nanobot/tests/agent/test_runner_hooks.py
chengyongru 8933da1ec5 fix: harden run-level hook lifecycle
maintainer edit: keep cancellation out of on_error so shutdown paths do not look like run failures, and let the SDK capture hook use the authoritative after_run snapshot.
2026-06-05 01:09:45 +08:00

391 lines
13 KiB
Python

"""Tests for AgentRunner hook lifecycle: ordering, streaming deltas,
cached-token propagation, and hook context."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
@pytest.mark.asyncio
async def test_runner_calls_hooks_in_order():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
call_count = {"n": 0}
events: list[tuple] = []
async def chat_with_retry(**kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
class RecordingHook(AgentHook):
async def before_iteration(self, context: AgentHookContext) -> None:
events.append(("before_iteration", context.iteration))
async def before_execute_tools(self, context: AgentHookContext) -> None:
events.append((
"before_execute_tools",
context.iteration,
[tc.name for tc in context.tool_calls],
))
async def after_iteration(self, context: AgentHookContext) -> None:
events.append((
"after_iteration",
context.iteration,
context.final_content,
list(context.tool_results),
list(context.tool_events),
context.stop_reason,
))
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
events.append(("finalize_content", context.iteration, content))
return content.upper() if content else content
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=RecordingHook(),
))
assert result.final_content == "DONE"
assert events == [
("before_iteration", 0),
("before_execute_tools", 0, ["list_dir"]),
(
"after_iteration",
0,
None,
["tool result"],
[{"name": "list_dir", "status": "ok", "detail": "tool result"}],
None,
),
("before_iteration", 1),
("finalize_content", 1, "done"),
("after_iteration", 1, "DONE", [], [], "completed"),
]
@pytest.mark.asyncio
async def test_runner_streaming_hook_receives_deltas_and_end_signal():
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
streamed: list[str] = []
endings: list[bool] = []
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
await on_content_delta("he")
await on_content_delta("llo")
return LLMResponse(content="hello", tool_calls=[], usage={})
provider.chat_stream_with_retry = chat_stream_with_retry
provider.chat_with_retry = AsyncMock()
tools = MagicMock()
tools.get_definitions.return_value = []
class StreamingHook(AgentHook):
def wants_streaming(self) -> bool:
return True
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
streamed.append(delta)
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
endings.append(resuming)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=StreamingHook(),
))
assert result.final_content == "hello"
assert streamed == ["he", "llo"]
assert endings == [False]
provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio
async def test_runner_passes_cached_tokens_to_hook_context():
"""Hook context.usage should contain cached_tokens."""
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
captured_usage: list[dict] = []
class UsageHook(AgentHook):
async def after_iteration(self, context: AgentHookContext) -> None:
captured_usage.append(dict(context.usage))
async def chat_with_retry(**kwargs):
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 200, "completion_tokens": 20, "cached_tokens": 150},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=UsageHook(),
))
assert len(captured_usage) == 1
assert captured_usage[0]["cached_tokens"] == 150
@pytest.mark.asyncio
async def test_runner_calls_run_level_hooks_on_success():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
events.append(("request_messages", list(kwargs["messages"])))
return LLMResponse(
content="done",
tool_calls=[],
usage={"prompt_tokens": 3, "completion_tokens": 2},
)
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class RunHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", list(context.messages), context.stop_reason))
context.messages.append({"role": "user", "content": "hook-only"})
async def after_run(self, context: AgentRunHookContext) -> None:
events.append((
"after_run",
context.final_content,
context.stop_reason,
context.error,
dict(context.usage),
[msg["role"] for msg in context.messages],
))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append(("on_error", context.error))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append(("on_finally", context.stop_reason, context.exception))
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=RunHook(),
))
assert result.final_content == "done"
assert events == [
("before_run", [{"role": "user", "content": "hi"}], None),
("request_messages", [{"role": "user", "content": "hi"}]),
(
"after_run",
"done",
"completed",
None,
{"prompt_tokens": 3, "completion_tokens": 2},
["user", "assistant"],
),
("on_finally", "completed", None),
]
@pytest.mark.asyncio
async def test_runner_calls_on_error_for_model_error_result():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
return LLMResponse(content="model failed", finish_reason="error", tool_calls=[])
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class ErrorHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", context.stop_reason))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append(("on_error", context.stop_reason, context.error, context.exception))
async def after_run(self, context: AgentRunHookContext) -> None:
events.append(("after_run", context.stop_reason, context.error))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append(("on_finally", context.stop_reason, context.error))
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=ErrorHook(),
))
assert result.stop_reason == "error"
assert result.error == "model failed"
assert events == [
("before_run", None),
("on_error", "error", "model failed", None),
("after_run", "error", "model failed"),
("on_finally", "error", "model failed"),
]
@pytest.mark.asyncio
async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
raise RuntimeError("provider exploded")
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class ExceptionHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", list(context.messages)))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append((
"on_error",
context.stop_reason,
context.error,
type(context.exception).__name__ if context.exception else None,
))
async def after_run(self, context: AgentRunHookContext) -> None:
events.append(("after_run", context.stop_reason))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append(("on_finally", context.stop_reason))
runner = AgentRunner(provider)
with pytest.raises(RuntimeError, match="provider exploded"):
await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=ExceptionHook(),
))
assert events == [
("before_run", [{"role": "user", "content": "hi"}]),
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
("on_finally", "error"),
]
@pytest.mark.asyncio
async def test_runner_does_not_report_cancellation_as_error():
import asyncio
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
raise asyncio.CancelledError()
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class CancellationHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", context.stop_reason))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append(("on_error", context.stop_reason, context.error))
async def after_run(self, context: AgentRunHookContext) -> None:
events.append(("after_run", context.stop_reason))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append((
"on_finally",
context.stop_reason,
context.error,
type(context.exception).__name__ if context.exception else None,
))
runner = AgentRunner(provider)
with pytest.raises(asyncio.CancelledError):
await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=CancellationHook(),
))
assert events == [
("before_run", None),
("on_finally", "cancelled", None, "CancelledError"),
]