feat: add run-level agent hook lifecycle

This commit is contained in:
chengyongru 2026-06-03 18:45:13 +08:00 committed by Xubin Ren
parent c77ca16d91
commit 2ea226055e
5 changed files with 317 additions and 7 deletions

View File

@ -1,7 +1,7 @@
"""Agent core module."""
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader
@ -10,6 +10,7 @@ from nanobot.agent.subagent import SubagentManager
__all__ = [
"AgentHook",
"AgentHookContext",
"AgentRunHookContext",
"AgentLoop",
"CompositeHook",
"ContextBuilder",

View File

@ -28,6 +28,21 @@ class AgentHookContext:
error: str | None = None
@dataclass(slots=True)
class AgentRunHookContext:
"""Run-level state snapshot exposed to runner hooks."""
messages: list[dict[str, Any]]
final_content: str | None = None
tools_used: list[str] = field(default_factory=list)
usage: dict[str, int] = field(default_factory=dict)
stop_reason: str | None = None
error: str | None = None
tool_events: list[dict[str, str]] = field(default_factory=list)
had_injections: bool = False
exception: BaseException | None = None
class AgentHook:
"""Minimal lifecycle surface for shared runner customization."""
@ -37,6 +52,18 @@ class AgentHook:
def wants_streaming(self) -> bool:
return False
async def before_run(self, context: AgentRunHookContext) -> None:
pass
async def after_run(self, context: AgentRunHookContext) -> None:
pass
async def on_error(self, context: AgentRunHookContext) -> None:
pass
async def on_finally(self, context: AgentRunHookContext) -> None:
pass
async def before_iteration(self, context: AgentHookContext) -> None:
pass
@ -98,6 +125,18 @@ class CompositeHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
await self._for_each_hook_safe("before_iteration", context)
async def before_run(self, context: AgentRunHookContext) -> None:
await self._for_each_hook_safe("before_run", context)
async def after_run(self, context: AgentRunHookContext) -> None:
await self._for_each_hook_safe("after_run", context)
async def on_error(self, context: AgentRunHookContext) -> None:
await self._for_each_hook_safe("on_error", context)
async def on_finally(self, context: AgentRunHookContext) -> None:
await self._for_each_hook_safe("on_finally", context)
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
await self._for_each_hook_safe("on_stream", context, delta)

View File

@ -12,7 +12,7 @@ from typing import Any, Callable
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.utils.file_edit_events import (
@ -272,6 +272,42 @@ class AgentRunner:
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages)
context = AgentRunHookContext(messages=list(messages))
try:
await hook.before_run(context)
result = await self._run_core(spec, hook, messages)
except BaseException as exc:
context.messages = list(messages)
context.stop_reason = "error"
context.error = f"Error: {type(exc).__name__}: {exc}"
context.exception = exc
await hook.on_error(context)
raise
else:
context.messages = list(result.messages)
context.final_content = result.final_content
context.tools_used = list(result.tools_used)
context.usage = dict(result.usage)
context.stop_reason = result.stop_reason
context.error = result.error
context.tool_events = list(result.tool_events)
context.had_injections = result.had_injections
context.exception = None
if context.error is not None:
await hook.on_error(context)
await hook.after_run(context)
return result
finally:
context.messages = list(messages)
await hook.on_finally(context)
async def _run_core(
self,
spec: AgentRunSpec,
hook: AgentHook,
messages: list[dict[str, Any]],
) -> AgentRunResult:
final_content: str | None = None
tools_used: list[str] = []
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}

View File

@ -6,13 +6,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook
def _ctx() -> AgentHookContext:
return AgentHookContext(iteration=0, messages=[])
def _run_ctx() -> AgentRunHookContext:
return AgentRunHookContext(messages=[])
# ---------------------------------------------------------------------------
# Base AgentHook emit_reasoning: no-op
# ---------------------------------------------------------------------------
@ -54,6 +58,9 @@ async def test_composite_fans_out_all_async_methods():
events: list[str] = []
class RecordingHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append("before_run")
async def before_iteration(self, context: AgentHookContext) -> None:
events.append("before_iteration")
@ -72,23 +79,41 @@ async def test_composite_fans_out_all_async_methods():
async def after_iteration(self, context: AgentHookContext) -> None:
events.append("after_iteration")
async def after_run(self, context: AgentRunHookContext) -> None:
events.append("after_run")
async def on_error(self, context: AgentRunHookContext) -> None:
events.append("on_error")
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append("on_finally")
hook = CompositeHook([RecordingHook(), RecordingHook()])
ctx = _ctx()
run_ctx = _run_ctx()
await hook.before_run(run_ctx)
await hook.before_iteration(ctx)
await hook.emit_reasoning("thinking...")
await hook.on_stream(ctx, "hi")
await hook.on_stream_end(ctx, resuming=True)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
await hook.after_run(run_ctx)
await hook.on_error(run_ctx)
await hook.on_finally(run_ctx)
assert events == [
"before_run", "before_run",
"before_iteration", "before_iteration",
"emit_reasoning:thinking...", "emit_reasoning:thinking...",
"on_stream:hi", "on_stream:hi",
"on_stream_end:True", "on_stream_end:True",
"before_execute_tools", "before_execute_tools",
"after_iteration", "after_iteration",
"after_run", "after_run",
"on_error", "on_error",
"on_finally", "on_finally",
]
@ -137,6 +162,8 @@ async def test_composite_error_isolation_all_async():
calls: list[str] = []
class Bad(AgentHook):
async def before_run(self, context):
raise RuntimeError("err")
async def emit_reasoning(self, reasoning_content):
raise RuntimeError("err")
async def on_stream_end(self, context, *, resuming):
@ -145,8 +172,16 @@ async def test_composite_error_isolation_all_async():
raise RuntimeError("err")
async def after_iteration(self, context):
raise RuntimeError("err")
async def after_run(self, context):
raise RuntimeError("err")
async def on_error(self, context):
raise RuntimeError("err")
async def on_finally(self, context):
raise RuntimeError("err")
class Good(AgentHook):
async def before_run(self, context):
calls.append("before_run")
async def emit_reasoning(self, reasoning_content):
calls.append("emit_reasoning")
async def on_stream_end(self, context, *, resuming):
@ -155,14 +190,34 @@ async def test_composite_error_isolation_all_async():
calls.append("before_execute_tools")
async def after_iteration(self, context):
calls.append("after_iteration")
async def after_run(self, context):
calls.append("after_run")
async def on_error(self, context):
calls.append("on_error")
async def on_finally(self, context):
calls.append("on_finally")
hook = CompositeHook([Bad(), Good()])
ctx = _ctx()
run_ctx = _run_ctx()
await hook.before_run(run_ctx)
await hook.emit_reasoning("test")
await hook.on_stream_end(ctx, resuming=False)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
assert calls == ["emit_reasoning", "on_stream_end", "before_execute_tools", "after_iteration"]
await hook.after_run(run_ctx)
await hook.on_error(run_ctx)
await hook.on_finally(run_ctx)
assert calls == [
"before_run",
"emit_reasoning",
"on_stream_end",
"before_execute_tools",
"after_iteration",
"after_run",
"on_error",
"on_finally",
]
# ---------------------------------------------------------------------------
@ -246,11 +301,16 @@ def test_composite_wants_streaming_empty():
async def test_composite_empty_hooks_no_ops():
hook = CompositeHook([])
ctx = _ctx()
run_ctx = _run_ctx()
await hook.before_run(run_ctx)
await hook.before_iteration(ctx)
await hook.on_stream(ctx, "delta")
await hook.on_stream_end(ctx, resuming=False)
await hook.before_execute_tools(ctx)
await hook.after_iteration(ctx)
await hook.after_run(run_ctx)
await hook.on_error(run_ctx)
await hook.on_finally(run_ctx)
assert hook.finalize_content(ctx, "test") == "test"
@ -316,12 +376,18 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
events: list[str] = []
class TrackingHook(AgentHook):
async def before_run(self, context):
events.append("before_run")
async def before_iteration(self, context):
events.append(f"before_iter:{context.iteration}")
async def after_iteration(self, context):
events.append(f"after_iter:{context.iteration}")
async def after_run(self, context):
events.append(f"after_run:{context.stop_reason}")
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
loop.provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(content="done", tool_calls=[], usage={})
@ -333,8 +399,10 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
)
assert content == "done"
assert "before_run" in events
assert "before_iter:0" in events
assert "after_iter:0" in events
assert "after_run:completed" in events
@pytest.mark.asyncio

View File

@ -16,7 +16,7 @@ _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
call_count = {"n": 0}
@ -92,7 +92,7 @@ async def test_runner_calls_hooks_in_order():
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
streamed: list[str] = []
@ -138,7 +138,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
async def test_runner_passes_cached_tokens_to_hook_context():
"""Hook context.usage should contain cached_tokens."""
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
captured_usage: list[dict] = []
@ -170,3 +170,169 @@ async def test_runner_passes_cached_tokens_to_hook_context():
assert len(captured_usage) == 1
assert captured_usage[0]["cached_tokens"] == 150
@pytest.mark.asyncio
async def test_runner_calls_run_level_hooks_on_success():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
events.append(("request_messages", list(kwargs["messages"])))
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 3, "completion_tokens": 2},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class RunHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", list(context.messages), context.stop_reason))
context.messages.append({"role": "user", "content": "hook-only"})
async def after_run(self, context: AgentRunHookContext) -> None:
events.append((
"after_run",
context.final_content,
context.stop_reason,
context.error,
dict(context.usage),
[msg["role"] for msg in context.messages],
))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append(("on_error", context.error))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append(("on_finally", context.stop_reason, context.exception))
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=RunHook(),
))
assert result.final_content == "done"
assert events == [
("before_run", [{"role": "user", "content": "hi"}], None),
("request_messages", [{"role": "user", "content": "hi"}]),
(
"after_run",
"done",
"completed",
None,
{"prompt_tokens": 3, "completion_tokens": 2},
["user", "assistant"],
),
("on_finally", "completed", None),
]
@pytest.mark.asyncio
async def test_runner_calls_on_error_for_model_error_result():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
return LLMResponse(content="model failed", finish_reason="error", tool_calls=[])
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class ErrorHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", context.stop_reason))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append(("on_error", context.stop_reason, context.error, context.exception))
async def after_run(self, context: AgentRunHookContext) -> None:
events.append(("after_run", context.stop_reason, context.error))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append(("on_finally", context.stop_reason, context.error))
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=ErrorHook(),
))
assert result.stop_reason == "error"
assert result.error == "model failed"
assert events == [
("before_run", None),
("on_error", "error", "model failed", None),
("after_run", "error", "model failed"),
("on_finally", "error", "model failed"),
]
@pytest.mark.asyncio
async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
raise RuntimeError("provider exploded")
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class ExceptionHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", list(context.messages)))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append((
"on_error",
context.stop_reason,
context.error,
type(context.exception).__name__ if context.exception else None,
))
async def after_run(self, context: AgentRunHookContext) -> None:
events.append(("after_run", context.stop_reason))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append(("on_finally", context.stop_reason))
runner = AgentRunner(provider)
with pytest.raises(RuntimeError, match="provider exploded"):
await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=ExceptionHook(),
))
assert events == [
("before_run", [{"role": "user", "content": "hi"}]),
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
("on_finally", "error"),
]