diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py index 1e299dd15..ed2c95498 100644 --- a/nanobot/agent/hook.py +++ b/nanobot/agent/hook.py @@ -166,7 +166,9 @@ class SDKCaptureHook(AgentHook): The runner mutates ``context.messages`` in place across iterations, so the snapshot is refreshed on every ``after_iteration`` call; the last call - reflects the end-of-turn state the SDK caller cares about. + reflects the end-of-turn state the SDK caller cares about. The run-level + snapshot is authoritative when available and covers paths without a final + per-iteration callback. """ def __init__(self) -> None: @@ -178,3 +180,7 @@ class SDKCaptureHook(AgentHook): for call in context.tool_calls: self.tools_used.append(call.name) self.messages = list(context.messages) + + async def after_run(self, context: AgentRunHookContext) -> None: + self.tools_used = list(context.tools_used) + self.messages = list(context.messages) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index f74f916b3..ff320c4fa 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -277,7 +277,13 @@ class AgentRunner: try: await hook.before_run(context) result = await self._run_core(spec, hook, messages) - except BaseException as exc: + except asyncio.CancelledError as exc: + context.messages = list(messages) + context.stop_reason = "cancelled" + context.error = None + context.exception = exc + raise + except Exception as exc: context.messages = list(messages) context.stop_reason = "error" context.error = f"Error: {type(exc).__name__}: {exc}" diff --git a/tests/agent/test_runner_hooks.py b/tests/agent/test_runner_hooks.py index 0ccd718cf..cc278d480 100644 --- a/tests/agent/test_runner_hooks.py +++ b/tests/agent/test_runner_hooks.py @@ -336,3 +336,55 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception(): ("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"), ("on_finally", "error"), ] + + +@pytest.mark.asyncio +async def test_runner_does_not_report_cancellation_as_error(): + import asyncio + + from nanobot.agent.hook import AgentHook, AgentRunHookContext + from nanobot.agent.runner import AgentRunner, AgentRunSpec + + provider = MagicMock(spec=LLMProvider) + events: list[tuple] = [] + + async def chat_with_retry(**kwargs): + raise asyncio.CancelledError() + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + class CancellationHook(AgentHook): + async def before_run(self, context: AgentRunHookContext) -> None: + events.append(("before_run", context.stop_reason)) + + async def on_error(self, context: AgentRunHookContext) -> None: + events.append(("on_error", context.stop_reason, context.error)) + + async def after_run(self, context: AgentRunHookContext) -> None: + events.append(("after_run", context.stop_reason)) + + async def on_finally(self, context: AgentRunHookContext) -> None: + events.append(( + "on_finally", + context.stop_reason, + context.error, + type(context.exception).__name__ if context.exception else None, + )) + + runner = AgentRunner(provider) + with pytest.raises(asyncio.CancelledError): + await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=CancellationHook(), + )) + + assert events == [ + ("before_run", None), + ("on_finally", "cancelled", None, "CancelledError"), + ] diff --git a/tests/test_nanobot_facade.py b/tests/test_nanobot_facade.py index c2ef35f9f..24f4681db 100644 --- a/tests/test_nanobot_facade.py +++ b/tests/test_nanobot_facade.py @@ -299,3 +299,30 @@ async def test_run_restores_extra_hooks_even_on_populated_iterations(tmp_path): bot._loop.process_direct = fake_process_direct await bot.run("hello") assert bot._loop._extra_hooks == [sentinel_hook] + + +@pytest.mark.asyncio +async def test_sdk_capture_prefers_run_level_snapshot(): + from nanobot.agent.hook import AgentHookContext, AgentRunHookContext, SDKCaptureHook + from nanobot.providers.base import ToolCallRequest + + hook = SDKCaptureHook() + iter_messages = [{"role": "user", "content": "work"}] + iter_context = AgentHookContext(iteration=0, messages=iter_messages) + iter_context.tool_calls = [ + ToolCallRequest(id="call_1", name="read_file", arguments={}), + ToolCallRequest(id="call_2", name="grep", arguments={}), + ] + await hook.after_iteration(iter_context) + + final_messages = [ + {"role": "user", "content": "work"}, + {"role": "assistant", "content": "done"}, + ] + await hook.after_run(AgentRunHookContext( + messages=final_messages, + tools_used=["read_file"], + )) + + assert hook.tools_used == ["read_file"] + assert hook.messages == final_messages