674 lines
29 KiB
Python

"""Agent loop: the core processing engine."""
from __future__ import annotations
import asyncio
import json
import re
import os
import time
from contextlib import AsyncExitStack, nullcontext
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import MemoryConsolidator
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
if TYPE_CHECKING:
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
from nanobot.cron.service import CronService
class _LoopHook(AgentHook):
    """Lifecycle hook installed by the main agent loop for every run.

    Relays streamed deltas with think blocks removed, reports progress and
    tool hints, logs each tool call, and strips think blocks from the final
    content of the built-in agent path.
    """

    def __init__(
        self,
        agent_loop: AgentLoop,
        on_progress: Callable[..., Awaitable[None]] | None = None,
        on_stream: Callable[[str], Awaitable[None]] | None = None,
        on_stream_end: Callable[..., Awaitable[None]] | None = None,
        *,
        channel: str = "cli",
        chat_id: str = "direct",
        message_id: str | None = None,
    ) -> None:
        """Remember the owning loop, the optional callbacks and the routing info."""
        self._loop = agent_loop
        # Optional callbacks; a missing one simply disables that feature.
        self._on_progress, self._on_stream, self._on_stream_end = (
            on_progress,
            on_stream,
            on_stream_end,
        )
        # Routing info pushed to the tools right before they execute.
        self._channel, self._chat_id, self._message_id = channel, chat_id, message_id
        # Raw (unstripped) text received so far in the current stream segment.
        self._stream_buf = ""

    def wants_streaming(self) -> bool:
        """Return True when a stream callback was supplied."""
        return self._on_stream is not None

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Forward only the newly visible, think-stripped part of the stream."""
        from nanobot.utils.helpers import strip_think

        # Length of the cleaned text before this delta was appended.
        visible_before = len(strip_think(self._stream_buf))
        self._stream_buf += delta
        visible_now = strip_think(self._stream_buf)
        fresh = visible_now[visible_before:]
        if not fresh or not self._on_stream:
            return
        await self._on_stream(fresh)

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Notify the stream-end callback, then reset the segment buffer."""
        notify = self._on_stream_end
        if notify:
            await notify(resuming=resuming)
        self._stream_buf = ""

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Report progress, log the pending tool calls and set tool routing."""
        report = self._on_progress
        if report:
            if not self._on_stream:
                # Without streaming, the model's text is sent as a progress update.
                response = context.response
                thought = self._loop._strip_think(response.content if response else None)
                if thought:
                    await report(thought)
            hint = self._loop._strip_think(self._loop._tool_hint(context.tool_calls))
            await report(hint, tool_hint=True)
        for call in context.tool_calls:
            rendered_args = json.dumps(call.arguments, ensure_ascii=False)
            logger.info("Tool call: {}({})", call.name, rendered_args[:200])
        self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Log the token usage reported for the iteration."""
        usage = context.usage or {}
        prompt, completion, cached = (
            usage.get(field, 0)
            for field in ("prompt_tokens", "completion_tokens", "cached_tokens")
        )
        logger.debug(
            "LLM usage: prompt={} completion={} cached={}",
            prompt,
            completion,
            cached,
        )

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Return *content* with think blocks removed (None when nothing is left)."""
        return self._loop._strip_think(content)
class _LoopHookChain(AgentHook):
    """Call the core loop hook first, then the best-effort extra hooks.

    The primary hook is invoked directly, so its historical failure behavior
    is preserved; user-supplied hooks are wrapped in a ``CompositeHook`` so
    they can opt into its isolation.
    """

    __slots__ = ("_primary", "_extras")

    def __init__(self, primary: AgentHook, extra_hooks: list[AgentHook]) -> None:
        """Store the primary hook and wrap the extras in one ``CompositeHook``."""
        self._primary = primary
        self._extras = CompositeHook(extra_hooks)

    def wants_streaming(self) -> bool:
        """Return True when the primary hook or the extras want streaming."""
        return any(hook.wants_streaming() for hook in (self._primary, self._extras))

    async def before_iteration(self, context: AgentHookContext) -> None:
        """Forward ``before_iteration`` to the primary hook, then the extras."""
        for hook in (self._primary, self._extras):
            await hook.before_iteration(context)

    async def on_stream(self, context: AgentHookContext, delta: str) -> None:
        """Forward a stream delta to the primary hook, then the extras."""
        for hook in (self._primary, self._extras):
            await hook.on_stream(context, delta)

    async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
        """Forward the end-of-stream signal to the primary hook, then the extras."""
        for hook in (self._primary, self._extras):
            await hook.on_stream_end(context, resuming=resuming)

    async def before_execute_tools(self, context: AgentHookContext) -> None:
        """Forward ``before_execute_tools`` to the primary hook, then the extras."""
        for hook in (self._primary, self._extras):
            await hook.before_execute_tools(context)

    async def after_iteration(self, context: AgentHookContext) -> None:
        """Forward ``after_iteration`` to the primary hook, then the extras."""
        for hook in (self._primary, self._extras):
            await hook.after_iteration(context)

    def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
        """Pipe *content* through the primary hook and then through the extras."""
        for hook in (self._primary, self._extras):
            content = hook.finalize_content(context, content)
        return content
class AgentLoop:
    """
    The agent loop is the core processing engine.

    It:
    1. Receives messages from the bus
    2. Builds context with history, memory, skills
    3. Calls the LLM
    4. Executes tool calls
    5. Sends responses back
    """

    # Maximum number of characters of tool-result text kept when a turn is
    # written to session history; longer text is cut to this length and
    # suffixed with "\n... (truncated)".
    _TOOL_RESULT_MAX_CHARS = 16_000
def __init__(
    self,
    bus: MessageBus,
    provider: LLMProvider,
    workspace: Path,
    model: str | None = None,
    max_iterations: int = 40,
    context_window_tokens: int = 65_536,
    web_search_config: WebSearchConfig | None = None,
    web_proxy: str | None = None,
    exec_config: ExecToolConfig | None = None,
    cron_service: CronService | None = None,
    restrict_to_workspace: bool = False,
    session_manager: SessionManager | None = None,
    mcp_servers: dict | None = None,
    channels_config: ChannelsConfig | None = None,
    timezone: str | None = None,
    hooks: list[AgentHook] | None = None,
):
    """Wire up the loop's collaborators and register the default tools.

    Args:
        bus: Message bus used to consume inbound and publish outbound messages.
        provider: LLM provider shared with the runner, the subagent manager
            and the memory consolidator.
        workspace: Workspace path handed to the context builder, the session
            manager and the file/shell tools.
        model: Model name; defaults to ``provider.get_default_model()``.
        max_iterations: Iteration cap passed to every agent run.
        context_window_tokens: Context window size given to the memory
            consolidator.
        web_search_config: Web search settings; a default ``WebSearchConfig``
            is created when omitted.
        web_proxy: Proxy passed to the web search/fetch tools.
        exec_config: Shell tool settings; a default ``ExecToolConfig`` is
            created when omitted.
        cron_service: When given, a ``CronTool`` is registered for it.
        restrict_to_workspace: When True, the file tools get the workspace as
            their allowed directory and the shell tool is told to restrict
            itself to the workspace.
        session_manager: Session store; defaults to ``SessionManager(workspace)``.
        mcp_servers: MCP server configuration, connected lazily.
        channels_config: Stored on the instance as ``channels_config``.
        timezone: Passed to the context builder.
        hooks: Extra hooks run after the core loop hook on each agent run.

    The environment variable ``NANOBOT_MAX_CONCURRENT_REQUESTS`` bounds how
    many messages are processed at once (``<= 0`` means unlimited, default 3).
    A value that is not an integer is reported with a warning and replaced
    by the default instead of aborting construction.
    """
    from nanobot.config.schema import ExecToolConfig, WebSearchConfig

    self.bus = bus
    self.channels_config = channels_config
    self.provider = provider
    self.workspace = workspace
    self.model = model or provider.get_default_model()
    self.max_iterations = max_iterations
    self.context_window_tokens = context_window_tokens
    self.web_search_config = web_search_config or WebSearchConfig()
    self.web_proxy = web_proxy
    self.exec_config = exec_config or ExecToolConfig()
    self.cron_service = cron_service
    self.restrict_to_workspace = restrict_to_workspace
    self._start_time = time.time()
    self._last_usage: dict[str, int] = {}
    self._extra_hooks: list[AgentHook] = hooks or []

    self.context = ContextBuilder(workspace, timezone=timezone)
    self.sessions = session_manager or SessionManager(workspace)
    self.tools = ToolRegistry()
    self.runner = AgentRunner(provider)
    self.subagents = SubagentManager(
        provider=provider,
        workspace=workspace,
        bus=bus,
        model=self.model,
        web_search_config=self.web_search_config,
        web_proxy=web_proxy,
        exec_config=self.exec_config,
        restrict_to_workspace=restrict_to_workspace,
    )

    self._running = False
    self._mcp_servers = mcp_servers or {}
    self._mcp_stack: AsyncExitStack | None = None
    self._mcp_connected = False
    self._mcp_connecting = False
    self._active_tasks: dict[str, list[asyncio.Task]] = {}  # session_key -> tasks
    self._background_tasks: list[asyncio.Task] = []
    self._session_locks: dict[str, asyncio.Lock] = {}

    # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
    raw_limit = os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3")
    try:
        max_concurrent = int(raw_limit)
    except ValueError:
        # A typo in the environment should not prevent the agent from starting.
        logger.warning(
            "Invalid NANOBOT_MAX_CONCURRENT_REQUESTS={!r}; falling back to 3",
            raw_limit,
        )
        max_concurrent = 3
    self._concurrency_gate: asyncio.Semaphore | None = (
        asyncio.Semaphore(max_concurrent) if max_concurrent > 0 else None
    )

    # The consolidator receives bound methods of the context builder and the
    # tool registry, so both must exist before this point.
    self.memory_consolidator = MemoryConsolidator(
        workspace=workspace,
        provider=provider,
        model=self.model,
        sessions=self.sessions,
        context_window_tokens=context_window_tokens,
        build_messages=self.context.build_messages,
        get_tool_definitions=self.tools.get_definitions,
        max_completion_tokens=provider.generation.max_tokens,
    )
    self._register_default_tools()
    self.commands = CommandRouter()
    register_builtin_commands(self.commands)
def _register_default_tools(self) -> None:
    """Populate the tool registry with the built-in tools.

    File tools are always registered. The shell tool is added only when
    ``exec_config.enable`` is set, and the cron tool only when a cron service
    was supplied.
    """
    register = self.tools.register
    # With workspace restriction on, file tools are confined to the workspace;
    # the read tool may additionally read the built-in skills directory.
    sandbox_dir = self.workspace if self.restrict_to_workspace else None
    extra_readable = [BUILTIN_SKILLS_DIR] if sandbox_dir else None

    register(
        ReadFileTool(
            workspace=self.workspace,
            allowed_dir=sandbox_dir,
            extra_allowed_dirs=extra_readable,
        )
    )
    for tool_cls in (WriteFileTool, EditFileTool, ListDirTool):
        register(tool_cls(workspace=self.workspace, allowed_dir=sandbox_dir))

    exec_cfg = self.exec_config
    if exec_cfg.enable:
        shell_tool = ExecTool(
            working_dir=str(self.workspace),
            timeout=exec_cfg.timeout,
            restrict_to_workspace=self.restrict_to_workspace,
            path_append=exec_cfg.path_append,
        )
        register(shell_tool)

    register(WebSearchTool(config=self.web_search_config, proxy=self.web_proxy))
    register(WebFetchTool(proxy=self.web_proxy))
    register(MessageTool(send_callback=self.bus.publish_outbound))
    register(SpawnTool(manager=self.subagents))

    if self.cron_service:
        cron_tz = self.context.timezone or "UTC"
        register(CronTool(self.cron_service, default_timezone=cron_tz))
async def _connect_mcp(self) -> None:
    """Connect to configured MCP servers (one-time, lazy).

    Returns immediately when a connection already exists, when another
    connect attempt is in flight, or when no MCP servers are configured.
    On failure the error is logged, the exit stack opened for this attempt
    is closed and dropped, and ``_mcp_connected`` stays False so that a
    later call tries again.
    """
    if self._mcp_connected or self._mcp_connecting or not self._mcp_servers:
        return
    # Mark the attempt as in flight so overlapping callers return early.
    self._mcp_connecting = True
    from nanobot.agent.tools.mcp import connect_mcp_servers
    try:
        self._mcp_stack = AsyncExitStack()
        await self._mcp_stack.__aenter__()
        await connect_mcp_servers(self._mcp_servers, self.tools, self._mcp_stack)
        self._mcp_connected = True
    except BaseException as e:
        # NOTE(review): BaseException also covers asyncio.CancelledError and
        # KeyboardInterrupt, which are logged and swallowed here — confirm
        # this breadth is intended.
        logger.error("Failed to connect MCP servers (will retry next message): {}", e)
        if self._mcp_stack:
            try:
                await self._mcp_stack.aclose()
            except Exception:
                pass
            self._mcp_stack = None
    finally:
        # Always clear the in-flight flag, whether the attempt succeeded or not.
        self._mcp_connecting = False
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
"""Update context for all tools that need routing info."""
for name in ("message", "spawn", "cron"):
if tool := self.tools.get(name):
if hasattr(tool, "set_context"):
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
@staticmethod
def _strip_think(text: str | None) -> str | None:
"""Remove <think>…</think> blocks that some models embed in content."""
if not text:
return None
from nanobot.utils.helpers import strip_think
return strip_think(text) or None
@staticmethod
def _tool_hint(tool_calls: list) -> str:
"""Format tool calls as concise hint, e.g. 'web_search("query")'."""
def _fmt(tc):
args = (tc.arguments[0] if isinstance(tc.arguments, list) else tc.arguments) or {}
val = next(iter(args.values()), None) if isinstance(args, dict) else None
if not isinstance(val, str):
return tc.name
return f'{tc.name}("{val[:40]}")' if len(val) > 40 else f'{tc.name}("{val}")'
return ", ".join(_fmt(tc) for tc in tool_calls)
async def _run_agent_loop(
    self,
    initial_messages: list[dict],
    on_progress: Callable[..., Awaitable[None]] | None = None,
    on_stream: Callable[[str], Awaitable[None]] | None = None,
    on_stream_end: Callable[..., Awaitable[None]] | None = None,
    *,
    channel: str = "cli",
    chat_id: str = "direct",
    message_id: str | None = None,
) -> tuple[str | None, list[str], list[dict]]:
    """Execute one agent run and return its outcome.

    *on_stream*: called with each content delta during streaming.
    *on_stream_end(resuming)*: called when a streaming session finishes.
    ``resuming=True`` means tool calls follow (spinner should restart);
    ``resuming=False`` means this is the final response.

    Returns:
        ``(final_content, tools_used, messages)`` taken from the run result.
        The run's usage is also stored in ``self._last_usage``.
    """
    core_hook = _LoopHook(
        self,
        on_progress=on_progress,
        on_stream=on_stream,
        on_stream_end=on_stream_end,
        channel=channel,
        chat_id=chat_id,
        message_id=message_id,
    )
    # Chain user-supplied hooks behind the core hook only when there are any.
    if self._extra_hooks:
        active_hook: AgentHook = _LoopHookChain(core_hook, self._extra_hooks)
    else:
        active_hook = core_hook

    spec = AgentRunSpec(
        initial_messages=initial_messages,
        tools=self.tools,
        model=self.model,
        max_iterations=self.max_iterations,
        hook=active_hook,
        error_message="Sorry, I encountered an error calling the AI model.",
        concurrent_tools=True,
    )
    outcome = await self.runner.run(spec)
    self._last_usage = outcome.usage

    stop_reason = outcome.stop_reason
    if stop_reason == "max_iterations":
        logger.warning("Max iterations ({}) reached", self.max_iterations)
    elif stop_reason == "error":
        logger.error("LLM returned error: {}", (outcome.final_content or "")[:200])
    return outcome.final_content, outcome.tools_used, outcome.messages
async def run(self) -> None:
    """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.

    Connects MCP servers, then repeatedly waits up to one second for an
    inbound message until ``_running`` is cleared. Priority commands are
    handled inline; every other message is dispatched as its own task and
    tracked per session in ``_active_tasks``. A finished task is removed
    from its session's list, and a list that becomes empty is dropped so
    the mapping does not keep one entry per session forever.

    Raises:
        asyncio.CancelledError: When this task is really being cancelled or
            the loop has been asked to stop.
    """
    self._running = True
    await self._connect_mcp()
    logger.info("Agent loop started")

    def _untrack(session_key: str, finished: asyncio.Task) -> None:
        """Forget a finished task and prune its session entry when empty."""
        tracked = self._active_tasks.get(session_key)
        if not tracked:
            # The entry may already have been removed elsewhere.
            return
        try:
            tracked.remove(finished)
        except ValueError:
            pass
        if not tracked:
            self._active_tasks.pop(session_key, None)

    while self._running:
        try:
            msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
        except asyncio.TimeoutError:
            continue
        except asyncio.CancelledError:
            # Preserve real task cancellation so shutdown can complete cleanly.
            # Only ignore non-task CancelledError signals that may leak from integrations.
            if not self._running or asyncio.current_task().cancelling():
                raise
            continue
        except Exception as e:
            logger.warning("Error consuming inbound message: {}, continuing...", e)
            continue

        raw = msg.content.strip()
        if self.commands.is_priority(raw):
            ctx = CommandContext(msg=msg, session=None, key=msg.session_key, raw=raw, loop=self)
            result = await self.commands.dispatch_priority(ctx)
            if result:
                await self.bus.publish_outbound(result)
            continue

        session_key = msg.session_key
        task = asyncio.create_task(self._dispatch(msg))
        self._active_tasks.setdefault(session_key, []).append(task)
        # Bind the key as a default so each callback keeps its own session key.
        task.add_done_callback(lambda done, key=session_key: _untrack(key, done))
async def _dispatch(self, msg: InboundMessage) -> None:
    """Handle one message: serial within a session, concurrent across sessions.

    The per-session lock is taken first, then the optional global
    concurrency gate. When the message metadata has ``_wants_stream`` set,
    stream callbacks that publish delta and end markers on the bus are
    passed to ``_process_message``. Any failure other than cancellation is
    logged and answered with a generic apology.
    """
    session_lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
    limiter = self._concurrency_gate or nullcontext()
    async with session_lock, limiter:
        try:
            on_stream = on_stream_end = None
            if msg.metadata.get("_wants_stream"):
                # Split one answer into distinct stream segments.
                base_id = f"{msg.session_key}:{time.time_ns()}"
                segment = 0

                async def _publish_stream(text: str, **flags: Any) -> None:
                    # Copy the inbound metadata, add the markers, then the
                    # id of the segment that is current right now.
                    meta = dict(msg.metadata or {})
                    meta.update(flags)
                    meta["_stream_id"] = f"{base_id}:{segment}"
                    await self.bus.publish_outbound(OutboundMessage(
                        channel=msg.channel, chat_id=msg.chat_id,
                        content=text, metadata=meta,
                    ))

                async def on_stream(delta: str) -> None:
                    await _publish_stream(delta, _stream_delta=True)

                async def on_stream_end(*, resuming: bool = False) -> None:
                    nonlocal segment
                    await _publish_stream("", _stream_end=True, _resuming=resuming)
                    # The next deltas belong to a new segment.
                    segment += 1

            reply = await self._process_message(
                msg, on_stream=on_stream, on_stream_end=on_stream_end,
            )
            if reply is not None:
                await self.bus.publish_outbound(reply)
            elif msg.channel == "cli":
                # The CLI channel still gets an (empty) message when there is no reply.
                await self.bus.publish_outbound(OutboundMessage(
                    channel=msg.channel, chat_id=msg.chat_id,
                    content="", metadata=msg.metadata or {},
                ))
        except asyncio.CancelledError:
            logger.info("Task cancelled for session {}", msg.session_key)
            raise
        except Exception:
            logger.exception("Error processing message for session {}", msg.session_key)
            await self.bus.publish_outbound(OutboundMessage(
                channel=msg.channel, chat_id=msg.chat_id,
                content="Sorry, I encountered an error.",
            ))
async def close_mcp(self) -> None:
    """Wait for tracked background tasks, then close the MCP exit stack.

    Exceptions raised by background tasks are collected rather than
    propagated. ``RuntimeError`` and ``BaseExceptionGroup`` raised while
    closing the stack are ignored.
    """
    pending = self._background_tasks
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
        pending.clear()

    stack = self._mcp_stack
    if not stack:
        return
    try:
        await stack.aclose()
    except (RuntimeError, BaseExceptionGroup):
        pass  # MCP SDK cancel scope cleanup is noisy but harmless
    self._mcp_stack = None
def _schedule_background(self, coro) -> None:
"""Schedule a coroutine as a tracked background task (drained on shutdown)."""
task = asyncio.create_task(coro)
self._background_tasks.append(task)
task.add_done_callback(self._background_tasks.remove)
def stop(self) -> None:
    """Stop the agent loop.

    Clears the ``_running`` flag; ``run`` checks it on every pass of its
    consume loop and exits once it is False.
    """
    self._running = False
    logger.info("Agent loop stopping")
async def _process_message(
    self,
    msg: InboundMessage,
    session_key: str | None = None,
    on_progress: Callable[[str], Awaitable[None]] | None = None,
    on_stream: Callable[[str], Awaitable[None]] | None = None,
    on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None:
    """Run one inbound message through the agent and build the reply.

    Messages on the "system" channel carry their origin in ``chat_id`` as
    "channel:chat_id" and are answered on that origin. Other messages are
    first offered to the slash-command router; if no command handles them,
    an agent run is executed, the turn is saved to the session, and memory
    consolidation is scheduled in the background.

    Returns:
        The outbound reply, or None when the message tool already sent a
        message during this turn.
    """
    # System messages: parse origin from chat_id ("channel:chat_id")
    if msg.channel == "system":
        if ":" in msg.chat_id:
            channel, chat_id = msg.chat_id.split(":", 1)
        else:
            channel, chat_id = "cli", msg.chat_id
        logger.info("Processing system message from {}", msg.sender_id)
        origin_key = f"{channel}:{chat_id}"
        origin_session = self.sessions.get_or_create(origin_key)
        await self.memory_consolidator.maybe_consolidate_by_tokens(origin_session)
        origin_message_id = msg.metadata.get("message_id")
        self._set_tool_context(channel, chat_id, origin_message_id)
        past = origin_session.get_history(max_messages=0)
        # Messages sent by a subagent are added with the assistant role.
        if msg.sender_id == "subagent":
            current_role = "assistant"
        else:
            current_role = "user"
        prompt_messages = self.context.build_messages(
            history=past,
            current_message=msg.content, channel=channel, chat_id=chat_id,
            current_role=current_role,
        )
        answer, _, run_messages = await self._run_agent_loop(
            prompt_messages, channel=channel, chat_id=chat_id,
            message_id=origin_message_id,
        )
        # Skip the first message plus the history: only the new turn is saved.
        self._save_turn(origin_session, run_messages, 1 + len(past))
        self.sessions.save(origin_session)
        self._schedule_background(
            self.memory_consolidator.maybe_consolidate_by_tokens(origin_session)
        )
        return OutboundMessage(
            channel=channel,
            chat_id=chat_id,
            content=answer or "Background task completed.",
        )

    text = msg.content
    inbound_preview = text[:80] + "..." if len(text) > 80 else text
    logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, inbound_preview)

    key = session_key or msg.session_key
    session = self.sessions.get_or_create(key)

    # Slash commands
    command_ctx = CommandContext(msg=msg, session=session, key=key, raw=text.strip(), loop=self)
    handled = await self.commands.dispatch(command_ctx)
    if handled:
        return handled

    await self.memory_consolidator.maybe_consolidate_by_tokens(session)
    self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
    message_tool = self.tools.get("message")
    if message_tool and isinstance(message_tool, MessageTool):
        message_tool.start_turn()

    history = session.get_history(max_messages=0)
    initial_messages = self.context.build_messages(
        history=history,
        current_message=msg.content,
        media=msg.media or None,
        channel=msg.channel, chat_id=msg.chat_id,
    )

    async def _bus_progress(content: str, *, tool_hint: bool = False) -> None:
        # Default progress reporter: publish the update on the bus.
        progress_meta = dict(msg.metadata or {})
        progress_meta["_progress"] = True
        progress_meta["_tool_hint"] = tool_hint
        await self.bus.publish_outbound(OutboundMessage(
            channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=progress_meta,
        ))

    final_content, _, all_msgs = await self._run_agent_loop(
        initial_messages,
        on_progress=on_progress or _bus_progress,
        on_stream=on_stream,
        on_stream_end=on_stream_end,
        channel=msg.channel, chat_id=msg.chat_id,
        message_id=msg.metadata.get("message_id"),
    )
    if final_content is None:
        final_content = "I've completed processing but have no response to give."

    self._save_turn(session, all_msgs, 1 + len(history))
    self.sessions.save(session)
    self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))

    # If the message tool already sent something this turn, there is no reply to return.
    sender = self.tools.get("message")
    if sender and isinstance(sender, MessageTool) and sender._sent_in_turn:
        return None

    if len(final_content) > 120:
        reply_preview = final_content[:120] + "..."
    else:
        reply_preview = final_content
    logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, reply_preview)

    reply_meta = dict(msg.metadata or {})
    if on_stream is not None:
        reply_meta["_streamed"] = True
    return OutboundMessage(
        channel=msg.channel, chat_id=msg.chat_id, content=final_content,
        metadata=reply_meta,
    )
@staticmethod
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
"""Convert an inline image block into a compact text placeholder."""
path = (block.get("_meta") or {}).get("path", "")
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
def _sanitize_persisted_blocks(
self,
content: list[dict[str, Any]],
*,
truncate_text: bool = False,
drop_runtime: bool = False,
) -> list[dict[str, Any]]:
"""Strip volatile multimodal payloads before writing session history."""
filtered: list[dict[str, Any]] = []
for block in content:
if not isinstance(block, dict):
filtered.append(block)
continue
if (
drop_runtime
and block.get("type") == "text"
and isinstance(block.get("text"), str)
and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
):
continue
if (
block.get("type") == "image_url"
and block.get("image_url", {}).get("url", "").startswith("data:image/")
):
filtered.append(self._image_placeholder(block))
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"]
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
filtered.append({**block, "text": text})
continue
filtered.append(block)
return filtered
def _save_turn(self, session: Session, messages: list[dict], skip: int) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
for m in messages[skip:]:
entry = dict(m)
role, content = entry.get("role"), entry.get("content")
if role == "assistant" and not content and not entry.get("tool_calls"):
continue # skip empty assistant messages — they poison session context
if role == "tool":
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
if not filtered:
continue
entry["content"] = filtered
elif role == "user":
if isinstance(content, str) and content.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
# Strip the runtime-context prefix, keep only the user text.
parts = content.split("\n\n", 1)
if len(parts) > 1 and parts[1].strip():
entry["content"] = parts[1]
else:
continue
if isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
if not filtered:
continue
entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
session.updated_at = datetime.now()
async def process_direct(
    self,
    content: str,
    session_key: str = "cli:direct",
    channel: str = "cli",
    chat_id: str = "direct",
    on_progress: Callable[[str], Awaitable[None]] | None = None,
    on_stream: Callable[[str], Awaitable[None]] | None = None,
    on_stream_end: Callable[..., Awaitable[None]] | None = None,
) -> OutboundMessage | None:
    """Handle *content* without going through the bus and return the reply.

    Makes sure MCP servers are connected, wraps the text in an inbound
    message sent by "user", and processes it under *session_key*.
    """
    await self._connect_mcp()
    inbound = InboundMessage(
        channel=channel,
        sender_id="user",
        chat_id=chat_id,
        content=content,
    )
    reply = await self._process_message(
        inbound,
        session_key=session_key,
        on_progress=on_progress,
        on_stream=on_stream,
        on_stream_end=on_stream_end,
    )
    return reply