diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 64345822a..776885ecb 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -32,7 +32,10 @@ from nanobot.utils.helpers import ( strip_think, truncate_text, ) -from nanobot.utils.progress_events import invoke_file_edit_progress +from nanobot.utils.progress_events import ( + invoke_file_edit_progress, + on_progress_accepts_file_edit_events, +) from nanobot.utils.prompt_templates import render_template from nanobot.utils.runtime import ( EMPTY_FINAL_RESPONSE_MESSAGE, @@ -820,16 +823,25 @@ class AgentRunner: return prep_error + hint, event, ( RuntimeError(prep_error) if spec.fail_on_tool_error else None ) - file_edit_tracker = prepare_file_edit_tracker( - call_id=tool_call.id, - tool_name=tool_call.name, - tool=tool, - workspace=spec.workspace, - params=params if isinstance(params, dict) else None, + emit_file_edit_events = ( + spec.progress_callback is not None + and on_progress_accepts_file_edit_events(spec.progress_callback) ) - if file_edit_tracker is not None and spec.progress_callback is not None: + progress_callback = spec.progress_callback if emit_file_edit_events else None + file_edit_tracker = ( + prepare_file_edit_tracker( + call_id=tool_call.id, + tool_name=tool_call.name, + tool=tool, + workspace=spec.workspace, + params=params if isinstance(params, dict) else None, + ) + if progress_callback is not None + else None + ) + if file_edit_tracker is not None and progress_callback is not None: await invoke_file_edit_progress( - spec.progress_callback, + progress_callback, [build_file_edit_start_event( file_edit_tracker, params if isinstance(params, dict) else None, @@ -843,9 +855,9 @@ class AgentRunner: except asyncio.CancelledError: raise except BaseException as exc: - if file_edit_tracker is not None and spec.progress_callback is not None: + if file_edit_tracker is not None and progress_callback is not None: await invoke_file_edit_progress( - spec.progress_callback, + progress_callback, [build_file_edit_error_event(file_edit_tracker, str(exc))], ) event = { @@ -869,9 +881,9 @@ class AgentRunner: return payload, event, None if isinstance(result, str) and result.startswith("Error"): - if file_edit_tracker is not None and spec.progress_callback is not None: + if file_edit_tracker is not None and progress_callback is not None: await invoke_file_edit_progress( - spec.progress_callback, + progress_callback, [build_file_edit_error_event(file_edit_tracker, result)], ) event = { @@ -892,9 +904,9 @@ class AgentRunner: return result + hint, event, RuntimeError(result) return result + hint, event, None - if file_edit_tracker is not None and spec.progress_callback is not None: + if file_edit_tracker is not None and progress_callback is not None: await invoke_file_edit_progress( - spec.progress_callback, + progress_callback, [build_file_edit_end_event(file_edit_tracker)], ) diff --git a/nanobot/utils/webui_turn_helpers.py b/nanobot/utils/webui_turn_helpers.py index 9ef4612f9..6a3ac2ba0 100644 --- a/nanobot/utils/webui_turn_helpers.py +++ b/nanobot/utils/webui_turn_helpers.py @@ -200,7 +200,7 @@ def build_bus_progress_callback( ) -> Callable[..., Awaitable[None]]: """Return the bus progress callback for agent runtime events.""" - async def _bus_progress( + async def _publish_progress( content: str, *, tool_hint: bool = False, @@ -209,8 +209,6 @@ def build_bus_progress_callback( reasoning: bool = False, reasoning_end: bool = False, ) -> None: - if file_edit_events and msg.channel != "websocket": - return meta = dict(msg.metadata or {}) meta["_progress"] = True meta["_tool_hint"] = tool_hint @@ -231,6 +229,43 @@ def build_bus_progress_callback( ) ) + if msg.channel == "websocket": + async def _websocket_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict[str, Any]] | None = None, + file_edit_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, + ) -> None: + await _publish_progress( + content, + tool_hint=tool_hint, + tool_events=tool_events, + file_edit_events=file_edit_events, + reasoning=reasoning, + reasoning_end=reasoning_end, + ) + + return _websocket_progress + + async def _bus_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, + ) -> None: + await _publish_progress( + content, + tool_hint=tool_hint, + tool_events=tool_events, + reasoning=reasoning, + reasoning_end=reasoning_end, + ) + return _bus_progress diff --git a/tests/agent/test_loop_progress.py b/tests/agent/test_loop_progress.py index b1b33612f..43a691437 100644 --- a/tests/agent/test_loop_progress.py +++ b/tests/agent/test_loop_progress.py @@ -6,10 +6,15 @@ from unittest.mock import AsyncMock, MagicMock import pytest +import nanobot.agent.runner as runner_module from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.utils.progress_events import ( + invoke_file_edit_progress, + on_progress_accepts_file_edit_events, +) def _make_loop(tmp_path: Path) -> AgentLoop: @@ -138,6 +143,52 @@ class TestToolEventProgress: assert file_events[1]["approximate"] is False assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1) + @pytest.mark.asyncio + async def test_file_edit_snapshot_skipped_when_progress_callback_cannot_emit_file_edits( + self, + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + loop = _make_loop(tmp_path) + target = tmp_path / "foo.txt" + target.write_text("old\n", encoding="utf-8") + tool_call = ToolCallRequest( + id="call-write", + name="write_file", + arguments={"path": "foo.txt", "content": "new\n"}, + ) + calls = iter([ + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="Done", tool_calls=[]), + ]) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.tools.prepare_call = MagicMock( + return_value=(None, {"path": "foo.txt", "content": "new\n"}, None), + ) + + async def execute(name: str, params: dict) -> str: + target.write_text(params["content"], encoding="utf-8") + return "ok" + + loop.tools.execute = AsyncMock(side_effect=execute) + prepare_tracker = MagicMock(side_effect=AssertionError("unexpected file snapshot")) + monkeypatch.setattr(runner_module, "prepare_file_edit_tracker", prepare_tracker) + + async def on_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict] | None = None, + ) -> None: + pass + + final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + + assert final_content == "Done" + assert target.read_text(encoding="utf-8") == "new\n" + prepare_tracker.assert_not_called() + @pytest.mark.asyncio async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) @@ -243,6 +294,7 @@ class TestToolEventProgress: chat_id="chat1", content="edit", )) + assert on_progress_accepts_file_edit_events(websocket_progress) is True await websocket_progress("", file_edit_events=edit_events) outbound = await bus.consume_outbound() assert outbound.metadata["_file_edit_events"] == edit_events @@ -253,7 +305,8 @@ class TestToolEventProgress: chat_id="chat2", content="edit", )) - await telegram_progress("", file_edit_events=edit_events) + assert on_progress_accepts_file_edit_events(telegram_progress) is False + await invoke_file_edit_progress(telegram_progress, edit_events) assert bus.outbound_size == 0 @pytest.mark.asyncio