Refine file edit progress gating

This commit is contained in:
Xubin Ren 2026-05-18 01:59:55 +08:00
parent de8761f25a
commit eb3aed359f
3 changed files with 119 additions and 19 deletions

View File

@ -32,7 +32,10 @@ from nanobot.utils.helpers import (
strip_think, strip_think,
truncate_text, truncate_text,
) )
from nanobot.utils.progress_events import invoke_file_edit_progress from nanobot.utils.progress_events import (
invoke_file_edit_progress,
on_progress_accepts_file_edit_events,
)
from nanobot.utils.prompt_templates import render_template from nanobot.utils.prompt_templates import render_template
from nanobot.utils.runtime import ( from nanobot.utils.runtime import (
EMPTY_FINAL_RESPONSE_MESSAGE, EMPTY_FINAL_RESPONSE_MESSAGE,
@ -820,16 +823,25 @@ class AgentRunner:
return prep_error + hint, event, ( return prep_error + hint, event, (
RuntimeError(prep_error) if spec.fail_on_tool_error else None RuntimeError(prep_error) if spec.fail_on_tool_error else None
) )
file_edit_tracker = prepare_file_edit_tracker( emit_file_edit_events = (
call_id=tool_call.id, spec.progress_callback is not None
tool_name=tool_call.name, and on_progress_accepts_file_edit_events(spec.progress_callback)
tool=tool,
workspace=spec.workspace,
params=params if isinstance(params, dict) else None,
) )
if file_edit_tracker is not None and spec.progress_callback is not None: progress_callback = spec.progress_callback if emit_file_edit_events else None
file_edit_tracker = (
prepare_file_edit_tracker(
call_id=tool_call.id,
tool_name=tool_call.name,
tool=tool,
workspace=spec.workspace,
params=params if isinstance(params, dict) else None,
)
if progress_callback is not None
else None
)
if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress( await invoke_file_edit_progress(
spec.progress_callback, progress_callback,
[build_file_edit_start_event( [build_file_edit_start_event(
file_edit_tracker, file_edit_tracker,
params if isinstance(params, dict) else None, params if isinstance(params, dict) else None,
@ -843,9 +855,9 @@ class AgentRunner:
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
except BaseException as exc: except BaseException as exc:
if file_edit_tracker is not None and spec.progress_callback is not None: if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress( await invoke_file_edit_progress(
spec.progress_callback, progress_callback,
[build_file_edit_error_event(file_edit_tracker, str(exc))], [build_file_edit_error_event(file_edit_tracker, str(exc))],
) )
event = { event = {
@ -869,9 +881,9 @@ class AgentRunner:
return payload, event, None return payload, event, None
if isinstance(result, str) and result.startswith("Error"): if isinstance(result, str) and result.startswith("Error"):
if file_edit_tracker is not None and spec.progress_callback is not None: if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress( await invoke_file_edit_progress(
spec.progress_callback, progress_callback,
[build_file_edit_error_event(file_edit_tracker, result)], [build_file_edit_error_event(file_edit_tracker, result)],
) )
event = { event = {
@ -892,9 +904,9 @@ class AgentRunner:
return result + hint, event, RuntimeError(result) return result + hint, event, RuntimeError(result)
return result + hint, event, None return result + hint, event, None
if file_edit_tracker is not None and spec.progress_callback is not None: if file_edit_tracker is not None and progress_callback is not None:
await invoke_file_edit_progress( await invoke_file_edit_progress(
spec.progress_callback, progress_callback,
[build_file_edit_end_event(file_edit_tracker)], [build_file_edit_end_event(file_edit_tracker)],
) )

View File

@ -200,7 +200,7 @@ def build_bus_progress_callback(
) -> Callable[..., Awaitable[None]]: ) -> Callable[..., Awaitable[None]]:
"""Return the bus progress callback for agent runtime events.""" """Return the bus progress callback for agent runtime events."""
async def _bus_progress( async def _publish_progress(
content: str, content: str,
*, *,
tool_hint: bool = False, tool_hint: bool = False,
@ -209,8 +209,6 @@ def build_bus_progress_callback(
reasoning: bool = False, reasoning: bool = False,
reasoning_end: bool = False, reasoning_end: bool = False,
) -> None: ) -> None:
if file_edit_events and msg.channel != "websocket":
return
meta = dict(msg.metadata or {}) meta = dict(msg.metadata or {})
meta["_progress"] = True meta["_progress"] = True
meta["_tool_hint"] = tool_hint meta["_tool_hint"] = tool_hint
@ -231,6 +229,43 @@ def build_bus_progress_callback(
) )
) )
if msg.channel == "websocket":
async def _websocket_progress(
content: str,
*,
tool_hint: bool = False,
tool_events: list[dict[str, Any]] | None = None,
file_edit_events: list[dict[str, Any]] | None = None,
reasoning: bool = False,
reasoning_end: bool = False,
) -> None:
await _publish_progress(
content,
tool_hint=tool_hint,
tool_events=tool_events,
file_edit_events=file_edit_events,
reasoning=reasoning,
reasoning_end=reasoning_end,
)
return _websocket_progress
async def _bus_progress(
content: str,
*,
tool_hint: bool = False,
tool_events: list[dict[str, Any]] | None = None,
reasoning: bool = False,
reasoning_end: bool = False,
) -> None:
await _publish_progress(
content,
tool_hint=tool_hint,
tool_events=tool_events,
reasoning=reasoning,
reasoning_end=reasoning_end,
)
return _bus_progress return _bus_progress

View File

@ -6,10 +6,15 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
import nanobot.agent.runner as runner_module
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.base import LLMResponse, ToolCallRequest
from nanobot.utils.progress_events import (
invoke_file_edit_progress,
on_progress_accepts_file_edit_events,
)
def _make_loop(tmp_path: Path) -> AgentLoop: def _make_loop(tmp_path: Path) -> AgentLoop:
@ -138,6 +143,52 @@ class TestToolEventProgress:
assert file_events[1]["approximate"] is False assert file_events[1]["approximate"] is False
assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1) assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1)
@pytest.mark.asyncio
async def test_file_edit_snapshot_skipped_when_progress_callback_cannot_emit_file_edits(
self,
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
loop = _make_loop(tmp_path)
target = tmp_path / "foo.txt"
target.write_text("old\n", encoding="utf-8")
tool_call = ToolCallRequest(
id="call-write",
name="write_file",
arguments={"path": "foo.txt", "content": "new\n"},
)
calls = iter([
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="Done", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
loop.tools.prepare_call = MagicMock(
return_value=(None, {"path": "foo.txt", "content": "new\n"}, None),
)
async def execute(name: str, params: dict) -> str:
target.write_text(params["content"], encoding="utf-8")
return "ok"
loop.tools.execute = AsyncMock(side_effect=execute)
prepare_tracker = MagicMock(side_effect=AssertionError("unexpected file snapshot"))
monkeypatch.setattr(runner_module, "prepare_file_edit_tracker", prepare_tracker)
async def on_progress(
content: str,
*,
tool_hint: bool = False,
tool_events: list[dict] | None = None,
) -> None:
pass
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert target.read_text(encoding="utf-8") == "new\n"
prepare_tracker.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None: async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
@ -243,6 +294,7 @@ class TestToolEventProgress:
chat_id="chat1", chat_id="chat1",
content="edit", content="edit",
)) ))
assert on_progress_accepts_file_edit_events(websocket_progress) is True
await websocket_progress("", file_edit_events=edit_events) await websocket_progress("", file_edit_events=edit_events)
outbound = await bus.consume_outbound() outbound = await bus.consume_outbound()
assert outbound.metadata["_file_edit_events"] == edit_events assert outbound.metadata["_file_edit_events"] == edit_events
@ -253,7 +305,8 @@ class TestToolEventProgress:
chat_id="chat2", chat_id="chat2",
content="edit", content="edit",
)) ))
await telegram_progress("", file_edit_events=edit_events) assert on_progress_accepts_file_edit_events(telegram_progress) is False
await invoke_file_edit_progress(telegram_progress, edit_events)
assert bus.outbound_size == 0 assert bus.outbound_size == 0
@pytest.mark.asyncio @pytest.mark.asyncio