mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
refactor: unify agent runner lifecycle hooks
This commit is contained in:
parent
e7d371ec1e
commit
5bf0f6fe7d
49
nanobot/agent/hook.py
Normal file
49
nanobot/agent/hook.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Shared lifecycle hook primitives for agent runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
@dataclass(slots=True)
class AgentHookContext:
    """Mutable per-iteration state exposed to runner hooks.

    The runner creates one instance per loop iteration and mutates it as
    the iteration progresses, so hooks observe progressively more state:
    ``before_iteration`` sees only ``iteration``/``messages``; later hooks
    also see the LLM response, tool activity, and the final outcome.
    """

    # Zero-based index of the current runner iteration.
    iteration: int
    # Conversation history passed to the provider for this iteration
    # (the runner appends assistant/tool messages to it as the loop runs).
    messages: list[dict[str, Any]]
    # Provider response for this iteration; None until the LLM call returns.
    response: LLMResponse | None = None
    # Token accounting parsed from the provider response
    # (keys: "prompt_tokens", "completion_tokens").
    usage: dict[str, int] = field(default_factory=dict)
    # Tool calls requested by the model in this iteration (empty if none).
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Raw results of executing `tool_calls`, set before after_iteration.
    tool_results: list[Any] = field(default_factory=list)
    # Structured event records for the executed tools (name/status/detail).
    tool_events: list[dict[str, str]] = field(default_factory=list)
    # Final (post-finalize_content) text of the run; set on the last iteration.
    final_content: str | None = None
    # Why the loop stopped, e.g. "completed", "error", "tool_error";
    # None while the loop is still going to iterate again.
    stop_reason: str | None = None
    # Human-readable error description when the run failed, else None.
    error: str | None = None
|
||||
|
||||
|
||||
class AgentHook:
    """Minimal lifecycle surface for shared runner customization.

    Every hook is a no-op by default; subclasses override only the
    callbacks they need.  The runner invokes these methods with the
    per-iteration :class:`AgentHookContext` instance.
    """

    def wants_streaming(self) -> bool:
        """Return ``True`` to make the runner use the streaming provider API."""
        return False

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Run at the start of each iteration, before the LLM call."""
        return None

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Receive one content delta while the provider streams a response."""
        return None

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Signal end of a streamed response.

        ``resuming=True`` means tool calls follow and the stream will
        restart; ``resuming=False`` means this was the final response.
        """
        return None

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Run after tool calls are parsed but before they execute."""
        return None

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Run at the end of each iteration, including early-exit paths."""
        return None

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Post-process the model's final content; identity by default."""
        return content
|
||||
@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -216,53 +217,52 @@ class AgentLoop:
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
# Wrap on_stream with stateful think-tag filter so downstream
|
||||
# consumers (CLI, channels) never see <think> blocks.
|
||||
_raw_stream = on_stream
|
||||
_stream_buf = ""
|
||||
loop_self = self
|
||||
|
||||
async def _filtered_stream(delta: str) -> None:
|
||||
nonlocal _stream_buf
|
||||
from nanobot.utils.helpers import strip_think
|
||||
prev_clean = strip_think(_stream_buf)
|
||||
_stream_buf += delta
|
||||
new_clean = strip_think(_stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and _raw_stream:
|
||||
await _raw_stream(incremental)
|
||||
class _LoopHook(AgentHook):
|
||||
def __init__(self) -> None:
|
||||
self._stream_buf = ""
|
||||
|
||||
async def _wrapped_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal _stream_buf
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
_stream_buf = ""
|
||||
def wants_streaming(self) -> bool:
|
||||
return on_stream is not None
|
||||
|
||||
async def _handle_tool_calls(response) -> None:
|
||||
if not on_progress:
|
||||
return
|
||||
if not on_stream:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = self._strip_think(self._tool_hint(response.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
async def _prepare_tools(tool_calls) -> None:
|
||||
for tc in tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._set_tool_context(channel, chat_id, message_id)
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and on_stream:
|
||||
await on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if on_progress:
|
||||
if not on_stream:
|
||||
thought = loop_self._strip_think(context.response.content if context.response else None)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
loop_self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return loop_self._strip_think(content)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
on_stream=_filtered_stream if on_stream else None,
|
||||
on_stream_end=_wrapped_stream_end if on_stream else None,
|
||||
on_tool_calls=_handle_tool_calls,
|
||||
before_execute_tools=_prepare_tools,
|
||||
finalize_content=self._strip_think,
|
||||
hook=_LoopHook(),
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
))
|
||||
|
||||
@ -3,12 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
|
||||
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
||||
@ -29,11 +29,7 @@ class AgentRunSpec:
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None
|
||||
on_tool_calls: Callable[[LLMResponse], Awaitable[None] | None] | None = None
|
||||
before_execute_tools: Callable[[list[ToolCallRequest]], Awaitable[None] | None] | None = None
|
||||
finalize_content: Callable[[str | None], str | None] | None = None
|
||||
hook: AgentHook | None = None
|
||||
error_message: str | None = _DEFAULT_ERROR_MESSAGE
|
||||
max_iterations_message: str | None = None
|
||||
concurrent_tools: bool = False
|
||||
@ -60,6 +56,7 @@ class AgentRunner:
|
||||
self.provider = provider
|
||||
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
messages = list(spec.initial_messages)
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
@ -68,7 +65,9 @@ class AgentRunner:
|
||||
stop_reason = "completed"
|
||||
tool_events: list[dict[str, str]] = []
|
||||
|
||||
for _ in range(spec.max_iterations):
|
||||
for iteration in range(spec.max_iterations):
|
||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||
await hook.before_iteration(context)
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"tools": spec.tools.get_definitions(),
|
||||
@ -81,10 +80,13 @@ class AgentRunner:
|
||||
if spec.reasoning_effort is not None:
|
||||
kwargs["reasoning_effort"] = spec.reasoning_effort
|
||||
|
||||
if spec.on_stream:
|
||||
if hook.wants_streaming():
|
||||
async def _stream(delta: str) -> None:
|
||||
await hook.on_stream(context, delta)
|
||||
|
||||
response = await self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=spec.on_stream,
|
||||
on_content_delta=_stream,
|
||||
)
|
||||
else:
|
||||
response = await self.provider.chat_with_retry(**kwargs)
|
||||
@ -94,14 +96,13 @@ class AgentRunner:
|
||||
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
||||
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
||||
}
|
||||
context.response = response
|
||||
context.usage = usage
|
||||
context.tool_calls = list(response.tool_calls)
|
||||
|
||||
if response.has_tool_calls:
|
||||
if spec.on_stream_end:
|
||||
await spec.on_stream_end(resuming=True)
|
||||
if spec.on_tool_calls:
|
||||
maybe = spec.on_tool_calls(response)
|
||||
if maybe is not None:
|
||||
await maybe
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
response.content or "",
|
||||
@ -111,16 +112,18 @@ class AgentRunner:
|
||||
))
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
|
||||
if spec.before_execute_tools:
|
||||
maybe = spec.before_execute_tools(response.tool_calls)
|
||||
if maybe is not None:
|
||||
await maybe
|
||||
await hook.before_execute_tools(context)
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
|
||||
tool_events.extend(new_events)
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
if fatal_error is not None:
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
stop_reason = "tool_error"
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
messages.append({
|
||||
@ -129,16 +132,21 @@ class AgentRunner:
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
if spec.on_stream_end:
|
||||
await spec.on_stream_end(resuming=False)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
|
||||
clean = spec.finalize_content(response.content) if spec.finalize_content else response.content
|
||||
clean = hook.finalize_content(context, response.content)
|
||||
if response.finish_reason == "error":
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
stop_reason = "error"
|
||||
error = final_content
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
@ -147,6 +155,9 @@ class AgentRunner:
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
final_content = clean
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
else:
|
||||
stop_reason = "max_iterations"
|
||||
|
||||
@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
@ -113,17 +114,19 @@ class SubagentManager:
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
async def _log_tool_calls(tool_calls) -> None:
|
||||
for tool_call in tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
before_execute_tools=_log_tool_calls,
|
||||
hook=_SubagentHook(),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
fail_on_tool_error=True,
|
||||
|
||||
@ -81,6 +81,125 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
    """Hooks fire in lifecycle order and see per-iteration context state."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    call_count = {"n": 0}
    events: list[tuple] = []

    # First call requests one tool; the second call completes the run.
    async def chat_with_retry(**kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            return LLMResponse(
                content="thinking",
                tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            )
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    # Records every hook invocation (with the context fields it observed)
    # so the expected call order can be asserted exactly.
    class RecordingHook(AgentHook):
        async def before_iteration(self, context: AgentHookContext) -> None:
            events.append(("before_iteration", context.iteration))

        async def before_execute_tools(self, context: AgentHookContext) -> None:
            events.append((
                "before_execute_tools",
                context.iteration,
                [tc.name for tc in context.tool_calls],
            ))

        async def after_iteration(self, context: AgentHookContext) -> None:
            events.append((
                "after_iteration",
                context.iteration,
                context.final_content,
                list(context.tool_results),
                list(context.tool_events),
                context.stop_reason,
            ))

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            events.append(("finalize_content", context.iteration, content))
            # Upper-casing proves the hook's return value is what the
            # runner reports as final content.
            return content.upper() if content else content

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        hook=RecordingHook(),
    ))

    assert result.final_content == "DONE"
    assert events == [
        ("before_iteration", 0),
        ("before_execute_tools", 0, ["list_dir"]),
        (
            "after_iteration",
            0,
            None,
            ["tool result"],
            [{"name": "list_dir", "status": "ok", "detail": "tool result"}],
            None,
        ),
        ("before_iteration", 1),
        ("finalize_content", 1, "done"),
        ("after_iteration", 1, "DONE", [], [], "completed"),
    ]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
    """A hook opting into streaming gets every delta plus a final end signal."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    streamed: list[str] = []
    endings: list[bool] = []

    # Emits two deltas then returns the assembled response.
    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        await on_content_delta("he")
        await on_content_delta("llo")
        return LLMResponse(content="hello", tool_calls=[], usage={})

    provider.chat_stream_with_retry = chat_stream_with_retry
    provider.chat_with_retry = AsyncMock()
    tools = MagicMock()
    tools.get_definitions.return_value = []

    class StreamingHook(AgentHook):
        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            streamed.append(delta)

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            endings.append(resuming)

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        hook=StreamingHook(),
    ))

    assert result.final_content == "hello"
    assert streamed == ["he", "llo"]
    # resuming=False: no tool calls followed, this was the final response.
    assert endings == [False]
    # wants_streaming() must route through the streaming API exclusively.
    provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_max_iterations_fallback():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
@ -158,6 +277,36 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
    """Deltas inside an unterminated <think> block are held back, not surfaced."""
    loop = _make_loop(tmp_path)
    deltas: list[str] = []
    endings: list[bool] = []

    # First delta is think-only; the visible text arrives only after the
    # closing </think> tag in the second delta.
    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        await on_content_delta("<think>hidden")
        await on_content_delta("</think>Hello")
        return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})

    loop.provider.chat_stream_with_retry = chat_stream_with_retry

    async def on_stream(delta: str) -> None:
        deltas.append(delta)

    async def on_stream_end(*, resuming: bool = False) -> None:
        endings.append(resuming)

    final_content, _, _ = await loop._run_agent_loop(
        [],
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )

    assert final_content == "Hello"
    # The think-block content never reaches the downstream consumer.
    assert deltas == ["Hello"]
    assert endings == [False]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user