mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
Refine file edit progress gating
This commit is contained in:
parent
de8761f25a
commit
eb3aed359f
@ -32,7 +32,10 @@ from nanobot.utils.helpers import (
|
|||||||
strip_think,
|
strip_think,
|
||||||
truncate_text,
|
truncate_text,
|
||||||
)
|
)
|
||||||
from nanobot.utils.progress_events import invoke_file_edit_progress
|
from nanobot.utils.progress_events import (
|
||||||
|
invoke_file_edit_progress,
|
||||||
|
on_progress_accepts_file_edit_events,
|
||||||
|
)
|
||||||
from nanobot.utils.prompt_templates import render_template
|
from nanobot.utils.prompt_templates import render_template
|
||||||
from nanobot.utils.runtime import (
|
from nanobot.utils.runtime import (
|
||||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||||
@ -820,16 +823,25 @@ class AgentRunner:
|
|||||||
return prep_error + hint, event, (
|
return prep_error + hint, event, (
|
||||||
RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||||
)
|
)
|
||||||
file_edit_tracker = prepare_file_edit_tracker(
|
emit_file_edit_events = (
|
||||||
call_id=tool_call.id,
|
spec.progress_callback is not None
|
||||||
tool_name=tool_call.name,
|
and on_progress_accepts_file_edit_events(spec.progress_callback)
|
||||||
tool=tool,
|
|
||||||
workspace=spec.workspace,
|
|
||||||
params=params if isinstance(params, dict) else None,
|
|
||||||
)
|
)
|
||||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
progress_callback = spec.progress_callback if emit_file_edit_events else None
|
||||||
|
file_edit_tracker = (
|
||||||
|
prepare_file_edit_tracker(
|
||||||
|
call_id=tool_call.id,
|
||||||
|
tool_name=tool_call.name,
|
||||||
|
tool=tool,
|
||||||
|
workspace=spec.workspace,
|
||||||
|
params=params if isinstance(params, dict) else None,
|
||||||
|
)
|
||||||
|
if progress_callback is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if file_edit_tracker is not None and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
spec.progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_start_event(
|
[build_file_edit_start_event(
|
||||||
file_edit_tracker,
|
file_edit_tracker,
|
||||||
params if isinstance(params, dict) else None,
|
params if isinstance(params, dict) else None,
|
||||||
@ -843,9 +855,9 @@ class AgentRunner:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except BaseException as exc:
|
except BaseException as exc:
|
||||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
if file_edit_tracker is not None and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
spec.progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_error_event(file_edit_tracker, str(exc))],
|
[build_file_edit_error_event(file_edit_tracker, str(exc))],
|
||||||
)
|
)
|
||||||
event = {
|
event = {
|
||||||
@ -869,9 +881,9 @@ class AgentRunner:
|
|||||||
return payload, event, None
|
return payload, event, None
|
||||||
|
|
||||||
if isinstance(result, str) and result.startswith("Error"):
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
if file_edit_tracker is not None and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
spec.progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_error_event(file_edit_tracker, result)],
|
[build_file_edit_error_event(file_edit_tracker, result)],
|
||||||
)
|
)
|
||||||
event = {
|
event = {
|
||||||
@ -892,9 +904,9 @@ class AgentRunner:
|
|||||||
return result + hint, event, RuntimeError(result)
|
return result + hint, event, RuntimeError(result)
|
||||||
return result + hint, event, None
|
return result + hint, event, None
|
||||||
|
|
||||||
if file_edit_tracker is not None and spec.progress_callback is not None:
|
if file_edit_tracker is not None and progress_callback is not None:
|
||||||
await invoke_file_edit_progress(
|
await invoke_file_edit_progress(
|
||||||
spec.progress_callback,
|
progress_callback,
|
||||||
[build_file_edit_end_event(file_edit_tracker)],
|
[build_file_edit_end_event(file_edit_tracker)],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -200,7 +200,7 @@ def build_bus_progress_callback(
|
|||||||
) -> Callable[..., Awaitable[None]]:
|
) -> Callable[..., Awaitable[None]]:
|
||||||
"""Return the bus progress callback for agent runtime events."""
|
"""Return the bus progress callback for agent runtime events."""
|
||||||
|
|
||||||
async def _bus_progress(
|
async def _publish_progress(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
tool_hint: bool = False,
|
tool_hint: bool = False,
|
||||||
@ -209,8 +209,6 @@ def build_bus_progress_callback(
|
|||||||
reasoning: bool = False,
|
reasoning: bool = False,
|
||||||
reasoning_end: bool = False,
|
reasoning_end: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if file_edit_events and msg.channel != "websocket":
|
|
||||||
return
|
|
||||||
meta = dict(msg.metadata or {})
|
meta = dict(msg.metadata or {})
|
||||||
meta["_progress"] = True
|
meta["_progress"] = True
|
||||||
meta["_tool_hint"] = tool_hint
|
meta["_tool_hint"] = tool_hint
|
||||||
@ -231,6 +229,43 @@ def build_bus_progress_callback(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if msg.channel == "websocket":
|
||||||
|
async def _websocket_progress(
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
tool_hint: bool = False,
|
||||||
|
tool_events: list[dict[str, Any]] | None = None,
|
||||||
|
file_edit_events: list[dict[str, Any]] | None = None,
|
||||||
|
reasoning: bool = False,
|
||||||
|
reasoning_end: bool = False,
|
||||||
|
) -> None:
|
||||||
|
await _publish_progress(
|
||||||
|
content,
|
||||||
|
tool_hint=tool_hint,
|
||||||
|
tool_events=tool_events,
|
||||||
|
file_edit_events=file_edit_events,
|
||||||
|
reasoning=reasoning,
|
||||||
|
reasoning_end=reasoning_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _websocket_progress
|
||||||
|
|
||||||
|
async def _bus_progress(
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
tool_hint: bool = False,
|
||||||
|
tool_events: list[dict[str, Any]] | None = None,
|
||||||
|
reasoning: bool = False,
|
||||||
|
reasoning_end: bool = False,
|
||||||
|
) -> None:
|
||||||
|
await _publish_progress(
|
||||||
|
content,
|
||||||
|
tool_hint=tool_hint,
|
||||||
|
tool_events=tool_events,
|
||||||
|
reasoning=reasoning,
|
||||||
|
reasoning_end=reasoning_end,
|
||||||
|
)
|
||||||
|
|
||||||
return _bus_progress
|
return _bus_progress
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -6,10 +6,15 @@ from unittest.mock import AsyncMock, MagicMock
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
import nanobot.agent.runner as runner_module
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.utils.progress_events import (
|
||||||
|
invoke_file_edit_progress,
|
||||||
|
on_progress_accepts_file_edit_events,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(tmp_path: Path) -> AgentLoop:
|
def _make_loop(tmp_path: Path) -> AgentLoop:
|
||||||
@ -138,6 +143,52 @@ class TestToolEventProgress:
|
|||||||
assert file_events[1]["approximate"] is False
|
assert file_events[1]["approximate"] is False
|
||||||
assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1)
|
assert (file_events[1]["added"], file_events[1]["deleted"]) == (2, 1)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_file_edit_snapshot_skipped_when_progress_callback_cannot_emit_file_edits(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
loop = _make_loop(tmp_path)
|
||||||
|
target = tmp_path / "foo.txt"
|
||||||
|
target.write_text("old\n", encoding="utf-8")
|
||||||
|
tool_call = ToolCallRequest(
|
||||||
|
id="call-write",
|
||||||
|
name="write_file",
|
||||||
|
arguments={"path": "foo.txt", "content": "new\n"},
|
||||||
|
)
|
||||||
|
calls = iter([
|
||||||
|
LLMResponse(content="", tool_calls=[tool_call]),
|
||||||
|
LLMResponse(content="Done", tool_calls=[]),
|
||||||
|
])
|
||||||
|
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.tools.prepare_call = MagicMock(
|
||||||
|
return_value=(None, {"path": "foo.txt", "content": "new\n"}, None),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(name: str, params: dict) -> str:
|
||||||
|
target.write_text(params["content"], encoding="utf-8")
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
loop.tools.execute = AsyncMock(side_effect=execute)
|
||||||
|
prepare_tracker = MagicMock(side_effect=AssertionError("unexpected file snapshot"))
|
||||||
|
monkeypatch.setattr(runner_module, "prepare_file_edit_tracker", prepare_tracker)
|
||||||
|
|
||||||
|
async def on_progress(
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
tool_hint: bool = False,
|
||||||
|
tool_events: list[dict] | None = None,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||||
|
|
||||||
|
assert final_content == "Done"
|
||||||
|
assert target.read_text(encoding="utf-8") == "new\n"
|
||||||
|
prepare_tracker.assert_not_called()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None:
|
async def test_exec_does_not_emit_file_edit_progress(self, tmp_path: Path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
@ -243,6 +294,7 @@ class TestToolEventProgress:
|
|||||||
chat_id="chat1",
|
chat_id="chat1",
|
||||||
content="edit",
|
content="edit",
|
||||||
))
|
))
|
||||||
|
assert on_progress_accepts_file_edit_events(websocket_progress) is True
|
||||||
await websocket_progress("", file_edit_events=edit_events)
|
await websocket_progress("", file_edit_events=edit_events)
|
||||||
outbound = await bus.consume_outbound()
|
outbound = await bus.consume_outbound()
|
||||||
assert outbound.metadata["_file_edit_events"] == edit_events
|
assert outbound.metadata["_file_edit_events"] == edit_events
|
||||||
@ -253,7 +305,8 @@ class TestToolEventProgress:
|
|||||||
chat_id="chat2",
|
chat_id="chat2",
|
||||||
content="edit",
|
content="edit",
|
||||||
))
|
))
|
||||||
await telegram_progress("", file_edit_events=edit_events)
|
assert on_progress_accepts_file_edit_events(telegram_progress) is False
|
||||||
|
await invoke_file_edit_progress(telegram_progress, edit_events)
|
||||||
assert bus.outbound_size == 0
|
assert bus.outbound_size == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user