Refine file edit progress gating

This commit is contained in:
Xubin Ren 2026-05-18 01:59:55 +08:00
parent de8761f25a
commit eb3aed359f
3 changed files with 119 additions and 19 deletions

View File

@ -32,7 +32,10 @@ from nanobot.utils.helpers import (
strip_think,
truncate_text,
)
from nanobot.utils.progress_events import invoke_file_edit_progress
from nanobot.utils.progress_events import (
invoke_file_edit_progress,
on_progress_accepts_file_edit_events,
)
from nanobot.utils.prompt_templates import render_template
from nanobot.utils.runtime import (
EMPTY_FINAL_RESPONSE_MESSAGE,
@ -820,16 +823,25 @@ class AgentRunner:
return prep_error + hint, event, (
RuntimeError(prep_error) if spec.fail_on_tool_error else None
)
file_edit_tracker = prepare_file_edit_tracker(
call_id=tool_call.id,
tool_name=tool_call.name,
tool=tool,
workspace=spec.workspace,
params=params if isinstance(params, dict) else None,
emit_file_edit_events = (
spec.progress_callback is not None
and on_progress_accepts_file_edit_events(spec.progress_callback)
)
if file_edit_tracker is not None and spec.progress_callback is not None:
progress_callback = spec.progress_callback if emit_file_edit_events else None
file_edit_tracker = (
prepare_file_edit_tracker(
call_id=tool_call.id,
tool_name=tool_call.name,
tool=tool,
workspace=spec.workspace,
params=params if isinstance(params, dict) else None,
)
if progress_callback is not None
else None
)
if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress(
spec.progress_callback,
progress_callback,
[build_file_edit_start_event(
file_edit_tracker,
params if isinstance(params, dict) else None,
@ -843,9 +855,9 @@ class AgentRunner:
except asyncio.CancelledError:
raise
except BaseException as exc:
if file_edit_tracker is not None and spec.progress_callback is not None:
if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress(
spec.progress_callback,
progress_callback,
[build_file_edit_error_event(file_edit_tracker, str(exc))],
)
event = {
@ -869,9 +881,9 @@ class AgentRunner:
return payload, event, None
if isinstance(result, str) and result.startswith("Error"):
if file_edit_tracker is not None and spec.progress_callback is not None:
if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress(
spec.progress_callback,
progress_callback,
[build_file_edit_error_event(file_edit_tracker, result)],
)
event = {
@ -892,9 +904,9 @@ class AgentRunner:
return result + hint, event, RuntimeError(result)
return result + hint, event, None
if file_edit_tracker is not None and spec.progress_callback is not None:
if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress(
spec.progress_callback,
progress_callback,
[build_file_edit_end_event(file_edit_tracker)],
)

View File

@ -200,7 +200,7 @@ def build_bus_progress_callback(
) -> Callable[..., Awaitable[None]]:
"""Return the bus progress callback for agent runtime events."""
async def _bus_progress(
async def _publish_progress(
content: str,
*,
tool_hint: bool = False,
@ -209,8 +209,6 @@ def build_bus_progress_callback(
reasoning: bool = False,
reasoning_end: bool = False,
) -> None:
if file_edit_events and msg.channel != "websocket":
return
meta = dict(msg.metadata or {})
meta["_progress"] = True
meta["_tool_hint"] = tool_hint
@ -231,6 +229,43 @@ def build_bus_progress_callback(
)
)
if msg.channel == "websocket":
async def _websocket_progress(
content: str,
*,
tool_hint: bool = False,
tool_events: list[dict[str, Any]] | None = None,
file_edit_events: list[dict[str, Any]] | None = None,
reasoning: bool = False,
reasoning_end: bool = False,
) -> None:
await _publish_progress(
content,
tool_hint=tool_hint,
tool_events=tool_events,
file_edit_events=file_edit_events,
reasoning=reasoning,
reasoning_end=reasoning_end,
)
return _websocket_progress
async def _bus_progress(
content: str,
*,
tool_hint: bool = False,
tool_events: list[dict[str, Any]] | None = None,
reasoning: bool = False,
reasoning_end: bool = False,
) -> None:
await _publish_progress(
content,
tool_hint=tool_hint,
tool_events=tool_events,
reasoning=reasoning,
reasoning_end=reasoning_end,
)
return _bus_progress

View File

@ -6,10 +6,15 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
import nanobot.agent.runner as runner_module
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.utils.progress_events import (
invoke_file_edit_progress,
on_progress_accepts_file_edit_events,
)
def _make_loop(tmp_path: Path) -> AgentLoop:
@ -138,6 +143,52 @@ class TestToolEventProgress:
assert file_events[1]["approximate"] is False
assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1)
@pytest.mark.asyncio
async def test_file_edit_snapshot_skipped_when_progress_callback_cannot_emit_file_edits(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
loop = _make_loop(tmp_path)
target = tmp_path / "foo.txt"
target.write_text("old\n", encoding="utf-8")
tool_call = ToolCallRequest(
id="call-write",
name="write_file",
arguments={"path": "foo.txt", "content": "new\n"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.prepare_call = MagicMock(
return_value=(None, {"path": "foo.txt", "content": "new\n"}, None),
)
async def execute(name: str, params: dict) -> str:
target.write_text(params["content"], encoding="utf-8")
return "ok"
loop.tools.execute = AsyncMock(side_effect=execute)
prepare_tracker = MagicMock(side_effect=AssertionError("unexpected file snapshot"))
monkeypatch.setattr(runner_module, "prepare_file_edit_tracker", prepare_tracker)
async def on_progress(
content: str,
*,
tool_hint: bool = False,
tool_events: list[dict] | None = None,
) -> None:
pass
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert target.read_text(encoding="utf-8") == "new\n"
prepare_tracker.assert_not_called()
@pytest.mark.asyncio
async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
@ -243,6 +294,7 @@ class TestToolEventProgress:
chat_id="chat1",
content="edit",
))
assert on_progress_accepts_file_edit_events(websocket_progress) is True
await websocket_progress("", file_edit_events=edit_events)
outbound = await bus.consume_outbound()
assert outbound.metadata["_file_edit_events"] == edit_events
@ -253,7 +305,8 @@ class TestToolEventProgress:
chat_id="chat2",
content="edit",
))
await telegram_progress("", file_edit_events=edit_events)
assert on_progress_accepts_file_edit_events(telegram_progress) is False
await invoke_file_edit_progress(telegram_progress, edit_events)
assert bus.outbound_size == 0
@pytest.mark.asyncio