mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-14 14:54:06 +00:00
feat: add run-level agent hook lifecycle
This commit is contained in:
parent
c77ca16d91
commit
2ea226055e
@ -1,7 +1,7 @@
|
|||||||
"""Agent core module."""
|
"""Agent core module."""
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.agent.memory import MemoryStore
|
from nanobot.agent.memory import MemoryStore
|
||||||
from nanobot.agent.skills import SkillsLoader
|
from nanobot.agent.skills import SkillsLoader
|
||||||
@ -10,6 +10,7 @@ from nanobot.agent.subagent import SubagentManager
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentHook",
|
"AgentHook",
|
||||||
"AgentHookContext",
|
"AgentHookContext",
|
||||||
|
"AgentRunHookContext",
|
||||||
"AgentLoop",
|
"AgentLoop",
|
||||||
"CompositeHook",
|
"CompositeHook",
|
||||||
"ContextBuilder",
|
"ContextBuilder",
|
||||||
|
|||||||
@ -28,6 +28,21 @@ class AgentHookContext:
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AgentRunHookContext:
|
||||||
|
"""Run-level state snapshot exposed to runner hooks."""
|
||||||
|
|
||||||
|
messages: list[dict[str, Any]]
|
||||||
|
final_content: str | None = None
|
||||||
|
tools_used: list[str] = field(default_factory=list)
|
||||||
|
usage: dict[str, int] = field(default_factory=dict)
|
||||||
|
stop_reason: str | None = None
|
||||||
|
error: str | None = None
|
||||||
|
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
had_injections: bool = False
|
||||||
|
exception: BaseException | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentHook:
|
class AgentHook:
|
||||||
"""Minimal lifecycle surface for shared runner customization."""
|
"""Minimal lifecycle surface for shared runner customization."""
|
||||||
|
|
||||||
@ -37,6 +52,18 @@ class AgentHook:
|
|||||||
def wants_streaming(self) -> bool:
|
def wants_streaming(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -98,6 +125,18 @@ class CompositeHook(AgentHook):
|
|||||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||||
await self._for_each_hook_safe("before_iteration", context)
|
await self._for_each_hook_safe("before_iteration", context)
|
||||||
|
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
await self._for_each_hook_safe("before_run", context)
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
await self._for_each_hook_safe("after_run", context)
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
await self._for_each_hook_safe("on_error", context)
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
await self._for_each_hook_safe("on_finally", context)
|
||||||
|
|
||||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||||
await self._for_each_hook_safe("on_stream", context, delta)
|
await self._for_each_hook_safe("on_stream", context, delta)
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,7 @@ from typing import Any, Callable
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
from nanobot.utils.file_edit_events import (
|
from nanobot.utils.file_edit_events import (
|
||||||
@ -272,6 +272,42 @@ class AgentRunner:
|
|||||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||||
hook = spec.hook or AgentHook()
|
hook = spec.hook or AgentHook()
|
||||||
messages = list(spec.initial_messages)
|
messages = list(spec.initial_messages)
|
||||||
|
context = AgentRunHookContext(messages=list(messages))
|
||||||
|
|
||||||
|
try:
|
||||||
|
await hook.before_run(context)
|
||||||
|
result = await self._run_core(spec, hook, messages)
|
||||||
|
except BaseException as exc:
|
||||||
|
context.messages = list(messages)
|
||||||
|
context.stop_reason = "error"
|
||||||
|
context.error = f"Error: {type(exc).__name__}: {exc}"
|
||||||
|
context.exception = exc
|
||||||
|
await hook.on_error(context)
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
context.messages = list(result.messages)
|
||||||
|
context.final_content = result.final_content
|
||||||
|
context.tools_used = list(result.tools_used)
|
||||||
|
context.usage = dict(result.usage)
|
||||||
|
context.stop_reason = result.stop_reason
|
||||||
|
context.error = result.error
|
||||||
|
context.tool_events = list(result.tool_events)
|
||||||
|
context.had_injections = result.had_injections
|
||||||
|
context.exception = None
|
||||||
|
if context.error is not None:
|
||||||
|
await hook.on_error(context)
|
||||||
|
await hook.after_run(context)
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
context.messages = list(messages)
|
||||||
|
await hook.on_finally(context)
|
||||||
|
|
||||||
|
async def _run_core(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
hook: AgentHook,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
) -> AgentRunResult:
|
||||||
final_content: str | None = None
|
final_content: str | None = None
|
||||||
tools_used: list[str] = []
|
tools_used: list[str] = []
|
||||||
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
|
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||||
|
|||||||
@ -6,13 +6,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
from nanobot.agent.hook import AgentHook, AgentHookContext, AgentRunHookContext, CompositeHook
|
||||||
|
|
||||||
|
|
||||||
def _ctx() -> AgentHookContext:
|
def _ctx() -> AgentHookContext:
|
||||||
return AgentHookContext(iteration=0, messages=[])
|
return AgentHookContext(iteration=0, messages=[])
|
||||||
|
|
||||||
|
|
||||||
|
def _run_ctx() -> AgentRunHookContext:
|
||||||
|
return AgentRunHookContext(messages=[])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Base AgentHook emit_reasoning: no-op
|
# Base AgentHook emit_reasoning: no-op
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -54,6 +58,9 @@ async def test_composite_fans_out_all_async_methods():
|
|||||||
events: list[str] = []
|
events: list[str] = []
|
||||||
|
|
||||||
class RecordingHook(AgentHook):
|
class RecordingHook(AgentHook):
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append("before_run")
|
||||||
|
|
||||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||||
events.append("before_iteration")
|
events.append("before_iteration")
|
||||||
|
|
||||||
@ -72,23 +79,41 @@ async def test_composite_fans_out_all_async_methods():
|
|||||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||||
events.append("after_iteration")
|
events.append("after_iteration")
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append("after_run")
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append("on_error")
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append("on_finally")
|
||||||
|
|
||||||
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
||||||
ctx = _ctx()
|
ctx = _ctx()
|
||||||
|
run_ctx = _run_ctx()
|
||||||
|
|
||||||
|
await hook.before_run(run_ctx)
|
||||||
await hook.before_iteration(ctx)
|
await hook.before_iteration(ctx)
|
||||||
await hook.emit_reasoning("thinking...")
|
await hook.emit_reasoning("thinking...")
|
||||||
await hook.on_stream(ctx, "hi")
|
await hook.on_stream(ctx, "hi")
|
||||||
await hook.on_stream_end(ctx, resuming=True)
|
await hook.on_stream_end(ctx, resuming=True)
|
||||||
await hook.before_execute_tools(ctx)
|
await hook.before_execute_tools(ctx)
|
||||||
await hook.after_iteration(ctx)
|
await hook.after_iteration(ctx)
|
||||||
|
await hook.after_run(run_ctx)
|
||||||
|
await hook.on_error(run_ctx)
|
||||||
|
await hook.on_finally(run_ctx)
|
||||||
|
|
||||||
assert events == [
|
assert events == [
|
||||||
|
"before_run", "before_run",
|
||||||
"before_iteration", "before_iteration",
|
"before_iteration", "before_iteration",
|
||||||
"emit_reasoning:thinking...", "emit_reasoning:thinking...",
|
"emit_reasoning:thinking...", "emit_reasoning:thinking...",
|
||||||
"on_stream:hi", "on_stream:hi",
|
"on_stream:hi", "on_stream:hi",
|
||||||
"on_stream_end:True", "on_stream_end:True",
|
"on_stream_end:True", "on_stream_end:True",
|
||||||
"before_execute_tools", "before_execute_tools",
|
"before_execute_tools", "before_execute_tools",
|
||||||
"after_iteration", "after_iteration",
|
"after_iteration", "after_iteration",
|
||||||
|
"after_run", "after_run",
|
||||||
|
"on_error", "on_error",
|
||||||
|
"on_finally", "on_finally",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -137,6 +162,8 @@ async def test_composite_error_isolation_all_async():
|
|||||||
calls: list[str] = []
|
calls: list[str] = []
|
||||||
|
|
||||||
class Bad(AgentHook):
|
class Bad(AgentHook):
|
||||||
|
async def before_run(self, context):
|
||||||
|
raise RuntimeError("err")
|
||||||
async def emit_reasoning(self, reasoning_content):
|
async def emit_reasoning(self, reasoning_content):
|
||||||
raise RuntimeError("err")
|
raise RuntimeError("err")
|
||||||
async def on_stream_end(self, context, *, resuming):
|
async def on_stream_end(self, context, *, resuming):
|
||||||
@ -145,8 +172,16 @@ async def test_composite_error_isolation_all_async():
|
|||||||
raise RuntimeError("err")
|
raise RuntimeError("err")
|
||||||
async def after_iteration(self, context):
|
async def after_iteration(self, context):
|
||||||
raise RuntimeError("err")
|
raise RuntimeError("err")
|
||||||
|
async def after_run(self, context):
|
||||||
|
raise RuntimeError("err")
|
||||||
|
async def on_error(self, context):
|
||||||
|
raise RuntimeError("err")
|
||||||
|
async def on_finally(self, context):
|
||||||
|
raise RuntimeError("err")
|
||||||
|
|
||||||
class Good(AgentHook):
|
class Good(AgentHook):
|
||||||
|
async def before_run(self, context):
|
||||||
|
calls.append("before_run")
|
||||||
async def emit_reasoning(self, reasoning_content):
|
async def emit_reasoning(self, reasoning_content):
|
||||||
calls.append("emit_reasoning")
|
calls.append("emit_reasoning")
|
||||||
async def on_stream_end(self, context, *, resuming):
|
async def on_stream_end(self, context, *, resuming):
|
||||||
@ -155,14 +190,34 @@ async def test_composite_error_isolation_all_async():
|
|||||||
calls.append("before_execute_tools")
|
calls.append("before_execute_tools")
|
||||||
async def after_iteration(self, context):
|
async def after_iteration(self, context):
|
||||||
calls.append("after_iteration")
|
calls.append("after_iteration")
|
||||||
|
async def after_run(self, context):
|
||||||
|
calls.append("after_run")
|
||||||
|
async def on_error(self, context):
|
||||||
|
calls.append("on_error")
|
||||||
|
async def on_finally(self, context):
|
||||||
|
calls.append("on_finally")
|
||||||
|
|
||||||
hook = CompositeHook([Bad(), Good()])
|
hook = CompositeHook([Bad(), Good()])
|
||||||
ctx = _ctx()
|
ctx = _ctx()
|
||||||
|
run_ctx = _run_ctx()
|
||||||
|
await hook.before_run(run_ctx)
|
||||||
await hook.emit_reasoning("test")
|
await hook.emit_reasoning("test")
|
||||||
await hook.on_stream_end(ctx, resuming=False)
|
await hook.on_stream_end(ctx, resuming=False)
|
||||||
await hook.before_execute_tools(ctx)
|
await hook.before_execute_tools(ctx)
|
||||||
await hook.after_iteration(ctx)
|
await hook.after_iteration(ctx)
|
||||||
assert calls == ["emit_reasoning", "on_stream_end", "before_execute_tools", "after_iteration"]
|
await hook.after_run(run_ctx)
|
||||||
|
await hook.on_error(run_ctx)
|
||||||
|
await hook.on_finally(run_ctx)
|
||||||
|
assert calls == [
|
||||||
|
"before_run",
|
||||||
|
"emit_reasoning",
|
||||||
|
"on_stream_end",
|
||||||
|
"before_execute_tools",
|
||||||
|
"after_iteration",
|
||||||
|
"after_run",
|
||||||
|
"on_error",
|
||||||
|
"on_finally",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -246,11 +301,16 @@ def test_composite_wants_streaming_empty():
|
|||||||
async def test_composite_empty_hooks_no_ops():
|
async def test_composite_empty_hooks_no_ops():
|
||||||
hook = CompositeHook([])
|
hook = CompositeHook([])
|
||||||
ctx = _ctx()
|
ctx = _ctx()
|
||||||
|
run_ctx = _run_ctx()
|
||||||
|
await hook.before_run(run_ctx)
|
||||||
await hook.before_iteration(ctx)
|
await hook.before_iteration(ctx)
|
||||||
await hook.on_stream(ctx, "delta")
|
await hook.on_stream(ctx, "delta")
|
||||||
await hook.on_stream_end(ctx, resuming=False)
|
await hook.on_stream_end(ctx, resuming=False)
|
||||||
await hook.before_execute_tools(ctx)
|
await hook.before_execute_tools(ctx)
|
||||||
await hook.after_iteration(ctx)
|
await hook.after_iteration(ctx)
|
||||||
|
await hook.after_run(run_ctx)
|
||||||
|
await hook.on_error(run_ctx)
|
||||||
|
await hook.on_finally(run_ctx)
|
||||||
assert hook.finalize_content(ctx, "test") == "test"
|
assert hook.finalize_content(ctx, "test") == "test"
|
||||||
|
|
||||||
|
|
||||||
@ -316,12 +376,18 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
|||||||
events: list[str] = []
|
events: list[str] = []
|
||||||
|
|
||||||
class TrackingHook(AgentHook):
|
class TrackingHook(AgentHook):
|
||||||
|
async def before_run(self, context):
|
||||||
|
events.append("before_run")
|
||||||
|
|
||||||
async def before_iteration(self, context):
|
async def before_iteration(self, context):
|
||||||
events.append(f"before_iter:{context.iteration}")
|
events.append(f"before_iter:{context.iteration}")
|
||||||
|
|
||||||
async def after_iteration(self, context):
|
async def after_iteration(self, context):
|
||||||
events.append(f"after_iter:{context.iteration}")
|
events.append(f"after_iter:{context.iteration}")
|
||||||
|
|
||||||
|
async def after_run(self, context):
|
||||||
|
events.append(f"after_run:{context.stop_reason}")
|
||||||
|
|
||||||
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
||||||
loop.provider.chat_with_retry = AsyncMock(
|
loop.provider.chat_with_retry = AsyncMock(
|
||||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
@ -333,8 +399,10 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert content == "done"
|
assert content == "done"
|
||||||
|
assert "before_run" in events
|
||||||
assert "before_iter:0" in events
|
assert "before_iter:0" in events
|
||||||
assert "after_iter:0" in events
|
assert "after_iter:0" in events
|
||||||
|
assert "after_run:completed" in events
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -16,7 +16,7 @@ _MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_calls_hooks_in_order():
|
async def test_runner_calls_hooks_in_order():
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
provider = MagicMock(spec=LLMProvider)
|
provider = MagicMock(spec=LLMProvider)
|
||||||
call_count = {"n": 0}
|
call_count = {"n": 0}
|
||||||
@ -92,7 +92,7 @@ async def test_runner_calls_hooks_in_order():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
provider = MagicMock(spec=LLMProvider)
|
provider = MagicMock(spec=LLMProvider)
|
||||||
streamed: list[str] = []
|
streamed: list[str] = []
|
||||||
@ -138,7 +138,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
|||||||
async def test_runner_passes_cached_tokens_to_hook_context():
|
async def test_runner_passes_cached_tokens_to_hook_context():
|
||||||
"""Hook context.usage should contain cached_tokens."""
|
"""Hook context.usage should contain cached_tokens."""
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
provider = MagicMock(spec=LLMProvider)
|
provider = MagicMock(spec=LLMProvider)
|
||||||
captured_usage: list[dict] = []
|
captured_usage: list[dict] = []
|
||||||
@ -170,3 +170,169 @@ async def test_runner_passes_cached_tokens_to_hook_context():
|
|||||||
|
|
||||||
assert len(captured_usage) == 1
|
assert len(captured_usage) == 1
|
||||||
assert captured_usage[0]["cached_tokens"] == 150
|
assert captured_usage[0]["cached_tokens"] == 150
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_calls_run_level_hooks_on_success():
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
events: list[tuple] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
events.append(("request_messages", list(kwargs["messages"])))
|
||||||
|
return LLMResponse(
|
||||||
|
content="done",
|
||||||
|
tool_calls=[],
|
||||||
|
usage={"prompt_tokens": 3, "completion_tokens": 2},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
class RunHook(AgentHook):
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("before_run", list(context.messages), context.stop_reason))
|
||||||
|
context.messages.append({"role": "user", "content": "hook-only"})
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append((
|
||||||
|
"after_run",
|
||||||
|
context.final_content,
|
||||||
|
context.stop_reason,
|
||||||
|
context.error,
|
||||||
|
dict(context.usage),
|
||||||
|
[msg["role"] for msg in context.messages],
|
||||||
|
))
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("on_error", context.error))
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("on_finally", context.stop_reason, context.exception))
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=RunHook(),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.final_content == "done"
|
||||||
|
assert events == [
|
||||||
|
("before_run", [{"role": "user", "content": "hi"}], None),
|
||||||
|
("request_messages", [{"role": "user", "content": "hi"}]),
|
||||||
|
(
|
||||||
|
"after_run",
|
||||||
|
"done",
|
||||||
|
"completed",
|
||||||
|
None,
|
||||||
|
{"prompt_tokens": 3, "completion_tokens": 2},
|
||||||
|
["user", "assistant"],
|
||||||
|
),
|
||||||
|
("on_finally", "completed", None),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_calls_on_error_for_model_error_result():
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
events: list[tuple] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
return LLMResponse(content="model failed", finish_reason="error", tool_calls=[])
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
class ErrorHook(AgentHook):
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("before_run", context.stop_reason))
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("on_error", context.stop_reason, context.error, context.exception))
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("after_run", context.stop_reason, context.error))
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("on_finally", context.stop_reason, context.error))
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=ErrorHook(),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.stop_reason == "error"
|
||||||
|
assert result.error == "model failed"
|
||||||
|
assert events == [
|
||||||
|
("before_run", None),
|
||||||
|
("on_error", "error", "model failed", None),
|
||||||
|
("after_run", "error", "model failed"),
|
||||||
|
("on_finally", "error", "model failed"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
events: list[tuple] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
raise RuntimeError("provider exploded")
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
class ExceptionHook(AgentHook):
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("before_run", list(context.messages)))
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append((
|
||||||
|
"on_error",
|
||||||
|
context.stop_reason,
|
||||||
|
context.error,
|
||||||
|
type(context.exception).__name__ if context.exception else None,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("after_run", context.stop_reason))
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("on_finally", context.stop_reason))
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
with pytest.raises(RuntimeError, match="provider exploded"):
|
||||||
|
await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=ExceptionHook(),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert events == [
|
||||||
|
("before_run", [{"role": "user", "content": "hi"}]),
|
||||||
|
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
|
||||||
|
("on_finally", "error"),
|
||||||
|
]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user