fix: isolate run-level hook snapshots

This commit is contained in:
chengyongru 2026-06-04 10:16:29 +08:00 committed by Xubin Ren
parent 8933da1ec5
commit 39454534d4
2 changed files with 133 additions and 7 deletions

View File

@ -6,6 +6,7 @@ import asyncio
import inspect import inspect
import os import os
from contextlib import suppress from contextlib import suppress
from copy import deepcopy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable
@ -272,32 +273,32 @@ class AgentRunner:
async def run(self, spec: AgentRunSpec) -> AgentRunResult: async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook() hook = spec.hook or AgentHook()
messages = list(spec.initial_messages) messages = list(spec.initial_messages)
context = AgentRunHookContext(messages=list(messages)) context = AgentRunHookContext(messages=deepcopy(messages))
try: try:
await hook.before_run(context) await hook.before_run(context)
result = await self._run_core(spec, hook, messages) result = await self._run_core(spec, hook, messages)
except asyncio.CancelledError as exc: except asyncio.CancelledError as exc:
context.messages = list(messages) context.messages = deepcopy(messages)
context.stop_reason = "cancelled" context.stop_reason = "cancelled"
context.error = None context.error = None
context.exception = exc context.exception = exc
raise raise
except Exception as exc: except Exception as exc:
context.messages = list(messages) context.messages = deepcopy(messages)
context.stop_reason = "error" context.stop_reason = "error"
context.error = f"Error: {type(exc).__name__}: {exc}" context.error = f"Error: {type(exc).__name__}: {exc}"
context.exception = exc context.exception = exc
await hook.on_error(context) await hook.on_error(context)
raise raise
else: else:
context.messages = list(result.messages) context.messages = deepcopy(result.messages)
context.final_content = result.final_content context.final_content = result.final_content
context.tools_used = list(result.tools_used) context.tools_used = list(result.tools_used)
context.usage = dict(result.usage) context.usage = dict(result.usage)
context.stop_reason = result.stop_reason context.stop_reason = result.stop_reason
context.error = result.error context.error = result.error
context.tool_events = list(result.tool_events) context.tool_events = deepcopy(result.tool_events)
context.had_injections = result.had_injections context.had_injections = result.had_injections
context.exception = None context.exception = None
if context.error is not None: if context.error is not None:
@ -305,8 +306,17 @@ class AgentRunner:
await hook.after_run(context) await hook.after_run(context)
return result return result
finally: finally:
context.messages = list(messages) context.messages = deepcopy(messages)
await hook.on_finally(context) if context.exception is None:
await hook.on_finally(context)
else:
try:
await hook.on_finally(context)
except Exception:
logger.exception(
"AgentHook.on_finally error after {}",
context.stop_reason or "run exception",
)
async def _run_core( async def _run_core(
self, self,

View File

@ -239,6 +239,60 @@ async def test_runner_calls_run_level_hooks_on_success():
] ]
@pytest.mark.asyncio
async def test_runner_run_level_context_is_detached_snapshot():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
call_count = {"n": 0}
request_messages: list[list[dict]] = []
async def chat_with_retry(**kwargs):
call_count["n"] += 1
request_messages.append([dict(msg) for msg in kwargs["messages"]])
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
class MutatingRunHook(AgentHook):
async def before_run(self, context: AgentRunHookContext) -> None:
context.messages[0]["content"] = "mutated-before"
async def after_run(self, context: AgentRunHookContext) -> None:
context.messages[0]["content"] = "mutated-after"
context.tool_events[0]["status"] = "mutated"
context.tools_used.append("mutated")
async def on_finally(self, context: AgentRunHookContext) -> None:
context.messages[0]["content"] = "mutated-finally"
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=MutatingRunHook(),
))
assert request_messages[0][0]["content"] == "hi"
assert result.messages[0]["content"] == "hi"
assert result.tools_used == ["list_dir"]
assert result.tool_events == [
{"name": "list_dir", "status": "ok", "detail": "tool result"}
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_runner_calls_on_error_for_model_error_result(): async def test_runner_calls_on_error_for_model_error_result():
from nanobot.agent.hook import AgentHook, AgentRunHookContext from nanobot.agent.hook import AgentHook, AgentRunHookContext
@ -338,6 +392,36 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
] ]
@pytest.mark.asyncio
async def test_runner_preserves_original_exception_when_finally_hook_fails():
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
async def chat_with_retry(**kwargs):
raise RuntimeError("provider exploded")
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class BadFinallyHook(AgentHook):
async def on_finally(self, context: AgentRunHookContext) -> None:
raise RuntimeError("finally exploded")
runner = AgentRunner(provider)
with pytest.raises(RuntimeError, match="provider exploded"):
await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=BadFinallyHook(),
))
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_runner_does_not_report_cancellation_as_error(): async def test_runner_does_not_report_cancellation_as_error():
import asyncio import asyncio
@ -388,3 +472,35 @@ async def test_runner_does_not_report_cancellation_as_error():
("before_run", None), ("before_run", None),
("on_finally", "cancelled", None, "CancelledError"), ("on_finally", "cancelled", None, "CancelledError"),
] ]
@pytest.mark.asyncio
async def test_runner_preserves_cancellation_when_finally_hook_fails():
import asyncio
from nanobot.agent.hook import AgentHook, AgentRunHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
provider = MagicMock(spec=LLMProvider)
async def chat_with_retry(**kwargs):
raise asyncio.CancelledError()
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
class BadFinallyHook(AgentHook):
async def on_finally(self, context: AgentRunHookContext) -> None:
raise RuntimeError("finally exploded")
runner = AgentRunner(provider)
with pytest.raises(asyncio.CancelledError):
await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=BadFinallyHook(),
))