mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-14 06:43:53 +00:00
fix: isolate run-level hook snapshots
This commit is contained in:
parent
8933da1ec5
commit
39454534d4
@ -6,6 +6,7 @@ import asyncio
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
@ -272,32 +273,32 @@ class AgentRunner:
|
|||||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||||
hook = spec.hook or AgentHook()
|
hook = spec.hook or AgentHook()
|
||||||
messages = list(spec.initial_messages)
|
messages = list(spec.initial_messages)
|
||||||
context = AgentRunHookContext(messages=list(messages))
|
context = AgentRunHookContext(messages=deepcopy(messages))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await hook.before_run(context)
|
await hook.before_run(context)
|
||||||
result = await self._run_core(spec, hook, messages)
|
result = await self._run_core(spec, hook, messages)
|
||||||
except asyncio.CancelledError as exc:
|
except asyncio.CancelledError as exc:
|
||||||
context.messages = list(messages)
|
context.messages = deepcopy(messages)
|
||||||
context.stop_reason = "cancelled"
|
context.stop_reason = "cancelled"
|
||||||
context.error = None
|
context.error = None
|
||||||
context.exception = exc
|
context.exception = exc
|
||||||
raise
|
raise
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
context.messages = list(messages)
|
context.messages = deepcopy(messages)
|
||||||
context.stop_reason = "error"
|
context.stop_reason = "error"
|
||||||
context.error = f"Error: {type(exc).__name__}: {exc}"
|
context.error = f"Error: {type(exc).__name__}: {exc}"
|
||||||
context.exception = exc
|
context.exception = exc
|
||||||
await hook.on_error(context)
|
await hook.on_error(context)
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
context.messages = list(result.messages)
|
context.messages = deepcopy(result.messages)
|
||||||
context.final_content = result.final_content
|
context.final_content = result.final_content
|
||||||
context.tools_used = list(result.tools_used)
|
context.tools_used = list(result.tools_used)
|
||||||
context.usage = dict(result.usage)
|
context.usage = dict(result.usage)
|
||||||
context.stop_reason = result.stop_reason
|
context.stop_reason = result.stop_reason
|
||||||
context.error = result.error
|
context.error = result.error
|
||||||
context.tool_events = list(result.tool_events)
|
context.tool_events = deepcopy(result.tool_events)
|
||||||
context.had_injections = result.had_injections
|
context.had_injections = result.had_injections
|
||||||
context.exception = None
|
context.exception = None
|
||||||
if context.error is not None:
|
if context.error is not None:
|
||||||
@ -305,8 +306,17 @@ class AgentRunner:
|
|||||||
await hook.after_run(context)
|
await hook.after_run(context)
|
||||||
return result
|
return result
|
||||||
finally:
|
finally:
|
||||||
context.messages = list(messages)
|
context.messages = deepcopy(messages)
|
||||||
await hook.on_finally(context)
|
if context.exception is None:
|
||||||
|
await hook.on_finally(context)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
await hook.on_finally(context)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"AgentHook.on_finally error after {}",
|
||||||
|
context.stop_reason or "run exception",
|
||||||
|
)
|
||||||
|
|
||||||
async def _run_core(
|
async def _run_core(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -239,6 +239,60 @@ async def test_runner_calls_run_level_hooks_on_success():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_run_level_context_is_detached_snapshot():
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
call_count = {"n": 0}
|
||||||
|
request_messages: list[list[dict]] = []
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
request_messages.append([dict(msg) for msg in kwargs["messages"]])
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="thinking",
|
||||||
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
|
)
|
||||||
|
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
tools.execute = AsyncMock(return_value="tool result")
|
||||||
|
|
||||||
|
class MutatingRunHook(AgentHook):
|
||||||
|
async def before_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
context.messages[0]["content"] = "mutated-before"
|
||||||
|
|
||||||
|
async def after_run(self, context: AgentRunHookContext) -> None:
|
||||||
|
context.messages[0]["content"] = "mutated-after"
|
||||||
|
context.tool_events[0]["status"] = "mutated"
|
||||||
|
context.tools_used.append("mutated")
|
||||||
|
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
context.messages[0]["content"] = "mutated-finally"
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=MutatingRunHook(),
|
||||||
|
))
|
||||||
|
|
||||||
|
assert request_messages[0][0]["content"] == "hi"
|
||||||
|
assert result.messages[0]["content"] == "hi"
|
||||||
|
assert result.tools_used == ["list_dir"]
|
||||||
|
assert result.tool_events == [
|
||||||
|
{"name": "list_dir", "status": "ok", "detail": "tool result"}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_calls_on_error_for_model_error_result():
|
async def test_runner_calls_on_error_for_model_error_result():
|
||||||
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
@ -338,6 +392,36 @@ async def test_runner_calls_on_error_and_finally_for_unhandled_exception():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_preserves_original_exception_when_finally_hook_fails():
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
raise RuntimeError("provider exploded")
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
class BadFinallyHook(AgentHook):
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
raise RuntimeError("finally exploded")
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
with pytest.raises(RuntimeError, match="provider exploded"):
|
||||||
|
await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=BadFinallyHook(),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_does_not_report_cancellation_as_error():
|
async def test_runner_does_not_report_cancellation_as_error():
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -388,3 +472,35 @@ async def test_runner_does_not_report_cancellation_as_error():
|
|||||||
("before_run", None),
|
("before_run", None),
|
||||||
("on_finally", "cancelled", None, "CancelledError"),
|
("on_finally", "cancelled", None, "CancelledError"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runner_preserves_cancellation_when_finally_hook_fails():
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from nanobot.agent.hook import AgentHook, AgentRunHookContext
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
|
||||||
|
provider = MagicMock(spec=LLMProvider)
|
||||||
|
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
class BadFinallyHook(AgentHook):
|
||||||
|
async def on_finally(self, context: AgentRunHookContext) -> None:
|
||||||
|
raise RuntimeError("finally exploded")
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hi"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
hook=BadFinallyHook(),
|
||||||
|
))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user