mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
refactor: extract shared agent runner and preserve subagent progress on failure
This commit is contained in:
parent
33abe915e7
commit
e7d371ec1e
@ -15,6 +15,7 @@ from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
@ -87,6 +88,7 @@ class AgentLoop:
|
||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self.tools = ToolRegistry()
|
||||
self.runner = AgentRunner(provider)
|
||||
self.subagents = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=workspace,
|
||||
@ -214,11 +216,6 @@ class AgentLoop:
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
"""
|
||||
messages = initial_messages
|
||||
iteration = 0
|
||||
final_content = None
|
||||
tools_used: list[str] = []
|
||||
|
||||
# Wrap on_stream with stateful think-tag filter so downstream
|
||||
# consumers (CLI, channels) never see <think> blocks.
|
||||
_raw_stream = on_stream
|
||||
@ -234,104 +231,47 @@ class AgentLoop:
|
||||
if incremental and _raw_stream:
|
||||
await _raw_stream(incremental)
|
||||
|
||||
while iteration < self.max_iterations:
|
||||
iteration += 1
|
||||
async def _wrapped_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal _stream_buf
|
||||
if on_stream_end:
|
||||
await on_stream_end(resuming=resuming)
|
||||
_stream_buf = ""
|
||||
|
||||
tool_defs = self.tools.get_definitions()
|
||||
async def _handle_tool_calls(response) -> None:
|
||||
if not on_progress:
|
||||
return
|
||||
if not on_stream:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = self._strip_think(self._tool_hint(response.tool_calls))
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
|
||||
if on_stream:
|
||||
response = await self.provider.chat_stream_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
on_content_delta=_filtered_stream,
|
||||
)
|
||||
else:
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tool_defs,
|
||||
model=self.model,
|
||||
)
|
||||
async def _prepare_tools(tool_calls) -> None:
|
||||
for tc in tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
usage = response.usage or {}
|
||||
self._last_usage = {
|
||||
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
|
||||
"completion_tokens": int(usage.get("completion_tokens", 0) or 0),
|
||||
}
|
||||
|
||||
if response.has_tool_calls:
|
||||
if on_stream and on_stream_end:
|
||||
await on_stream_end(resuming=True)
|
||||
_stream_buf = ""
|
||||
|
||||
if on_progress:
|
||||
if not on_stream:
|
||||
thought = self._strip_think(response.content)
|
||||
if thought:
|
||||
await on_progress(thought)
|
||||
tool_hint = self._tool_hint(response.tool_calls)
|
||||
tool_hint = self._strip_think(tool_hint)
|
||||
await on_progress(tool_hint, tool_hint=True)
|
||||
|
||||
tool_call_dicts = [
|
||||
tc.to_openai_tool_call()
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, response.content, tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
for tc in response.tool_calls:
|
||||
tools_used.append(tc.name)
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
|
||||
# Re-bind tool context right before execution so that
|
||||
# concurrent sessions don't clobber each other's routing.
|
||||
self._set_tool_context(channel, chat_id, message_id)
|
||||
|
||||
# Execute all tool calls concurrently — the LLM batches
|
||||
# independent calls in a single response on purpose.
|
||||
# return_exceptions=True ensures all results are collected
|
||||
# even if one tool is cancelled or raises BaseException.
|
||||
results = await asyncio.gather(*(
|
||||
self.tools.execute(tc.name, tc.arguments)
|
||||
for tc in response.tool_calls
|
||||
), return_exceptions=True)
|
||||
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
if isinstance(result, BaseException):
|
||||
result = f"Error: {type(result).__name__}: {result}"
|
||||
messages = self.context.add_tool_result(
|
||||
messages, tool_call.id, tool_call.name, result
|
||||
)
|
||||
else:
|
||||
if on_stream and on_stream_end:
|
||||
await on_stream_end(resuming=False)
|
||||
_stream_buf = ""
|
||||
|
||||
clean = self._strip_think(response.content)
|
||||
if response.finish_reason == "error":
|
||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
||||
break
|
||||
messages = self.context.add_assistant_message(
|
||||
messages, clean, reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
final_content = clean
|
||||
break
|
||||
|
||||
if final_content is None and iteration >= self.max_iterations:
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
on_stream=_filtered_stream if on_stream else None,
|
||||
on_stream_end=_wrapped_stream_end if on_stream else None,
|
||||
on_tool_calls=_handle_tool_calls,
|
||||
before_execute_tools=_prepare_tools,
|
||||
finalize_content=self._strip_think,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
))
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
final_content = (
|
||||
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
return final_content, tools_used, messages
|
||||
elif result.stop_reason == "error":
|
||||
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
||||
return result.final_content, result.tools_used, result.messages
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
|
||||
221
nanobot/agent/runner.py
Normal file
221
nanobot/agent/runner.py
Normal file
@ -0,0 +1,221 @@
|
||||
"""Shared execution loop for tool-using agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
|
||||
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
||||
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
|
||||
initial_messages: list[dict[str, Any]]
|
||||
tools: ToolRegistry
|
||||
model: str
|
||||
max_iterations: int
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None
|
||||
on_tool_calls: Callable[[LLMResponse], Awaitable[None] | None] | None = None
|
||||
before_execute_tools: Callable[[list[ToolCallRequest]], Awaitable[None] | None] | None = None
|
||||
finalize_content: Callable[[str | None], str | None] | None = None
|
||||
error_message: str | None = _DEFAULT_ERROR_MESSAGE
|
||||
max_iterations_message: str | None = None
|
||||
concurrent_tools: bool = False
|
||||
fail_on_tool_error: bool = False
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunResult:
|
||||
"""Outcome of a shared agent execution."""
|
||||
|
||||
final_content: str | None
|
||||
messages: list[dict[str, Any]]
|
||||
tools_used: list[str] = field(default_factory=list)
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
stop_reason: str = "completed"
|
||||
error: str | None = None
|
||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
"""Run a tool-capable LLM loop without product-layer concerns."""
|
||||
|
||||
def __init__(self, provider: LLMProvider):
|
||||
self.provider = provider
|
||||
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
messages = list(spec.initial_messages)
|
||||
final_content: str | None = None
|
||||
tools_used: list[str] = []
|
||||
usage = {"prompt_tokens": 0, "completion_tokens": 0}
|
||||
error: str | None = None
|
||||
stop_reason = "completed"
|
||||
tool_events: list[dict[str, str]] = []
|
||||
|
||||
for _ in range(spec.max_iterations):
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"tools": spec.tools.get_definitions(),
|
||||
"model": spec.model,
|
||||
}
|
||||
if spec.temperature is not None:
|
||||
kwargs["temperature"] = spec.temperature
|
||||
if spec.max_tokens is not None:
|
||||
kwargs["max_tokens"] = spec.max_tokens
|
||||
if spec.reasoning_effort is not None:
|
||||
kwargs["reasoning_effort"] = spec.reasoning_effort
|
||||
|
||||
if spec.on_stream:
|
||||
response = await self.provider.chat_stream_with_retry(
|
||||
**kwargs,
|
||||
on_content_delta=spec.on_stream,
|
||||
)
|
||||
else:
|
||||
response = await self.provider.chat_with_retry(**kwargs)
|
||||
|
||||
raw_usage = response.usage or {}
|
||||
usage = {
|
||||
"prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
|
||||
"completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
|
||||
}
|
||||
|
||||
if response.has_tool_calls:
|
||||
if spec.on_stream_end:
|
||||
await spec.on_stream_end(resuming=True)
|
||||
if spec.on_tool_calls:
|
||||
maybe = spec.on_tool_calls(response)
|
||||
if maybe is not None:
|
||||
await maybe
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
|
||||
if spec.before_execute_tools:
|
||||
maybe = spec.before_execute_tools(response.tool_calls)
|
||||
if maybe is not None:
|
||||
await maybe
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls)
|
||||
tool_events.extend(new_events)
|
||||
if fatal_error is not None:
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
stop_reason = "tool_error"
|
||||
break
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
continue
|
||||
|
||||
if spec.on_stream_end:
|
||||
await spec.on_stream_end(resuming=False)
|
||||
|
||||
clean = spec.finalize_content(response.content) if spec.finalize_content else response.content
|
||||
if response.finish_reason == "error":
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
stop_reason = "error"
|
||||
error = final_content
|
||||
break
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
final_content = clean
|
||||
break
|
||||
else:
|
||||
stop_reason = "max_iterations"
|
||||
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
||||
final_content = template.format(max_iterations=spec.max_iterations)
|
||||
|
||||
return AgentRunResult(
|
||||
final_content=final_content,
|
||||
messages=messages,
|
||||
tools_used=tools_used,
|
||||
usage=usage,
|
||||
stop_reason=stop_reason,
|
||||
error=error,
|
||||
tool_events=tool_events,
|
||||
)
|
||||
|
||||
async def _execute_tools(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_calls: list[ToolCallRequest],
|
||||
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
||||
if spec.concurrent_tools:
|
||||
tool_results = await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call)
|
||||
for tool_call in tool_calls
|
||||
))
|
||||
else:
|
||||
tool_results = [
|
||||
await self._run_tool(spec, tool_call)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
fatal_error: BaseException | None = None
|
||||
for result, event, error in tool_results:
|
||||
results.append(result)
|
||||
events.append(event)
|
||||
if error is not None and fatal_error is None:
|
||||
fatal_error = error
|
||||
return results, events, fatal_error
|
||||
|
||||
async def _run_tool(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_call: ToolCallRequest,
|
||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||
try:
|
||||
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except BaseException as exc:
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
detail = detail.replace("\n", " ").strip()
|
||||
if not detail:
|
||||
detail = "(empty)"
|
||||
elif len(detail) > 120:
|
||||
detail = detail[:120] + "..."
|
||||
return result, {
|
||||
"name": tool_call.name,
|
||||
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
|
||||
"detail": detail,
|
||||
}, None
|
||||
@ -8,6 +8,7 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
@ -17,7 +18,6 @@ from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import ExecToolConfig
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
|
||||
|
||||
class SubagentManager:
|
||||
@ -44,6 +44,7 @@ class SubagentManager:
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
self.restrict_to_workspace = restrict_to_workspace
|
||||
self.runner = AgentRunner(provider)
|
||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||
|
||||
@ -112,50 +113,42 @@ class SubagentManager:
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": task},
|
||||
]
|
||||
async def _log_tool_calls(tool_calls) -> None:
|
||||
for tool_call in tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
|
||||
# Run agent loop (limited iterations)
|
||||
max_iterations = 15
|
||||
iteration = 0
|
||||
final_result: str | None = None
|
||||
|
||||
while iteration < max_iterations:
|
||||
iteration += 1
|
||||
|
||||
response = await self.provider.chat_with_retry(
|
||||
messages=messages,
|
||||
tools=tools.get_definitions(),
|
||||
model=self.model,
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=messages,
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
before_execute_tools=_log_tool_calls,
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
fail_on_tool_error=True,
|
||||
))
|
||||
if result.stop_reason == "tool_error":
|
||||
await self._announce_result(
|
||||
task_id,
|
||||
label,
|
||||
task,
|
||||
self._format_partial_progress(result),
|
||||
origin,
|
||||
"error",
|
||||
)
|
||||
|
||||
if response.has_tool_calls:
|
||||
tool_call_dicts = [
|
||||
tc.to_openai_tool_call()
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=tool_call_dicts,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
|
||||
# Execute tools
|
||||
for tool_call in response.tool_calls:
|
||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
||||
messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
else:
|
||||
final_result = response.content
|
||||
break
|
||||
|
||||
if final_result is None:
|
||||
final_result = "Task completed but no final response was generated."
|
||||
return
|
||||
if result.stop_reason == "error":
|
||||
await self._announce_result(
|
||||
task_id,
|
||||
label,
|
||||
task,
|
||||
result.error or "Error: subagent execution failed.",
|
||||
origin,
|
||||
"error",
|
||||
)
|
||||
return
|
||||
final_result = result.final_content or "Task completed but no final response was generated."
|
||||
|
||||
logger.info("Subagent [{}] completed successfully", task_id)
|
||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||
@ -196,6 +189,27 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||
|
||||
@staticmethod
|
||||
def _format_partial_progress(result) -> str:
|
||||
completed = [e for e in result.tool_events if e["status"] == "ok"]
|
||||
failure = next((e for e in reversed(result.tool_events) if e["status"] == "error"), None)
|
||||
lines: list[str] = []
|
||||
if completed:
|
||||
lines.append("Completed steps:")
|
||||
for event in completed[-3:]:
|
||||
lines.append(f"- {event['name']}: {event['detail']}")
|
||||
if failure:
|
||||
if lines:
|
||||
lines.append("")
|
||||
lines.append("Failure:")
|
||||
lines.append(f"- {failure['name']}: {failure['detail']}")
|
||||
if result.error and not failure:
|
||||
if lines:
|
||||
lines.append("")
|
||||
lines.append("Failure:")
|
||||
lines.append(f"- {result.error}")
|
||||
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
|
||||
|
||||
def _build_subagent_prompt(self) -> str:
|
||||
"""Build a focused system prompt for the subagent."""
|
||||
|
||||
186
tests/agent/test_runner.py
Normal file
186
tests/agent/test_runner.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""Tests for the shared agent runner and its integration contracts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
with patch("nanobot.agent.loop.ContextBuilder"), \
|
||||
patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr:
|
||||
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path)
|
||||
return loop
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_preserves_reasoning_fields_and_tool_results():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
captured_second_call: list[dict] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
usage={"prompt_tokens": 5, "completion_tokens": 3},
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "do task"},
|
||||
],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
assert result.tools_used == ["list_dir"]
|
||||
assert result.tool_events == [
|
||||
{"name": "list_dir", "status": "ok", "detail": "tool result"}
|
||||
]
|
||||
|
||||
assistant_messages = [
|
||||
msg for msg in captured_second_call
|
||||
if msg.get("role") == "assistant" and msg.get("tool_calls")
|
||||
]
|
||||
assert len(assistant_messages) == 1
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
assert any(
|
||||
msg.get("role") == "tool" and msg.get("content") == "tool result"
|
||||
for msg in captured_second_call
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_max_iterations_fallback():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="still working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="tool result")
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "max_iterations"
|
||||
assert result.final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_structured_tool_error():
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
fail_on_tool_error=True,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "tool_error"
|
||||
assert result.error == "Error: RuntimeError: boom"
|
||||
assert result.tool_events == [
|
||||
{"name": "list_dir", "status": "error", "detail": "boom"}
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
final_content, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
mgr._announce_result.assert_awaited_once()
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert args[3] == "Task completed but no final response was generated."
|
||||
assert args[5] == "ok"
|
||||
@ -221,3 +221,83 @@ class TestSubagentCancellation:
|
||||
assert len(assistant_messages) == 1
|
||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return "first result"
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
mgr._announce_result.assert_awaited_once()
|
||||
args = mgr._announce_result.await_args.args
|
||||
assert "Completed steps:" in args[3]
|
||||
assert "- list_dir: first result" in args[3]
|
||||
assert "Failure:" in args[3]
|
||||
assert "- list_dir: boom" in args[3]
|
||||
assert args[5] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_by_session_cancels_running_subagent_tool(self, monkeypatch, tmp_path):
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
started.set()
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
|
||||
task = asyncio.create_task(
|
||||
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
)
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
await started.wait()
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
|
||||
assert count == 1
|
||||
assert cancelled.is_set()
|
||||
assert task.cancelled()
|
||||
mgr._announce_result.assert_not_awaited()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user