fix: harden run-level hook lifecycle

maintainer edit: keep cancellation out of on_error so shutdown paths do not look like run failures, and let the SDK capture hook use the authoritative after_run snapshot.
This commit is contained in:
chengyongru 2026-06-03 21:11:20 +08:00 committed by Xubin Ren
parent 2ea226055e
commit 8933da1ec5
4 changed files with 93 additions and 2 deletions

View File

@ -166,7 +166,9 @@ class SDKCaptureHook(AgentHook):
The runner mutates ``context.messages`` in place across iterations, so the
snapshot is refreshed on every ``after_iteration`` call; the last call
reflects the end-of-turn state the SDK caller cares about.
reflects the end-of-turn state the SDK caller cares about. The run-level
snapshot is authoritative when available and covers paths without a final
per-iteration callback.
"""
def __init__(self) -> None:
@ -178,3 +180,7 @@ class SDKCaptureHook(AgentHook):
for call in context.tool_calls:
self.tools_used.append(call.name)
self.messages = list(context.messages)
async def after_run(self, context: AgentRunHookContext) -> None:
self.tools_used = list(context.tools_used)
self.messages = list(context.messages)

View File

@ -277,7 +277,13 @@ class AgentRunner:
try:
await hook.before_run(context)
result = await self._run_core(spec, hook, messages)
except BaseException as exc:
except asyncio.CancelledError as exc:
context.messages = list(messages)
context.stop_reason = "cancelled"
context.error = None
context.exception = exc
raise
except Exception as exc:
context.messages = list(messages)
context.stop_reason = "error"
context.error = f"Error: {type(exc).__name__}: {exc}"

View File

@ -336,3 +336,55 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
("on_finally", "error"),
]
@pytest.mark.asyncio
async def test_runner_does_not_report_cancellation_as_error():
import asyncio
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
events: list[tuple] = []
async def chat_with_retry(**kwargs):
raise asyncio.CancelledError()
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class CancellationHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
events.append(("before_run", context.stop_reason))
async def on_error(self, context: AgentRunHookContext) -> None:
events.append(("on_error", context.stop_reason, context.error))
async def after_run(self, context: AgentRunHookContext) -> None:
events.append(("after_run", context.stop_reason))
async def on_finally(self, context: AgentRunHookContext) -> None:
events.append((
"on_finally",
context.stop_reason,
context.error,
type(context.exception).__name__ if context.exception else None,
))
runner = AgentRunner(provider)
with pytest.raises(asyncio.CancelledError):
await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=CancellationHook(),
))
assert events == [
("before_run", None),
("on_finally", "cancelled", None, "CancelledError"),
]

View File

@ -299,3 +299,30 @@ async def test_run_restores_extra_hooks_even_on_populated_iterations(tmp_path):
bot._loop.process_direct = fake_process_direct
await bot.run("hello")
assert bot._loop._extra_hooks == [sentinel_hook]
@pytest.mark.asyncio
async def test_sdk_capture_prefers_run_level_snapshot():
from nanobot.agent.hook import AgentHookContext, AgentRunHookContext, SDKCaptureHook
from nanobot.providers.base import ToolCallRequest
hook = SDKCaptureHook()
iter_messages = [{"role": "user", "content": "work"}]
iter_context = AgentHookContext(iteration=0, messages=iter_messages)
iter_context.tool_calls = [
ToolCallRequest(id="call_1", name="read_file", arguments={}),
ToolCallRequest(id="call_2", name="grep", arguments={}),
]
await hook.after_iteration(iter_context)
final_messages = [
{"role": "user", "content": "work"},
{"role": "assistant", "content": "done"},
]
await hook.after_run(AgentRunHookContext(
messages=final_messages,
tools_used=["read_file"],
))
assert hook.tools_used == ["read_file"]
assert hook.messages == final_messages