diff --git a/nanobot/agent/__init__.py b/nanobot/agent/__init__.py index 7d3ab2af4..47d453a8b 100644 --- a/nanobot/agent/__init__.py +++ b/nanobot/agent/__init__.py @@ -1,7 +1,7 @@ """Agent core module.""" from nanobot.agent.context import ContextBuilder -from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook +from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook from nanobot.agent.loop import AgentLoop from nanobot.agent.memory import MemoryStore from nanobot.agent.skills import SkillsLoader @@ -10,6 +10,7 @@ from nanobot.agent.subagent import SubagentManager __all__ = [ "AgentHook", "AgentHookContext", + "AgentRunHookContext", "AgentLoop", "CompositeHook", "ContextBuilder", diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index 5b6fed445..1e299dd15 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -28,6 +28,21 @@ class AgentHookContext: error: str | None = None +@dataclass(slots=True) +class AgentRunHookContext: + """Run-level state snapshot exposed to runner hooks.""" + + messages: list[dict[str, Any]] + final_content: str | None = None + tools_used: list[str] = field(default_factory=list) + usage: dict[str, int] = field(default_factory=dict) + stop_reason: str | None = None + error: str | None = None + tool_events: list[dict[str, str]] = field(default_factory=list) + had_injections: bool = False + exception: BaseException | None = None + + class AgentHook: """Minimal lifecycle surface for shared runner customization.""" @@ -37,6 +52,18 @@ class AgentHook: def wants_streaming(self) -> bool: return False + async def before_run(self, context: AgentRunHookContext) -> None: + pass + + async def after_run(self, context: AgentRunHookContext) -> None: + pass + + async def on_error(self, context: AgentRunHookContext) -> None: + pass + + async def on_finally(self, context: AgentRunHookContext) -> None: + pass + async def before_iteration(self, context: AgentHookContext) -> None: pass @@ -98,6 +125,18 @@ class CompositeHook(AgentHook): async def before_iteration(self, context: AgentHookContext) -> None: await self._for_each_hook_safe("before_iteration", context) + async def before_run(self, context: AgentRunHookContext) -> None: + await self._for_each_hook_safe("before_run", context) + + async def after_run(self, context: AgentRunHookContext) -> None: + await self._for_each_hook_safe("after_run", context) + + async def on_error(self, context: AgentRunHookContext) -> None: + await self._for_each_hook_safe("on_error", context) + + async def on_finally(self, context: AgentRunHookContext) -> None: + await self._for_each_hook_safe("on_finally", context) + async def on_stream(self, context: AgentHookContext, delta: str) -> None: await self._for_each_hook_safe("on_stream", context, delta) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 83438dbef..f74f916b3 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -12,7 +12,7 @@ from typing import Any, Callable from loguru import logger -from nanobot.agent.hook import AgentHook, AgentHookContext +from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.utils.file_edit_events import ( @@ -272,6 +272,42 @@ class AgentRunner: async def run(self, spec: AgentRunSpec) -> AgentRunResult: hook = spec.hook or AgentHook() messages = list(spec.initial_messages) + context = AgentRunHookContext(messages=list(messages)) + + try: + await hook.before_run(context) + result = await self._run_core(spec, hook, messages) + except BaseException as exc: + context.messages = list(messages) + context.stop_reason = "error" + context.error = f"Error: {type(exc).__name__}: {exc}" + context.exception = exc + await hook.on_error(context) + raise + else: + context.messages = list(result.messages) + context.final_content = result.final_content + context.tools_used = list(result.tools_used) + context.usage = dict(result.usage) + context.stop_reason = result.stop_reason + context.error = result.error + context.tool_events = list(result.tool_events) + context.had_injections = result.had_injections + context.exception = None + if context.error is not None: + await hook.on_error(context) + await hook.after_run(context) + return result + finally: + context.messages = list(messages) + await hook.on_finally(context) + + async def _run_core( + self, + spec: AgentRunSpec, + hook: AgentHook, + messages: list[dict[str, Any]], + ) -> AgentRunResult: final_content: str | None = None tools_used: list[str] = [] usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 7ee5dd6b5..58fdec6b2 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -6,13 +6,17 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook +from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook def _ctx() -> AgentHookContext: return AgentHookContext(iteration=0, messages=[]) +def _run_ctx() -> AgentRunHookContext: + return AgentRunHookContext(messages=[]) + + # --------------------------------------------------------------------------- # Base AgentHook emit_reasoning: no-op # --------------------------------------------------------------------------- @@ -54,6 +58,9 @@ async def test_composite_fans_out_all_async_methods(): events: list[str] = [] class RecordingHook(AgentHook): + async def before_run(self, context: AgentRunHookContext) -> None: + events.append("before_run") + async def before_iteration(self, context: AgentHookContext) -> None: events.append("before_iteration") @@ -72,23 +79,41 @@ async def test_composite_fans_out_all_async_methods(): async def after_iteration(self, context: AgentHookContext) -> None: events.append("after_iteration") + async def after_run(self, context: AgentRunHookContext) -> None: + events.append("after_run") + + async def on_error(self, context: AgentRunHookContext) -> None: + events.append("on_error") + + async def on_finally(self, context: AgentRunHookContext) -> None: + events.append("on_finally") + hook = CompositeHook([RecordingHook(), RecordingHook()]) ctx = _ctx() + run_ctx = _run_ctx() + await hook.before_run(run_ctx) await hook.before_iteration(ctx) await hook.emit_reasoning("thinking...") await hook.on_stream(ctx, "hi") await hook.on_stream_end(ctx, resuming=True) await hook.before_execute_tools(ctx) await hook.after_iteration(ctx) + await hook.after_run(run_ctx) + await hook.on_error(run_ctx) + await hook.on_finally(run_ctx) assert events == [ + "before_run", "before_run", "before_iteration", "before_iteration", "emit_reasoning:thinking...", "emit_reasoning:thinking...", "on_stream:hi", "on_stream:hi", "on_stream_end:True", "on_stream_end:True", "before_execute_tools", "before_execute_tools", "after_iteration", "after_iteration", + "after_run", "after_run", + "on_error", "on_error", + "on_finally", "on_finally", ] @@ -137,6 +162,8 @@ async def test_composite_error_isolation_all_async(): calls: list[str] = [] class Bad(AgentHook): + async def before_run(self, context): + raise RuntimeError("err") async def emit_reasoning(self, reasoning_content): raise RuntimeError("err") async def on_stream_end(self, context, *, resuming): @@ -145,8 +172,16 @@ async def test_composite_error_isolation_all_async(): raise RuntimeError("err") async def after_iteration(self, context): raise RuntimeError("err") + async def after_run(self, context): + raise RuntimeError("err") + async def on_error(self, context): + raise RuntimeError("err") + async def on_finally(self, context): + raise RuntimeError("err") class Good(AgentHook): + async def before_run(self, context): + calls.append("before_run") async def emit_reasoning(self, reasoning_content): calls.append("emit_reasoning") async def on_stream_end(self, context, *, resuming): @@ -155,14 +190,34 @@ async def test_composite_error_isolation_all_async(): calls.append("before_execute_tools") async def after_iteration(self, context): calls.append("after_iteration") + async def after_run(self, context): + calls.append("after_run") + async def on_error(self, context): + calls.append("on_error") + async def on_finally(self, context): + calls.append("on_finally") hook = CompositeHook([Bad(), Good()]) ctx = _ctx() + run_ctx = _run_ctx() + await hook.before_run(run_ctx) await hook.emit_reasoning("test") await hook.on_stream_end(ctx, resuming=False) await hook.before_execute_tools(ctx) await hook.after_iteration(ctx) - assert calls == ["emit_reasoning", "on_stream_end", "before_execute_tools", "after_iteration"] + await hook.after_run(run_ctx) + await hook.on_error(run_ctx) + await hook.on_finally(run_ctx) + assert calls == [ + "before_run", + "emit_reasoning", + "on_stream_end", + "before_execute_tools", + "after_iteration", + "after_run", + "on_error", + "on_finally", + ] # --------------------------------------------------------------------------- @@ -246,11 +301,16 @@ def test_composite_wants_streaming_empty(): async def test_composite_empty_hooks_no_ops(): hook = CompositeHook([]) ctx = _ctx() + run_ctx = _run_ctx() + await hook.before_run(run_ctx) await hook.before_iteration(ctx) await hook.on_stream(ctx, "delta") await hook.on_stream_end(ctx, resuming=False) await hook.before_execute_tools(ctx) await hook.after_iteration(ctx) + await hook.after_run(run_ctx) + await hook.on_error(run_ctx) + await hook.on_finally(run_ctx) assert hook.finalize_content(ctx, "test") == "test" @@ -316,12 +376,18 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path): events: list[str] = [] class TrackingHook(AgentHook): + async def before_run(self, context): + events.append("before_run") + async def before_iteration(self, context): events.append(f"before_iter:{context.iteration}") async def after_iteration(self, context): events.append(f"after_iter:{context.iteration}") + async def after_run(self, context): + events.append(f"after_run:{context.stop_reason}") + loop = _make_loop(tmp_path, hooks=[TrackingHook()]) loop.provider.chat_with_retry = AsyncMock( return_value=LLMResponse(content="done", tool_calls=[], usage={}) @@ -333,8 +399,10 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path): ) assert content == "done" + assert "before_run" in events assert "before_iter:0" in events assert "after_iter:0" in events + assert "after_run:completed" in events @pytest.mark.asyncio diff --git a/tests/agent/test_runner_hooks.py b/tests/agent/test_runner_hooks.py index 7718eee20..0ccd718cf 100644 --- a/tests/agent/test_runner_hooks.py +++ b/tests/agent/test_runner_hooks.py @@ -16,7 +16,7 @@ _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars @pytest.mark.asyncio async def test_runner_calls_hooks_in_order(): from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock(spec=LLMProvider) call_count = {"n": 0} @@ -92,7 +92,7 @@ async def test_runner_calls_hooks_in_order(): @pytest.mark.asyncio async def test_runner_streaming_hook_receives_deltas_and_end_signal(): from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock(spec=LLMProvider) streamed: list[str] = [] @@ -138,7 +138,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal(): async def test_runner_passes_cached_tokens_to_hook_context(): """Hook context.usage should contain cached_tokens.""" from nanobot.agent.hook import AgentHook, AgentHookContext - from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.runner import AgentRunner, AgentRunSpec provider = MagicMock(spec=LLMProvider) captured_usage: list[dict] = [] @@ -170,3 +170,169 @@ async def test_runner_passes_cached_tokens_to_hook_context(): assert len(captured_usage) == 1 assert captured_usage[0]["cached_tokens"] == 150 + + +@pytest.mark.asyncio +async def test_runner_calls_run_level_hooks_on_success(): + from nanobot.agent.hook import AgentHook, AgentRunHookContext + from nanobot.agent.runner import AgentRunner, AgentRunSpec + + provider = MagicMock(spec=LLMProvider) + events: list[tuple] = [] + + async def chat_with_retry(**kwargs): + events.append(("request_messages", list(kwargs["messages"]))) + return LLMResponse( + content="done", + tool_calls=[], + usage={"prompt_tokens": 3, "completion_tokens": 2}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + class RunHook(AgentHook): + async def before_run(self, context: AgentRunHookContext) -> None: + events.append(("before_run", list(context.messages), context.stop_reason)) + context.messages.append({"role": "user", "content": "hook-only"}) + + async def after_run(self, context: AgentRunHookContext) -> None: + events.append(( + "after_run", + context.final_content, + context.stop_reason, + context.error, + dict(context.usage), + [msg["role"] for msg in context.messages], + )) + + async def on_error(self, context: AgentRunHookContext) -> None: + events.append(("on_error", context.error)) + + async def on_finally(self, context: AgentRunHookContext) -> None: + events.append(("on_finally", context.stop_reason, context.exception)) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=RunHook(), + )) + + assert result.final_content == "done" + assert events == [ + ("before_run", [{"role": "user", "content": "hi"}], None), + ("request_messages", [{"role": "user", "content": "hi"}]), + ( + "after_run", + "done", + "completed", + None, + {"prompt_tokens": 3, "completion_tokens": 2}, + ["user", "assistant"], + ), + ("on_finally", "completed", None), + ] + + +@pytest.mark.asyncio +async def test_runner_calls_on_error_for_model_error_result(): + from nanobot.agent.hook import AgentHook, AgentRunHookContext + from nanobot.agent.runner import AgentRunner, AgentRunSpec + + provider = MagicMock(spec=LLMProvider) + events: list[tuple] = [] + + async def chat_with_retry(**kwargs): + return LLMResponse(content="model failed", finish_reason="error", tool_calls=[]) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + class ErrorHook(AgentHook): + async def before_run(self, context: AgentRunHookContext) -> None: + events.append(("before_run", context.stop_reason)) + + async def on_error(self, context: AgentRunHookContext) -> None: + events.append(("on_error", context.stop_reason, context.error, context.exception)) + + async def after_run(self, context: AgentRunHookContext) -> None: + events.append(("after_run", context.stop_reason, context.error)) + + async def on_finally(self, context: AgentRunHookContext) -> None: + events.append(("on_finally", context.stop_reason, context.error)) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=ErrorHook(), + )) + + assert result.stop_reason == "error" + assert result.error == "model failed" + assert events == [ + ("before_run", None), + ("on_error", "error", "model failed", None), + ("after_run", "error", "model failed"), + ("on_finally", "error", "model failed"), + ] + + +@pytest.mark.asyncio +async def test_runner_calls_on_error_and_finally_for_unhandled_exception(): + from nanobot.agent.hook import AgentHook, AgentRunHookContext + from nanobot.agent.runner import AgentRunner, AgentRunSpec + + provider = MagicMock(spec=LLMProvider) + events: list[tuple] = [] + + async def chat_with_retry(**kwargs): + raise RuntimeError("provider exploded") + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + class ExceptionHook(AgentHook): + async def before_run(self, context: AgentRunHookContext) -> None: + events.append(("before_run", list(context.messages))) + + async def on_error(self, context: AgentRunHookContext) -> None: + events.append(( + "on_error", + context.stop_reason, + context.error, + type(context.exception).__name__ if context.exception else None, + )) + + async def after_run(self, context: AgentRunHookContext) -> None: + events.append(("after_run", context.stop_reason)) + + async def on_finally(self, context: AgentRunHookContext) -> None: + events.append(("on_finally", context.stop_reason)) + + runner = AgentRunner(provider) + with pytest.raises(RuntimeError, match="provider exploded"): + await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=ExceptionHook(), + )) + + assert events == [ + ("before_run", [{"role": "user", "content": "hi"}]), + ("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"), + ("on_finally", "error"), + ]