refactor: unify agent runner lifecycle hooks

This commit is contained in:
Xubin Ren 2026-03-26 19:39:57 +00:00 committed by Xubin Ren
parent e7d371ec1e
commit 5bf0f6fe7d
5 changed files with 277 additions and 65 deletions

49
nanobot/agent/hook.py Normal file
View File

@ -0,0 +1,49 @@
"""Shared lifecycle hook primitives for agent runs."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from nanobot.providers.base import LLMResponse, ToolCallRequest
@dataclass(slots=True)
class AgentHookContext:
    """Mutable per-iteration state exposed to runner hooks."""

    # 0-based index of the current runner iteration.
    iteration: int
    # Conversation history handed to the provider for this iteration.
    messages: list[dict[str, Any]]
    # Provider response for this iteration; None until the call completes.
    response: LLMResponse | None = None
    # Per-iteration token accounting (e.g. prompt/completion token counts).
    usage: dict[str, int] = field(default_factory=dict)
    # Tool invocations the model requested in this iteration.
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Outputs produced by executing tool_calls, in call order.
    tool_results: list[Any] = field(default_factory=list)
    # Structured event records emitted while executing tools.
    tool_events: list[dict[str, str]] = field(default_factory=list)
    # Final assistant content once the run settles; None while still iterating.
    final_content: str | None = None
    # Why the run stopped (e.g. "completed", "tool_error"); None mid-run.
    stop_reason: str | None = None
    # Human-readable error description when the iteration failed.
    error: str | None = None
class AgentHook:
    """No-op base class for runner lifecycle hooks.

    Subclasses override only the callbacks they care about; every default
    here does nothing, so the runner may invoke each hook unconditionally.
    """

    def wants_streaming(self) -> bool:
        """Return True when the provider should stream content deltas."""
        return False

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Called once at the top of each runner iteration."""
        return None

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Called for every streamed content delta."""
        return None

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Called when a stream finishes; ``resuming=True`` means tool calls follow."""
        return None

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Called after tool calls are parsed but before they execute."""
        return None

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Called at the end of each iteration, including the terminal one."""
        return None

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Post-process the model's final content; identity by default."""
        return content

View File

@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
@ -216,53 +217,52 @@ class AgentLoop:
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
"""
# Wrap on_stream with stateful think-tag filter so downstream
# consumers (CLI, channels) never see <think> blocks.
_raw_stream = on_stream
_stream_buf = ""
loop_self = self
async def _filtered_stream(delta: str) -> None:
nonlocal _stream_buf
from nanobot.utils.helpers import strip_think
prev_clean = strip_think(_stream_buf)
_stream_buf += delta
new_clean = strip_think(_stream_buf)
incremental = new_clean[len(prev_clean):]
if incremental and _raw_stream:
await _raw_stream(incremental)
class _LoopHook(AgentHook):
def __init__(self) -> None:
self._stream_buf = ""
async def _wrapped_stream_end(*, resuming: bool = False) -> None:
nonlocal _stream_buf
if on_stream_end:
await on_stream_end(resuming=resuming)
_stream_buf = ""
def wants_streaming(self) -> bool:
return on_stream is not None
async def _handle_tool_calls(response) -> None:
if not on_progress:
return
if not on_stream:
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
tool_hint = self._strip_think(self._tool_hint(response.tool_calls))
await on_progress(tool_hint, tool_hint=True)
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
from nanobot.utils.helpers import strip_think
async def _prepare_tools(tool_calls) -> None:
for tc in tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
self._set_tool_context(channel, chat_id, message_id)
prev_clean = strip_think(self._stream_buf)
self._stream_buf += delta
new_clean = strip_think(self._stream_buf)
incremental = new_clean[len(prev_clean):]
if incremental and on_stream:
await on_stream(incremental)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
if on_stream_end:
await on_stream_end(resuming=resuming)
self._stream_buf = ""
async def before_execute_tools(self, context: AgentHookContext) -> None:
if on_progress:
if not on_stream:
thought = loop_self._strip_think(context.response.content if context.response else None)
if thought:
await on_progress(thought)
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
await on_progress(tool_hint, tool_hint=True)
for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
loop_self._set_tool_context(channel, chat_id, message_id)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return loop_self._strip_think(content)
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
on_stream=_filtered_stream if on_stream else None,
on_stream_end=_wrapped_stream_end if on_stream else None,
on_tool_calls=_handle_tool_calls,
before_execute_tools=_prepare_tools,
finalize_content=self._strip_think,
hook=_LoopHook(),
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
))

View File

@ -3,12 +3,12 @@
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.base import LLMProvider, ToolCallRequest
from nanobot.utils.helpers import build_assistant_message
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
@ -29,11 +29,7 @@ class AgentRunSpec:
temperature: float | None = None
max_tokens: int | None = None
reasoning_effort: str | None = None
on_stream: Callable[[str], Awaitable[None]] | None = None
on_stream_end: Callable[..., Awaitable[None]] | None = None
on_tool_calls: Callable[[LLMResponse], Awaitable[None] | None] | None = None
before_execute_tools: Callable[[list[ToolCallRequest]], Awaitable[None] | None] | None = None
finalize_content: Callable[[str | None], str | None] | None = None
hook: AgentHook | None = None
error_message: str | None = _DEFAULT_ERROR_MESSAGE
max_iterations_message: str | None = None
concurrent_tools: bool = False
@ -60,6 +56,7 @@ class AgentRunner:
self.provider = provider
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages)
final_content: str | None = None
tools_used: list[str] = []
@ -68,7 +65,9 @@ class AgentRunner:
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
for _ in range(spec.max_iterations):
for iteration in range(spec.max_iterations):
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
kwargs: dict[str, Any] = {
"messages": messages,
"tools": spec.tools.get_definitions(),
@ -81,10 +80,13 @@ class AgentRunner:
if spec.reasoning_effort is not None:
kwargs["reasoning_effort"] = spec.reasoning_effort
if spec.on_stream:
if hook.wants_streaming():
async def _stream(delta: str) -> None:
await hook.on_stream(context, delta)
response = await self.provider.chat_stream_with_retry(
**kwargs,
on_content_delta=spec.on_stream,
on_content_delta=_stream,
)
else:
response = await self.provider.chat_with_retry(**kwargs)
@ -94,14 +96,13 @@ class AgentRunner:
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
}
context.response = response
context.usage = usage
context.tool_calls = list(response.tool_calls)
if response.has_tool_calls:
if spec.on_stream_end:
await spec.on_stream_end(resuming=True)
if spec.on_tool_calls:
maybe = spec.on_tool_calls(response)
if maybe is not None:
await maybe
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True)
messages.append(build_assistant_message(
response.content or "",
@ -111,16 +112,18 @@ class AgentRunner:
))
tools_used.extend(tc.name for tc in response.tool_calls)
if spec.before_execute_tools:
maybe = spec.before_execute_tools(response.tool_calls)
if maybe is not None:
await maybe
await hook.before_execute_tools(context)
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
tool_events.extend(new_events)
context.tool_results = list(results)
context.tool_events = list(new_events)
if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
stop_reason = "tool_error"
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
for tool_call, result in zip(response.tool_calls, results):
messages.append({
@ -129,16 +132,21 @@ class AgentRunner:
"name": tool_call.name,
"content": result,
})
await hook.after_iteration(context)
continue
if spec.on_stream_end:
await spec.on_stream_end(resuming=False)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
clean = spec.finalize_content(response.content) if spec.finalize_content else response.content
clean = hook.finalize_content(context, response.content)
if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error"
error = final_content
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
messages.append(build_assistant_message(
@ -147,6 +155,9 @@ class AgentRunner:
thinking_blocks=response.thinking_blocks,
))
final_content = clean
context.final_content = final_content
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
else:
stop_reason = "max_iterations"

View File

@ -8,6 +8,7 @@ from typing import Any
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
@ -113,17 +114,19 @@ class SubagentManager:
{"role": "system", "content": system_prompt},
{"role": "user", "content": task},
]
async def _log_tool_calls(tool_calls) -> None:
for tool_call in tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
class _SubagentHook(AgentHook):
async def before_execute_tools(self, context: AgentHookContext) -> None:
for tool_call in context.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
result = await self.runner.run(AgentRunSpec(
initial_messages=messages,
tools=tools,
model=self.model,
max_iterations=15,
before_execute_tools=_log_tool_calls,
hook=_SubagentHook(),
max_iterations_message="Task completed but no final response was generated.",
error_message=None,
fail_on_tool_error=True,

View File

@ -81,6 +81,125 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
)
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
    """Hooks fire in lifecycle order and finalize_content rewrites the output."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    provider_calls: list[dict] = []
    events: list[tuple] = []

    async def chat_with_retry(**kwargs):
        provider_calls.append(kwargs)
        if len(provider_calls) == 1:
            # First turn: model asks for one tool call.
            request = ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})
            return LLMResponse(content="thinking", tool_calls=[request])
        # Second turn: model answers and the run completes.
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    class RecordingHook(AgentHook):
        async def before_iteration(self, context: AgentHookContext) -> None:
            events.append(("before_iteration", context.iteration))

        async def before_execute_tools(self, context: AgentHookContext) -> None:
            names = [tc.name for tc in context.tool_calls]
            events.append(("before_execute_tools", context.iteration, names))

        async def after_iteration(self, context: AgentHookContext) -> None:
            snapshot = (
                "after_iteration",
                context.iteration,
                context.final_content,
                list(context.tool_results),
                list(context.tool_events),
                context.stop_reason,
            )
            events.append(snapshot)

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            events.append(("finalize_content", context.iteration, content))
            return content.upper() if content else content

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=3,
        hook=RecordingHook(),
    )
    result = await runner.run(spec)

    assert result.final_content == "DONE"
    assert events == [
        ("before_iteration", 0),
        ("before_execute_tools", 0, ["list_dir"]),
        (
            "after_iteration",
            0,
            None,
            ["tool result"],
            [{"name": "list_dir", "status": "ok", "detail": "tool result"}],
            None,
        ),
        ("before_iteration", 1),
        ("finalize_content", 1, "done"),
        ("after_iteration", 1, "DONE", [], [], "completed"),
    ]
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
    """A streaming hook receives every delta plus a single non-resuming end."""
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    streamed: list[str] = []
    endings: list[bool] = []

    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        for piece in ("he", "llo"):
            await on_content_delta(piece)
        return LLMResponse(content="hello", tool_calls=[], usage={})

    provider.chat_stream_with_retry = chat_stream_with_retry
    provider.chat_with_retry = AsyncMock()

    tools = MagicMock()
    tools.get_definitions.return_value = []

    class StreamingHook(AgentHook):
        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            streamed.append(delta)

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            endings.append(resuming)

    result = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        hook=StreamingHook(),
    ))

    assert result.final_content == "hello"
    assert streamed == ["he", "llo"]
    assert endings == [False]
    # Streaming path must be used exclusively.
    provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
@ -158,6 +277,36 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
)
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
    """Streaming a <think>-only prefix first must not break the think-tag filter."""
    loop = _make_loop(tmp_path)
    seen_deltas: list[str] = []
    seen_endings: list[bool] = []

    async def chat_stream_with_retry(*, on_content_delta, **kwargs):
        await on_content_delta("<think>hidden")
        await on_content_delta("</think>Hello")
        return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})

    loop.provider.chat_stream_with_retry = chat_stream_with_retry

    async def on_stream(delta: str) -> None:
        seen_deltas.append(delta)

    async def on_stream_end(*, resuming: bool = False) -> None:
        seen_endings.append(resuming)

    final_content, _, _ = await loop._run_agent_loop(
        [],
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )

    assert final_content == "Hello"
    assert seen_deltas == ["Hello"]
    assert seen_endings == [False]
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
from nanobot.agent.subagent import SubagentManager