refactor: unify agent runner lifecycle hooks

This commit is contained in:
Xubin Ren 2026-03-26 19:39:57 +00:00 committed by Xubin Ren
parent e7d371ec1e
commit 5bf0f6fe7d
5 changed files with 277 additions and 65 deletions

49
nanobot/agent/hook.py Normal file
View File

@ -0,0 +1,49 @@
"""Shared lifecycle hook primitives for agent runs."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from nanobot.providers.base import LLMResponse, ToolCallRequest
@dataclass(slots=True)
class AgentHookContext:
    """Mutable per-iteration state exposed to runner hooks.

    A fresh context is created by the runner for every iteration; hooks
    may read it and mutate its fields to influence later hook calls.
    """

    # 0-based index of the current runner iteration.
    iteration: int
    # Conversation history sent to the provider for this iteration.
    messages: list[dict[str, Any]]
    # Provider response for this iteration; None until the LLM call returns.
    response: LLMResponse | None = None
    # Token usage parsed from the response ("prompt_tokens"/"completion_tokens").
    usage: dict[str, int] = field(default_factory=dict)
    # Tool calls requested by the model in this iteration (empty when none).
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # Results from executing tool_calls, in the same order as tool_calls.
    tool_results: list[Any] = field(default_factory=list)
    # Structured per-tool events recorded during tool execution.
    tool_events: list[dict[str, str]] = field(default_factory=list)
    # Final assistant content (after finalize_content); set on the last iteration.
    final_content: str | None = None
    # Why the run stopped (e.g. "completed", "error", "tool_error"); None mid-run.
    stop_reason: str | None = None
    # Error description when the iteration failed; None on success.
    error: str | None = None
class AgentHook:
    """Minimal lifecycle surface for shared runner customization.

    Every method is a safe default (no-op or identity), so the runner can
    invoke the full lifecycle unconditionally; subclasses override only
    the hooks they care about.
    """

    def wants_streaming(self) -> bool:
        """Return True when the runner should use the streaming provider path."""
        return False

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Called at the start of each runner iteration. Default: no-op."""

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Called with each streamed content delta. Default: no-op."""

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Called when a streamed response ends.

        ``resuming`` is True when tool calls follow and streaming will
        restart, False when this was the final response. Default: no-op.
        """

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Called just before the requested tool calls run. Default: no-op."""

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Called at the end of each iteration. Default: no-op."""

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Post-process the model's final content; default returns it unchanged."""
        return content

View File

@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger from loguru import logger
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
@ -216,53 +217,52 @@ class AgentLoop:
``resuming=True`` means tool calls follow (spinner should restart); ``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response. ``resuming=False`` means this is the final response.
""" """
# Wrap on_stream with stateful think-tag filter so downstream loop_self = self
# consumers (CLI, channels) never see <think> blocks.
_raw_stream = on_stream
_stream_buf = ""
async def _filtered_stream(delta: str) -> None: class _LoopHook(AgentHook):
nonlocal _stream_buf def __init__(self) -> None:
from nanobot.utils.helpers import strip_think self._stream_buf = ""
prev_clean = strip_think(_stream_buf)
_stream_buf += delta
new_clean = strip_think(_stream_buf)
incremental = new_clean[len(prev_clean):]
if incremental and _raw_stream:
await _raw_stream(incremental)
async def _wrapped_stream_end(*, resuming: bool = False) -> None: def wants_streaming(self) -> bool:
nonlocal _stream_buf return on_stream is not None
if on_stream_end:
await on_stream_end(resuming=resuming)
_stream_buf = ""
async def _handle_tool_calls(response) -> None: async def on_stream(self, context: AgentHookContext, delta: str) -> None:
if not on_progress: from nanobot.utils.helpers import strip_think
return
if not on_stream:
thought = self._strip_think(response.content)
if thought:
await on_progress(thought)
tool_hint = self._strip_think(self._tool_hint(response.tool_calls))
await on_progress(tool_hint, tool_hint=True)
async def _prepare_tools(tool_calls) -> None: prev_clean = strip_think(self._stream_buf)
for tc in tool_calls: self._stream_buf += delta
args_str = json.dumps(tc.arguments, ensure_ascii=False) new_clean = strip_think(self._stream_buf)
logger.info("Tool call: {}({})", tc.name, args_str[:200]) incremental = new_clean[len(prev_clean):]
self._set_tool_context(channel, chat_id, message_id) if incremental and on_stream:
await on_stream(incremental)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
if on_stream_end:
await on_stream_end(resuming=resuming)
self._stream_buf = ""
async def before_execute_tools(self, context: AgentHookContext) -> None:
if on_progress:
if not on_stream:
thought = loop_self._strip_think(context.response.content if context.response else None)
if thought:
await on_progress(thought)
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
await on_progress(tool_hint, tool_hint=True)
for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
loop_self._set_tool_context(channel, chat_id, message_id)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return loop_self._strip_think(content)
result = await self.runner.run(AgentRunSpec( result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages, initial_messages=initial_messages,
tools=self.tools, tools=self.tools,
model=self.model, model=self.model,
max_iterations=self.max_iterations, max_iterations=self.max_iterations,
on_stream=_filtered_stream if on_stream else None, hook=_LoopHook(),
on_stream_end=_wrapped_stream_end if on_stream else None,
on_tool_calls=_handle_tool_calls,
before_execute_tools=_prepare_tools,
finalize_content=self._strip_think,
error_message="Sorry, I encountered an error calling the AI model.", error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True, concurrent_tools=True,
)) ))

View File

@ -3,12 +3,12 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any from typing import Any
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, ToolCallRequest
from nanobot.utils.helpers import build_assistant_message from nanobot.utils.helpers import build_assistant_message
_DEFAULT_MAX_ITERATIONS_MESSAGE = ( _DEFAULT_MAX_ITERATIONS_MESSAGE = (
@ -29,11 +29,7 @@ class AgentRunSpec:
temperature: float | None = None temperature: float | None = None
max_tokens: int | None = None max_tokens: int | None = None
reasoning_effort: str | None = None reasoning_effort: str | None = None
on_stream: Callable[[str], Awaitable[None]] | None = None hook: AgentHook | None = None
on_stream_end: Callable[..., Awaitable[None]] | None = None
on_tool_calls: Callable[[LLMResponse], Awaitable[None] | None] | None = None
before_execute_tools: Callable[[list[ToolCallRequest]], Awaitable[None] | None] | None = None
finalize_content: Callable[[str | None], str | None] | None = None
error_message: str | None = _DEFAULT_ERROR_MESSAGE error_message: str | None = _DEFAULT_ERROR_MESSAGE
max_iterations_message: str | None = None max_iterations_message: str | None = None
concurrent_tools: bool = False concurrent_tools: bool = False
@ -60,6 +56,7 @@ class AgentRunner:
self.provider = provider self.provider = provider
async def run(self, spec: AgentRunSpec) -> AgentRunResult: async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages) messages = list(spec.initial_messages)
final_content: str | None = None final_content: str | None = None
tools_used: list[str] = [] tools_used: list[str] = []
@ -68,7 +65,9 @@ class AgentRunner:
stop_reason = "completed" stop_reason = "completed"
tool_events: list[dict[str, str]] = [] tool_events: list[dict[str, str]] = []
for _ in range(spec.max_iterations): for iteration in range(spec.max_iterations):
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"messages": messages, "messages": messages,
"tools": spec.tools.get_definitions(), "tools": spec.tools.get_definitions(),
@ -81,10 +80,13 @@ class AgentRunner:
if spec.reasoning_effort is not None: if spec.reasoning_effort is not None:
kwargs["reasoning_effort"] = spec.reasoning_effort kwargs["reasoning_effort"] = spec.reasoning_effort
if spec.on_stream: if hook.wants_streaming():
async def _stream(delta: str) -> None:
await hook.on_stream(context, delta)
response = await self.provider.chat_stream_with_retry( response = await self.provider.chat_stream_with_retry(
**kwargs, **kwargs,
on_content_delta=spec.on_stream, on_content_delta=_stream,
) )
else: else:
response = await self.provider.chat_with_retry(**kwargs) response = await self.provider.chat_with_retry(**kwargs)
@ -94,14 +96,13 @@ class AgentRunner:
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
} }
context.response = response
context.usage = usage
context.tool_calls = list(response.tool_calls)
if response.has_tool_calls: if response.has_tool_calls:
if spec.on_stream_end: if hook.wants_streaming():
await spec.on_stream_end(resuming=True) await hook.on_stream_end(context, resuming=True)
if spec.on_tool_calls:
maybe = spec.on_tool_calls(response)
if maybe is not None:
await maybe
messages.append(build_assistant_message( messages.append(build_assistant_message(
response.content or "", response.content or "",
@ -111,16 +112,18 @@ class AgentRunner:
)) ))
tools_used.extend(tc.name for tc in response.tool_calls) tools_used.extend(tc.name for tc in response.tool_calls)
if spec.before_execute_tools: await hook.before_execute_tools(context)
maybe = spec.before_execute_tools(response.tool_calls)
if maybe is not None:
await maybe
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls) results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
tool_events.extend(new_events) tool_events.extend(new_events)
context.tool_results = list(results)
context.tool_events = list(new_events)
if fatal_error is not None: if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}" error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
stop_reason = "tool_error" stop_reason = "tool_error"
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break break
for tool_call, result in zip(response.tool_calls, results): for tool_call, result in zip(response.tool_calls, results):
messages.append({ messages.append({
@ -129,16 +132,21 @@ class AgentRunner:
"name": tool_call.name, "name": tool_call.name,
"content": result, "content": result,
}) })
await hook.after_iteration(context)
continue continue
if spec.on_stream_end: if hook.wants_streaming():
await spec.on_stream_end(resuming=False) await hook.on_stream_end(context, resuming=False)
clean = spec.finalize_content(response.content) if spec.finalize_content else response.content clean = hook.finalize_content(context, response.content)
if response.finish_reason == "error": if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error" stop_reason = "error"
error = final_content error = final_content
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
break break
messages.append(build_assistant_message( messages.append(build_assistant_message(
@ -147,6 +155,9 @@ class AgentRunner:
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
)) ))
final_content = clean final_content = clean
context.final_content = final_content
context.stop_reason = stop_reason
await hook.after_iteration(context)
break break
else: else:
stop_reason = "max_iterations" stop_reason = "max_iterations"

View File

@ -8,6 +8,7 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
@ -113,17 +114,19 @@ class SubagentManager:
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": task}, {"role": "user", "content": task},
] ]
async def _log_tool_calls(tool_calls) -> None:
for tool_call in tool_calls: class _SubagentHook(AgentHook):
args_str = json.dumps(tool_call.arguments, ensure_ascii=False) async def before_execute_tools(self, context: AgentHookContext) -> None:
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str) for tool_call in context.tool_calls:
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
result = await self.runner.run(AgentRunSpec( result = await self.runner.run(AgentRunSpec(
initial_messages=messages, initial_messages=messages,
tools=tools, tools=tools,
model=self.model, model=self.model,
max_iterations=15, max_iterations=15,
before_execute_tools=_log_tool_calls, hook=_SubagentHook(),
max_iterations_message="Task completed but no final response was generated.", max_iterations_message="Task completed but no final response was generated.",
error_message=None, error_message=None,
fail_on_tool_error=True, fail_on_tool_error=True,

View File

@ -81,6 +81,125 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
) )
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = {"n": 0}
events: list[tuple] = []
async def chat_with_retry(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
class RecordingHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
events.append(("before_iteration", context.iteration))
async def before_execute_tools(self, context: AgentHookContext) -> None:
events.append((
"before_execute_tools",
context.iteration,
[tc.name for tc in context.tool_calls],
))
async def after_iteration(self, context: AgentHookContext) -> None:
events.append((
"after_iteration",
context.iteration,
context.final_content,
list(context.tool_results),
list(context.tool_events),
context.stop_reason,
))
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
events.append(("finalize_content", context.iteration, content))
return content.upper() if content else content
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=3,
hook=RecordingHook(),
))
assert result.final_content == "DONE"
assert events == [
("before_iteration", 0),
("before_execute_tools", 0, ["list_dir"]),
(
"after_iteration",
0,
None,
["tool result"],
[{"name": "list_dir", "status": "ok", "detail": "tool result"}],
None,
),
("before_iteration", 1),
("finalize_content", 1, "done"),
("after_iteration", 1, "DONE", [], [], "completed"),
]
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
streamed: list[str] = []
endings: list[bool] = []
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
await on_content_delta("he")
await on_content_delta("llo")
return LLMResponse(content="hello", tool_calls=[], usage={})
provider.chat_stream_with_retry = chat_stream_with_retry
provider.chat_with_retry = AsyncMock()
tools = MagicMock()
tools.get_definitions.return_value = []
class StreamingHook(AgentHook):
def wants_streaming(self) -> bool:
return True
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
streamed.append(delta)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
endings.append(resuming)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
hook=StreamingHook(),
))
assert result.final_content == "hello"
assert streamed == ["he", "llo"]
assert endings == [False]
provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback(): async def test_runner_returns_max_iterations_fallback():
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.runner import AgentRunSpec, AgentRunner
@ -158,6 +277,36 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
) )
@pytest.mark.asyncio
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
loop = _make_loop(tmp_path)
deltas: list[str] = []
endings: list[bool] = []
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
await on_content_delta("<think>hidden")
await on_content_delta("</think>Hello")
return LLMResponse(content="<think>hidden</think>Hello", tool_calls=[], usage={})
loop.provider.chat_stream_with_retry = chat_stream_with_retry
async def on_stream(delta: str) -> None:
deltas.append(delta)
async def on_stream_end(*, resuming: bool = False) -> None:
endings.append(resuming)
final_content, _, _ = await loop._run_agent_loop(
[],
on_stream=on_stream,
on_stream_end=on_stream_end,
)
assert final_content == "Hello"
assert deltas == ["Hello"]
assert endings == [False]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch): async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager