mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix: harden run-level hook lifecycle
maintainer edit: keep cancellation out of on_error so shutdown paths do not look like run failures, and let the SDK capture hook use the authoritative after_run snapshot.
This commit is contained in:
parent
2ea226055e
commit
8933da1ec5
@ -166,7 +166,9 @@ class SDKCaptureHook(AgentHook):
|
||||
|
||||
The runner mutates ``context.messages`` in place across iterations, so the
|
||||
snapshot is refreshed on every ``after_iteration`` call; the last call
|
||||
reflects the end-of-turn state the SDK caller cares about.
|
||||
reflects the end-of-turn state the SDK caller cares about. The run-level
|
||||
snapshot is authoritative when available and covers paths without a final
|
||||
per-iteration callback.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -178,3 +180,7 @@ class SDKCaptureHook(AgentHook):
|
||||
for call in context.tool_calls:
|
||||
self.tools_used.append(call.name)
|
||||
self.messages = list(context.messages)
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
self.tools_used = list(context.tools_used)
|
||||
self.messages = list(context.messages)
|
||||
|
||||
@ -277,7 +277,13 @@ class AgentRunner:
|
||||
try:
|
||||
await hook.before_run(context)
|
||||
result = await self._run_core(spec, hook, messages)
|
||||
except BaseException as exc:
|
||||
except asyncio.CancelledError as exc:
|
||||
context.messages = list(messages)
|
||||
context.stop_reason = "cancelled"
|
||||
context.error = None
|
||||
context.exception = exc
|
||||
raise
|
||||
except Exception as exc:
|
||||
context.messages = list(messages)
|
||||
context.stop_reason = "error"
|
||||
context.error = f"Error: {type(exc).__name__}: {exc}"
|
||||
|
||||
@ -336,3 +336,55 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
|
||||
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
|
||||
("on_finally", "error"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_does_not_report_cancellation_as_error():
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
events: list[tuple] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class CancellationHook(AgentHook):
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("before_run", context.stop_reason))
|
||||
|
||||
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("on_error", context.stop_reason, context.error))
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
events.append(("after_run", context.stop_reason))
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
events.append((
|
||||
"on_finally",
|
||||
context.stop_reason,
|
||||
context.error,
|
||||
type(context.exception).__name__ if context.exception else None,
|
||||
))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=CancellationHook(),
|
||||
))
|
||||
|
||||
assert events == [
|
||||
("before_run", None),
|
||||
("on_finally", "cancelled", None, "CancelledError"),
|
||||
]
|
||||
|
||||
@ -299,3 +299,30 @@ async def test_run_restores_extra_hooks_even_on_populated_iterations(tmp_path):
|
||||
bot._loop.process_direct = fake_process_direct
|
||||
await bot.run("hello")
|
||||
assert bot._loop._extra_hooks == [sentinel_hook]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sdk_capture_prefers_run_level_snapshot():
|
||||
from nanobot.agent.hook import AgentHookContext, AgentRunHookContext, SDKCaptureHook
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
|
||||
hook = SDKCaptureHook()
|
||||
iter_messages = [{"role": "user", "content": "work"}]
|
||||
iter_context = AgentHookContext(iteration=0, messages=iter_messages)
|
||||
iter_context.tool_calls = [
|
||||
ToolCallRequest(id="call_1", name="read_file", arguments={}),
|
||||
ToolCallRequest(id="call_2", name="grep", arguments={}),
|
||||
]
|
||||
await hook.after_iteration(iter_context)
|
||||
|
||||
final_messages = [
|
||||
{"role": "user", "content": "work"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
await hook.after_run(AgentRunHookContext(
|
||||
messages=final_messages,
|
||||
tools_used=["read_file"],
|
||||
))
|
||||
|
||||
assert hook.tools_used == ["read_file"]
|
||||
assert hook.messages == final_messages
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user