mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-20 18:09:56 +00:00
feat: harden agent runtime for long-running tasks
This commit is contained in:
parent
63d646f731
commit
fbedf7ad77
@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
|||||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
|
if isinstance(left, str) and isinstance(right, str):
|
||||||
|
return f"{left}\n\n{right}" if left else right
|
||||||
|
|
||||||
|
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
return [{"type": "text", "text": str(value)}]
|
||||||
|
|
||||||
|
return _to_blocks(left) + _to_blocks(right)
|
||||||
|
|
||||||
def _load_bootstrap_files(self) -> str:
|
def _load_bootstrap_files(self) -> str:
|
||||||
"""Load all bootstrap files from workspace."""
|
"""Load all bootstrap files from workspace."""
|
||||||
parts = []
|
parts = []
|
||||||
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
|||||||
merged = f"{runtime_ctx}\n\n{user_content}"
|
merged = f"{runtime_ctx}\n\n{user_content}"
|
||||||
else:
|
else:
|
||||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||||
|
messages = [
|
||||||
return [
|
|
||||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||||
*history,
|
*history,
|
||||||
{"role": current_role, "content": merged},
|
|
||||||
]
|
]
|
||||||
|
if messages[-1].get("role") == current_role:
|
||||||
|
last = dict(messages[-1])
|
||||||
|
last["content"] = self._merge_message_content(last.get("content"), merged)
|
||||||
|
messages[-1] = last
|
||||||
|
return messages
|
||||||
|
messages.append({"role": current_role, "content": merged})
|
||||||
|
return messages
|
||||||
|
|
||||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||||
"""Build user message content with optional base64-encoded images."""
|
"""Build user message content with optional base64-encoded images."""
|
||||||
|
|||||||
@ -29,8 +29,10 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
|||||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
|
from nanobot.utils.helpers import image_placeholder_text, truncate_text
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
|
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
|
||||||
@ -38,11 +40,7 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
|
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
"""Core lifecycle hook for the main agent loop.
|
"""Core hook for the main loop."""
|
||||||
|
|
||||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
|
||||||
and think-tag stripping for the built-in agent path.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -102,11 +100,7 @@ class _LoopHook(AgentHook):
|
|||||||
|
|
||||||
|
|
||||||
class _LoopHookChain(AgentHook):
|
class _LoopHookChain(AgentHook):
|
||||||
"""Run the core loop hook first, then best-effort extra hooks.
|
"""Run the core hook before extra hooks."""
|
||||||
|
|
||||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
|
||||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_primary", "_extras")
|
__slots__ = ("_primary", "_extras")
|
||||||
|
|
||||||
@ -154,7 +148,7 @@ class AgentLoop:
|
|||||||
5. Sends responses back
|
5. Sends responses back
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_TOOL_RESULT_MAX_CHARS = 16_000
|
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -162,8 +156,11 @@ class AgentLoop:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
max_iterations: int = 40,
|
max_iterations: int | None = None,
|
||||||
context_window_tokens: int = 65_536,
|
context_window_tokens: int | None = None,
|
||||||
|
context_block_limit: int | None = None,
|
||||||
|
max_tool_result_chars: int | None = None,
|
||||||
|
provider_retry_mode: str = "standard",
|
||||||
web_search_config: WebSearchConfig | None = None,
|
web_search_config: WebSearchConfig | None = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
exec_config: ExecToolConfig | None = None,
|
exec_config: ExecToolConfig | None = None,
|
||||||
@ -177,13 +174,27 @@ class AgentLoop:
|
|||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||||
|
|
||||||
|
defaults = AgentDefaults()
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.channels_config = channels_config
|
self.channels_config = channels_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = max_iterations
|
self.max_iterations = (
|
||||||
self.context_window_tokens = context_window_tokens
|
max_iterations if max_iterations is not None else defaults.max_tool_iterations
|
||||||
|
)
|
||||||
|
self.context_window_tokens = (
|
||||||
|
context_window_tokens
|
||||||
|
if context_window_tokens is not None
|
||||||
|
else defaults.context_window_tokens
|
||||||
|
)
|
||||||
|
self.context_block_limit = context_block_limit
|
||||||
|
self.max_tool_result_chars = (
|
||||||
|
max_tool_result_chars
|
||||||
|
if max_tool_result_chars is not None
|
||||||
|
else defaults.max_tool_result_chars
|
||||||
|
)
|
||||||
|
self.provider_retry_mode = provider_retry_mode
|
||||||
self.web_search_config = web_search_config or WebSearchConfig()
|
self.web_search_config = web_search_config or WebSearchConfig()
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@ -202,6 +213,7 @@ class AgentLoop:
|
|||||||
workspace=workspace,
|
workspace=workspace,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
web_search_config=self.web_search_config,
|
web_search_config=self.web_search_config,
|
||||||
web_proxy=web_proxy,
|
web_proxy=web_proxy,
|
||||||
exec_config=self.exec_config,
|
exec_config=self.exec_config,
|
||||||
@ -313,6 +325,7 @@ class AgentLoop:
|
|||||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||||
*,
|
*,
|
||||||
|
session: Session | None = None,
|
||||||
channel: str = "cli",
|
channel: str = "cli",
|
||||||
chat_id: str = "direct",
|
chat_id: str = "direct",
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
@ -339,14 +352,27 @@ class AgentLoop:
|
|||||||
else loop_hook
|
else loop_hook
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||||
|
if session is None:
|
||||||
|
return
|
||||||
|
self._set_runtime_checkpoint(session, payload)
|
||||||
|
|
||||||
result = await self.runner.run(AgentRunSpec(
|
result = await self.runner.run(AgentRunSpec(
|
||||||
initial_messages=initial_messages,
|
initial_messages=initial_messages,
|
||||||
tools=self.tools,
|
tools=self.tools,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_iterations=self.max_iterations,
|
max_iterations=self.max_iterations,
|
||||||
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
hook=hook,
|
hook=hook,
|
||||||
error_message="Sorry, I encountered an error calling the AI model.",
|
error_message="Sorry, I encountered an error calling the AI model.",
|
||||||
concurrent_tools=True,
|
concurrent_tools=True,
|
||||||
|
workspace=self.workspace,
|
||||||
|
session_key=session.key if session else None,
|
||||||
|
context_window_tokens=self.context_window_tokens,
|
||||||
|
context_block_limit=self.context_block_limit,
|
||||||
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
|
progress_callback=on_progress,
|
||||||
|
checkpoint_callback=_checkpoint,
|
||||||
))
|
))
|
||||||
self._last_usage = result.usage
|
self._last_usage = result.usage
|
||||||
if result.stop_reason == "max_iterations":
|
if result.stop_reason == "max_iterations":
|
||||||
@ -484,6 +510,8 @@ class AgentLoop:
|
|||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
if self._restore_runtime_checkpoint(session):
|
||||||
|
self.sessions.save(session)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
@ -494,10 +522,11 @@ class AgentLoop:
|
|||||||
current_role=current_role,
|
current_role=current_role,
|
||||||
)
|
)
|
||||||
final_content, _, all_msgs = await self._run_agent_loop(
|
final_content, _, all_msgs = await self._run_agent_loop(
|
||||||
messages, channel=channel, chat_id=chat_id,
|
messages, session=session, channel=channel, chat_id=chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
@ -508,6 +537,8 @@ class AgentLoop:
|
|||||||
|
|
||||||
key = session_key or msg.session_key
|
key = session_key or msg.session_key
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
|
if self._restore_runtime_checkpoint(session):
|
||||||
|
self.sessions.save(session)
|
||||||
|
|
||||||
# Slash commands
|
# Slash commands
|
||||||
raw = msg.content.strip()
|
raw = msg.content.strip()
|
||||||
@ -543,6 +574,7 @@ class AgentLoop:
|
|||||||
on_progress=on_progress or _bus_progress,
|
on_progress=on_progress or _bus_progress,
|
||||||
on_stream=on_stream,
|
on_stream=on_stream,
|
||||||
on_stream_end=on_stream_end,
|
on_stream_end=on_stream_end,
|
||||||
|
session=session,
|
||||||
channel=msg.channel, chat_id=msg.chat_id,
|
channel=msg.channel, chat_id=msg.chat_id,
|
||||||
message_id=msg.metadata.get("message_id"),
|
message_id=msg.metadata.get("message_id"),
|
||||||
)
|
)
|
||||||
@ -551,6 +583,7 @@ class AgentLoop:
|
|||||||
final_content = "I've completed processing but have no response to give."
|
final_content = "I've completed processing but have no response to give."
|
||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
|
self._clear_runtime_checkpoint(session)
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||||
|
|
||||||
@ -568,12 +601,6 @@ class AgentLoop:
|
|||||||
metadata=meta,
|
metadata=meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
|
|
||||||
"""Convert an inline image block into a compact text placeholder."""
|
|
||||||
path = (block.get("_meta") or {}).get("path", "")
|
|
||||||
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
|
|
||||||
|
|
||||||
def _sanitize_persisted_blocks(
|
def _sanitize_persisted_blocks(
|
||||||
self,
|
self,
|
||||||
content: list[dict[str, Any]],
|
content: list[dict[str, Any]],
|
||||||
@ -600,13 +627,14 @@ class AgentLoop:
|
|||||||
block.get("type") == "image_url"
|
block.get("type") == "image_url"
|
||||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||||
):
|
):
|
||||||
filtered.append(self._image_placeholder(block))
|
path = (block.get("_meta") or {}).get("path", "")
|
||||||
|
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||||
text = block["text"]
|
text = block["text"]
|
||||||
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
|
if truncate_text and len(text) > self.max_tool_result_chars:
|
||||||
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
text = truncate_text(text, self.max_tool_result_chars)
|
||||||
filtered.append({**block, "text": text})
|
filtered.append({**block, "text": text})
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -623,8 +651,8 @@ class AgentLoop:
|
|||||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||||
continue # skip empty assistant messages — they poison session context
|
continue # skip empty assistant messages — they poison session context
|
||||||
if role == "tool":
|
if role == "tool":
|
||||||
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
entry["content"] = truncate_text(content, self.max_tool_result_chars)
|
||||||
elif isinstance(content, list):
|
elif isinstance(content, list):
|
||||||
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
||||||
if not filtered:
|
if not filtered:
|
||||||
@ -647,6 +675,78 @@ class AgentLoop:
|
|||||||
session.messages.append(entry)
|
session.messages.append(entry)
|
||||||
session.updated_at = datetime.now()
|
session.updated_at = datetime.now()
|
||||||
|
|
||||||
|
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
    """Store ``payload`` as the session's runtime checkpoint and save the session.

    The payload is written into ``session.metadata`` under
    ``_RUNTIME_CHECKPOINT_KEY``, replacing any previous checkpoint, and the
    session is then handed to ``self.sessions.save``.
    """
    checkpoint_key = self._RUNTIME_CHECKPOINT_KEY
    metadata = session.metadata
    metadata[checkpoint_key] = payload
    self.sessions.save(session)
|
||||||
|
|
||||||
|
def _clear_runtime_checkpoint(self, session: Session) -> None:
    """Remove the runtime checkpoint from ``session.metadata`` if one is stored.

    Only the in-memory metadata is changed; this method does not save the
    session itself.
    """
    # pop() with a default is already a no-op for a missing key, so no
    # separate membership test is needed.
    session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
|
||||||
|
return (
|
||||||
|
message.get("role"),
|
||||||
|
message.get("content"),
|
||||||
|
message.get("tool_call_id"),
|
||||||
|
message.get("name"),
|
||||||
|
message.get("tool_calls"),
|
||||||
|
message.get("reasoning_content"),
|
||||||
|
message.get("thinking_blocks"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _restore_runtime_checkpoint(self, session: Session) -> bool:
    """Materialize an unfinished turn into session history before a new request.

    Reads the checkpoint stored in ``session.metadata`` under
    ``_RUNTIME_CHECKPOINT_KEY`` and appends to ``session.messages``, in order:
    the checkpointed assistant message, every completed tool result, and one
    synthetic error tool message for each tool call that was still pending.
    Messages that already sit at the tail of the history are not appended a
    second time. The checkpoint is removed from the metadata afterwards.

    Args:
        session: Session whose metadata may hold a runtime checkpoint.

    Returns:
        ``True`` when a dict checkpoint was found and consumed (the session
        was modified in memory but not saved here), ``False`` otherwise.
    """
    from datetime import datetime

    checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
    if not isinstance(checkpoint, dict):
        return False

    def _stamped(message: dict[str, Any]) -> dict[str, Any]:
        """Copy ``message`` and give the copy a timestamp if it has none."""
        copy = dict(message)
        copy.setdefault("timestamp", datetime.now().isoformat())
        return copy

    restored_messages: list[dict[str, Any]] = []

    assistant_message = checkpoint.get("assistant_message")
    if isinstance(assistant_message, dict):
        restored_messages.append(_stamped(assistant_message))

    restored_messages.extend(
        _stamped(message)
        for message in checkpoint.get("completed_tool_results") or []
        if isinstance(message, dict)
    )

    # Every tool call that never finished gets an explicit error result so
    # the assistant's tool_calls are all answered in the restored history.
    for tool_call in checkpoint.get("pending_tool_calls") or []:
        if not isinstance(tool_call, dict):
            continue
        function = tool_call.get("function")
        # The checkpoint comes from persisted metadata, so tolerate a
        # malformed "function" entry instead of raising AttributeError.
        name = function.get("name") if isinstance(function, dict) else None
        restored_messages.append({
            "role": "tool",
            "tool_call_id": tool_call.get("id"),
            "name": name or "tool",
            "content": "Error: Task interrupted before this tool finished.",
            "timestamp": datetime.now().isoformat(),
        })

    # Find the longest prefix of the restored messages that already forms
    # the tail of the history, so those messages are not duplicated.
    message_key = self._checkpoint_message_key
    overlap = 0
    max_overlap = min(len(session.messages), len(restored_messages))
    for size in range(max_overlap, 0, -1):
        tail = session.messages[-size:]
        if all(
            message_key(existing) == message_key(restored)
            for existing, restored in zip(tail, restored_messages)
        ):
            overlap = size
            break

    new_messages = restored_messages[overlap:]
    if new_messages:
        session.messages.extend(new_messages)
        # Keep the modification time in step with the history, as is done
        # elsewhere in this class after appending to session.messages.
        session.updated_at = datetime.now()

    self._clear_runtime_checkpoint(session)
    return True
|
||||||
|
|
||||||
async def process_direct(
|
async def process_direct(
|
||||||
self,
|
self,
|
||||||
content: str,
|
content: str,
|
||||||
|
|||||||
@ -4,20 +4,29 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||||
from nanobot.agent.tools.registry import ToolRegistry
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
||||||
from nanobot.utils.helpers import build_assistant_message
|
from nanobot.utils.helpers import (
|
||||||
|
build_assistant_message,
|
||||||
|
estimate_message_tokens,
|
||||||
|
estimate_prompt_tokens_chain,
|
||||||
|
find_legal_message_start,
|
||||||
|
maybe_persist_tool_result,
|
||||||
|
truncate_text,
|
||||||
|
)
|
||||||
|
|
||||||
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
||||||
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
||||||
"without completing the task. You can try breaking the task into smaller steps."
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
)
|
)
|
||||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||||
|
_SNIP_SAFETY_BUFFER = 1024
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
class AgentRunSpec:
|
class AgentRunSpec:
|
||||||
"""Configuration for a single agent execution."""
|
"""Configuration for a single agent execution."""
|
||||||
@ -26,6 +35,7 @@ class AgentRunSpec:
|
|||||||
tools: ToolRegistry
|
tools: ToolRegistry
|
||||||
model: str
|
model: str
|
||||||
max_iterations: int
|
max_iterations: int
|
||||||
|
max_tool_result_chars: int
|
||||||
temperature: float | None = None
|
temperature: float | None = None
|
||||||
max_tokens: int | None = None
|
max_tokens: int | None = None
|
||||||
reasoning_effort: str | None = None
|
reasoning_effort: str | None = None
|
||||||
@ -34,6 +44,13 @@ class AgentRunSpec:
|
|||||||
max_iterations_message: str | None = None
|
max_iterations_message: str | None = None
|
||||||
concurrent_tools: bool = False
|
concurrent_tools: bool = False
|
||||||
fail_on_tool_error: bool = False
|
fail_on_tool_error: bool = False
|
||||||
|
workspace: Path | None = None
|
||||||
|
session_key: str | None = None
|
||||||
|
context_window_tokens: int | None = None
|
||||||
|
context_block_limit: int | None = None
|
||||||
|
provider_retry_mode: str = "standard"
|
||||||
|
progress_callback: Any | None = None
|
||||||
|
checkpoint_callback: Any | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@ -66,12 +83,25 @@ class AgentRunner:
|
|||||||
tool_events: list[dict[str, str]] = []
|
tool_events: list[dict[str, str]] = []
|
||||||
|
|
||||||
for iteration in range(spec.max_iterations):
|
for iteration in range(spec.max_iterations):
|
||||||
|
try:
|
||||||
|
messages = self._apply_tool_result_budget(spec, messages)
|
||||||
|
messages_for_model = self._snip_history(spec, messages)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Context governance failed on turn {} for {}: {}; using raw messages",
|
||||||
|
iteration,
|
||||||
|
spec.session_key or "default",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
messages_for_model = messages
|
||||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||||
await hook.before_iteration(context)
|
await hook.before_iteration(context)
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"messages": messages,
|
"messages": messages_for_model,
|
||||||
"tools": spec.tools.get_definitions(),
|
"tools": spec.tools.get_definitions(),
|
||||||
"model": spec.model,
|
"model": spec.model,
|
||||||
|
"retry_mode": spec.provider_retry_mode,
|
||||||
|
"on_retry_wait": spec.progress_callback,
|
||||||
}
|
}
|
||||||
if spec.temperature is not None:
|
if spec.temperature is not None:
|
||||||
kwargs["temperature"] = spec.temperature
|
kwargs["temperature"] = spec.temperature
|
||||||
@ -104,13 +134,25 @@ class AgentRunner:
|
|||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=True)
|
await hook.on_stream_end(context, resuming=True)
|
||||||
|
|
||||||
messages.append(build_assistant_message(
|
assistant_message = build_assistant_message(
|
||||||
response.content or "",
|
response.content or "",
|
||||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
))
|
)
|
||||||
|
messages.append(assistant_message)
|
||||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "awaiting_tools",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
await hook.before_execute_tools(context)
|
await hook.before_execute_tools(context)
|
||||||
|
|
||||||
@ -125,13 +167,31 @@ class AgentRunner:
|
|||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
break
|
break
|
||||||
|
completed_tool_results: list[dict[str, Any]] = []
|
||||||
for tool_call, result in zip(response.tool_calls, results):
|
for tool_call, result in zip(response.tool_calls, results):
|
||||||
messages.append({
|
tool_message = {
|
||||||
"role": "tool",
|
"role": "tool",
|
||||||
"tool_call_id": tool_call.id,
|
"tool_call_id": tool_call.id,
|
||||||
"name": tool_call.name,
|
"name": tool_call.name,
|
||||||
"content": result,
|
"content": self._normalize_tool_result(
|
||||||
})
|
spec,
|
||||||
|
tool_call.id,
|
||||||
|
result,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
messages.append(tool_message)
|
||||||
|
completed_tool_results.append(tool_message)
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "tools_completed",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": completed_tool_results,
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -143,6 +203,7 @@ class AgentRunner:
|
|||||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||||
stop_reason = "error"
|
stop_reason = "error"
|
||||||
error = final_content
|
error = final_content
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
context.final_content = final_content
|
context.final_content = final_content
|
||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
@ -154,6 +215,17 @@ class AgentRunner:
|
|||||||
reasoning_content=response.reasoning_content,
|
reasoning_content=response.reasoning_content,
|
||||||
thinking_blocks=response.thinking_blocks,
|
thinking_blocks=response.thinking_blocks,
|
||||||
))
|
))
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "final_response",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": messages[-1],
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
final_content = clean
|
final_content = clean
|
||||||
context.final_content = final_content
|
context.final_content = final_content
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
@ -163,6 +235,7 @@ class AgentRunner:
|
|||||||
stop_reason = "max_iterations"
|
stop_reason = "max_iterations"
|
||||||
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
||||||
final_content = template.format(max_iterations=spec.max_iterations)
|
final_content = template.format(max_iterations=spec.max_iterations)
|
||||||
|
self._append_final_message(messages, final_content)
|
||||||
|
|
||||||
return AgentRunResult(
|
return AgentRunResult(
|
||||||
final_content=final_content,
|
final_content=final_content,
|
||||||
@ -179,16 +252,17 @@ class AgentRunner:
|
|||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
tool_calls: list[ToolCallRequest],
|
tool_calls: list[ToolCallRequest],
|
||||||
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
||||||
if spec.concurrent_tools:
|
batches = self._partition_tool_batches(spec, tool_calls)
|
||||||
tool_results = await asyncio.gather(*(
|
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||||
self._run_tool(spec, tool_call)
|
for batch in batches:
|
||||||
for tool_call in tool_calls
|
if spec.concurrent_tools and len(batch) > 1:
|
||||||
))
|
tool_results.extend(await asyncio.gather(*(
|
||||||
else:
|
self._run_tool(spec, tool_call)
|
||||||
tool_results = [
|
for tool_call in batch
|
||||||
await self._run_tool(spec, tool_call)
|
)))
|
||||||
for tool_call in tool_calls
|
else:
|
||||||
]
|
for tool_call in batch:
|
||||||
|
tool_results.append(await self._run_tool(spec, tool_call))
|
||||||
|
|
||||||
results: list[Any] = []
|
results: list[Any] = []
|
||||||
events: list[dict[str, str]] = []
|
events: list[dict[str, str]] = []
|
||||||
@ -205,8 +279,28 @@ class AgentRunner:
|
|||||||
spec: AgentRunSpec,
|
spec: AgentRunSpec,
|
||||||
tool_call: ToolCallRequest,
|
tool_call: ToolCallRequest,
|
||||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||||
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
|
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||||
|
tool, params, prep_error = None, tool_call.arguments, None
|
||||||
|
if callable(prepare_call):
|
||||||
|
try:
|
||||||
|
prepared = prepare_call(tool_call.name, tool_call.arguments)
|
||||||
|
if isinstance(prepared, tuple) and len(prepared) == 3:
|
||||||
|
tool, params, prep_error = prepared
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if prep_error:
|
||||||
|
event = {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||||
|
}
|
||||||
|
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||||
try:
|
try:
|
||||||
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
|
if tool is not None:
|
||||||
|
result = await tool.execute(**params)
|
||||||
|
else:
|
||||||
|
result = await spec.tools.execute(tool_call.name, params)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
raise
|
raise
|
||||||
except BaseException as exc:
|
except BaseException as exc:
|
||||||
@ -219,14 +313,175 @@ class AgentRunner:
|
|||||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||||
|
|
||||||
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
|
event = {
|
||||||
|
"name": tool_call.name,
|
||||||
|
"status": "error",
|
||||||
|
"detail": result.replace("\n", " ").strip()[:120],
|
||||||
|
}
|
||||||
|
if spec.fail_on_tool_error:
|
||||||
|
return result + _HINT, event, RuntimeError(result)
|
||||||
|
return result + _HINT, event, None
|
||||||
|
|
||||||
detail = "" if result is None else str(result)
|
detail = "" if result is None else str(result)
|
||||||
detail = detail.replace("\n", " ").strip()
|
detail = detail.replace("\n", " ").strip()
|
||||||
if not detail:
|
if not detail:
|
||||||
detail = "(empty)"
|
detail = "(empty)"
|
||||||
elif len(detail) > 120:
|
elif len(detail) > 120:
|
||||||
detail = detail[:120] + "..."
|
detail = detail[:120] + "..."
|
||||||
return result, {
|
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
|
||||||
"name": tool_call.name,
|
|
||||||
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
|
async def _emit_checkpoint(
|
||||||
"detail": detail,
|
self,
|
||||||
}, None
|
spec: AgentRunSpec,
|
||||||
|
payload: dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
callback = spec.checkpoint_callback
|
||||||
|
if callback is not None:
|
||||||
|
await callback(payload)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
    """Record the final assistant text at the end of ``messages`` in place.

    Nothing happens when ``content`` is empty or ``None``.  When the last
    message is an assistant message without tool calls it is treated as the
    slot for the final answer: it is left alone if it already carries the
    same content, otherwise it is replaced.  In every other case a new
    assistant message is appended.

    Args:
        messages: Conversation history; mutated in place.
        content: Final assistant text, possibly empty.
    """
    if not content:
        return

    last = messages[-1] if messages else None
    replaceable = (
        last is not None
        and last.get("role") == "assistant"
        and not last.get("tool_calls")
    )
    if not replaceable:
        messages.append(build_assistant_message(content))
    elif last.get("content") != content:
        # Same slot, different text: overwrite rather than stack a second answer.
        messages[-1] = build_assistant_message(content)
|
||||||
|
|
||||||
|
def _normalize_tool_result(
    self,
    spec: AgentRunSpec,
    tool_call_id: str,
    result: Any,
) -> Any:
    """Bring one tool result within the run's size limit.

    The result is first handed to ``maybe_persist_tool_result`` together with
    the workspace, session key, tool-call id and ``spec.max_tool_result_chars``.
    If that call raises, a warning is logged and the raw result is used
    instead.  String content that is still longer than the limit is passed
    through ``truncate_text``.

    Args:
        spec: Run specification supplying workspace, session key and limit.
        tool_call_id: Identifier of the tool call that produced ``result``.
        result: Raw tool output of any type.

    Returns:
        The persisted/trimmed content, or the value unchanged when it is not
        an over-long string.
    """
    try:
        persisted = maybe_persist_tool_result(
            spec.workspace,
            spec.session_key,
            tool_call_id,
            result,
            max_chars=spec.max_tool_result_chars,
        )
    except Exception as err:
        # Best effort: a persistence failure must not lose the tool output.
        logger.warning(
            "Tool result persist failed for {} in {}: {}; using raw result",
            tool_call_id,
            spec.session_key or "default",
            err,
        )
        persisted = result

    oversized = isinstance(persisted, str) and len(persisted) > spec.max_tool_result_chars
    return truncate_text(persisted, spec.max_tool_result_chars) if oversized else persisted
|
||||||
|
|
||||||
|
def _apply_tool_result_budget(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
updated = messages
|
||||||
|
for idx, message in enumerate(messages):
|
||||||
|
if message.get("role") != "tool":
|
||||||
|
continue
|
||||||
|
normalized = self._normalize_tool_result(
|
||||||
|
spec,
|
||||||
|
str(message.get("tool_call_id") or f"tool_{idx}"),
|
||||||
|
message.get("content"),
|
||||||
|
)
|
||||||
|
if normalized != message.get("content"):
|
||||||
|
if updated is messages:
|
||||||
|
updated = [dict(m) for m in messages]
|
||||||
|
updated[idx]["content"] = normalized
|
||||||
|
return updated
|
||||||
|
|
||||||
|
def _snip_history(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Drop the oldest non-system messages when the prompt exceeds its budget.

    The budget is ``spec.context_block_limit`` when set, otherwise the context
    window minus the reserved output tokens and ``_SNIP_SAFETY_BUFFER``.  When
    the estimated prompt size is within budget (or no budget applies) the
    input list is returned untouched.  Otherwise all system messages are kept
    and non-system messages are retained from newest to oldest until the
    remaining budget is used up; the retained tail is then trimmed to start at
    its first ``user`` message and at the offset reported by
    ``find_legal_message_start``.

    Args:
        spec: Run specification (context window, limits, model, tools).
        messages: Conversation history.

    Returns:
        ``messages`` itself when no snipping is needed, otherwise a new list
        of copied system messages followed by the retained tail.
    """
    if not messages or not spec.context_window_tokens:
        return messages

    # Reserve room for the reply: explicit spec value, provider default, then 4096.
    generation = getattr(self.provider, "generation", None)
    provider_cap = getattr(generation, "max_tokens", 4096)
    if isinstance(spec.max_tokens, int):
        reserved_output = spec.max_tokens
    elif isinstance(provider_cap, int):
        reserved_output = provider_cap
    else:
        reserved_output = 4096

    budget = spec.context_block_limit
    if not budget:
        budget = spec.context_window_tokens - reserved_output - _SNIP_SAFETY_BUFFER
    if budget <= 0:
        return messages

    estimate, _ = estimate_prompt_tokens_chain(
        self.provider,
        spec.model,
        messages,
        spec.tools.get_definitions(),
    )
    if estimate <= budget:
        return messages

    system_messages: list[dict[str, Any]] = []
    non_system: list[dict[str, Any]] = []
    for msg in messages:
        bucket = system_messages if msg.get("role") == "system" else non_system
        bucket.append(dict(msg))
    if not non_system:
        return messages

    system_tokens = sum(map(estimate_message_tokens, system_messages))
    remaining_budget = max(128, budget - system_tokens)

    # Walk backwards so the newest messages win; the newest one is always kept.
    tail: list[dict[str, Any]] = []
    used = 0
    for candidate in reversed(non_system):
        cost = estimate_message_tokens(candidate)
        if tail and used + cost > remaining_budget:
            break
        tail.append(candidate)
        used += cost
    tail.reverse()

    if tail:
        first_user = next(
            (i for i, item in enumerate(tail) if item.get("role") == "user"),
            None,
        )
        if first_user is not None:
            tail = tail[first_user:]
        start = find_legal_message_start(tail)
        if start:
            tail = tail[start:]
    if not tail:
        # Fallback: keep up to the last four non-system messages.
        tail = non_system[-min(len(non_system), 4) :]
        start = find_legal_message_start(tail)
        if start:
            tail = tail[start:]
    return system_messages + tail
|
||||||
|
|
||||||
|
def _partition_tool_batches(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
tool_calls: list[ToolCallRequest],
|
||||||
|
) -> list[list[ToolCallRequest]]:
|
||||||
|
if not spec.concurrent_tools:
|
||||||
|
return [[tool_call] for tool_call in tool_calls]
|
||||||
|
|
||||||
|
batches: list[list[ToolCallRequest]] = []
|
||||||
|
current: list[ToolCallRequest] = []
|
||||||
|
for tool_call in tool_calls:
|
||||||
|
get_tool = getattr(spec.tools, "get", None)
|
||||||
|
tool = get_tool(tool_call.name) if callable(get_tool) else None
|
||||||
|
can_batch = bool(tool and tool.concurrency_safe)
|
||||||
|
if can_batch:
|
||||||
|
current.append(tool_call)
|
||||||
|
continue
|
||||||
|
if current:
|
||||||
|
batches.append(current)
|
||||||
|
current = []
|
||||||
|
batches.append([tool_call])
|
||||||
|
if current:
|
||||||
|
batches.append(current)
|
||||||
|
return batches
|
||||||
|
|
||||||
|
|||||||
@ -44,6 +44,7 @@ class SubagentManager:
|
|||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
bus: MessageBus,
|
bus: MessageBus,
|
||||||
|
max_tool_result_chars: int,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
web_search_config: "WebSearchConfig | None" = None,
|
web_search_config: "WebSearchConfig | None" = None,
|
||||||
web_proxy: str | None = None,
|
web_proxy: str | None = None,
|
||||||
@ -56,6 +57,7 @@ class SubagentManager:
|
|||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
|
self.max_tool_result_chars = max_tool_result_chars
|
||||||
self.web_search_config = web_search_config or WebSearchConfig()
|
self.web_search_config = web_search_config or WebSearchConfig()
|
||||||
self.web_proxy = web_proxy
|
self.web_proxy = web_proxy
|
||||||
self.exec_config = exec_config or ExecToolConfig()
|
self.exec_config = exec_config or ExecToolConfig()
|
||||||
@ -136,6 +138,7 @@ class SubagentManager:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
max_iterations=15,
|
max_iterations=15,
|
||||||
|
max_tool_result_chars=self.max_tool_result_chars,
|
||||||
hook=_SubagentHook(task_id),
|
hook=_SubagentHook(task_id),
|
||||||
max_iterations_message="Task completed but no final response was generated.",
|
max_iterations_message="Task completed but no final response was generated.",
|
||||||
error_message=None,
|
error_message=None,
|
||||||
|
|||||||
@ -53,6 +53,21 @@ class Tool(ABC):
|
|||||||
"""JSON Schema for tool parameters."""
|
"""JSON Schema for tool parameters."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
"""Whether this tool is side-effect free and safe to parallelize."""
|
||||||
|
# Conservative default: a tool opts in by overriding this property to return True.
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def concurrency_safe(self) -> bool:
|
||||||
|
"""Whether this tool can run alongside other concurrency-safe tools."""
|
||||||
|
# Derived flag: requires read_only and is vetoed by exclusive.
return self.read_only and not self.exclusive
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
"""Whether this tool should run alone even if concurrency is enabled."""
|
||||||
|
# Default: no exclusivity requested; a True value makes concurrency_safe False.
return False
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, **kwargs: Any) -> Any:
|
async def execute(self, **kwargs: Any) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
|
|||||||
"Use offset and limit to paginate through large files."
|
"Use offset and limit to paginate through large files."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
# Declares the file-read tool as read-only (overrides the base default of False).
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
|
|||||||
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
# Declares the directory-listing tool as read-only (overrides the base default of False).
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -35,22 +35,35 @@ class ToolRegistry:
|
|||||||
"""Get all tool definitions in OpenAI format."""
|
"""Get all tool definitions in OpenAI format."""
|
||||||
return [tool.to_schema() for tool in self._tools.values()]
|
return [tool.to_schema() for tool in self._tools.values()]
|
||||||
|
|
||||||
|
def prepare_call(
    self,
    name: str,
    params: dict[str, Any],
) -> tuple[Tool | None, dict[str, Any], str | None]:
    """Resolve, cast, and validate one tool call.

    Args:
        name: Name of the tool to look up in the registry.
        params: Raw arguments supplied for the call.

    Returns:
        A ``(tool, params, error)`` triple:

        * unknown tool: ``(None, params, "Error: Tool ... not found ...")``
          with the original params;
        * validation failure: ``(tool, cast_params, "Error: Invalid
          parameters ...")``;
        * success: ``(tool, cast_params, None)``.
    """
    tool = self._tools.get(name)
    # Identity check: a registered tool that happens to be falsy (defines
    # __len__ or __bool__) must not be reported as missing.
    if tool is None:
        available = ", ".join(self.tool_names)
        return None, params, f"Error: Tool '{name}' not found. Available: {available}"

    cast_params = tool.cast_params(params)
    errors = tool.validate_params(cast_params)
    if errors:
        return tool, cast_params, (
            f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors)
        )
    return tool, cast_params, None
|
||||||
|
|
||||||
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
||||||
"""Execute a tool by name with given parameters."""
|
"""Execute a tool by name with given parameters."""
|
||||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||||
|
tool, params, error = self.prepare_call(name, params)
|
||||||
tool = self._tools.get(name)
|
if error:
|
||||||
if not tool:
|
return error + _HINT
|
||||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Attempt to cast parameters to match schema types
|
assert tool is not None # guarded by prepare_call()
|
||||||
params = tool.cast_params(params)
|
|
||||||
|
|
||||||
# Validate parameters
|
|
||||||
errors = tool.validate_params(params)
|
|
||||||
if errors:
|
|
||||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
|
||||||
result = await tool.execute(**params)
|
result = await tool.execute(**params)
|
||||||
if isinstance(result, str) and result.startswith("Error"):
|
if isinstance(result, str) and result.startswith("Error"):
|
||||||
return result + _HINT
|
return result + _HINT
|
||||||
|
|||||||
@ -52,6 +52,10 @@ class ExecTool(Tool):
|
|||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
return "Execute a shell command and return its output. Use with caution."
|
return "Execute a shell command and return its output. Use with caution."
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exclusive(self) -> bool:
|
||||||
|
# Shell execution asks to run alone (overrides the base default of False).
return True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> dict[str, Any]:
|
def parameters(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|||||||
@ -92,6 +92,10 @@ class WebSearchTool(Tool):
|
|||||||
self.config = config if config is not None else WebSearchConfig()
|
self.config = config if config is not None else WebSearchConfig()
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
# Declares web search as read-only (overrides the base default of False).
return True
|
||||||
|
|
||||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||||
provider = self.config.provider.strip().lower() or "brave"
|
provider = self.config.provider.strip().lower() or "brave"
|
||||||
n = min(max(count or self.config.max_results, 1), 10)
|
n = min(max(count or self.config.max_results, 1), 10)
|
||||||
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
|
|||||||
self.max_chars = max_chars
|
self.max_chars = max_chars
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
@property
|
||||||
|
def read_only(self) -> bool:
|
||||||
|
# Declares web fetch as read-only (overrides the base default of False).
return True
|
||||||
|
|
||||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
|
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
|
||||||
max_chars = maxChars or self.max_chars
|
max_chars = maxChars or self.max_chars
|
||||||
is_valid, error_msg = _validate_url_safe(url)
|
is_valid, error_msg = _validate_url_safe(url)
|
||||||
|
|||||||
@ -539,6 +539,9 @@ def serve(
|
|||||||
model=runtime_config.agents.defaults.model,
|
model=runtime_config.agents.defaults.model,
|
||||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||||
|
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||||
web_search_config=runtime_config.tools.web.search,
|
web_search_config=runtime_config.tools.web.search,
|
||||||
web_proxy=runtime_config.tools.web.proxy or None,
|
web_proxy=runtime_config.tools.web.proxy or None,
|
||||||
exec_config=runtime_config.tools.exec,
|
exec_config=runtime_config.tools.exec,
|
||||||
@ -626,6 +629,9 @@ def gateway(
|
|||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||||
web_search_config=config.tools.web.search,
|
web_search_config=config.tools.web.search,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
@ -832,6 +838,9 @@ def agent(
|
|||||||
model=config.agents.defaults.model,
|
model=config.agents.defaults.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||||
web_search_config=config.tools.web.search,
|
web_search_config=config.tools.web.search,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
|
|||||||
@ -38,8 +38,11 @@ class AgentDefaults(Base):
|
|||||||
)
|
)
|
||||||
max_tokens: int = 8192
|
max_tokens: int = 8192
|
||||||
context_window_tokens: int = 65_536
|
context_window_tokens: int = 65_536
|
||||||
|
context_block_limit: int | None = None
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
max_tool_iterations: int = 40
|
max_tool_iterations: int = 200
|
||||||
|
max_tool_result_chars: int = 16_000
|
||||||
|
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||||
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
|
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
|
||||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||||
|
|
||||||
|
|||||||
@ -73,6 +73,9 @@ class Nanobot:
|
|||||||
model=defaults.model,
|
model=defaults.model,
|
||||||
max_iterations=defaults.max_tool_iterations,
|
max_iterations=defaults.max_tool_iterations,
|
||||||
context_window_tokens=defaults.context_window_tokens,
|
context_window_tokens=defaults.context_window_tokens,
|
||||||
|
context_block_limit=defaults.context_block_limit,
|
||||||
|
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||||
|
provider_retry_mode=defaults.provider_retry_mode,
|
||||||
web_search_config=config.tools.web.search,
|
web_search_config=config.tools.web.search,
|
||||||
web_proxy=config.tools.web.proxy or None,
|
web_proxy=config.tools.web.proxy or None,
|
||||||
exec_config=config.tools.exec,
|
exec_config=config.tools.exec,
|
||||||
|
|||||||
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
@ -427,13 +429,33 @@ class AnthropicProvider(LLMProvider):
|
|||||||
messages, tools, model, max_tokens, temperature,
|
messages, tools, model, max_tokens, temperature,
|
||||||
reasoning_effort, tool_choice,
|
reasoning_effort, tool_choice,
|
||||||
)
|
)
|
||||||
|
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||||
try:
|
try:
|
||||||
async with self._client.messages.stream(**kwargs) as stream:
|
async with self._client.messages.stream(**kwargs) as stream:
|
||||||
if on_content_delta:
|
if on_content_delta:
|
||||||
async for text in stream.text_stream:
|
stream_iter = stream.text_stream.__aiter__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
text = await asyncio.wait_for(
|
||||||
|
stream_iter.__anext__(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
await on_content_delta(text)
|
await on_content_delta(text)
|
||||||
response = await stream.get_final_message()
|
response = await asyncio.wait_for(
|
||||||
|
stream.get_final_message(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
return self._parse_response(response)
|
return self._parse_response(response)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return LLMResponse(
|
||||||
|
content=(
|
||||||
|
f"Error calling LLM: stream stalled for more than "
|
||||||
|
f"{idle_timeout_s} seconds"
|
||||||
|
),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
@ -9,6 +10,8 @@ from typing import Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.utils.helpers import image_placeholder_text
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolCallRequest:
|
class ToolCallRequest:
|
||||||
@ -57,13 +60,7 @@ class LLMResponse:
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class GenerationSettings:
|
class GenerationSettings:
|
||||||
"""Default generation parameters for LLM calls.
|
"""Default generation settings."""
|
||||||
|
|
||||||
Stored on the provider so every call site inherits the same defaults
|
|
||||||
without having to pass temperature / max_tokens / reasoning_effort
|
|
||||||
through every layer. Individual call sites can still override by
|
|
||||||
passing explicit keyword arguments to chat() / chat_with_retry().
|
|
||||||
"""
|
|
||||||
|
|
||||||
temperature: float = 0.7
|
temperature: float = 0.7
|
||||||
max_tokens: int = 4096
|
max_tokens: int = 4096
|
||||||
@ -71,14 +68,11 @@ class GenerationSettings:
|
|||||||
|
|
||||||
|
|
||||||
class LLMProvider(ABC):
|
class LLMProvider(ABC):
|
||||||
"""
|
"""Base class for LLM providers."""
|
||||||
Abstract base class for LLM providers.
|
|
||||||
|
|
||||||
Implementations should handle the specifics of each provider's API
|
|
||||||
while maintaining a consistent interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||||
|
_PERSISTENT_MAX_DELAY = 60
|
||||||
|
_RETRY_HEARTBEAT_CHUNK = 30
|
||||||
_TRANSIENT_ERROR_MARKERS = (
|
_TRANSIENT_ERROR_MARKERS = (
|
||||||
"429",
|
"429",
|
||||||
"rate limit",
|
"rate limit",
|
||||||
@ -208,7 +202,7 @@ class LLMProvider(ABC):
|
|||||||
for b in content:
|
for b in content:
|
||||||
if isinstance(b, dict) and b.get("type") == "image_url":
|
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||||
path = (b.get("_meta") or {}).get("path", "")
|
path = (b.get("_meta") or {}).get("path", "")
|
||||||
placeholder = f"[image: {path}]" if path else "[image omitted]"
|
placeholder = image_placeholder_text(path, empty="[image omitted]")
|
||||||
new_content.append({"type": "text", "text": placeholder})
|
new_content.append({"type": "text", "text": placeholder})
|
||||||
found = True
|
found = True
|
||||||
else:
|
else:
|
||||||
@ -273,6 +267,8 @@ class LLMProvider(ABC):
|
|||||||
reasoning_effort: object = _SENTINEL,
|
reasoning_effort: object = _SENTINEL,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
retry_mode: str = "standard",
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat_stream() with retry on transient provider failures."""
|
"""Call chat_stream() with retry on transient provider failures."""
|
||||||
if max_tokens is self._SENTINEL:
|
if max_tokens is self._SENTINEL:
|
||||||
@ -288,28 +284,13 @@ class LLMProvider(ABC):
|
|||||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
on_content_delta=on_content_delta,
|
on_content_delta=on_content_delta,
|
||||||
)
|
)
|
||||||
|
return await self._run_with_retry(
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
self._safe_chat_stream,
|
||||||
response = await self._safe_chat_stream(**kw)
|
kw,
|
||||||
|
messages,
|
||||||
if response.finish_reason != "error":
|
retry_mode=retry_mode,
|
||||||
return response
|
on_retry_wait=on_retry_wait,
|
||||||
|
)
|
||||||
if not self._is_transient_error(response.content):
|
|
||||||
stripped = self._strip_image_content(messages)
|
|
||||||
if stripped is not None:
|
|
||||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
|
||||||
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
|
||||||
return response
|
|
||||||
|
|
||||||
logger.warning(
|
|
||||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
|
||||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
|
||||||
(response.content or "")[:120].lower(),
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
return await self._safe_chat_stream(**kw)
|
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(
|
||||||
self,
|
self,
|
||||||
@ -320,6 +301,8 @@ class LLMProvider(ABC):
|
|||||||
temperature: object = _SENTINEL,
|
temperature: object = _SENTINEL,
|
||||||
reasoning_effort: object = _SENTINEL,
|
reasoning_effort: object = _SENTINEL,
|
||||||
tool_choice: str | dict[str, Any] | None = None,
|
tool_choice: str | dict[str, Any] | None = None,
|
||||||
|
retry_mode: str = "standard",
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""Call chat() with retry on transient provider failures.
|
"""Call chat() with retry on transient provider failures.
|
||||||
|
|
||||||
@ -339,28 +322,102 @@ class LLMProvider(ABC):
|
|||||||
max_tokens=max_tokens, temperature=temperature,
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
return await self._run_with_retry(
|
||||||
|
self._safe_chat,
|
||||||
|
kw,
|
||||||
|
messages,
|
||||||
|
retry_mode=retry_mode,
|
||||||
|
on_retry_wait=on_retry_wait,
|
||||||
|
)
|
||||||
|
|
||||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
@classmethod
|
||||||
response = await self._safe_chat(**kw)
|
def _extract_retry_after(cls, content: str | None) -> float | None:
|
||||||
|
text = (content or "").lower()
|
||||||
|
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
value = float(match.group(1))
|
||||||
|
unit = (match.group(2) or "s").lower()
|
||||||
|
if unit in {"ms", "milliseconds"}:
|
||||||
|
return max(0.1, value / 1000.0)
|
||||||
|
if unit in {"m", "min", "minutes"}:
|
||||||
|
return value * 60.0
|
||||||
|
return value
|
||||||
|
|
||||||
|
async def _sleep_with_heartbeat(
|
||||||
|
self,
|
||||||
|
delay: float,
|
||||||
|
*,
|
||||||
|
attempt: int,
|
||||||
|
persistent: bool,
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
remaining = max(0.0, delay)
|
||||||
|
while remaining > 0:
|
||||||
|
if on_retry_wait:
|
||||||
|
kind = "persistent retry" if persistent else "retry"
|
||||||
|
await on_retry_wait(
|
||||||
|
f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
|
||||||
|
f"(attempt {attempt})."
|
||||||
|
)
|
||||||
|
chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
|
||||||
|
await asyncio.sleep(chunk)
|
||||||
|
remaining -= chunk
|
||||||
|
|
||||||
|
async def _run_with_retry(
|
||||||
|
self,
|
||||||
|
call: Callable[..., Awaitable[LLMResponse]],
|
||||||
|
kw: dict[str, Any],
|
||||||
|
original_messages: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
retry_mode: str,
|
||||||
|
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||||
|
) -> LLMResponse:
|
||||||
|
attempt = 0
|
||||||
|
delays = list(self._CHAT_RETRY_DELAYS)
|
||||||
|
persistent = retry_mode == "persistent"
|
||||||
|
last_response: LLMResponse | None = None
|
||||||
|
while True:
|
||||||
|
attempt += 1
|
||||||
|
response = await call(**kw)
|
||||||
if response.finish_reason != "error":
|
if response.finish_reason != "error":
|
||||||
return response
|
return response
|
||||||
|
last_response = response
|
||||||
|
|
||||||
if not self._is_transient_error(response.content):
|
if not self._is_transient_error(response.content):
|
||||||
stripped = self._strip_image_content(messages)
|
stripped = self._strip_image_content(original_messages)
|
||||||
if stripped is not None:
|
if stripped is not None and stripped != kw["messages"]:
|
||||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
logger.warning(
|
||||||
return await self._safe_chat(**{**kw, "messages": stripped})
|
"Non-transient LLM error with image content, retrying without images"
|
||||||
|
)
|
||||||
|
retry_kw = dict(kw)
|
||||||
|
retry_kw["messages"] = stripped
|
||||||
|
return await call(**retry_kw)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
if not persistent and attempt > len(delays):
|
||||||
|
break
|
||||||
|
|
||||||
|
base_delay = delays[min(attempt - 1, len(delays) - 1)]
|
||||||
|
delay = self._extract_retry_after(response.content) or base_delay
|
||||||
|
if persistent:
|
||||||
|
delay = min(delay, self._PERSISTENT_MAX_DELAY)
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
"LLM transient error (attempt {}{}), retrying in {}s: {}",
|
||||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
attempt,
|
||||||
|
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
|
||||||
|
int(round(delay)),
|
||||||
(response.content or "")[:120].lower(),
|
(response.content or "")[:120].lower(),
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
await self._sleep_with_heartbeat(
|
||||||
|
delay,
|
||||||
|
attempt=attempt,
|
||||||
|
persistent=persistent,
|
||||||
|
on_retry_wait=on_retry_wait,
|
||||||
|
)
|
||||||
|
|
||||||
return await self._safe_chat(**kw)
|
return last_response if last_response is not None else await call(**kw)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_default_model(self) -> str:
|
def get_default_model(self) -> str:
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
@ -20,7 +21,6 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
_ALLOWED_MSG_KEYS = frozenset({
|
_ALLOWED_MSG_KEYS = frozenset({
|
||||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||||
"reasoning_content", "extra_content",
|
|
||||||
})
|
})
|
||||||
_ALNUM = string.ascii_letters + string.digits
|
_ALNUM = string.ascii_letters + string.digits
|
||||||
|
|
||||||
@ -572,16 +572,33 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
)
|
)
|
||||||
kwargs["stream"] = True
|
kwargs["stream"] = True
|
||||||
kwargs["stream_options"] = {"include_usage": True}
|
kwargs["stream_options"] = {"include_usage": True}
|
||||||
|
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||||
try:
|
try:
|
||||||
stream = await self._client.chat.completions.create(**kwargs)
|
stream = await self._client.chat.completions.create(**kwargs)
|
||||||
chunks: list[Any] = []
|
chunks: list[Any] = []
|
||||||
async for chunk in stream:
|
stream_iter = stream.__aiter__()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = await asyncio.wait_for(
|
||||||
|
stream_iter.__anext__(),
|
||||||
|
timeout=idle_timeout_s,
|
||||||
|
)
|
||||||
|
except StopAsyncIteration:
|
||||||
|
break
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
if on_content_delta and chunk.choices:
|
if on_content_delta and chunk.choices:
|
||||||
text = getattr(chunk.choices[0].delta, "content", None)
|
text = getattr(chunk.choices[0].delta, "content", None)
|
||||||
if text:
|
if text:
|
||||||
await on_content_delta(text)
|
await on_content_delta(text)
|
||||||
return self._parse_chunks(chunks)
|
return self._parse_chunks(chunks)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return LLMResponse(
|
||||||
|
content=(
|
||||||
|
f"Error calling LLM: stream stalled for more than "
|
||||||
|
f"{idle_timeout_s} seconds"
|
||||||
|
),
|
||||||
|
finish_reason="error",
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return self._handle_error(e)
|
return self._handle_error(e)
|
||||||
|
|
||||||
|
|||||||
@ -10,20 +10,12 @@ from typing import Any
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.config.paths import get_legacy_sessions_dir
|
from nanobot.config.paths import get_legacy_sessions_dir
|
||||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Session:
|
class Session:
|
||||||
"""
|
"""A conversation session."""
|
||||||
A conversation session.
|
|
||||||
|
|
||||||
Stores messages in JSONL format for easy reading and persistence.
|
|
||||||
|
|
||||||
Important: Messages are append-only for LLM cache efficiency.
|
|
||||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
|
||||||
but does NOT modify the messages list or get_history() output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
key: str # channel:chat_id
|
key: str # channel:chat_id
|
||||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||||
@ -43,43 +35,19 @@ class Session:
|
|||||||
self.messages.append(msg)
|
self.messages.append(msg)
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
|
|
||||||
"""Find first index where every tool result has a matching assistant tool_call."""
|
|
||||||
declared: set[str] = set()
|
|
||||||
start = 0
|
|
||||||
for i, msg in enumerate(messages):
|
|
||||||
role = msg.get("role")
|
|
||||||
if role == "assistant":
|
|
||||||
for tc in msg.get("tool_calls") or []:
|
|
||||||
if isinstance(tc, dict) and tc.get("id"):
|
|
||||||
declared.add(str(tc["id"]))
|
|
||||||
elif role == "tool":
|
|
||||||
tid = msg.get("tool_call_id")
|
|
||||||
if tid and str(tid) not in declared:
|
|
||||||
start = i + 1
|
|
||||||
declared.clear()
|
|
||||||
for prev in messages[start:i + 1]:
|
|
||||||
if prev.get("role") == "assistant":
|
|
||||||
for tc in prev.get("tool_calls") or []:
|
|
||||||
if isinstance(tc, dict) and tc.get("id"):
|
|
||||||
declared.add(str(tc["id"]))
|
|
||||||
return start
|
|
||||||
|
|
||||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||||
unconsolidated = self.messages[self.last_consolidated:]
|
unconsolidated = self.messages[self.last_consolidated:]
|
||||||
sliced = unconsolidated[-max_messages:]
|
sliced = unconsolidated[-max_messages:]
|
||||||
|
|
||||||
# Drop leading non-user messages to avoid starting mid-turn when possible.
|
# Avoid starting mid-turn when possible.
|
||||||
for i, message in enumerate(sliced):
|
for i, message in enumerate(sliced):
|
||||||
if message.get("role") == "user":
|
if message.get("role") == "user":
|
||||||
sliced = sliced[i:]
|
sliced = sliced[i:]
|
||||||
break
|
break
|
||||||
|
|
||||||
# Some providers reject orphan tool results if the matching assistant
|
# Drop orphan tool results at the front.
|
||||||
# tool_calls message fell outside the fixed-size history window.
|
start = find_legal_message_start(sliced)
|
||||||
start = self._find_legal_start(sliced)
|
|
||||||
if start:
|
if start:
|
||||||
sliced = sliced[start:]
|
sliced = sliced[start:]
|
||||||
|
|
||||||
@ -115,7 +83,7 @@ class Session:
|
|||||||
retained = self.messages[start_idx:]
|
retained = self.messages[start_idx:]
|
||||||
|
|
||||||
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
||||||
start = self._find_legal_start(retained)
|
start = find_legal_message_start(retained)
|
||||||
if start:
|
if start:
|
||||||
retained = retained[start:]
|
retained = retained[start:]
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,9 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -56,11 +58,7 @@ def timestamp() -> str:
|
|||||||
|
|
||||||
|
|
||||||
def current_time_str(timezone: str | None = None) -> str:
|
def current_time_str(timezone: str | None = None) -> str:
|
||||||
"""Human-readable current time with weekday and UTC offset.
|
"""Return the current time string."""
|
||||||
|
|
||||||
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
|
|
||||||
is converted to that zone. Otherwise falls back to the host local time.
|
|
||||||
"""
|
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -76,12 +74,164 @@ def current_time_str(timezone: str | None = None) -> str:
|
|||||||
|
|
||||||
|
|
||||||
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
||||||
|
_TOOL_RESULT_PREVIEW_CHARS = 1200
|
||||||
|
_TOOL_RESULTS_DIR = ".nanobot/tool-results"
|
||||||
|
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60
|
||||||
|
_TOOL_RESULT_MAX_BUCKETS = 32
|
||||||
|
|
||||||
def safe_filename(name: str) -> str:
    """Replace unsafe path characters with underscores.

    Every character matched by ``_UNSAFE_CHARS`` becomes ``_``; the result is
    then stripped of leading and trailing whitespace.
    """
    return _UNSAFE_CHARS.sub("_", name).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
    """Build an image placeholder string.

    Args:
        path: Path shown inside the placeholder.
        empty: Text returned when *path* is ``None`` or empty.

    Returns:
        ``"[image: <path>]"`` when *path* is truthy, otherwise *empty*.
    """
    return f"[image: {path}]" if path else empty
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_text(text: str, max_chars: int) -> str:
    """Truncate text with a stable suffix.

    Args:
        text: The string to shorten.
        max_chars: Maximum number of characters kept from *text*. A value
            ``<= 0`` leaves *text* untouched.

    Returns:
        *text* unchanged when it already fits (or when *max_chars* ``<= 0``),
        otherwise its first *max_chars* characters followed by
        ``"\\n... (truncated)"``.
    """
    if max_chars <= 0 or len(text) <= max_chars:
        return text
    return text[:max_chars] + "\n... (truncated)"
|
||||||
|
|
||||||
|
|
||||||
|
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
    """Find the first index whose tool results have matching assistant calls.

    Walks *messages* in order, collecting the ids of tool calls declared by
    ``assistant`` messages. Whenever a ``tool`` message carries a
    ``tool_call_id`` that has not been declared since the current start, the
    start moves just past that message and the declared ids are forgotten.

    Args:
        messages: Chat messages as dicts with a ``role`` key and, where
            relevant, ``tool_calls`` / ``tool_call_id``.

    Returns:
        The index from which the list contains no orphan tool result
        (``0`` when there is none; ``len(messages)`` when the last message
        is itself an orphan).
    """
    declared: set[str] = set()
    start = 0
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict) and tc.get("id"):
                    declared.add(str(tc["id"]))
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                # Orphan result: drop everything up to and including it.
                # Nothing between the new start and this index remains, so
                # the declared set simply restarts empty.
                start = i + 1
                declared.clear()
    return start
|
||||||
|
|
||||||
|
|
||||||
|
def _stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
    """Join a list of text blocks into one newline-separated string.

    Returns:
        The ``text`` values joined with ``"\\n"`` (``""`` for an empty list),
        or ``None`` as soon as an item is not a dict, is not of type
        ``"text"``, or has a non-string ``text`` value.
    """
    parts: list[str] = []
    for block in content:
        if not isinstance(block, dict) or block.get("type") != "text":
            return None
        text = block.get("text")
        if not isinstance(text, str):
            return None
        parts.append(text)
    return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _render_tool_result_reference(
    filepath: Path,
    *,
    original_size: int,
    preview: str,
    truncated_preview: bool,
) -> str:
    """Render the reference text that stands in for a persisted tool output.

    Args:
        filepath: Where the full output was saved.
        original_size: Length of the full output, reported in characters.
        preview: Leading excerpt of the output to show inline.
        truncated_preview: Whether *preview* is shorter than the full output;
            when true a hint to read the saved file is appended.
    """
    result = (
        "[tool output persisted]\n"
        f"Full output saved to: {filepath}\n"
        f"Original size: {original_size} chars\n"
        f"Preview:\n{preview}"
    )
    if truncated_preview:
        result += "\n...\n(Read the saved file if you need the full output.)"
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def _bucket_mtime(path: Path) -> float:
    """Return the modification time of *path*, or ``0.0`` if it cannot be stat'ed."""
    try:
        return path.stat().st_mtime
    except OSError:
        return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
    """Prune sibling bucket directories under *root*, sparing *current_bucket*.

    Two passes are applied to the other directories found in *root*:

    1. Any bucket whose mtime is older than ``_TOOL_RESULT_RETENTION_SECS``
       is removed. A bucket that cannot be stat'ed reports mtime ``0.0`` and
       is therefore treated as expired.
    2. If more than ``_TOOL_RESULT_MAX_BUCKETS - 1`` siblings remain, only the
       most recently modified ones are kept and the rest are removed.

    Removal errors are ignored (``shutil.rmtree(..., ignore_errors=True)``).
    """
    siblings = [path for path in root.iterdir() if path.is_dir() and path != current_bucket]

    # Pass 1: age-based expiry.
    cutoff = time.time() - _TOOL_RESULT_RETENTION_SECS
    for path in siblings:
        if _bucket_mtime(path) < cutoff:
            shutil.rmtree(path, ignore_errors=True)

    # Pass 2: cap the bucket count; one slot is reserved for current_bucket.
    keep = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
    siblings = [path for path in siblings if path.exists()]
    if len(siblings) <= keep:
        return
    siblings.sort(key=_bucket_mtime, reverse=True)
    for path in siblings[keep:]:
        shutil.rmtree(path, ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_text_atomic(path: Path, content: str) -> None:
    """Write *content* to *path* via a uniquely named temporary sibling file.

    The text is written (UTF-8) to a hidden ``.<name>.<uuid>.tmp`` file in the
    same directory and then moved over *path* with ``Path.replace``. The
    temporary file is removed if it is still present afterwards, e.g. when
    the write or the replace raised.
    """
    tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
    try:
        tmp.write_text(content, encoding="utf-8")
        tmp.replace(path)
    finally:
        # After a successful replace the temp name no longer exists;
        # missing_ok makes this a no-op in that case.
        tmp.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_persist_tool_result(
    workspace: Path | None,
    session_key: str | None,
    tool_call_id: str,
    content: Any,
    *,
    max_chars: int,
) -> Any:
    """Persist oversized tool output and replace it with a stable reference string.

    Args:
        workspace: Directory under which ``_TOOL_RESULTS_DIR`` is created;
            ``None`` disables persistence.
        session_key: Names the per-session bucket directory (sanitised with
            ``safe_filename``); ``"default"`` is used when it is falsy.
        tool_call_id: Sanitised and used as the stem of the saved file.
        content: Tool output. Only ``str`` values and lists made up solely of
            text blocks are eligible for persistence.
        max_chars: Size threshold; values ``<= 0`` disable persistence.

    Returns:
        *content* unchanged when persistence is disabled, the content is not
        eligible, or its text is at most *max_chars* long. Otherwise the
        reference string built by ``_render_tool_result_reference``.
    """
    if workspace is None or max_chars <= 0:
        return content

    # Pick the text used for size checks/preview and the on-disk format.
    if isinstance(content, str):
        text_payload: str | None = content
        suffix = "txt"
    elif isinstance(content, list):
        text_payload = _stringify_text_blocks(content)
        suffix = "json"
    else:
        return content
    if text_payload is None or len(text_payload) <= max_chars:
        return content

    root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
    bucket = ensure_dir(root / safe_filename(session_key or "default"))
    try:
        _cleanup_tool_result_buckets(root, bucket)
    except Exception:
        # Pruning old buckets is best-effort; a failure here must not stop
        # the current result from being persisted.
        pass

    path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
    if not path.exists():
        # An existing file for this tool call id is left as it is.
        if isinstance(content, list):
            # Lists keep their block structure on disk as JSON.
            serialized = json.dumps(content, ensure_ascii=False, indent=2)
        else:
            serialized = text_payload
        _write_text_atomic(path, serialized)

    return _render_tool_result_reference(
        path,
        original_size=len(text_payload),
        preview=text_payload[:_TOOL_RESULT_PREVIEW_CHARS],
        truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS,
    )
|
||||||
|
|
||||||
|
|
||||||
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
||||||
"""
|
"""
|
||||||
Split content into chunks within max_len, preferring line breaks.
|
Split content into chunks within max_len, preferring line breaks.
|
||||||
|
|||||||
@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
|||||||
assert "Channel: cli" in user_content
|
assert "Channel: cli" in user_content
|
||||||
assert "Chat ID: direct" in user_content
|
assert "Chat ID: direct" in user_content
|
||||||
assert "Return exactly: OK" in user_content
|
assert "Return exactly: OK" in user_content
|
||||||
|
|
||||||
|
|
||||||
|
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
    """An assistant-role current message must not follow an assistant history entry."""
    workspace = _make_workspace(tmp_path)
    builder = ContextBuilder(workspace)

    messages = builder.build_messages(
        history=[{"role": "assistant", "content": "previous result"}],
        current_message="subagent result",
        channel="cli",
        chat_id="direct",
        current_role="assistant",
    )

    # No two adjacent messages may both carry the assistant role.
    for left, right in zip(messages, messages[1:]):
        assert not (left.get("role") == right.get("role") == "assistant")
|
||||||
|
|||||||
@ -5,7 +5,9 @@ from nanobot.session.manager import Session
|
|||||||
|
|
||||||
def _mk_loop() -> AgentLoop:
|
def _mk_loop() -> AgentLoop:
|
||||||
loop = AgentLoop.__new__(AgentLoop)
|
loop = AgentLoop.__new__(AgentLoop)
|
||||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
|
||||||
return loop
|
return loop
|
||||||
|
|
||||||
|
|
||||||
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert session.messages[0]["content"] == content
|
assert session.messages[0]["content"] == content
|
||||||
|
|
||||||
|
|
||||||
|
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
    """Restoring a checkpoint rebuilds the assistant turn and its tool results.

    The completed result is restored as-is; the pending call gets a tool
    message saying it was interrupted. The checkpoint entry is then cleared.
    """
    loop = _mk_loop()
    session = Session(
        key="test:checkpoint",
        metadata={
            AgentLoop._RUNTIME_CHECKPOINT_KEY: {
                "assistant_message": {
                    "role": "assistant",
                    "content": "working",
                    "tool_calls": [
                        {
                            "id": "call_done",
                            "type": "function",
                            "function": {"name": "read_file", "arguments": "{}"},
                        },
                        {
                            "id": "call_pending",
                            "type": "function",
                            "function": {"name": "exec", "arguments": "{}"},
                        },
                    ],
                },
                "completed_tool_results": [
                    {
                        "role": "tool",
                        "tool_call_id": "call_done",
                        "name": "read_file",
                        "content": "ok",
                    }
                ],
                "pending_tool_calls": [
                    {
                        "id": "call_pending",
                        "type": "function",
                        "function": {"name": "exec", "arguments": "{}"},
                    }
                ],
            }
        },
    )

    restored = loop._restore_runtime_checkpoint(session)

    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"
    assert "interrupted before this tool finished" in session.messages[2]["content"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
    """Messages already at the session tail are not duplicated on restore.

    The session already holds the assistant message and the completed tool
    result that the checkpoint also records, so only the pending call's
    result should be added (three messages in total).
    """
    loop = _mk_loop()
    session = Session(
        key="test:checkpoint-overlap",
        messages=[
            {
                "role": "assistant",
                "content": "working",
                "tool_calls": [
                    {
                        "id": "call_done",
                        "type": "function",
                        "function": {"name": "read_file", "arguments": "{}"},
                    },
                    {
                        "id": "call_pending",
                        "type": "function",
                        "function": {"name": "exec", "arguments": "{}"},
                    },
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call_done",
                "name": "read_file",
                "content": "ok",
            },
        ],
        metadata={
            AgentLoop._RUNTIME_CHECKPOINT_KEY: {
                "assistant_message": {
                    "role": "assistant",
                    "content": "working",
                    "tool_calls": [
                        {
                            "id": "call_done",
                            "type": "function",
                            "function": {"name": "read_file", "arguments": "{}"},
                        },
                        {
                            "id": "call_pending",
                            "type": "function",
                            "function": {"name": "exec", "arguments": "{}"},
                        },
                    ],
                },
                "completed_tool_results": [
                    {
                        "role": "tool",
                        "tool_call_id": "call_done",
                        "name": "read_file",
                        "content": "ok",
                    }
                ],
                "pending_tool_calls": [
                    {
                        "id": "call_pending",
                        "type": "function",
                        "function": {"name": "exec", "arguments": "{}"},
                    }
                ],
            }
        },
    )

    restored = loop._restore_runtime_checkpoint(session)

    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert len(session.messages) == 3
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||||
|
|||||||
@ -2,12 +2,20 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.agent.tools.base import Tool
|
||||||
|
from nanobot.agent.tools.registry import ToolRegistry
|
||||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(tmp_path):
|
def _make_loop(tmp_path):
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=3,
|
max_iterations=3,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
))
|
))
|
||||||
|
|
||||||
assert result.final_content == "done"
|
assert result.final_content == "done"
|
||||||
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=3,
|
max_iterations=3,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
hook=RecordingHook(),
|
hook=RecordingHook(),
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=1,
|
max_iterations=1,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
hook=StreamingHook(),
|
hook=StreamingHook(),
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=2,
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
))
|
))
|
||||||
|
|
||||||
assert result.stop_reason == "max_iterations"
|
assert result.stop_reason == "max_iterations"
|
||||||
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
|
|||||||
"I reached the maximum number of tool call iterations (2) "
|
"I reached the maximum number of tool call iterations (2) "
|
||||||
"without completing the task. You can try breaking the task into smaller steps."
|
"without completing the task. You can try breaking the task into smaller steps."
|
||||||
)
|
)
|
||||||
|
assert result.messages[-1]["role"] == "assistant"
|
||||||
|
assert result.messages[-1]["content"] == result.final_content
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_runner_returns_structured_tool_error():
|
async def test_runner_returns_structured_tool_error():
|
||||||
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
model="test-model",
|
model="test-model",
|
||||||
max_iterations=2,
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
fail_on_tool_error=True,
|
fail_on_tool_error=True,
|
||||||
))
|
))
|
||||||
|
|
||||||
@ -258,6 +272,232 @@ async def test_runner_returns_structured_tool_error():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """A tool result above the limit reaches the next LLM call as a file reference."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        # First call requests a tool; the second call records what it was sent.
        call_count["n"] += 1
        if call_count["n"] == 1:
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="x" * 20_000)

    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert "[tool output persisted]" in tool_message["content"]
    assert "tool-results" in tool_message["content"]
    assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a result removes stale session buckets but keeps recent ones."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    old_bucket = root / "old_session"
    recent_bucket = root / "recent_session"
    old_bucket.mkdir(parents=True)
    recent_bucket.mkdir(parents=True)
    (old_bucket / "old.txt").write_text("old", encoding="utf-8")
    (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")

    # Backdate the old bucket by eight days.
    stale = time.time() - (8 * 24 * 60 * 60)
    os.utime(old_bucket, (stale, stale))
    os.utime(old_bucket / "old.txt", (stale, stale))

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not old_bucket.exists()
    assert recent_bucket.exists()
    assert (root / "current_session" / "call_big.txt").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """After persisting, the bucket holds the result file and no ``*.tmp`` files."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert (root / "current_session" / "call_big.txt").exists()
    assert list((root / "current_session").glob("*.tmp")) == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """If ``_snip_history`` raises, the provider still receives the initial messages."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]

    runner = AgentRunner(provider)
    # Force the history-snipping step to fail.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    result = await runner.run(AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))

    assert result.final_content == "done"
    assert captured_messages == initial_messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """A failure in ``maybe_persist_tool_result`` leaves the raw tool result in place."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        # First call requests a tool; the second call records what it was sent.
        call_count["n"] += 1
        if call_count["n"] == 1:
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
        result = await runner.run(AgentRunSpec(
            initial_messages=[{"role": "user", "content": "do task"}],
            tools=tools,
            model="test-model",
            max_iterations=2,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        ))

    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert tool_message["content"] == "tool result"
|
||||||
|
|
||||||
|
|
||||||
|
class _DelayTool(Tool):
    """Test tool that sleeps for a fixed delay and logs start/end events.

    Each execution appends ``"start:<name>"`` and ``"end:<name>"`` to the
    shared event list so tests can assert on scheduling order.
    """

    def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
        self._name = name
        self._delay = delay
        self._read_only = read_only
        self._shared_events = shared_events

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._name

    @property
    def parameters(self) -> dict:
        # Takes no arguments.
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        return self._read_only

    async def execute(self, **kwargs):
        """Record start, sleep for the configured delay, record end, return the name."""
        self._shared_events.append(f"start:{self._name}")
        await asyncio.sleep(self._delay)
        self._shared_events.append(f"end:{self._name}")
        return self._name
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """Read-only tools start together and both finish before the write tool starts."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tools = ToolRegistry()
    shared_events: list[str] = []
    read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
    read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
    write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)
    tools.register(read_a)
    tools.register(read_b)
    tools.register(write_a)

    runner = AgentRunner(MagicMock())
    await runner._execute_tools(
        AgentRunSpec(
            initial_messages=[],
            tools=tools,
            model="test-model",
            max_iterations=1,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
            concurrent_tools=True,
        ),
        [
            ToolCallRequest(id="ro1", name="read_a", arguments={}),
            ToolCallRequest(id="ro2", name="read_b", arguments={}),
            ToolCallRequest(id="rw1", name="write_a", arguments={}),
        ],
    )

    # Both reads begin before either finishes.
    assert shared_events[0:2] == ["start:read_a", "start:read_b"]
    assert "end:read_a" in shared_events and "end:read_b" in shared_events
    # The write runs only after both reads are done, and runs last.
    assert shared_events.index("end:read_a") < shared_events.index("start:write_a")
    assert shared_events.index("end:read_b") < shared_events.index("start:write_a")
    assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
@ -317,15 +557,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
content="working",
|
content="working",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
))
|
))
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
return "tool result"
|
return "tool result"
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
|
||||||
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(*, exec_config=None):
|
def _make_loop(*, exec_config=None):
|
||||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||||
@ -186,7 +190,12 @@ class TestSubagentCancellation:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=MagicMock(),
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
|
|
||||||
cancelled = asyncio.Event()
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
@ -214,7 +223,12 @@ class TestSubagentCancellation:
|
|||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=MagicMock(),
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -236,19 +250,24 @@ class TestSubagentCancellation:
|
|||||||
if call_count["n"] == 1:
|
if call_count["n"] == 1:
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content="thinking",
|
content="thinking",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
reasoning_content="hidden reasoning",
|
reasoning_content="hidden reasoning",
|
||||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||||
)
|
)
|
||||||
captured_second_call[:] = messages
|
captured_second_call[:] = messages
|
||||||
return LLMResponse(content="done", tool_calls=[])
|
return LLMResponse(content="done", tool_calls=[])
|
||||||
provider.chat_with_retry = scripted_chat_with_retry
|
provider.chat_with_retry = scripted_chat_with_retry
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
return "tool result"
|
return "tool result"
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
@ -273,6 +292,7 @@ class TestSubagentCancellation:
|
|||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=tmp_path,
|
workspace=tmp_path,
|
||||||
bus=bus,
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
exec_config=ExecToolConfig(enable=False),
|
exec_config=ExecToolConfig(enable=False),
|
||||||
)
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
@ -304,20 +324,25 @@ class TestSubagentCancellation:
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
content="thinking",
|
content="thinking",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
))
|
))
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
calls = {"n": 0}
|
calls = {"n": 0}
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
calls["n"] += 1
|
calls["n"] += 1
|
||||||
if calls["n"] == 1:
|
if calls["n"] == 1:
|
||||||
return "first result"
|
return "first result"
|
||||||
raise RuntimeError("boom")
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
|
|
||||||
@ -340,15 +365,20 @@ class TestSubagentCancellation:
|
|||||||
provider.get_default_model.return_value = "test-model"
|
provider.get_default_model.return_value = "test-model"
|
||||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||||
content="thinking",
|
content="thinking",
|
||||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||||
))
|
))
|
||||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
mgr = SubagentManager(
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
bus=bus,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
)
|
||||||
mgr._announce_result = AsyncMock()
|
mgr._announce_result = AsyncMock()
|
||||||
|
|
||||||
started = asyncio.Event()
|
started = asyncio.Event()
|
||||||
cancelled = asyncio.Event()
|
cancelled = asyncio.Event()
|
||||||
|
|
||||||
async def fake_execute(self, name, arguments):
|
async def fake_execute(self, **kwargs):
|
||||||
started.set()
|
started.set()
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(60)
|
await asyncio.sleep(60)
|
||||||
@ -356,7 +386,7 @@ class TestSubagentCancellation:
|
|||||||
cancelled.set()
|
cancelled.set()
|
||||||
raise
|
raise
|
||||||
|
|
||||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||||
@ -364,7 +394,7 @@ class TestSubagentCancellation:
|
|||||||
mgr._running_tasks["sub-1"] = task
|
mgr._running_tasks["sub-1"] = task
|
||||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||||
|
|
||||||
await started.wait()
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
count = await mgr.cancel_by_session("test:c1")
|
count = await mgr.cancel_by_session("test:c1")
|
||||||
|
|
||||||
|
|||||||
@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
|
|||||||
typing_channel.typing_enter_hook = slow_typing
|
typing_channel.typing_enter_hook = slow_typing
|
||||||
|
|
||||||
await channel._start_typing(typing_channel)
|
await channel._start_typing(typing_channel)
|
||||||
await start.wait()
|
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||||
|
|
||||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||||
release.set()
|
release.set()
|
||||||
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
|
|||||||
typing_channel.typing_enter_hook = slow_typing_progress
|
typing_channel.typing_enter_hook = slow_typing_progress
|
||||||
|
|
||||||
await channel._start_typing(typing_channel)
|
await channel._start_typing(typing_channel)
|
||||||
await start.wait()
|
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||||
|
|
||||||
await channel.send(
|
await channel.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
|||||||
|
|
||||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||||
await entered.wait()
|
await asyncio.wait_for(entered.wait(), timeout=1.0)
|
||||||
|
|
||||||
assert "123" in channel._typing_tasks
|
assert "123" in channel._typing_tasks
|
||||||
|
|
||||||
|
|||||||
@ -8,6 +8,7 @@ Validates that:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
|||||||
return SimpleNamespace(choices=[choice], usage=usage)
|
return SimpleNamespace(choices=[choice], usage=usage)
|
||||||
|
|
||||||
|
|
||||||
|
class _StalledStream:
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
raise StopAsyncIteration
|
||||||
|
|
||||||
|
|
||||||
def test_openrouter_spec_is_gateway() -> None:
|
def test_openrouter_spec_is_gateway() -> None:
|
||||||
spec = find_by_name("openrouter")
|
spec = find_by_name("openrouter")
|
||||||
assert spec is not None
|
assert spec is not None
|
||||||
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
|
|||||||
spec=spec,
|
spec=spec,
|
||||||
)
|
)
|
||||||
assert provider.get_default_model() == "gpt-4o"
|
assert provider.get_default_model() == "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
    """Check which reasoning-related keys ``_sanitize_messages`` removes.

    ``reasoning_content`` and ``extra_content`` must be gone from the message
    itself, while ``extra_content`` nested inside a tool call must be kept.
    """
    # AsyncOpenAI is patched while the provider is constructed.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    sanitized = provider._sanitize_messages([
        {
            "role": "assistant",
            "content": "done",
            "reasoning_content": "hidden",
            "extra_content": {"debug": True},
            "tool_calls": [
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {"name": "fn", "arguments": "{}"},
                    "extra_content": {"google": {"thought_signature": "sig"}},
                }
            ],
        }
    ])

    assert "reasoning_content" not in sanitized[0]
    assert "extra_content" not in sanitized[0]
    # The tool-call-level extra_content is expected to pass through unchanged.
    assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
    """Check that ``chat_stream`` reports an error when the stream yields nothing.

    The mocked client returns a ``_StalledStream`` and the idle timeout
    environment variable is set to ``"0"``; the result must carry
    ``finish_reason == "error"`` and mention "stream stalled".
    """
    monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
    mock_create = AsyncMock(return_value=_StalledStream())
    spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        # Every completion request now resolves to the stalled stream.
        client_instance.chat.completions.create = mock_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=spec,
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-4o",
        )

    assert result.finish_reason == "error"
    assert result.content is not None
    assert "stream stalled" in result.content
|
||||||
|
|||||||
@ -211,3 +211,32 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
|||||||
content = msg.get("content")
|
content = msg.get("content")
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
    """Check the retry delay and progress message after a rate-limit error.

    The scripted provider first returns an error response whose text says
    "retry after 7s", then a normal response. The test expects one sleep of
    7.0 seconds and a progress message containing "7s".
    """
    provider = ScriptedProvider([
        LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"),
        LLMResponse(content="ok"),
    ])
    delays: list[float] = []
    progress: list[str] = []

    async def _fake_sleep(delay: float) -> None:
        # Record the requested delay instead of actually waiting.
        delays.append(delay)

    async def _progress(msg: str) -> None:
        progress.append(msg)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)

    response = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        on_retry_wait=_progress,
    )

    assert response.content == "ok"
    assert delays == [7.0]
    assert progress and "7s" in progress[0]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
|
|||||||
|
|
||||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||||
task = asyncio.create_task(wrapper.execute())
|
task = asyncio.create_task(wrapper.execute())
|
||||||
await started.wait()
|
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||||
|
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user