diff --git a/nanobot/agent/hook.py b/nanobot/agent/hook.py new file mode 100644 index 000000000..368c46aa2 --- /dev/null +++ b/nanobot/agent/hook.py @@ -0,0 +1,49 @@ +"""Shared lifecycle hook primitives for agent runs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any + +from nanobot.providers.base import LLMResponse, ToolCallRequest + + +@dataclass(slots=True) +class AgentHookContext: + """Mutable per-iteration state exposed to runner hooks.""" + + iteration: int + messages: list[dict[str, Any]] + response: LLMResponse | None = None + usage: dict[str, int] = field(default_factory=dict) + tool_calls: list[ToolCallRequest] = field(default_factory=list) + tool_results: list[Any] = field(default_factory=list) + tool_events: list[dict[str, str]] = field(default_factory=list) + final_content: str | None = None + stop_reason: str | None = None + error: str | None = None + + +class AgentHook: + """Minimal lifecycle surface for shared runner customization.""" + + def wants_streaming(self) -> bool: + return False + + async def before_iteration(self, context: AgentHookContext) -> None: + pass + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + pass + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + pass + + async def before_execute_tools(self, context: AgentHookContext) -> None: + pass + + async def after_iteration(self, context: AgentHookContext) -> None: + pass + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return content diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 2a3109a38..63ee92ca5 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable from loguru import logger from nanobot.agent.context import ContextBuilder +from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager @@ -216,53 +217,52 @@ class AgentLoop: ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. """ - # Wrap on_stream with stateful think-tag filter so downstream - # consumers (CLI, channels) never see blocks. - _raw_stream = on_stream - _stream_buf = "" + loop_self = self - async def _filtered_stream(delta: str) -> None: - nonlocal _stream_buf - from nanobot.utils.helpers import strip_think - prev_clean = strip_think(_stream_buf) - _stream_buf += delta - new_clean = strip_think(_stream_buf) - incremental = new_clean[len(prev_clean):] - if incremental and _raw_stream: - await _raw_stream(incremental) + class _LoopHook(AgentHook): + def __init__(self) -> None: + self._stream_buf = "" - async def _wrapped_stream_end(*, resuming: bool = False) -> None: - nonlocal _stream_buf - if on_stream_end: - await on_stream_end(resuming=resuming) - _stream_buf = "" + def wants_streaming(self) -> bool: + return on_stream is not None - async def _handle_tool_calls(response) -> None: - if not on_progress: - return - if not on_stream: - thought = self._strip_think(response.content) - if thought: - await on_progress(thought) - tool_hint = self._strip_think(self._tool_hint(response.tool_calls)) - await on_progress(tool_hint, tool_hint=True) + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + from nanobot.utils.helpers import strip_think - async def _prepare_tools(tool_calls) -> None: - for tc in tool_calls: - args_str = json.dumps(tc.arguments, ensure_ascii=False) - logger.info("Tool call: {}({})", tc.name, args_str[:200]) - self._set_tool_context(channel, chat_id, message_id) + prev_clean = strip_think(self._stream_buf) + self._stream_buf += delta + new_clean = strip_think(self._stream_buf) + incremental = new_clean[len(prev_clean):] + if incremental and on_stream: + await on_stream(incremental) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + if on_stream_end: + await on_stream_end(resuming=resuming) + self._stream_buf = "" + + async def before_execute_tools(self, context: AgentHookContext) -> None: + if on_progress: + if not on_stream: + thought = loop_self._strip_think(context.response.content if context.response else None) + if thought: + await on_progress(thought) + tool_hint = loop_self._strip_think(loop_self._tool_hint(context.tool_calls)) + await on_progress(tool_hint, tool_hint=True) + for tc in context.tool_calls: + args_str = json.dumps(tc.arguments, ensure_ascii=False) + logger.info("Tool call: {}({})", tc.name, args_str[:200]) + loop_self._set_tool_context(channel, chat_id, message_id) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return loop_self._strip_think(content) result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, model=self.model, max_iterations=self.max_iterations, - on_stream=_filtered_stream if on_stream else None, - on_stream_end=_wrapped_stream_end if on_stream else None, - on_tool_calls=_handle_tool_calls, - before_execute_tools=_prepare_tools, - finalize_content=self._strip_think, + hook=_LoopHook(), error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, )) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 1827bab66..d6242a6b4 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -3,12 +3,12 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import Any +from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.registry import ToolRegistry -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.base import LLMProvider, ToolCallRequest from nanobot.utils.helpers import build_assistant_message _DEFAULT_MAX_ITERATIONS_MESSAGE = ( @@ -29,11 +29,7 @@ class AgentRunSpec: temperature: float | None = None max_tokens: int | None = None reasoning_effort: str | None = None - on_stream: Callable[[str], Awaitable[None]] | None = None - on_stream_end: Callable[..., Awaitable[None]] | None = None - on_tool_calls: Callable[[LLMResponse], Awaitable[None] | None] | None = None - before_execute_tools: Callable[[list[ToolCallRequest]], Awaitable[None] | None] | None = None - finalize_content: Callable[[str | None], str | None] | None = None + hook: AgentHook | None = None error_message: str | None = _DEFAULT_ERROR_MESSAGE max_iterations_message: str | None = None concurrent_tools: bool = False @@ -60,6 +56,7 @@ class AgentRunner: self.provider = provider async def run(self, spec: AgentRunSpec) -> AgentRunResult: + hook = spec.hook or AgentHook() messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] @@ -68,7 +65,9 @@ class AgentRunner: stop_reason = "completed" tool_events: list[dict[str, str]] = [] - for _ in range(spec.max_iterations): + for iteration in range(spec.max_iterations): + context = AgentHookContext(iteration=iteration, messages=messages) + await hook.before_iteration(context) kwargs: dict[str, Any] = { "messages": messages, "tools": spec.tools.get_definitions(), @@ -81,10 +80,13 @@ class AgentRunner: if spec.reasoning_effort is not None: kwargs["reasoning_effort"] = spec.reasoning_effort - if spec.on_stream: + if hook.wants_streaming(): + async def _stream(delta: str) -> None: + await hook.on_stream(context, delta) + response = await self.provider.chat_stream_with_retry( **kwargs, - on_content_delta=spec.on_stream, + on_content_delta=_stream, ) else: response = await self.provider.chat_with_retry(**kwargs) @@ -94,14 +96,13 @@ class AgentRunner: "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0), "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0), } + context.response = response + context.usage = usage + context.tool_calls = list(response.tool_calls) if response.has_tool_calls: - if spec.on_stream_end: - await spec.on_stream_end(resuming=True) - if spec.on_tool_calls: - maybe = spec.on_tool_calls(response) - if maybe is not None: - await maybe + if hook.wants_streaming(): + await hook.on_stream_end(context, resuming=True) messages.append(build_assistant_message( response.content or "", @@ -111,16 +112,18 @@ class AgentRunner: )) tools_used.extend(tc.name for tc in response.tool_calls) - if spec.before_execute_tools: - maybe = spec.before_execute_tools(response.tool_calls) - if maybe is not None: - await maybe + await hook.before_execute_tools(context) results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls) tool_events.extend(new_events) + context.tool_results = list(results) + context.tool_events = list(new_events) if fatal_error is not None: error = f"Error: {type(fatal_error).__name__}: {fatal_error}" stop_reason = "tool_error" + context.error = error + context.stop_reason = stop_reason + await hook.after_iteration(context) break for tool_call, result in zip(response.tool_calls, results): messages.append({ @@ -129,16 +132,21 @@ class AgentRunner: "name": tool_call.name, "content": result, }) + await hook.after_iteration(context) continue - if spec.on_stream_end: - await spec.on_stream_end(resuming=False) + if hook.wants_streaming(): + await hook.on_stream_end(context, resuming=False) - clean = spec.finalize_content(response.content) if spec.finalize_content else response.content + clean = hook.finalize_content(context, response.content) if response.finish_reason == "error": final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE stop_reason = "error" error = final_content + context.final_content = final_content + context.error = error + context.stop_reason = stop_reason + await hook.after_iteration(context) break messages.append(build_assistant_message( @@ -147,6 +155,9 @@ class AgentRunner: thinking_blocks=response.thinking_blocks, )) final_content = clean + context.final_content = final_content + context.stop_reason = stop_reason + await hook.after_iteration(context) break else: stop_reason = "max_iterations" diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 4d112b834..5266fc8b1 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -8,6 +8,7 @@ from typing import Any from loguru import logger +from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool @@ -113,17 +114,19 @@ class SubagentManager: {"role": "system", "content": system_prompt}, {"role": "user", "content": task}, ] - async def _log_tool_calls(tool_calls) -> None: - for tool_call in tool_calls: - args_str = json.dumps(tool_call.arguments, ensure_ascii=False) - logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str) + + class _SubagentHook(AgentHook): + async def before_execute_tools(self, context: AgentHookContext) -> None: + for tool_call in context.tool_calls: + args_str = json.dumps(tool_call.arguments, ensure_ascii=False) + logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str) result = await self.runner.run(AgentRunSpec( initial_messages=messages, tools=tools, model=self.model, max_iterations=15, - before_execute_tools=_log_tool_calls, + hook=_SubagentHook(), max_iterations_message="Task completed but no final response was generated.", error_message=None, fail_on_tool_error=True, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index b534c03c6..86b0ba710 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -81,6 +81,125 @@ async def test_runner_preserves_reasoning_fields_and_tool_results(): ) +@pytest.mark.asyncio +async def test_runner_calls_hooks_in_order(): + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + events: list[tuple] = [] + + async def chat_with_retry(**kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="thinking", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + ) + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + class RecordingHook(AgentHook): + async def before_iteration(self, context: AgentHookContext) -> None: + events.append(("before_iteration", context.iteration)) + + async def before_execute_tools(self, context: AgentHookContext) -> None: + events.append(( + "before_execute_tools", + context.iteration, + [tc.name for tc in context.tool_calls], + )) + + async def after_iteration(self, context: AgentHookContext) -> None: + events.append(( + "after_iteration", + context.iteration, + context.final_content, + list(context.tool_results), + list(context.tool_events), + context.stop_reason, + )) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + events.append(("finalize_content", context.iteration, content)) + return content.upper() if content else content + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=3, + hook=RecordingHook(), + )) + + assert result.final_content == "DONE" + assert events == [ + ("before_iteration", 0), + ("before_execute_tools", 0, ["list_dir"]), + ( + "after_iteration", + 0, + None, + ["tool result"], + [{"name": "list_dir", "status": "ok", "detail": "tool result"}], + None, + ), + ("before_iteration", 1), + ("finalize_content", 1, "done"), + ("after_iteration", 1, "DONE", [], [], "completed"), + ] + + +@pytest.mark.asyncio +async def test_runner_streaming_hook_receives_deltas_and_end_signal(): + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + streamed: list[str] = [] + endings: list[bool] = [] + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("he") + await on_content_delta("llo") + return LLMResponse(content="hello", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + + class StreamingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream(self, context: AgentHookContext, delta: str) -> None: + streamed.append(delta) + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + endings.append(resuming) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + hook=StreamingHook(), + )) + + assert result.final_content == "hello" + assert streamed == ["he", "llo"] + assert endings == [False] + provider.chat_with_retry.assert_not_awaited() + + @pytest.mark.asyncio async def test_runner_returns_max_iterations_fallback(): from nanobot.agent.runner import AgentRunSpec, AgentRunner @@ -158,6 +277,36 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path): ) +@pytest.mark.asyncio +async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path): + loop = _make_loop(tmp_path) + deltas: list[str] = [] + endings: list[bool] = [] + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("hidden") + await on_content_delta("Hello") + return LLMResponse(content="hiddenHello", tool_calls=[], usage={}) + + loop.provider.chat_stream_with_retry = chat_stream_with_retry + + async def on_stream(delta: str) -> None: + deltas.append(delta) + + async def on_stream_end(*, resuming: bool = False) -> None: + endings.append(resuming) + + final_content, _, _ = await loop._run_agent_loop( + [], + on_stream=on_stream, + on_stream_end=on_stream_end, + ) + + assert final_content == "Hello" + assert deltas == ["Hello"] + assert endings == [False] + + @pytest.mark.asyncio async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch): from nanobot.agent.subagent import SubagentManager