mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
feat(agent): add CompositeHook for composable lifecycle hooks
Introduce a CompositeHook that fans out lifecycle callbacks to an ordered list of AgentHook instances with per-hook error isolation. Extract the nested _LoopHook and _SubagentHook to module scope as public LoopHook / SubagentHook so downstream users can subclass or compose them. Add `hooks` parameter to AgentLoop.__init__ for registering custom hooks at construction time. Closes #2603
This commit is contained in:
parent
1814272583
commit
f08de72f18
@ -1,8 +1,21 @@
|
||||
"""Agent core module."""
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.loop import AgentLoop, LoopHook
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.agent.skills import SkillsLoader
|
||||
from nanobot.agent.subagent import SubagentHook, SubagentManager
|
||||
|
||||
__all__ = ["AgentLoop", "ContextBuilder", "MemoryStore", "SkillsLoader"]
|
||||
__all__ = [
|
||||
"AgentHook",
|
||||
"AgentHookContext",
|
||||
"AgentLoop",
|
||||
"CompositeHook",
|
||||
"ContextBuilder",
|
||||
"LoopHook",
|
||||
"MemoryStore",
|
||||
"SkillsLoader",
|
||||
"SubagentHook",
|
||||
"SubagentManager",
|
||||
]
|
||||
|
||||
@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
@ -47,3 +49,60 @@ class AgentHook:
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return content
|
||||
|
||||
|
||||
class CompositeHook(AgentHook):
|
||||
"""Fan-out hook that delegates to an ordered list of hooks.
|
||||
|
||||
Error isolation: async methods catch and log per-hook exceptions
|
||||
so a faulty custom hook cannot crash the agent loop.
|
||||
``finalize_content`` is a pipeline (no isolation — bugs should surface).
|
||||
"""
|
||||
|
||||
__slots__ = ("_hooks",)
|
||||
|
||||
def __init__(self, hooks: list[AgentHook]) -> None:
|
||||
self._hooks = list(hooks)
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return any(h.wants_streaming() for h in self._hooks)
|
||||
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_iteration error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream(context, delta)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream error in {}", type(h).__name__)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.on_stream_end(context, resuming=resuming)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.on_stream_end error in {}", type(h).__name__)
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.before_execute_tools(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.before_execute_tools error in {}", type(h).__name__)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
for h in self._hooks:
|
||||
try:
|
||||
await h.after_iteration(context)
|
||||
except Exception:
|
||||
logger.exception("AgentHook.after_iteration error in {}", type(h).__name__)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
for h in self._hooks:
|
||||
content = h.finalize_content(context, content)
|
||||
return content
|
||||
|
||||
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
@ -37,6 +37,71 @@ if TYPE_CHECKING:
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
class LoopHook(AgentHook):
|
||||
"""Core lifecycle hook for the main agent loop.
|
||||
|
||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
||||
and think-tag stripping. Public so downstream users can subclass or
|
||||
compose it via :class:`CompositeHook`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_loop: AgentLoop,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> None:
|
||||
self._loop = agent_loop
|
||||
self._on_progress = on_progress
|
||||
self._on_stream = on_stream
|
||||
self._on_stream_end = on_stream_end
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return self._on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and self._on_stream:
|
||||
await self._on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if self._on_stream_end:
|
||||
await self._on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if self._on_progress:
|
||||
if not self._on_stream:
|
||||
thought = self._loop._strip_think(
|
||||
context.response.content if context.response else None
|
||||
)
|
||||
if thought:
|
||||
await self._on_progress(thought)
|
||||
tool_hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
|
||||
await self._on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return self._loop._strip_think(content)
|
||||
|
||||
|
||||
class AgentLoop:
|
||||
"""
|
||||
The agent loop is the core processing engine.
|
||||
@ -68,6 +133,7 @@ class AgentLoop:
|
||||
mcp_servers: dict | None = None,
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||
|
||||
@ -85,6 +151,7 @@ class AgentLoop:
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
@ -217,52 +284,27 @@ class AgentLoop:
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
loop_self = self
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
def __init__(self) -> None:
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
return on_stream is not None
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
from nanobot.utils.helpers import strip_think
|
||||
|
||||
prev_clean = strip_think(self._stream_buf)
|
||||
self._stream_buf += delta
|
||||
new_clean = strip_think(self._stream_buf)
|
||||
incremental = new_clean[len(prev_clean):]
|
||||
if incremental and on_stream:
|
||||
await on_stream(incremental)
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
self._stream_buf = ""
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
if on_progress:
|
||||
if not on_stream:
|
||||
thought = loop_self._strip_think(context.response.content if context.response else None)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
loop_self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return loop_self._strip_think(content)
|
||||
loop_hook = LoopHook(
|
||||
self,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
CompositeHook([loop_hook, *self._extra_hooks])
|
||||
if self._extra_hooks
|
||||
else loop_hook
|
||||
)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
hook=_LoopHook(),
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
))
|
||||
|
||||
@ -21,6 +21,24 @@ from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
class SubagentHook(AgentHook):
|
||||
"""Logging-only hook for subagent execution.
|
||||
|
||||
Public so downstream users can subclass or compose via :class:`CompositeHook`.
|
||||
"""
|
||||
|
||||
def __init__(self, task_id: str) -> None:
|
||||
self._task_id = task_id
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug(
|
||||
"Subagent [{}] executing: {} with arguments: {}",
|
||||
self._task_id, tool_call.name, args_str,
|
||||
)
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
"""Manages background subagent execution."""
|
||||
|
||||
@ -108,25 +126,19 @@ class SubagentManager:
|
||||
))
|
||||
tools.register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
|
||||
tools.register(WebFetchTool(proxy=self.web_proxy))
|
||||
|
||||
|
||||
system_prompt = self._build_subagent_prompt()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
|
||||
class _SubagentHook(AgentHook):
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
for tool_call in context.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
hook=_SubagentHook(),
|
||||
hook=SubagentHook(task_id),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
fail_on_tool_error=True,
|
||||
@ -213,7 +225,7 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
||||
lines.append("Failure:")
|
||||
lines.append(f"- {result.error}")
|
||||
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
|
||||
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
|
||||
330
tests/agent/test_hook_composite.py
Normal file
330
tests/agent/test_hook_composite.py
Normal file
@ -0,0 +1,330 @@
|
||||
"""Tests for CompositeHook fan-out, error isolation, and integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
|
||||
|
||||
def _ctx() -> AgentHookContext:
|
||||
return AgentHookContext(iteration=0, messages=[])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fan-out: every hook is called in order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class H(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"A:{context.iteration}")
|
||||
|
||||
class H2(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append(f"B:{context.iteration}")
|
||||
|
||||
hook = CompositeHook([H(), H2()])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
assert calls == ["A:0", "B:0"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_fans_out_all_async_methods():
|
||||
"""Verify all async methods fan out to every hook."""
|
||||
events: list[str] = []
|
||||
|
||||
class RecordingHook(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("before_iteration")
|
||||
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
events.append(f"on_stream:{delta}")
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
events.append(f"on_stream_end:{resuming}")
|
||||
|
||||
async def before_execute_tools(self, context: AgentHookContext) -> None:
|
||||
events.append("before_execute_tools")
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
events.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([RecordingHook(), RecordingHook()])
|
||||
ctx = _ctx()
|
||||
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "hi")
|
||||
await hook.on_stream_end(ctx, resuming=True)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
|
||||
assert events == [
|
||||
"before_iteration", "before_iteration",
|
||||
"on_stream:hi", "on_stream:hi",
|
||||
"on_stream_end:True", "on_stream_end:True",
|
||||
"before_execute_tools", "before_execute_tools",
|
||||
"after_iteration", "after_iteration",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error isolation: one hook raises, others still run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_before_iteration():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
raise RuntimeError("boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def before_iteration(self, context: AgentHookContext) -> None:
|
||||
calls.append("good")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.before_iteration(_ctx())
|
||||
assert calls == ["good"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_on_stream():
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
raise RuntimeError("stream-boom")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
|
||||
calls.append(delta)
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
await hook.on_stream(_ctx(), "delta")
|
||||
assert calls == ["delta"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_error_isolation_all_async():
|
||||
"""Error isolation for on_stream_end, before_execute_tools, after_iteration."""
|
||||
calls: list[str] = []
|
||||
|
||||
class Bad(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
raise RuntimeError("err")
|
||||
async def before_execute_tools(self, context):
|
||||
raise RuntimeError("err")
|
||||
async def after_iteration(self, context):
|
||||
raise RuntimeError("err")
|
||||
|
||||
class Good(AgentHook):
|
||||
async def on_stream_end(self, context, *, resuming):
|
||||
calls.append("on_stream_end")
|
||||
async def before_execute_tools(self, context):
|
||||
calls.append("before_execute_tools")
|
||||
async def after_iteration(self, context):
|
||||
calls.append("after_iteration")
|
||||
|
||||
hook = CompositeHook([Bad(), Good()])
|
||||
ctx = _ctx()
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert calls == ["on_stream_end", "before_execute_tools", "after_iteration"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# finalize_content: pipeline semantics (no error isolation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_finalize_content_pipeline():
|
||||
class Upper(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return content.upper() if content else content
|
||||
|
||||
class Suffix(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
return (content + "!") if content else content
|
||||
|
||||
hook = CompositeHook([Upper(), Suffix()])
|
||||
result = hook.finalize_content(_ctx(), "hello")
|
||||
assert result == "HELLO!"
|
||||
|
||||
|
||||
def test_composite_finalize_content_none_passthrough():
|
||||
hook = CompositeHook([AgentHook()])
|
||||
assert hook.finalize_content(_ctx(), None) is None
|
||||
|
||||
|
||||
def test_composite_finalize_content_ordering():
|
||||
"""First hook transforms first, result feeds second hook."""
|
||||
steps: list[str] = []
|
||||
|
||||
class H1(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H1:{content}")
|
||||
return content.upper()
|
||||
|
||||
class H2(AgentHook):
|
||||
def finalize_content(self, context, content):
|
||||
steps.append(f"H2:{content}")
|
||||
return content + "!"
|
||||
|
||||
hook = CompositeHook([H1(), H2()])
|
||||
result = hook.finalize_content(_ctx(), "hi")
|
||||
assert result == "HI!"
|
||||
assert steps == ["H1:hi", "H2:HI"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# wants_streaming: any-semantics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_composite_wants_streaming_any_true():
|
||||
class No(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return False
|
||||
|
||||
class Yes(AgentHook):
|
||||
def wants_streaming(self):
|
||||
return True
|
||||
|
||||
hook = CompositeHook([No(), Yes(), No()])
|
||||
assert hook.wants_streaming() is True
|
||||
|
||||
|
||||
def test_composite_wants_streaming_all_false():
|
||||
hook = CompositeHook([AgentHook(), AgentHook()])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
def test_composite_wants_streaming_empty():
|
||||
hook = CompositeHook([])
|
||||
assert hook.wants_streaming() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Empty hooks list: behaves like no-op AgentHook
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_composite_empty_hooks_no_ops():
|
||||
hook = CompositeHook([])
|
||||
ctx = _ctx()
|
||||
await hook.before_iteration(ctx)
|
||||
await hook.on_stream(ctx, "delta")
|
||||
await hook.on_stream_end(ctx, resuming=False)
|
||||
await hook.before_execute_tools(ctx)
|
||||
await hook.after_iteration(ctx)
|
||||
assert hook.finalize_content(ctx, "test") == "test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: AgentLoop with extra hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_loop(tmp_path, hooks=None):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation.max_tokens = 4096
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \
|
||||
patch("nanobot.agent.loop.MemoryConsolidator"):
|
||||
mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path, hooks=hooks,
|
||||
)
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
"""Extra hook passed to AgentLoop is called alongside core LoopHook."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
events: list[str] = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
events.append(f"before_iter:{context.iteration}")
|
||||
|
||||
async def after_iteration(self, context):
|
||||
events.append(f"after_iter:{context.iteration}")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[TrackingHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "done"
|
||||
assert "before_iter:0" in events
|
||||
assert "after_iter:0" in events
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
"""A faulty extra hook does not crash the agent loop."""
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
class BadHook(AgentHook):
|
||||
async def before_iteration(self, context):
|
||||
raise RuntimeError("I am broken")
|
||||
|
||||
loop = _make_loop(tmp_path, hooks=[BadHook()])
|
||||
loop.provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="still works", tool_calls=[], usage={})
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
assert content == "still works"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
"""Without hooks param, behavior is identical to before."""
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
assert tools_used == ["list_dir", "list_dir"]
|
||||
Loading…
x
Reference in New Issue
Block a user