mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
fix: harden run-level hook lifecycle
maintainer edit: keep cancellation out of on_error so shutdown paths do not look like run failures, and let the SDK capture hook use the authoritative after_run snapshot.
This commit is contained in:
parent
2ea226055e
commit
8933da1ec5
@ -166,7 +166,9 @@ class SDKCaptureHook(AgentHook):
|
|||||||
|
|
||||||
The runner mutates ``context.messages`` in place across iterations, so the
|
The runner mutates ``context.messages`` in place across iterations, so the
|
||||||
snapshot is refreshed on every ``after_iteration`` call; the last call
|
snapshot is refreshed on every ``after_iteration`` call; the last call
|
||||||
reflects the end-of-turn state the SDK caller cares about.
|
reflects the end-of-turn state the SDK caller cares about. The run-level
|
||||||
|
snapshot is authoritative when available and covers paths without a final
|
||||||
|
per-iteration callback.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@ -178,3 +180,7 @@ class SDKCaptureHook(AgentHook):
|
|||||||
for call in context.tool_calls:
|
for call in context.tool_calls:
|
||||||
self.tools_used.append(call.name)
|
self.tools_used.append(call.name)
|
||||||
self.messages = list(context.messages)
|
self.messages = list(context.messages)
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
self.tools_used = list(context.tools_used)
|
||||||
|
self.messages = list(context.messages)
|
||||||
|
|||||||
@ -277,7 +277,13 @@ class AgentRunner:
|
|||||||
try:
|
try:
|
||||||
await hook.before_run(context)
|
await hook.before_run(context)
|
||||||
result = await self._run_core(spec, hook, messages)
|
result = await self._run_core(spec, hook, messages)
|
||||||
except BaseException as exc:
|
except asyncio.CancelledError as exc:
|
||||||
|
context.messages = list(messages)
|
||||||
|
context.stop_reason = "cancelled"
|
||||||
|
context.error = None
|
||||||
|
context.exception = exc
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
context.messages = list(messages)
|
context.messages = list(messages)
|
||||||
context.stop_reason = "error"
|
context.stop_reason = "error"
|
||||||
context.error = f"Error: {type(exc).__name__}: {exc}"
|
context.error = f"Error: {type(exc).__name__}: {exc}"
|
||||||
|
|||||||
@ -336,3 +336,55 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
|
|||||||
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
|
("on_error", "error", "Error: RuntimeError: provider exploded", "RuntimeError"),
|
||||||
("on_finally", "error"),
|
("on_finally", "error"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_does_not_report_cancellation_as_error():
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
events: list[tuple] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
class CancellationHook(AgentHook):
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("before_run", context.stop_reason))
|
||||||
|
|
||||||
|
async def on_error(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("on_error", context.stop_reason, context.error))
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append(("after_run", context.stop_reason))
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
events.append((
|
||||||
|
"on_finally",
|
||||||
|
context.stop_reason,
|
||||||
|
context.error,
|
||||||
|
type(context.exception).__name__ if context.exception else None,
|
||||||
|
))
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=CancellationHook(),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert events == [
|
||||||
|
("before_run", None),
|
||||||
|
("on_finally", "cancelled", None, "CancelledError"),
|
||||||
|
]
|
||||||
|
|||||||
@ -299,3 +299,30 @@ async def test_run_restores_extra_hooks_even_on_populated_iterations(tmp_path):
|
|||||||
bot._loop.process_direct = fake_process_direct
|
bot._loop.process_direct = fake_process_direct
|
||||||
await bot.run("hello")
|
await bot.run("hello")
|
||||||
assert bot._loop._extra_hooks == [sentinel_hook]
|
assert bot._loop._extra_hooks == [sentinel_hook]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sdk_capture_prefers_run_level_snapshot():
|
||||||
|
from nanobot.agent.hook import AgentHookContext, AgentRunHookContext, SDKCaptureHook
|
||||||
|
from nanobot.providers.base import ToolCallRequest
|
||||||
|
|
||||||
|
hook = SDKCaptureHook()
|
||||||
|
iter_messages = [{"role": "user", "content": "work"}]
|
||||||
|
iter_context = AgentHookContext(iteration=0, messages=iter_messages)
|
||||||
|
iter_context.tool_calls = [
|
||||||
|
ToolCallRequest(id="call_1", name="read_file", arguments={}),
|
||||||
|
ToolCallRequest(id="call_2", name="grep", arguments={}),
|
||||||
|
]
|
||||||
|
await hook.after_iteration(iter_context)
|
||||||
|
|
||||||
|
final_messages = [
|
||||||
|
{"role": "user", "content": "work"},
|
||||||
|
{"role": "assistant", "content": "done"},
|
||||||
|
]
|
||||||
|
await hook.after_run(AgentRunHookContext(
|
||||||
|
messages=final_messages,
|
||||||
|
tools_used=["read_file"],
|
||||||
|
))
|
||||||
|
|
||||||
|
assert hook.tools_used == ["read_file"]
|
||||||
|
assert hook.messages == final_messages
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user