mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-19 09:29:55 +00:00
refactor: extract shared agent runner and preserve subagent progress on failure
This commit is contained in:
parent
33abe915e7
commit
e7d371ec1e
@ -15,6 +15,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.memory import MemoryConsolidator
|
from nanobot.agent.memory import MemoryConsolidator
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
@ -87,6 +88,7 @@ class AgentLoop:
|
|||||||
self.context = ContextBuilder(workspace, timezone=timezone)
|
self.context = ContextBuilder(workspace, timezone=timezone)
|
||||||
self.sessions = session_manager or SessionManager(workspace)
|
self.sessions = session_manager or SessionManager(workspace)
|
||||||
self.tools = ToolRegistry()
|
self.tools = ToolRegistry()
|
||||||
|
self.runner = AgentRunner(provider)
|
||||||
self.subagents = SubagentManager(
|
self.subagents = SubagentManager(
|
||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
@ -214,11 +216,6 @@ class AgentLoop:
|
|||||||
``resuming=True`` means tool calls follow (spinner should restart);
|
``resuming=True`` means tool calls follow (spinner should restart);
|
||||||
``resuming=False`` means this is the final response.
|
``resuming=False`` means this is the final response.
|
||||||
"""
|
"""
|
||||||
messages = initial_messages
|
|
||||||
iteration = 0
|
|
||||||
final_content = None
|
|
||||||
tools_used: list[str] = []
|
|
||||||
|
|
||||||
# Wrap on_stream with stateful think-tag filter so downstream
|
# Wrap on_stream with stateful think-tag filter so downstream
|
||||||
# consumers (CLI, channels) never see <think> blocks.
|
# consumers (CLI, channels) never see <think> blocks.
|
||||||
_raw_stream = on_stream
|
_raw_stream = on_stream
|
||||||
@ -234,104 +231,47 @@ class AgentLoop:
|
|||||||
if incremental and _raw_stream:
|
if incremental and _raw_stream:
|
||||||
await _raw_stream(incremental)
|
await _raw_stream(incremental)
|
||||||
|
|
||||||
while iteration < self.max_iterations:
|
async def _wrapped_stream_end(*, resuming: bool = False) -> None:
|
||||||
iteration += 1
|
nonlocal _stream_buf
|
||||||
|
if on_stream_end:
|
||||||
|
await on_stream_end(resuming=resuming)
|
||||||
|
_stream_buf = ""
|
||||||
|
|
||||||
tool_defs = self.tools.get_definitions()
|
async def _handle_tool_calls(response) -> None:
|
||||||
|
if not on_progress:
|
||||||
|
return
|
||||||
|
if not on_stream:
|
||||||
|
thought = self._strip_think(response.content)
|
||||||
|
if thought:
|
||||||
|
await on_progress(thought)
|
||||||
|
tool_hint = self._strip_think(self._tool_hint(response.tool_calls))
|
||||||
|
await on_progress(tool_hint, tool_hint=True)
|
||||||
|
|
||||||
if on_stream:
|
async def _prepare_tools(tool_calls) -> None:
|
||||||
response = await self.provider.chat_stream_with_retry(
|
for tc in tool_calls:
|
||||||
messages=messages,
|
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||||
tools=tool_defs,
|
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||||
model=self.model,
|
self._set_tool_context(channel, chat_id, message_id)
|
||||||
on_content_delta=_filtered_stream,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
response = await self.provider.chat_with_retry(
|
|
||||||
messages=messages,
|
|
||||||
tools=tool_defs,
|
|
||||||
model=self.model,
|
|
||||||
)
|
|
||||||
|
|
||||||
usage = response.usage or {}
|
result = await self.runner.run(AgentRunSpec(
|
||||||
self._last_usage = {
|
initial_messages=initial_messages,
|
||||||
"prompt_tokens": int(usage.get("prompt_tokens", 0) or 0),
|
tools=self.tools,
|
||||||
"completion_tokens": int(usage.get("completion_tokens", 0) or 0),
|
model=self.model,
|
||||||
}
|
max_iterations=self.max_iterations,
|
||||||
|
on_stream=_filtered_stream if on_stream else None,
|
||||||
if response.has_tool_calls:
|
on_stream_end=_wrapped_stream_end if on_stream else None,
|
||||||
if on_stream and on_stream_end:
|
on_tool_calls=_handle_tool_calls,
|
||||||
await on_stream_end(resuming=True)
|
before_execute_tools=_prepare_tools,
|
||||||
_stream_buf = ""
|
finalize_content=self._strip_think,
|
||||||
|
error_message="Sorry, I encountered an error calling the AI model.",
|
||||||
if on_progress:
|
concurrent_tools=True,
|
||||||
if not on_stream:
|
))
|
||||||
thought = self._strip_think(response.content)
|
self._last_usage = result.usage
|
||||||
if thought:
|
if result.stop_reason == "max_iterations":
|
||||||
await on_progress(thought)
|
|
||||||
tool_hint = self._tool_hint(response.tool_calls)
|
|
||||||
tool_hint = self._strip_think(tool_hint)
|
|
||||||
await on_progress(tool_hint, tool_hint=True)
|
|
||||||
|
|
||||||
tool_call_dicts = [
|
|
||||||
tc.to_openai_tool_call()
|
|
||||||
for tc in response.tool_calls
|
|
||||||
]
|
|
||||||
messages = self.context.add_assistant_message(
|
|
||||||
messages, response.content, tool_call_dicts,
|
|
||||||
reasoning_content=response.reasoning_content,
|
|
||||||
thinking_blocks=response.thinking_blocks,
|
|
||||||
)
|
|
||||||
|
|
||||||
for tc in response.tool_calls:
|
|
||||||
tools_used.append(tc.name)
|
|
||||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
|
||||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
|
||||||
|
|
||||||
# Re-bind tool context right before execution so that
|
|
||||||
# concurrent sessions don't clobber each other's routing.
|
|
||||||
self._set_tool_context(channel, chat_id, message_id)
|
|
||||||
|
|
||||||
# Execute all tool calls concurrently — the LLM batches
|
|
||||||
# independent calls in a single response on purpose.
|
|
||||||
# return_exceptions=True ensures all results are collected
|
|
||||||
# even if one tool is cancelled or raises BaseException.
|
|
||||||
results = await asyncio.gather(*(
|
|
||||||
self.tools.execute(tc.name, tc.arguments)
|
|
||||||
for tc in response.tool_calls
|
|
||||||
), return_exceptions=True)
|
|
||||||
|
|
||||||
for tool_call, result in zip(response.tool_calls, results):
|
|
||||||
if isinstance(result, BaseException):
|
|
||||||
result = f"Error: {type(result).__name__}: {result}"
|
|
||||||
messages = self.context.add_tool_result(
|
|
||||||
messages, tool_call.id, tool_call.name, result
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if on_stream and on_stream_end:
|
|
||||||
await on_stream_end(resuming=False)
|
|
||||||
_stream_buf = ""
|
|
||||||
|
|
||||||
clean = self._strip_think(response.content)
|
|
||||||
if response.finish_reason == "error":
|
|
||||||
logger.error("LLM returned error: {}", (clean or "")[:200])
|
|
||||||
final_content = clean or "Sorry, I encountered an error calling the AI model."
|
|
||||||
break
|
|
||||||
messages = self.context.add_assistant_message(
|
|
||||||
messages, clean, reasoning_content=response.reasoning_content,
|
|
||||||
thinking_blocks=response.thinking_blocks,
|
|
||||||
)
|
|
||||||
final_content = clean
|
|
||||||
break
|
|
||||||
|
|
||||||
if final_content is None and iteration >= self.max_iterations:
|
|
||||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||||
final_content = (
|
elif result.stop_reason == "error":
|
||||||
f"I reached the maximum number of tool call iterations ({self.max_iterations}) "
|
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
||||||
"without completing the task. You can try breaking the task into smaller steps."
|
return result.final_content, result.tools_used, result.messages
|
||||||
)
|
|
||||||
|
|
||||||
return final_content, tools_used, messages
|
|
||||||
|
|
||||||
async def run(self) -> None:
|
async def run(self) -> None:
|
||||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||||
|
|||||||
221
nanobot/agent/runner.py
Normal file
221
nanobot/agent/runner.py
Normal file
@ -0,0 +1,221 @@
|
|||||||
|
"""Shared execution loop for tool-using agents."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.utils.helpers import build_assistant_message
|
||||||
|
|
||||||
|
# Fallback reply used when the loop runs out of iterations without a final
# answer. The ``{max_iterations}`` placeholder is filled in via ``str.format``.
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
    "I reached the maximum number of tool call iterations ({max_iterations}) "
    "without completing the task. You can try breaking the task into smaller steps."
)
# Fallback reply used when the provider reports an error and the response
# carries no usable content of its own.
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class AgentRunSpec:
    """Configuration for a single agent execution.

    Bundles everything ``AgentRunner.run`` needs: the seed conversation, the
    tool registry, model parameters, optional callbacks and the policies that
    decide how tools are executed and how failures are reported.
    """

    # Seed conversation; the runner copies the list before appending to it.
    initial_messages: list[dict[str, Any]]
    # Registry queried for tool definitions and used to execute tool calls.
    tools: ToolRegistry
    # Model identifier forwarded to the provider on every request.
    model: str
    # Upper bound on provider round-trips before the run gives up.
    max_iterations: int
    # Optional sampling parameters; each is forwarded only when not None.
    temperature: float | None = None
    max_tokens: int | None = None
    reasoning_effort: str | None = None
    # When set, the streaming provider call is used and content deltas are
    # passed to this callback.
    on_stream: Callable[[str], Awaitable[None]] | None = None
    # Called with ``resuming=True`` before tools run and ``resuming=False``
    # once the final response has arrived.
    on_stream_end: Callable[..., Awaitable[None]] | None = None
    # Called with the response whenever it contains tool calls; may be sync
    # (returning None) or async.
    on_tool_calls: Callable[[LLMResponse], Awaitable[None] | None] | None = None
    # Called with the requested tool calls right before they are executed;
    # may be sync (returning None) or async.
    before_execute_tools: Callable[[list[ToolCallRequest]], Awaitable[None] | None] | None = None
    # Post-processes the content of the final (non-tool) response.
    finalize_content: Callable[[str | None], str | None] | None = None
    # Reply used when the provider reports an error without usable content;
    # a falsy value falls back to the module default.
    error_message: str | None = _DEFAULT_ERROR_MESSAGE
    # Template for the reply when the iteration budget is exhausted; a falsy
    # value falls back to the module default.
    max_iterations_message: str | None = None
    # True runs all tool calls of one response via ``asyncio.gather``;
    # False awaits them one after another.
    concurrent_tools: bool = False
    # True stops the run with ``stop_reason="tool_error"`` when a tool raises;
    # False feeds the error text back to the model as the tool result.
    fail_on_tool_error: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
class AgentRunResult:
    """Outcome of a shared agent execution.

    Produced by ``AgentRunner.run``; carries the final reply together with
    the conversation and bookkeeping gathered while the loop ran.
    """

    # Final reply text; stays None when the run stops on a fatal tool error.
    final_content: str | None
    # Conversation after the run: the seed messages plus everything appended.
    messages: list[dict[str, Any]]
    # Names of the tools the model requested, in request order.
    tools_used: list[str] = field(default_factory=list)
    # ``prompt_tokens`` / ``completion_tokens`` of the most recent provider
    # response (overwritten on every iteration, not accumulated).
    usage: dict[str, int] = field(default_factory=dict)
    # One of "completed", "max_iterations", "error" or "tool_error".
    stop_reason: str = "completed"
    # Error description for the "error" and "tool_error" stop reasons.
    error: str | None = None
    # One ``{"name", "status", "detail"}`` entry per executed tool call.
    tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRunner:
    """Run a tool-capable LLM loop without product-layer concerns.

    The runner repeatedly asks the provider for a response, executes any
    requested tools, feeds their results back into the conversation, and
    stops on a final answer, a provider error, a fatal tool error, or an
    exhausted iteration budget.
    """

    def __init__(self, provider: LLMProvider):
        """Remember the provider used for every chat request."""
        self.provider = provider

    async def run(self, spec: AgentRunSpec) -> AgentRunResult:
        """Execute the loop described by *spec* and return its outcome.

        The caller's ``initial_messages`` list is copied, never mutated.
        ``stop_reason`` on the result is ``"completed"``, ``"error"``
        (provider reported an error), ``"tool_error"`` (a tool raised while
        ``fail_on_tool_error`` is set) or ``"max_iterations"``.
        """
        messages = list(spec.initial_messages)
        final_content: str | None = None
        tools_used: list[str] = []
        usage = {"prompt_tokens": 0, "completion_tokens": 0}
        error: str | None = None
        stop_reason = "completed"
        tool_events: list[dict[str, str]] = []

        for _ in range(spec.max_iterations):
            kwargs: dict[str, Any] = {
                "messages": messages,
                "tools": spec.tools.get_definitions(),
                "model": spec.model,
            }
            # Optional sampling parameters are only sent when configured so
            # the provider keeps its own defaults otherwise.
            if spec.temperature is not None:
                kwargs["temperature"] = spec.temperature
            if spec.max_tokens is not None:
                kwargs["max_tokens"] = spec.max_tokens
            if spec.reasoning_effort is not None:
                kwargs["reasoning_effort"] = spec.reasoning_effort

            if spec.on_stream:
                response = await self.provider.chat_stream_with_retry(
                    **kwargs,
                    on_content_delta=spec.on_stream,
                )
            else:
                response = await self.provider.chat_with_retry(**kwargs)

            # Usage reflects the latest response only; it is not accumulated.
            raw_usage = response.usage or {}
            usage = {
                "prompt_tokens": int(raw_usage.get("prompt_tokens", 0) or 0),
                "completion_tokens": int(raw_usage.get("completion_tokens", 0) or 0),
            }

            if response.has_tool_calls:
                if spec.on_stream_end:
                    await spec.on_stream_end(resuming=True)
                await self._call_hook(spec.on_tool_calls, response)

                messages.append(build_assistant_message(
                    response.content or "",
                    tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
                    reasoning_content=response.reasoning_content,
                    thinking_blocks=response.thinking_blocks,
                ))
                tools_used.extend(tc.name for tc in response.tool_calls)

                await self._call_hook(spec.before_execute_tools, response.tool_calls)

                results, new_events, fatal_error = await self._execute_tools(
                    spec, response.tool_calls
                )
                tool_events.extend(new_events)
                if fatal_error is not None:
                    # Stop before appending tool results; the events gathered
                    # so far remain available on the result.
                    error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
                    stop_reason = "tool_error"
                    break
                for tool_call, result in zip(response.tool_calls, results):
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_call.name,
                        "content": result,
                    })
                continue

            if spec.on_stream_end:
                await spec.on_stream_end(resuming=False)

            clean = spec.finalize_content(response.content) if spec.finalize_content else response.content
            if response.finish_reason == "error":
                # The failed reply is reported but not added to the history.
                final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
                stop_reason = "error"
                error = final_content
                break

            messages.append(build_assistant_message(
                clean,
                reasoning_content=response.reasoning_content,
                thinking_blocks=response.thinking_blocks,
            ))
            final_content = clean
            break
        else:
            # Loop ended without ``break``: the iteration budget is spent.
            stop_reason = "max_iterations"
            final_content = self._max_iterations_text(spec)

        return AgentRunResult(
            final_content=final_content,
            messages=messages,
            tools_used=tools_used,
            usage=usage,
            stop_reason=stop_reason,
            error=error,
            tool_events=tool_events,
        )

    @staticmethod
    async def _call_hook(hook: Callable[[Any], Any] | None, payload: Any) -> None:
        """Invoke an optional sync-or-async callback with *payload*."""
        if not hook:
            return
        maybe = hook(payload)
        if maybe is not None:
            await maybe

    @staticmethod
    def _max_iterations_text(spec: AgentRunSpec) -> str:
        """Render the reply used when the iteration budget is exhausted.

        The template is formatted with ``max_iterations``. A caller-supplied
        template that contains literal braces or unknown placeholders is
        returned verbatim instead of raising, so a completed run is never
        lost to a formatting error.
        """
        template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
        try:
            return template.format(max_iterations=spec.max_iterations)
        except (LookupError, ValueError):
            return template

    async def _execute_tools(
        self,
        spec: AgentRunSpec,
        tool_calls: list[ToolCallRequest],
    ) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
        """Run *tool_calls* and collect results, events and the first fatal error.

        Calls run concurrently when ``spec.concurrent_tools`` is set and one
        after another otherwise. Results and events keep the order of
        *tool_calls* in both modes.
        """
        if spec.concurrent_tools:
            tool_results = await asyncio.gather(*(
                self._run_tool(spec, tool_call)
                for tool_call in tool_calls
            ))
        else:
            tool_results = [
                await self._run_tool(spec, tool_call)
                for tool_call in tool_calls
            ]

        results: list[Any] = []
        events: list[dict[str, str]] = []
        fatal_error: BaseException | None = None
        for result, event, error in tool_results:
            results.append(result)
            events.append(event)
            if error is not None and fatal_error is None:
                fatal_error = error
        return results, events, fatal_error

    async def _run_tool(
        self,
        spec: AgentRunSpec,
        tool_call: ToolCallRequest,
    ) -> tuple[Any, dict[str, str], BaseException | None]:
        """Execute one tool call and describe what happened.

        Returns ``(result, event, fatal_error)``. A raised exception becomes
        an ``"Error: ..."`` result string; it is reported as fatal only when
        ``spec.fail_on_tool_error`` is set. Cancellation always propagates.
        """
        try:
            result = await spec.tools.execute(tool_call.name, tool_call.arguments)
        except asyncio.CancelledError:
            raise
        except BaseException as exc:
            event = {
                "name": tool_call.name,
                "status": "error",
                "detail": str(exc),
            }
            message = f"Error: {type(exc).__name__}: {exc}"
            return message, event, (exc if spec.fail_on_tool_error else None)

        # Collapse the result into a short single-line summary for the event.
        detail = "" if result is None else str(result)
        detail = detail.replace("\n", " ").strip()
        if not detail:
            detail = "(empty)"
        elif len(detail) > 120:
            detail = detail[:120] + "..."
        # Tools may signal failure by returning a string starting with "Error".
        status = "error" if isinstance(result, str) and result.startswith("Error") else "ok"
        return result, {
            "name": tool_call.name,
            "status": status,
            "detail": detail,
        }, None
|
||||||
@ -8,6 +8,7 @@ from typing import Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
@ -17,7 +18,6 @@ from nanobot.bus.events import InboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.config.schema import ExecToolConfig
|
from nanobot.config.schema import ExecToolConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.utils.helpers import build_assistant_message
|
|
||||||
|
|
||||||
|
|
||||||
class SubagentManager:
|
class SubagentManager:
|
||||||
@ -44,6 +44,7 @@ class SubagentManager:
|
|||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
self.restrict_to_workspace = restrict_to_workspace
|
self.restrict_to_workspace = restrict_to_workspace
|
||||||
|
self.runner = AgentRunner(provider)
|
||||||
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
self._running_tasks: dict[str, asyncio.Task[None]] = {}
|
||||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
|
|
||||||
@ -112,50 +113,42 @@ class SubagentManager:
|
|||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": task},
|
{"role": "user", "content": task},
|
||||||
]
|
]
|
||||||
|
async def _log_tool_calls(tool_calls) -> None:
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
||||||
|
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
||||||
|
|
||||||
# Run agent loop (limited iterations)
|
result = await self.runner.run(AgentRunSpec(
|
||||||
max_iterations = 15
|
initial_messages=messages,
|
||||||
iteration = 0
|
tools=tools,
|
||||||
final_result: str | None = None
|
model=self.model,
|
||||||
|
max_iterations=15,
|
||||||
while iteration < max_iterations:
|
before_execute_tools=_log_tool_calls,
|
||||||
iteration += 1
|
max_iterations_message="Task completed but no final response was generated.",
|
||||||
|
error_message=None,
|
||||||
response = await self.provider.chat_with_retry(
|
fail_on_tool_error=True,
|
||||||
messages=messages,
|
))
|
||||||
tools=tools.get_definitions(),
|
if result.stop_reason == "tool_error":
|
||||||
model=self.model,
|
await self._announce_result(
|
||||||
|
task_id,
|
||||||
|
label,
|
||||||
|
task,
|
||||||
|
self._format_partial_progress(result),
|
||||||
|
origin,
|
||||||
|
"error",
|
||||||
)
|
)
|
||||||
|
return
|
||||||
if response.has_tool_calls:
|
if result.stop_reason == "error":
|
||||||
tool_call_dicts = [
|
await self._announce_result(
|
||||||
tc.to_openai_tool_call()
|
task_id,
|
||||||
for tc in response.tool_calls
|
label,
|
||||||
]
|
task,
|
||||||
messages.append(build_assistant_message(
|
result.error or "Error: subagent execution failed.",
|
||||||
response.content or "",
|
origin,
|
||||||
tool_calls=tool_call_dicts,
|
"error",
|
||||||
reasoning_content=response.reasoning_content,
|
)
|
||||||
thinking_blocks=response.thinking_blocks,
|
return
|
||||||
))
|
final_result = result.final_content or "Task completed but no final response was generated."
|
||||||
|
|
||||||
# Execute tools
|
|
||||||
for tool_call in response.tool_calls:
|
|
||||||
args_str = json.dumps(tool_call.arguments, ensure_ascii=False)
|
|
||||||
logger.debug("Subagent [{}] executing: {} with arguments: {}", task_id, tool_call.name, args_str)
|
|
||||||
result = await tools.execute(tool_call.name, tool_call.arguments)
|
|
||||||
messages.append({
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tool_call.id,
|
|
||||||
"name": tool_call.name,
|
|
||||||
"content": result,
|
|
||||||
})
|
|
||||||
else:
|
|
||||||
final_result = response.content
|
|
||||||
break
|
|
||||||
|
|
||||||
if final_result is None:
|
|
||||||
final_result = "Task completed but no final response was generated."
|
|
||||||
|
|
||||||
logger.info("Subagent [{}] completed successfully", task_id)
|
logger.info("Subagent [{}] completed successfully", task_id)
|
||||||
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
await self._announce_result(task_id, label, task, final_result, origin, "ok")
|
||||||
@ -196,6 +189,27 @@ Summarize this naturally for the user. Keep it brief (1-2 sentences). Do not men
|
|||||||
|
|
||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
logger.debug("Subagent [{}] announced result to {}:{}", task_id, origin['channel'], origin['chat_id'])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_partial_progress(result) -> str:
|
||||||
|
completed = [e for e in result.tool_events if e["status"] == "ok"]
|
||||||
|
failure = next((e for e in reversed(result.tool_events) if e["status"] == "error"), None)
|
||||||
|
lines: list[str] = []
|
||||||
|
if completed:
|
||||||
|
lines.append("Completed steps:")
|
||||||
|
for event in completed[-3:]:
|
||||||
|
lines.append(f"- {event['name']}: {event['detail']}")
|
||||||
|
if failure:
|
||||||
|
if lines:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Failure:")
|
||||||
|
lines.append(f"- {failure['name']}: {failure['detail']}")
|
||||||
|
if result.error and not failure:
|
||||||
|
if lines:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("Failure:")
|
||||||
|
lines.append(f"- {result.error}")
|
||||||
|
return "\n".join(lines) or (result.error or "Error: subagent execution failed.")
|
||||||
|
|
||||||
def _build_subagent_prompt(self) -> str:
|
def _build_subagent_prompt(self) -> str:
|
||||||
"""Build a focused system prompt for the subagent."""
|
"""Build a focused system prompt for the subagent."""
|
||||||
|
|||||||
186
tests/agent/test_runner.py
Normal file
186
tests/agent/test_runner.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
"""Tests for the shared agent runner and its integration contracts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(tmp_path):
    """Build an ``AgentLoop`` whose collaborators are replaced by mocks.

    The context builder, session manager and subagent manager are patched
    while the loop is constructed; the provider is a ``MagicMock`` that
    reports ``"test-model"`` as its default model.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.queue import MessageBus

    message_bus = MessageBus()
    fake_provider = MagicMock()
    fake_provider.get_default_model.return_value = "test-model"

    with (
        patch("nanobot.agent.loop.ContextBuilder"),
        patch("nanobot.agent.loop.SessionManager"),
        patch("nanobot.agent.loop.SubagentManager") as subagent_manager_cls,
    ):
        subagent_manager_cls.return_value.cancel_by_session = AsyncMock(return_value=0)
        return AgentLoop(bus=message_bus, provider=fake_provider, workspace=tmp_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_preserves_reasoning_fields_and_tool_results():
    """The second provider call must see reasoning fields and the tool result."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    second_call_messages: list[dict] = []
    calls_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        # First call requests a tool; every later call records what it was
        # given and finishes the conversation.
        nonlocal calls_seen
        calls_seen += 1
        if calls_seen > 1:
            second_call_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        return LLMResponse(
            content="thinking",
            tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
            reasoning_content="hidden reasoning",
            thinking_blocks=[{"type": "thinking", "thinking": "step"}],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "system", "content": "system"},
            {"role": "user", "content": "do task"},
        ],
        tools=registry,
        model="test-model",
        max_iterations=3,
    )
    outcome = await AgentRunner(provider).run(spec)

    assert outcome.final_content == "done"
    assert outcome.tools_used == ["list_dir"]
    assert outcome.tool_events == [
        {"name": "list_dir", "status": "ok", "detail": "tool result"}
    ]

    tool_requesting = [
        entry for entry in second_call_messages
        if entry.get("role") == "assistant" and entry.get("tool_calls")
    ]
    assert len(tool_requesting) == 1
    assistant = tool_requesting[0]
    assert assistant["reasoning_content"] == "hidden reasoning"
    assert assistant["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]

    tool_contents = [
        entry.get("content") for entry in second_call_messages
        if entry.get("role") == "tool"
    ]
    assert "tool result" in tool_contents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_returns_max_iterations_fallback():
    """A model that never stops calling tools yields the default fallback text."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    endless_tool_request = LLMResponse(
        content="still working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
    )
    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(return_value=endless_tool_request)

    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(return_value="tool result")

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=2,
    )
    outcome = await AgentRunner(provider).run(spec)

    expected_text = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert outcome.stop_reason == "max_iterations"
    assert outcome.final_content == expected_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
    """With ``fail_on_tool_error`` a raising tool ends the run as ``tool_error``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tool_request = LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    )
    provider = MagicMock()
    provider.chat_with_retry = AsyncMock(return_value=tool_request)

    registry = MagicMock()
    registry.get_definitions.return_value = []
    registry.execute = AsyncMock(side_effect=RuntimeError("boom"))

    spec = AgentRunSpec(
        initial_messages=[],
        tools=registry,
        model="test-model",
        max_iterations=2,
        fail_on_tool_error=True,
    )
    outcome = await AgentRunner(provider).run(spec)

    assert outcome.stop_reason == "tool_error"
    assert outcome.error == "Error: RuntimeError: boom"
    assert outcome.tool_events == [
        {"name": "list_dir", "status": "error", "detail": "boom"}
    ]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
    """The main loop keeps its user-facing max-iterations wording."""
    agent_loop = _make_loop(tmp_path)
    agent_loop.max_iterations = 2
    agent_loop.tools.get_definitions = MagicMock(return_value=[])
    agent_loop.tools.execute = AsyncMock(return_value="ok")
    agent_loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    ))

    reply, _tools_used, _messages = await agent_loop._run_agent_loop([])

    expected_text = (
        "I reached the maximum number of tool call iterations (2) "
        "without completing the task. You can try breaking the task into smaller steps."
    )
    assert reply == expected_text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch):
    """A subagent that exhausts its iterations still announces the old fallback."""
    from nanobot.agent.subagent import SubagentManager
    from nanobot.bus.queue import MessageBus

    async def fake_execute(self, name, arguments):
        return "tool result"

    monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
        content="working",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    ))

    manager = SubagentManager(provider=provider, workspace=tmp_path, bus=MessageBus())
    manager._announce_result = AsyncMock()

    origin = {"channel": "test", "chat_id": "c1"}
    await manager._run_subagent("sub-1", "do task", "label", origin)

    manager._announce_result.assert_awaited_once()
    announced = manager._announce_result.await_args.args
    assert announced[3] == "Task completed but no final response was generated."
    assert announced[5] == "ok"
|
||||||
@ -221,3 +221,83 @@ class TestSubagentCancellation:
|
|||||||
assert len(assistant_messages) == 1
|
assert len(assistant_messages) == 1
|
||||||
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
assert assistant_messages[0]["reasoning_content"] == "hidden reasoning"
|
||||||
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
assert assistant_messages[0]["thinking_blocks"] == [{"type": "thinking", "thinking": "step"}]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_subagent_announces_error_when_tool_execution_fails(self, monkeypatch, tmp_path):
    """A tool failure is announced as an error that includes earlier progress."""
    from nanobot.agent.subagent import SubagentManager
    from nanobot.bus.queue import MessageBus
    from nanobot.providers.base import LLMResponse, ToolCallRequest

    executions = 0

    async def fake_execute(self, name, arguments):
        # First tool call succeeds, every later one blows up.
        nonlocal executions
        executions += 1
        if executions > 1:
            raise RuntimeError("boom")
        return "first result"

    monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
        content="thinking",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    ))

    manager = SubagentManager(provider=provider, workspace=tmp_path, bus=MessageBus())
    manager._announce_result = AsyncMock()

    origin = {"channel": "test", "chat_id": "c1"}
    await manager._run_subagent("sub-1", "do task", "label", origin)

    manager._announce_result.assert_awaited_once()
    announced = manager._announce_result.await_args.args
    summary = announced[3]
    assert "Completed steps:" in summary
    assert "- list_dir: first result" in summary
    assert "Failure:" in summary
    assert "- list_dir: boom" in summary
    assert announced[5] == "error"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_cancel_by_session_cancels_running_subagent_tool(self, monkeypatch, tmp_path):
    """Cancelling a session interrupts a tool that is still running."""
    from nanobot.agent.subagent import SubagentManager
    from nanobot.bus.queue import MessageBus
    from nanobot.providers.base import LLMResponse, ToolCallRequest

    tool_started = asyncio.Event()
    tool_cancelled = asyncio.Event()

    async def fake_execute(self, name, arguments):
        # Block long enough that only cancellation can end the call.
        tool_started.set()
        try:
            await asyncio.sleep(60)
        except asyncio.CancelledError:
            tool_cancelled.set()
            raise

    monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
        content="thinking",
        tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
    ))

    manager = SubagentManager(provider=provider, workspace=tmp_path, bus=MessageBus())
    manager._announce_result = AsyncMock()

    origin = {"channel": "test", "chat_id": "c1"}
    subagent_task = asyncio.create_task(
        manager._run_subagent("sub-1", "do task", "label", origin)
    )
    # Register the task by hand, mirroring the manager's own bookkeeping.
    manager._running_tasks["sub-1"] = subagent_task
    manager._session_tasks["test:c1"] = {"sub-1"}

    await tool_started.wait()

    cancelled_count = await manager.cancel_by_session("test:c1")

    assert cancelled_count == 1
    assert tool_cancelled.is_set()
    assert subagent_task.cancelled()
    manager._announce_result.assert_not_awaited()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user