mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
feat: add run-level agent hook lifecycle
This commit is contained in:
parent
c77ca16d91
commit
2ea226055e
@ -1,7 +1,7 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
@ -10,6 +10,7 @@ from nanobot.agent.subagent import SubagentManager
|
||||
__all__ = [
|
||||
"AgentHook",
|
||||
"AgentHookContext",
|
||||
"AgentRunHookContext",
|
||||
"AgentLoop",
|
||||
"CompositeHook",
|
||||
"ContextBuilder",
|
||||
|
||||
@ -28,6 +28,21 @@ class AgentHookContext:
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunHookContext:
|
||||
"""Run-level state snapshot exposed to runner hooks."""
|
||||
|
||||
messages: list[dict[str, Any]]
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = field(default_factory=list)
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
stop_reason: str | None = None
|
||||
error: str | None = None
|
||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||
had_injections: bool = False
|
||||
exception: BaseException | None = None
|
||||
|
||||
|
||||
class AgentHook:
|
||||
"""Minimal lifecycle surface for shared runner customization."""
|
||||
|
||||
@ -37,6 +52,18 @@ class AgentHook:
|
||||
def wants_streaming(self) -> bool:
|
||||
return False
|
||||
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
pass
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
pass
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
pass
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
pass
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
pass
|
||||
|
||||
@ -98,6 +125,18 @@ class CompositeHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
await self._for_each_hook_safe("before_iteration", context)
|
||||
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
await self._for_each_hook_safe("before_run", context)
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
await self._for_each_hook_safe("after_run", context)
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
await self._for_each_hook_safe("on_error", context)
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
await self._for_each_hook_safe("on_finally", context)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
await self._for_each_hook_safe("on_stream", context, delta)
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ from typing import Any, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.file_edit_events import (
|
||||
@ -272,6 +272,42 @@ class AgentRunner:
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
messages = list(spec.initial_messages)
|
||||
context = AgentRunHookContext(messages=list(messages))
|
||||
|
||||
try:
|
||||
await hook.before_run(context)
|
||||
result = await self._run_core(spec, hook, messages)
|
||||
except BaseException as exc:
|
||||
context.messages = list(messages)
|
||||
context.stop_reason = "error"
|
||||
context.error = f"Error: {type(exc).__name__}: {exc}"
|
||||
context.exception = exc
|
||||
await hook.on_error(context)
|
||||
raise
|
||||
else:
|
||||
context.messages = list(result.messages)
|
||||
context.final_content = result.final_content
|
||||
context.tools_used = list(result.tools_used)
|
||||
context.usage = dict(result.usage)
|
||||
context.stop_reason = result.stop_reason
|
||||
context.error = result.error
|
||||
context.tool_events = list(result.tool_events)
|
||||
context.had_injections = result.had_injections
|
||||
context.exception = None
|
||||
if context.error is not None:
|
||||
await hook.on_error(context)
|
||||
await hook.after_run(context)
|
||||
return result
|
||||
finally:
|
||||
context.messages = list(messages)
|
||||
await hook.on_finally(context)
|
||||
|
||||
async def _run_core(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
hook: AgentHook,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> AgentRunResult:
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
|
||||
@ -6,13 +6,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook
|
||||
|
||||
|
||||
def _ctx() -> AgentHookContext:
|
||||
return AgentHookContext(iteration=0, messages=[])
|
||||
|
||||
|
||||
def _run_ctx() -> AgentRunHookContext:
|
||||
return AgentRunHookContext(messages=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base AgentHook emit_reasoning: no-op
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -54,6 +58,9 @@ async def test_composite_fans_out_all_async_methods():
|
||||
events: list[str] = []
|
||||
|
||||
class RecordingHook(AgentHook):
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append("before_run")
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("before_iteration")
|
||||
|
||||
@ -72,23 +79,41 @@ async def test_composite_fans_out_all_async_methods():
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("after_iteration")
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append("after_run")
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
events.append("on_error")
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
events.append("on_finally")
|
||||
|
||||
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
||||
ctx = _ctx()
|
||||
run_ctx = _run_ctx()
|
||||
|
||||
await hook.before_run(run_ctx)
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.emit_reasoning("thinking...")
|
||||
await hook.on_stream(ctx, "hi")
|
||||
await hook.on_stream_end(ctx, resuming=True)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
await hook.after_run(run_ctx)
|
||||
await hook.on_error(run_ctx)
|
||||
await hook.on_finally(run_ctx)
|
||||
|
||||
assert events == [
|
||||
"before_run", "before_run",
|
||||
"before_iteration", "before_iteration",
|
||||
"emit_reasoning:thinking...", "emit_reasoning:thinking...",
|
||||
"on_stream:hi", "on_stream:hi",
|
||||
"on_stream_end:True", "on_stream_end:True",
|
||||
"before_execute_tools", "before_execute_tools",
|
||||
"after_iteration", "after_iteration",
|
||||
"after_run", "after_run",
|
||||
"on_error", "on_error",
|
||||
"on_finally", "on_finally",
|
||||
]
|
||||
|
||||
|
||||
@ -137,6 +162,8 @@ async def test_composite_error_isolation_all_async():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def before_run(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def emit_reasoning(self, reasoning_content):
|
||||
raise RuntimeError("err")
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
@ -145,8 +172,16 @@ async def test_composite_error_isolation_all_async():
|
||||
raise RuntimeError("err")
|
||||
async def after_iteration(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def after_run(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def on_error(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def on_finally(self, context):
|
||||
raise RuntimeError("err")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def before_run(self, context):
|
||||
calls.append("before_run")
|
||||
async def emit_reasoning(self, reasoning_content):
|
||||
calls.append("emit_reasoning")
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
@ -155,14 +190,34 @@ async def test_composite_error_isolation_all_async():
|
||||
calls.append("before_execute_tools")
|
||||
async def after_iteration(self, context):
|
||||
calls.append("after_iteration")
|
||||
async def after_run(self, context):
|
||||
calls.append("after_run")
|
||||
async def on_error(self, context):
|
||||
calls.append("on_error")
|
||||
async def on_finally(self, context):
|
||||
calls.append("on_finally")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
ctx = _ctx()
|
||||
run_ctx = _run_ctx()
|
||||
await hook.before_run(run_ctx)
|
||||
await hook.emit_reasoning("test")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert calls == ["emit_reasoning", "on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
await hook.after_run(run_ctx)
|
||||
await hook.on_error(run_ctx)
|
||||
await hook.on_finally(run_ctx)
|
||||
assert calls == [
|
||||
"before_run",
|
||||
"emit_reasoning",
|
||||
"on_stream_end",
|
||||
"before_execute_tools",
|
||||
"after_iteration",
|
||||
"after_run",
|
||||
"on_error",
|
||||
"on_finally",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@ -246,11 +301,16 @@ def test_composite_wants_streaming_empty():
|
||||
async def test_composite_empty_hooks_no_ops():
|
||||
hook = CompositeHook([])
|
||||
ctx = _ctx()
|
||||
run_ctx = _run_ctx()
|
||||
await hook.before_run(run_ctx)
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "delta")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
await hook.after_run(run_ctx)
|
||||
await hook.on_error(run_ctx)
|
||||
await hook.on_finally(run_ctx)
|
||||
assert hook.finalize_content(ctx, "test") == "test"
|
||||
|
||||
|
||||
@ -316,12 +376,18 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
events: list[str] = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
async def before_run(self, context):
|
||||
events.append("before_run")
|
||||
|
||||
async def before_iteration(self, context):
|
||||
events.append(f"before_iter:{context.iteration}")
|
||||
|
||||
async def after_iteration(self, context):
|
||||
events.append(f"after_iter:{context.iteration}")
|
||||
|
||||
async def after_run(self, context):
|
||||
events.append(f"after_run:{context.stop_reason}")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||
@ -333,8 +399,10 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
)
|
||||
|
||||
assert content == "done"
|
||||
assert "before_run" in events
|
||||
assert "before_iter:0" in events
|
||||
assert "after_iter:0" in events
|
||||
assert "after_run:completed" in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -16,7 +16,7 @@ _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_calls_hooks_in_order():
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
call_count = {"n": 0}
|
||||
@ -92,7 +92,7 @@ async def test_runner_calls_hooks_in_order():
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
streamed: list[str] = []
|
||||
@ -138,7 +138,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
||||
async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
"""Hook context.usage should contain cached_tokens."""
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
captured_usage: list[dict] = []
|
||||
@ -170,3 +170,169 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
||||
|
||||
assert len(captured_usage) == 1
|
||||
assert captured_usage[0]["cached_tokens"] == 150
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_calls_run_level_hooks_on_success():
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
events: list[tuple] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
events.append(("request_messages", list(kwargs["messages"])))
|
||||
return LLMResponse(
|
||||
content="done",
|
||||
tool_calls=[],
|
||||
usage={"prompt_tokens": 3, "completion_tokens": 2},
|
||||
)
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class RunHook(AgentHook):
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("before_run", list(context.messages), context.stop_reason))
|
||||
context.messages.append({"role": "user", "content": "hook-only"})
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append((
|
||||
"after_run",
|
||||
context.final_content,
|
||||
context.stop_reason,
|
||||
context.error,
|
||||
dict(context.usage),
|
||||
[msg["role"] for msg in context.messages],
|
||||
))
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("on_error", context.error))
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("on_finally", context.stop_reason, context.exception))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=RunHook(),
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
assert events == [
|
||||
("before_run", [{"role": "user", "content": "hi"}], None),
|
||||
("request_messages", [{"role": "user", "content": "hi"}]),
|
||||
(
|
||||
"after_run",
|
||||
"done",
|
||||
"completed",
|
||||
None,
|
||||
{"prompt_tokens": 3, "completion_tokens": 2},
|
||||
["user", "assistant"],
|
||||
),
|
||||
("on_finally", "completed", None),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_calls_on_error_for_model_error_result():
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
events: list[tuple] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(content="model failed", finish_reason="error", tool_calls=[])
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class ErrorHook(AgentHook):
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("before_run", context.stop_reason))
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("on_error", context.stop_reason, context.error, context.exception))
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("after_run", context.stop_reason, context.error))
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("on_finally", context.stop_reason, context.error))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=ErrorHook(),
|
||||
))
|
||||
|
||||
assert result.stop_reason == "error"
|
||||
assert result.error == "model failed"
|
||||
assert events == [
|
||||
("before_run", None),
|
||||
("on_error", "error", "model failed", None),
|
||||
("after_run", "error", "model failed"),
|
||||
("on_finally", "error", "model failed"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
events: list[tuple] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
raise RuntimeError("provider exploded")
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class ExceptionHook(AgentHook):
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("before_run", list(context.messages)))
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
events.append((
|
||||
"on_error",
|
||||
context.stop_reason,
|
||||
context.error,
|
||||
type(context.exception).__name__ if context.exception else None,
|
||||
))
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("after_run", context.stop_reason))
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("on_finally", context.stop_reason))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
with pytest.raises(RuntimeError, match="provider exploded"):
|
||||
await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=ExceptionHook(),
|
||||
))
|
||||
|
||||
assert events == [
|
||||
("before_run", [{"role": "user", "content": "hi"}]),
|
||||
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
|
||||
("on_finally", "error"),
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user