mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix: isolate run-level hook snapshots
This commit is contained in:
parent
8933da1ec5
commit
39454534d4
@ -6,6 +6,7 @@ import asyncio
|
||||
import inspect
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
@ -272,32 +273,32 @@ class AgentRunner:
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
messages = list(spec.initial_messages)
|
||||
context = AgentRunHookContext(messages=list(messages))
|
||||
context = AgentRunHookContext(messages=deepcopy(messages))
|
||||
|
||||
try:
|
||||
await hook.before_run(context)
|
||||
result = await self._run_core(spec, hook, messages)
|
||||
except asyncio.CancelledError as exc:
|
||||
context.messages = list(messages)
|
||||
context.messages = deepcopy(messages)
|
||||
context.stop_reason = "cancelled"
|
||||
context.error = None
|
||||
context.exception = exc
|
||||
raise
|
||||
except Exception as exc:
|
||||
context.messages = list(messages)
|
||||
context.messages = deepcopy(messages)
|
||||
context.stop_reason = "error"
|
||||
context.error = f"Error: {type(exc).__name__}: {exc}"
|
||||
context.exception = exc
|
||||
await hook.on_error(context)
|
||||
raise
|
||||
else:
|
||||
context.messages = list(result.messages)
|
||||
context.messages = deepcopy(result.messages)
|
||||
context.final_content = result.final_content
|
||||
context.tools_used = list(result.tools_used)
|
||||
context.usage = dict(result.usage)
|
||||
context.stop_reason = result.stop_reason
|
||||
context.error = result.error
|
||||
context.tool_events = list(result.tool_events)
|
||||
context.tool_events = deepcopy(result.tool_events)
|
||||
context.had_injections = result.had_injections
|
||||
context.exception = None
|
||||
if context.error is not None:
|
||||
@ -305,8 +306,17 @@ class AgentRunner:
|
||||
await hook.after_run(context)
|
||||
return result
|
||||
finally:
|
||||
context.messages = list(messages)
|
||||
await hook.on_finally(context)
|
||||
context.messages = deepcopy(messages)
|
||||
if context.exception is None:
|
||||
await hook.on_finally(context)
|
||||
else:
|
||||
try:
|
||||
await hook.on_finally(context)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"AgentHook.on_finally error after {}",
|
||||
context.stop_reason or "run exception",
|
||||
)
|
||||
|
||||
async def _run_core(
|
||||
self,
|
||||
|
||||
@ -239,6 +239,60 @@ async def test_runner_calls_run_level_hooks_on_success():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_run_level_context_is_detached_snapshot():
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
call_count = {"n": 0}
|
||||
request_messages: list[list[dict]] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
call_count["n"] += 1
|
||||
request_messages.append([dict(msg) for msg in kwargs["messages"]])
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
)
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
class MutatingRunHook(AgentHook):
|
||||
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||
context.messages[0]["content"] = "mutated-before"
|
||||
|
||||
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||
context.messages[0]["content"] = "mutated-after"
|
||||
context.tool_events[0]["status"] = "mutated"
|
||||
context.tools_used.append("mutated")
|
||||
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
context.messages[0]["content"] = "mutated-finally"
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=MutatingRunHook(),
|
||||
))
|
||||
|
||||
assert request_messages[0][0]["content"] == "hi"
|
||||
assert result.messages[0]["content"] == "hi"
|
||||
assert result.tools_used == ["list_dir"]
|
||||
assert result.tool_events == [
|
||||
{"name": "list_dir", "status": "ok", "detail": "tool result"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_calls_on_error_for_model_error_result():
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
@ -338,6 +392,36 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_preserves_original_exception_when_finally_hook_fails():
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
raise RuntimeError("provider exploded")
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class BadFinallyHook(AgentHook):
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
raise RuntimeError("finally exploded")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
with pytest.raises(RuntimeError, match="provider exploded"):
|
||||
await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=BadFinallyHook(),
|
||||
))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_does_not_report_cancellation_as_error():
|
||||
import asyncio
|
||||
@ -388,3 +472,35 @@ async def test_runner_does_not_report_cancellation_as_error():
|
||||
("before_run", None),
|
||||
("on_finally", "cancelled", None, "CancelledError"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_preserves_cancellation_when_finally_hook_fails():
|
||||
import asyncio
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
|
||||
provider = MagicMock(spec=LLMProvider)
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
class BadFinallyHook(AgentHook):
|
||||
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||
raise RuntimeError("finally exploded")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=BadFinallyHook(),
|
||||
))
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user