mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 23:05:51 +00:00
feat(agent): emit structured _tool_events progress metadata
Extend the existing on_progress callback to carry structured tool-event payloads alongside the plain-text hint, so channels can render rich tool execution state (start/finish/error, arguments, results, file attachments) rather than only the pre-formatted hint string. Changes ------- - AgentLoop._tool_event_start_payload() — builds a version-1 start payload from a ToolCallRequest - AgentLoop._tool_event_result_extras() — extracts files/embeds from a tool result dict - AgentLoop._tool_event_finish_payloads() — maps tool_calls + tool_results + tool_events from AgentHookContext into finish payloads - _LoopHook.before_execute_tools() — passes tool_events=[...] to on_progress together with the existing tool_hint flag - _LoopHook.after_iteration() — emits a second on_progress call with the finish payloads once tool results are available - _bus_progress() — forwards tool_events as _tool_events in OutboundMessage metadata so channel implementations can read them - on_progress type widened to Callable[..., Awaitable[None]] on all public entry points; _cli_progress updated to accept and ignore tool_events The contract is additive: callers that only accept (content, *, tool_hint) continue to work unchanged. Callers that also accept tool_events receive the structured data. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
185a8fd34d
commit
c23d719780
@ -103,13 +103,18 @@ class _LoopHook(AgentHook):
|
||||
if thought:
|
||||
await self._on_progress(thought)
|
||||
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
|
||||
await self._on_progress(tool_hint, tool_hint=True)
|
||||
tool_events = [self._loop._tool_event_start_payload(tc) for tc in context.tool_calls]
|
||||
await self._on_progress(tool_hint, tool_hint=True, tool_events=tool_events)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
if self._on_progress and context.tool_calls and context.tool_events:
|
||||
tool_events = self._loop._tool_event_finish_payloads(context)
|
||||
if tool_events:
|
||||
await self._on_progress("", tool_events=tool_events)
|
||||
u = context.usage or {}
|
||||
logger.debug(
|
||||
"LLM usage: prompt={} completion={} cached={}",
|
||||
@ -375,6 +380,58 @@ class AgentLoop:
|
||||
sub_cancelled = await self.subagents.cancel_by_session(key)
|
||||
return cancelled + sub_cancelled
|
||||
|
||||
@staticmethod
|
||||
def _tool_event_start_payload(tool_call: Any) -> dict[str, Any]:
|
||||
return {
|
||||
"version": 1,
|
||||
"phase": "start",
|
||||
"call_id": str(getattr(tool_call, "id", "") or ""),
|
||||
"name": getattr(tool_call, "name", ""),
|
||||
"arguments": getattr(tool_call, "arguments", {}) or {},
|
||||
"result": None,
|
||||
"error": None,
|
||||
"files": [],
|
||||
"embeds": [],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _tool_event_result_extras(result: Any) -> tuple[list[Any], list[Any]]:
|
||||
if not isinstance(result, dict):
|
||||
return [], []
|
||||
files = result.get("files") if isinstance(result.get("files"), list) else []
|
||||
embeds = result.get("embeds") if isinstance(result.get("embeds"), list) else []
|
||||
return files, embeds
|
||||
|
||||
@classmethod
def _tool_event_finish_payloads(cls, context: AgentHookContext) -> list[dict[str, Any]]:
    """Map paired tool calls/results/events into finish-phase payloads.

    ``zip`` truncates to the shortest of the three sequences, so partially
    populated contexts (e.g. a result missing for a call) are handled by
    simply skipping the unmatched tail.
    """
    payloads: list[dict[str, Any]] = []
    for tool_call, result, raw_event in zip(
        context.tool_calls, context.tool_results, context.tool_events
    ):
        event = raw_event if isinstance(raw_event, dict) else {}
        succeeded = event.get("status") == "ok"
        files, embeds = cls._tool_event_result_extras(result)

        error: Any = None
        if not succeeded:
            # Prefer the raw string result as the error text; fall back to
            # the event's detail, then to a generic message.
            if isinstance(result, str) and result.strip():
                error = result.strip()
            else:
                error = str(event.get("detail") or "Tool execution failed")

        payloads.append({
            "version": 1,
            "phase": "end" if succeeded else "error",
            "call_id": str(getattr(tool_call, "id", "") or ""),
            "name": getattr(tool_call, "name", ""),
            "arguments": getattr(tool_call, "arguments", {}) or {},
            # The result is only surfaced on success; failures carry it
            # (when textual) through the "error" field instead.
            "result": result if succeeded else None,
            "error": error,
            "files": files,
            "embeds": embeds,
        })
    return payloads
|
||||
|
||||
def _effective_session_key(self, msg: InboundMessage) -> str:
|
||||
"""Return the session key used for task routing and mid-turn injections."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
@ -726,7 +783,7 @@ class AgentLoop:
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
@ -833,10 +890,17 @@ class AgentLoop:
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
async def _bus_progress(
|
||||
content: str,
|
||||
*,
|
||||
tool_hint: bool = False,
|
||||
tool_events: list[dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_progress"] = True
|
||||
meta["_tool_hint"] = tool_hint
|
||||
if tool_events:
|
||||
meta["_tool_events"] = tool_events
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
@ -1137,7 +1201,7 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
media: list[str] | None = None,
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
|
||||
@ -1028,7 +1028,7 @@ def agent(
|
||||
# Shared reference for progress callbacks
|
||||
_thinking: ThinkingSpinner | None = None
|
||||
|
||||
async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
async def _cli_progress(content: str, *, tool_hint: bool = False, **_kwargs: Any) -> None:
|
||||
ch = agent_loop.channels_config
|
||||
if ch and tool_hint and not ch.send_tool_hints:
|
||||
return
|
||||
|
||||
125
tests/agent/test_loop_progress.py
Normal file
125
tests/agent/test_loop_progress.py
Normal file
@ -0,0 +1,125 @@
|
||||
"""Tests for structured tool-event progress metadata emitted by AgentLoop."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path) -> AgentLoop:
    """Build a minimal AgentLoop: fresh bus, mocked provider, tmp workspace."""
    mock_provider = MagicMock()
    mock_provider.get_default_model.return_value = "test-model"
    return AgentLoop(
        bus=MessageBus(),
        provider=mock_provider,
        workspace=tmp_path,
        model="test-model",
    )
|
||||
|
||||
|
||||
class TestToolEventProgress:
    """_run_agent_loop emits structured tool_events via on_progress."""

    @pytest.mark.asyncio
    async def test_start_and_finish_events_emitted(self, tmp_path: Path) -> None:
        # Drive one tool-call iteration followed by a final text-only reply;
        # record every on_progress invocation and compare the full sequence.
        loop = _make_loop(tmp_path)
        tool_call = ToolCallRequest(id="call1", name="custom_tool", arguments={"path": "foo.txt"})
        # First LLM turn requests the tool; second turn ends the loop.
        calls = iter([
            LLMResponse(content="Visible", tool_calls=[tool_call]),
            LLMResponse(content="Done", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])
        # prepare_call returns (error, normalized_args, extra) — no error here.
        loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
        loop.tools.execute = AsyncMock(return_value="ok")

        # Captured as (content, tool_hint, tool_events) triples per callback.
        progress: list[tuple[str, bool, list[dict] | None]] = []

        async def on_progress(
            content: str,
            *,
            tool_hint: bool = False,
            tool_events: list[dict] | None = None,
        ) -> None:
            progress.append((content, tool_hint, tool_events))

        final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)

        assert final_content == "Done"
        # Expected order: plain thought, then hint+start payload, then an
        # empty-content finish call carrying the end-phase payload.
        assert progress == [
            ("Visible", False, None),
            (
                'custom_tool("foo.txt")',
                True,
                [{
                    "version": 1,
                    "phase": "start",
                    "call_id": "call1",
                    "name": "custom_tool",
                    "arguments": {"path": "foo.txt"},
                    "result": None,
                    "error": None,
                    "files": [],
                    "embeds": [],
                }],
            ),
            (
                "",
                False,
                [{
                    "version": 1,
                    "phase": "end",
                    "call_id": "call1",
                    "name": "custom_tool",
                    "arguments": {"path": "foo.txt"},
                    "result": "ok",
                    "error": None,
                    "files": [],
                    "embeds": [],
                }],
            ),
        ]

    @pytest.mark.asyncio
    async def test_bus_progress_forwards_tool_events_to_outbound_metadata(self, tmp_path: Path) -> None:
        """When run() handles a bus message, _tool_events lands in OutboundMessage metadata."""
        bus = MessageBus()
        provider = MagicMock()
        provider.get_default_model.return_value = "test-model"
        loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")

        tool_call = ToolCallRequest(id="tc1", name="exec", arguments={"command": "ls"})
        # One tool round-trip, then a terminating reply.
        calls = iter([
            LLMResponse(content="", tool_calls=[tool_call]),
            LLMResponse(content="Done", tool_calls=[]),
        ])
        loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
        loop.tools.get_definitions = MagicMock(return_value=[])
        loop.tools.prepare_call = MagicMock(return_value=(None, {"command": "ls"}, None))
        loop.tools.execute = AsyncMock(return_value="file.txt")

        msg = InboundMessage(channel="telegram", chat_id="chat1", content="run ls")
        await loop.run(msg)

        # Drain all outbound messages and find the one carrying _tool_events
        outbound = []
        while bus.outbound_size() > 0:
            outbound.append(await bus.consume_outbound())

        tool_event_msgs = [m for m in outbound if m.metadata and m.metadata.get("_tool_events")]
        assert tool_event_msgs, "expected at least one outbound message with _tool_events"

        # Split by phase: each progress publish carries a homogeneous batch,
        # so inspecting the first event of each message is sufficient.
        start_msgs = [m for m in tool_event_msgs if m.metadata["_tool_events"][0]["phase"] == "start"]
        finish_msgs = [m for m in tool_event_msgs if m.metadata["_tool_events"][0]["phase"] in ("end", "error")]
        assert start_msgs, "expected a start-phase tool event"
        assert finish_msgs, "expected a finish-phase tool event"

        start = start_msgs[0].metadata["_tool_events"][0]
        assert start["name"] == "exec"
        assert start["call_id"] == "tc1"
        assert start["result"] is None

        finish = finish_msgs[0].metadata["_tool_events"][0]
        assert finish["phase"] == "end"
        assert finish["result"] == "file.txt"
|
||||
@ -152,7 +152,6 @@ class TestMessageToolSuppressLogic:
|
||||
('read foo.txt', True),
|
||||
]
|
||||
|
||||
|
||||
class TestMessageToolTurnTracking:
|
||||
|
||||
def test_sent_in_turn_tracks_same_target(self) -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user