mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
Refine file edit progress gating
This commit is contained in:
parent
de8761f25a
commit
eb3aed359f
@ -32,7 +32,10 @@ from nanobot.utils.helpers import (
|
||||
strip_think,
|
||||
truncate_text,
|
||||
)
|
||||
from nanobot.utils.progress_events import invoke_file_edit_progress
|
||||
from nanobot.utils.progress_events import (
|
||||
invoke_file_edit_progress,
|
||||
on_progress_accepts_file_edit_events,
|
||||
)
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.utils.runtime import (
|
||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||
@ -820,16 +823,25 @@ class AgentRunner:
|
||||
return prep_error + hint, event, (
|
||||
RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
)
|
||||
file_edit_tracker = prepare_file_edit_tracker(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
tool=tool,
|
||||
workspace=spec.workspace,
|
||||
params=params if isinstance(params, dict) else None,
|
||||
emit_file_edit_events = (
|
||||
spec.progress_callback is not None
|
||||
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
||||
)
|
||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
||||
progress_callback = spec.progress_callback if emit_file_edit_events else None
|
||||
file_edit_tracker = (
|
||||
prepare_file_edit_tracker(
|
||||
call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
tool=tool,
|
||||
workspace=spec.workspace,
|
||||
params=params if isinstance(params, dict) else None,
|
||||
)
|
||||
if progress_callback is not None
|
||||
else None
|
||||
)
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
spec.progress_callback,
|
||||
progress_callback,
|
||||
[build_file_edit_start_event(
|
||||
file_edit_tracker,
|
||||
params if isinstance(params, dict) else None,
|
||||
@ -843,9 +855,9 @@ class AgentRunner:
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except BaseException as exc:
|
||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
spec.progress_callback,
|
||||
progress_callback,
|
||||
[build_file_edit_error_event(file_edit_tracker, str(exc))],
|
||||
)
|
||||
event = {
|
||||
@ -869,9 +881,9 @@ class AgentRunner:
|
||||
return payload, event, None
|
||||
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
spec.progress_callback,
|
||||
progress_callback,
|
||||
[build_file_edit_error_event(file_edit_tracker, result)],
|
||||
)
|
||||
event = {
|
||||
@ -892,9 +904,9 @@ class AgentRunner:
|
||||
return result + hint, event, RuntimeError(result)
|
||||
return result + hint, event, None
|
||||
|
||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
||||
if file_edit_tracker is not None and progress_callback is not None:
|
||||
await invoke_file_edit_progress(
|
||||
spec.progress_callback,
|
||||
progress_callback,
|
||||
[build_file_edit_end_event(file_edit_tracker)],
|
||||
)
|
||||
|
||||
|
||||
@ -200,7 +200,7 @@ def build_bus_progress_callback(
|
||||
) -> Callable[..., Awaitable[None]]:
|
||||
"""Return the bus progress callback for agent runtime events."""
|
||||
|
||||
async def _bus_progress(
|
||||
async def _publish_progress(
|
||||
content: str,
|
||||
*,
|
||||
tool_hint: bool = False,
|
||||
@ -209,8 +209,6 @@ def build_bus_progress_callback(
|
||||
reasoning: bool = False,
|
||||
reasoning_end: bool = False,
|
||||
) -> None:
|
||||
if file_edit_events and msg.channel != "websocket":
|
||||
return
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
@ -231,6 +229,43 @@ def build_bus_progress_callback(
|
||||
)
|
||||
)
|
||||
|
||||
if msg.channel == "websocket":
|
||||
async def _websocket_progress(
|
||||
content: str,
|
||||
*,
|
||||
tool_hint: bool = False,
|
||||
tool_events: list[dict[str, Any]] | None = None,
|
||||
file_edit_events: list[dict[str, Any]] | None = None,
|
||||
reasoning: bool = False,
|
||||
reasoning_end: bool = False,
|
||||
) -> None:
|
||||
await _publish_progress(
|
||||
content,
|
||||
tool_hint=tool_hint,
|
||||
tool_events=tool_events,
|
||||
file_edit_events=file_edit_events,
|
||||
reasoning=reasoning,
|
||||
reasoning_end=reasoning_end,
|
||||
)
|
||||
|
||||
return _websocket_progress
|
||||
|
||||
async def _bus_progress(
|
||||
content: str,
|
||||
*,
|
||||
tool_hint: bool = False,
|
||||
tool_events: list[dict[str, Any]] | None = None,
|
||||
reasoning: bool = False,
|
||||
reasoning_end: bool = False,
|
||||
) -> None:
|
||||
await _publish_progress(
|
||||
content,
|
||||
tool_hint=tool_hint,
|
||||
tool_events=tool_events,
|
||||
reasoning=reasoning,
|
||||
reasoning_end=reasoning_end,
|
||||
)
|
||||
|
||||
return _bus_progress
|
||||
|
||||
|
||||
|
||||
@ -6,10 +6,15 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import nanobot.agent.runner as runner_module
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.progress_events import (
|
||||
invoke_file_edit_progress,
|
||||
on_progress_accepts_file_edit_events,
|
||||
)
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||
@ -138,6 +143,52 @@ class TestToolEventProgress:
|
||||
assert file_events[1]["approximate"] is False
|
||||
assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_edit_snapshot_skipped_when_progress_callback_cannot_emit_file_edits(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
target = tmp_path / "foo.txt"
|
||||
target.write_text("old\n", encoding="utf-8")
|
||||
tool_call = ToolCallRequest(
|
||||
id="call-write",
|
||||
name="write_file",
|
||||
arguments={"path": "foo.txt", "content": "new\n"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="Done", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.prepare_call = MagicMock(
|
||||
return_value=(None, {"path": "foo.txt", "content": "new\n"}, None),
|
||||
)
|
||||
|
||||
async def execute(name: str, params: dict) -> str:
|
||||
target.write_text(params["content"], encoding="utf-8")
|
||||
return "ok"
|
||||
|
||||
loop.tools.execute = AsyncMock(side_effect=execute)
|
||||
prepare_tracker = MagicMock(side_effect=AssertionError("unexpected file snapshot"))
|
||||
monkeypatch.setattr(runner_module, "prepare_file_edit_tracker", prepare_tracker)
|
||||
|
||||
async def on_progress(
|
||||
content: str,
|
||||
*,
|
||||
tool_hint: bool = False,
|
||||
tool_events: list[dict] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert target.read_text(encoding="utf-8") == "new\n"
|
||||
prepare_tracker.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
@ -243,6 +294,7 @@ class TestToolEventProgress:
|
||||
chat_id="chat1",
|
||||
content="edit",
|
||||
))
|
||||
assert on_progress_accepts_file_edit_events(websocket_progress) is True
|
||||
await websocket_progress("", file_edit_events=edit_events)
|
||||
outbound = await bus.consume_outbound()
|
||||
assert outbound.metadata["_file_edit_events"] == edit_events
|
||||
@ -253,7 +305,8 @@ class TestToolEventProgress:
|
||||
chat_id="chat2",
|
||||
content="edit",
|
||||
))
|
||||
await telegram_progress("", file_edit_events=edit_events)
|
||||
assert on_progress_accepts_file_edit_events(telegram_progress) is False
|
||||
await invoke_file_edit_progress(telegram_progress, edit_events)
|
||||
assert bus.outbound_size == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user