# Mirror of https://github.com/HKUDS/nanobot.git
# Synced 2026-04-08 04:03:38 +00:00
# 235 lines, 8.9 KiB, Python
"""Shared execution loop for tool-using agents."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
|
from nanobot.agent.tools.registry import ToolRegistry
|
|
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
|
from nanobot.utils.helpers import build_assistant_message
|
|
|
|
# Fallback final reply for a run that exhausts its iteration budget.
# The ``{max_iterations}`` placeholder is filled in with ``str.format``.
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
    "I reached the maximum number of tool call iterations ({max_iterations}) "
    "without completing the task. You can try breaking the task into smaller steps."
)

# Fallback final reply for a model response whose finish reason is "error"
# and that carries no usable content of its own.
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
|
|
|
|
|
@dataclass(slots=True)
class AgentRunSpec:
    """Configuration for a single agent execution."""

    # Conversation so far; the runner works on a shallow copy of this list.
    initial_messages: list[dict[str, Any]]
    # Supplies the tool definitions sent to the model and executes tool calls.
    tools: ToolRegistry
    # Model identifier forwarded to the provider on every request.
    model: str
    # Upper bound on the number of model round-trips in one run.
    max_iterations: int
    # Optional sampling/limit settings; each one is forwarded to the provider
    # only when it is not None.
    temperature: float | None = None
    max_tokens: int | None = None
    reasoning_effort: str | None = None
    # Lifecycle hook; a default ``AgentHook()`` is used when this is falsy.
    hook: AgentHook | None = None
    # Final content used when the model response finishes with "error" and
    # has no content of its own; a falsy value falls back to the default.
    error_message: str | None = _DEFAULT_ERROR_MESSAGE
    # ``str.format`` template (receives ``max_iterations``) used when the
    # iteration budget is exhausted; a falsy value selects the default text.
    max_iterations_message: str | None = None
    # Run the tool calls of one response via ``asyncio.gather`` instead of
    # awaiting them one after another.
    concurrent_tools: bool = False
    # When true, an exception raised by a tool ends the run with stop reason
    # "tool_error" instead of being fed back to the model as a tool result.
    fail_on_tool_error: bool = False
|
|
|
|
|
|
@dataclass(slots=True)
|
|
class AgentRunResult:
|
|
"""Outcome of a shared agent execution."""
|
|
|
|
final_content: str | None
|
|
messages: list[dict[str, Any]]
|
|
tools_used: list[str] = field(default_factory=list)
|
|
usage: dict[str, int] = field(default_factory=dict)
|
|
stop_reason: str = "completed"
|
|
error: str | None = None
|
|
tool_events: list[dict[str, str]] = field(default_factory=list)
|
|
|
|
|
|
class AgentRunner:
    """Run a tool-capable LLM loop without product-layer concerns."""

    def __init__(self, provider: LLMProvider):
        """Store the provider used for every model call of this runner."""
        self.provider = provider

    async def run(self, spec: AgentRunSpec) -> AgentRunResult:
        """Drive the model/tool loop described by *spec* until it stops.

        Each iteration asks the provider for a response. A response with
        tool calls has its tools executed and their results appended as
        ``tool`` messages before the next iteration; any other response
        ends the run.

        Returns:
            An ``AgentRunResult`` whose ``stop_reason`` is one of:

            * ``"completed"`` - the model answered without tool calls;
            * ``"error"`` - the response finish reason was ``"error"``;
            * ``"tool_error"`` - a tool raised and
              ``spec.fail_on_tool_error`` is set;
            * ``"max_iterations"`` - ``spec.max_iterations`` iterations
              passed without a final answer.
        """
        hook = spec.hook or AgentHook()
        # Work on a copy so the caller's list is not extended in place.
        messages = list(spec.initial_messages)
        final_content: str | None = None
        tools_used: list[str] = []
        usage: dict[str, int] = {}
        error: str | None = None
        stop_reason = "completed"
        tool_events: list[dict[str, str]] = []

        for iteration in range(spec.max_iterations):
            context = AgentHookContext(iteration=iteration, messages=messages)
            await hook.before_iteration(context)

            kwargs: dict[str, Any] = {
                "messages": messages,
                "tools": spec.tools.get_definitions(),
                "model": spec.model,
            }
            # Optional settings are only sent when explicitly configured.
            if spec.temperature is not None:
                kwargs["temperature"] = spec.temperature
            if spec.max_tokens is not None:
                kwargs["max_tokens"] = spec.max_tokens
            if spec.reasoning_effort is not None:
                kwargs["reasoning_effort"] = spec.reasoning_effort

            if hook.wants_streaming():
                async def _stream(delta: str) -> None:
                    await hook.on_stream(context, delta)

                response = await self.provider.chat_stream_with_retry(
                    **kwargs,
                    on_content_delta=_stream,
                )
            else:
                response = await self.provider.chat_with_retry(**kwargs)

            raw_usage = response.usage or {}
            context.response = response
            context.usage = raw_usage
            context.tool_calls = list(response.tool_calls)

            # Accumulate standard fields into result usage.
            usage["prompt_tokens"] = (
                usage.get("prompt_tokens", 0)
                + int(raw_usage.get("prompt_tokens", 0) or 0)
            )
            usage["completion_tokens"] = (
                usage.get("completion_tokens", 0)
                + int(raw_usage.get("completion_tokens", 0) or 0)
            )
            # "cached_tokens" only appears in the result once a response
            # reports a non-zero value for it.
            cached = raw_usage.get("cached_tokens")
            if cached:
                usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached)

            if response.has_tool_calls:
                if hook.wants_streaming():
                    await hook.on_stream_end(context, resuming=True)

                messages.append(build_assistant_message(
                    response.content or "",
                    tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
                    reasoning_content=response.reasoning_content,
                    thinking_blocks=response.thinking_blocks,
                ))
                tools_used.extend(tc.name for tc in response.tool_calls)

                await hook.before_execute_tools(context)

                results, new_events, fatal_error = await self._execute_tools(
                    spec, response.tool_calls
                )
                tool_events.extend(new_events)
                context.tool_results = list(results)
                context.tool_events = list(new_events)
                if fatal_error is not None:
                    # Abort the run. The assistant message carrying the tool
                    # calls is already in ``messages``; the tool results are
                    # not appended on this path.
                    error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
                    stop_reason = "tool_error"
                    context.error = error
                    context.stop_reason = stop_reason
                    await hook.after_iteration(context)
                    break
                for tool_call, result in zip(response.tool_calls, results):
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_call.name,
                        "content": result,
                    })
                await hook.after_iteration(context)
                continue

            # No tool calls: this response ends the run one way or another.
            if hook.wants_streaming():
                await hook.on_stream_end(context, resuming=False)

            clean = hook.finalize_content(context, response.content)
            if response.finish_reason == "error":
                final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
                stop_reason = "error"
                error = final_content
                context.final_content = final_content
                context.error = error
                context.stop_reason = stop_reason
                await hook.after_iteration(context)
                break

            messages.append(build_assistant_message(
                clean,
                reasoning_content=response.reasoning_content,
                thinking_blocks=response.thinking_blocks,
            ))
            final_content = clean
            context.final_content = final_content
            context.stop_reason = stop_reason
            await hook.after_iteration(context)
            break
        else:
            # The loop ended without ``break``: every iteration issued tool
            # calls (or ``max_iterations`` allowed no iteration at all).
            stop_reason = "max_iterations"
            template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
            final_content = template.format(max_iterations=spec.max_iterations)

        return AgentRunResult(
            final_content=final_content,
            messages=messages,
            tools_used=tools_used,
            usage=usage,
            stop_reason=stop_reason,
            error=error,
            tool_events=tool_events,
        )

    async def _execute_tools(
        self,
        spec: AgentRunSpec,
        tool_calls: list[ToolCallRequest],
    ) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
        """Execute *tool_calls* and collect their outcomes in call order.

        Calls run through ``asyncio.gather`` when ``spec.concurrent_tools``
        is set, otherwise they are awaited one after another.

        Returns:
            ``(results, events, fatal_error)``. ``results`` and ``events``
            hold one entry per tool call; ``fatal_error`` is the first
            exception reported as fatal by ``_run_tool``, or None.
        """
        if spec.concurrent_tools:
            tool_results = await asyncio.gather(*(
                self._run_tool(spec, tool_call)
                for tool_call in tool_calls
            ))
        else:
            tool_results = [
                await self._run_tool(spec, tool_call)
                for tool_call in tool_calls
            ]

        results: list[Any] = []
        events: list[dict[str, str]] = []
        fatal_error: BaseException | None = None
        for result, event, error in tool_results:
            results.append(result)
            events.append(event)
            # Only the first fatal error is kept.
            if error is not None and fatal_error is None:
                fatal_error = error
        return results, events, fatal_error

    async def _run_tool(
        self,
        spec: AgentRunSpec,
        tool_call: ToolCallRequest,
    ) -> tuple[Any, dict[str, str], BaseException | None]:
        """Execute one tool call and describe its outcome.

        Returns:
            ``(result, event, fatal_error)``. ``event`` is a
            ``{"name", "status", "detail"}`` record. When the tool raises,
            ``result`` is an ``"Error: <type>: <message>"`` string and
            ``fatal_error`` is the exception if ``spec.fail_on_tool_error``
            is set, otherwise None.

        Raises:
            asyncio.CancelledError: propagated so task cancellation works.
            KeyboardInterrupt: propagated so an interrupt is not turned
                into a tool error.
        """
        try:
            result = await spec.tools.execute(tool_call.name, tool_call.arguments)
        except (asyncio.CancelledError, KeyboardInterrupt):
            # Cancellation and Ctrl-C are requests to stop, not tool
            # failures; reporting them to the model would swallow them.
            raise
        except BaseException as exc:
            # Every other exception (including non-``Exception``
            # subclasses of ``BaseException``) is reported as a tool error.
            event = {
                "name": tool_call.name,
                "status": "error",
                "detail": str(exc),
            }
            fatal = exc if spec.fail_on_tool_error else None
            return f"Error: {type(exc).__name__}: {exc}", event, fatal

        # Build a short single-line preview of the result for the event log.
        detail = "" if result is None else str(result)
        detail = detail.replace("\n", " ").strip()
        if not detail:
            detail = "(empty)"
        elif len(detail) > 120:
            detail = detail[:120] + "..."
        # A string result starting with "Error" is flagged as an error
        # event even though the tool did not raise.
        is_error = isinstance(result, str) and result.startswith("Error")
        return result, {
            "name": tool_call.name,
            "status": "error" if is_error else "ok",
            "detail": detail,
        }, None
|