feat: harden agent runtime for long-running tasks

This commit is contained in:
Xubin Ren 2026-04-01 19:12:49 +00:00
parent 63d646f731
commit fbedf7ad77
25 changed files with 1348 additions and 185 deletions

View File

@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
if isinstance(left, str) and isinstance(right, str):
return f"{left}\n\n{right}" if left else right
def _to_blocks(value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
if value is None:
return []
return [{"type": "text", "text": str(value)}]
return _to_blocks(left) + _to_blocks(right)
def _load_bootstrap_files(self) -> str: def _load_bootstrap_files(self) -> str:
"""Load all bootstrap files from workspace.""" """Load all bootstrap files from workspace."""
parts = [] parts = []
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
merged = f"{runtime_ctx}\n\n{user_content}" merged = f"{runtime_ctx}\n\n{user_content}"
else: else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content merged = [{"type": "text", "text": runtime_ctx}] + user_content
messages = [
return [
{"role": "system", "content": self.build_system_prompt(skill_names)}, {"role": "system", "content": self.build_system_prompt(skill_names)},
*history, *history,
{"role": current_role, "content": merged},
] ]
if messages[-1].get("role") == current_role:
last = dict(messages[-1])
last["content"] = self._merge_message_content(last.get("content"), merged)
messages[-1] = last
return messages
messages.append({"role": current_role, "content": merged})
return messages
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
"""Build user message content with optional base64-encoded images.""" """Build user message content with optional base64-encoded images."""

View File

@ -29,8 +29,10 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text, truncate_text
if TYPE_CHECKING: if TYPE_CHECKING:
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
@ -38,11 +40,7 @@ if TYPE_CHECKING:
class _LoopHook(AgentHook): class _LoopHook(AgentHook):
"""Core lifecycle hook for the main agent loop. """Core hook for the main loop."""
Handles streaming delta relay, progress reporting, tool-call logging,
and think-tag stripping for the built-in agent path.
"""
def __init__( def __init__(
self, self,
@ -102,11 +100,7 @@ class _LoopHook(AgentHook):
class _LoopHookChain(AgentHook): class _LoopHookChain(AgentHook):
"""Run the core loop hook first, then best-effort extra hooks. """Run the core hook before extra hooks."""
This preserves the historical failure behavior of ``_LoopHook`` while still
letting user-supplied hooks opt into ``CompositeHook`` isolation.
"""
__slots__ = ("_primary", "_extras") __slots__ = ("_primary", "_extras")
@ -154,7 +148,7 @@ class AgentLoop:
5. Sends responses back 5. Sends responses back
""" """
_TOOL_RESULT_MAX_CHARS = 16_000 _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
def __init__( def __init__(
self, self,
@ -162,8 +156,11 @@ class AgentLoop:
provider: LLMProvider, provider: LLMProvider,
workspace: Path, workspace: Path,
model: str | None = None, model: str | None = None,
max_iterations: int = 40, max_iterations: int | None = None,
context_window_tokens: int = 65_536, context_window_tokens: int | None = None,
context_block_limit: int | None = None,
max_tool_result_chars: int | None = None,
provider_retry_mode: str = "standard",
web_search_config: WebSearchConfig | None = None, web_search_config: WebSearchConfig | None = None,
web_proxy: str | None = None, web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None, exec_config: ExecToolConfig | None = None,
@ -177,13 +174,27 @@ class AgentLoop:
): ):
from nanobot.config.schema import ExecToolConfig, WebSearchConfig from nanobot.config.schema import ExecToolConfig, WebSearchConfig
defaults = AgentDefaults()
self.bus = bus self.bus = bus
self.channels_config = channels_config self.channels_config = channels_config
self.provider = provider self.provider = provider
self.workspace = workspace self.workspace = workspace
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_iterations = max_iterations self.max_iterations = (
self.context_window_tokens = context_window_tokens max_iterations if max_iterations is not None else defaults.max_tool_iterations
)
self.context_window_tokens = (
context_window_tokens
if context_window_tokens is not None
else defaults.context_window_tokens
)
self.context_block_limit = context_block_limit
self.max_tool_result_chars = (
max_tool_result_chars
if max_tool_result_chars is not None
else defaults.max_tool_result_chars
)
self.provider_retry_mode = provider_retry_mode
self.web_search_config = web_search_config or WebSearchConfig() self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@ -202,6 +213,7 @@ class AgentLoop:
workspace=workspace, workspace=workspace,
bus=bus, bus=bus,
model=self.model, model=self.model,
max_tool_result_chars=self.max_tool_result_chars,
web_search_config=self.web_search_config, web_search_config=self.web_search_config,
web_proxy=web_proxy, web_proxy=web_proxy,
exec_config=self.exec_config, exec_config=self.exec_config,
@ -313,6 +325,7 @@ class AgentLoop:
on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None,
*, *,
session: Session | None = None,
channel: str = "cli", channel: str = "cli",
chat_id: str = "direct", chat_id: str = "direct",
message_id: str | None = None, message_id: str | None = None,
@ -339,14 +352,27 @@ class AgentLoop:
else loop_hook else loop_hook
) )
async def _checkpoint(payload: dict[str, Any]) -> None:
if session is None:
return
self._set_runtime_checkpoint(session, payload)
result = await self.runner.run(AgentRunSpec( result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages, initial_messages=initial_messages,
tools=self.tools, tools=self.tools,
model=self.model, model=self.model,
max_iterations=self.max_iterations, max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook, hook=hook,
error_message="Sorry, I encountered an error calling the AI model.", error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True, concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
)) ))
self._last_usage = result.usage self._last_usage = result.usage
if result.stop_reason == "max_iterations": if result.stop_reason == "max_iterations":
@ -484,6 +510,8 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}" key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0) history = session.get_history(max_messages=0)
@ -494,10 +522,11 @@ class AgentLoop:
current_role=current_role, current_role=current_role,
) )
final_content, _, all_msgs = await self._run_agent_loop( final_content, _, all_msgs = await self._run_agent_loop(
messages, channel=channel, chat_id=chat_id, messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
) )
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(channel=channel, chat_id=chat_id,
@ -508,6 +537,8 @@ class AgentLoop:
key = session_key or msg.session_key key = session_key or msg.session_key
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
# Slash commands # Slash commands
raw = msg.content.strip() raw = msg.content.strip()
@ -543,6 +574,7 @@ class AgentLoop:
on_progress=on_progress or _bus_progress, on_progress=on_progress or _bus_progress,
on_stream=on_stream, on_stream=on_stream,
on_stream_end=on_stream_end, on_stream_end=on_stream_end,
session=session,
channel=msg.channel, chat_id=msg.chat_id, channel=msg.channel, chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
) )
@ -551,6 +583,7 @@ class AgentLoop:
final_content = "I've completed processing but have no response to give." final_content = "I've completed processing but have no response to give."
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
@ -568,12 +601,6 @@ class AgentLoop:
metadata=meta, metadata=meta,
) )
@staticmethod
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
"""Convert an inline image block into a compact text placeholder."""
path = (block.get("_meta") or {}).get("path", "")
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
def _sanitize_persisted_blocks( def _sanitize_persisted_blocks(
self, self,
content: list[dict[str, Any]], content: list[dict[str, Any]],
@ -600,13 +627,14 @@ class AgentLoop:
block.get("type") == "image_url" block.get("type") == "image_url"
and block.get("image_url", {}).get("url", "").startswith("data:image/") and block.get("image_url", {}).get("url", "").startswith("data:image/")
): ):
filtered.append(self._image_placeholder(block)) path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue continue
if block.get("type") == "text" and isinstance(block.get("text"), str): if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"] text = block["text"]
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS: if truncate_text and len(text) > self.max_tool_result_chars:
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" text = truncate_text(text, self.max_tool_result_chars)
filtered.append({**block, "text": text}) filtered.append({**block, "text": text})
continue continue
@ -623,8 +651,8 @@ class AgentLoop:
if role == "assistant" and not content and not entry.get("tool_calls"): if role == "assistant" and not content and not entry.get("tool_calls"):
continue # skip empty assistant messages — they poison session context continue # skip empty assistant messages — they poison session context
if role == "tool": if role == "tool":
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS: if isinstance(content, str) and len(content) > self.max_tool_result_chars:
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" entry["content"] = truncate_text(content, self.max_tool_result_chars)
elif isinstance(content, list): elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, truncate_text=True) filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
if not filtered: if not filtered:
@ -647,6 +675,78 @@ class AgentLoop:
session.messages.append(entry) session.messages.append(entry)
session.updated_at = datetime.now() session.updated_at = datetime.now()
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
    """Persist the latest in-flight turn state into session metadata.

    The payload is stored under ``self._RUNTIME_CHECKPOINT_KEY`` and the
    session is saved immediately, so the checkpoint survives a crash or
    restart that interrupts the turn.
    """
    session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
    # Save eagerly: the checkpoint is only useful if it reaches storage
    # before the process dies mid-turn.
    self.sessions.save(session)
def _clear_runtime_checkpoint(self, session: Session) -> None:
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
@staticmethod
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
return (
message.get("role"),
message.get("content"),
message.get("tool_call_id"),
message.get("name"),
message.get("tool_calls"),
message.get("reasoning_content"),
message.get("thinking_blocks"),
)
def _restore_runtime_checkpoint(self, session: Session) -> bool:
    """Materialize an unfinished turn into session history before a new request.

    Reads the checkpoint written by ``_set_runtime_checkpoint`` and appends
    the interrupted turn's messages (assistant message, completed tool
    results, and synthetic error results for tools that never finished) to
    ``session.messages``, deduplicating any suffix that was already saved.

    Returns:
        True if a checkpoint dict was found (and consumed), False otherwise.
        Note: returns True even when every restored message was already
        present — the checkpoint is cleared either way.
    """
    from datetime import datetime

    checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
    if not isinstance(checkpoint, dict):
        return False

    assistant_message = checkpoint.get("assistant_message")
    completed_tool_results = checkpoint.get("completed_tool_results") or []
    pending_tool_calls = checkpoint.get("pending_tool_calls") or []

    restored_messages: list[dict[str, Any]] = []
    if isinstance(assistant_message, dict):
        # Copy before mutating so the checkpoint payload stays pristine.
        restored = dict(assistant_message)
        restored.setdefault("timestamp", datetime.now().isoformat())
        restored_messages.append(restored)
    for message in completed_tool_results:
        if isinstance(message, dict):
            restored = dict(message)
            restored.setdefault("timestamp", datetime.now().isoformat())
            restored_messages.append(restored)
    for tool_call in pending_tool_calls:
        # Tools that were dispatched but never completed get a synthetic
        # error result so the tool_call/result pairing stays well-formed.
        if not isinstance(tool_call, dict):
            continue
        tool_id = tool_call.get("id")
        name = ((tool_call.get("function") or {}).get("name")) or "tool"
        restored_messages.append({
            "role": "tool",
            "tool_call_id": tool_id,
            "name": name,
            "content": "Error: Task interrupted before this tool finished.",
            "timestamp": datetime.now().isoformat(),
        })

    # Find the longest suffix of the saved history that already equals a
    # prefix of the restored messages, to avoid duplicating messages that
    # were persisted before the interruption. Largest overlap wins.
    overlap = 0
    max_overlap = min(len(session.messages), len(restored_messages))
    for size in range(max_overlap, 0, -1):
        existing = session.messages[-size:]
        restored = restored_messages[:size]
        if all(
            self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
            for left, right in zip(existing, restored)
        ):
            overlap = size
            break
    session.messages.extend(restored_messages[overlap:])
    # Consume the checkpoint so it cannot be replayed on a later request.
    self._clear_runtime_checkpoint(session)
    return True
async def process_direct( async def process_direct(
self, self,
content: str, content: str,

View File

@ -4,20 +4,29 @@ from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, ToolCallRequest from nanobot.providers.base import LLMProvider, ToolCallRequest
from nanobot.utils.helpers import build_assistant_message from nanobot.utils.helpers import (
build_assistant_message,
estimate_message_tokens,
estimate_prompt_tokens_chain,
find_legal_message_start,
maybe_persist_tool_result,
truncate_text,
)
_DEFAULT_MAX_ITERATIONS_MESSAGE = ( _DEFAULT_MAX_ITERATIONS_MESSAGE = (
"I reached the maximum number of tool call iterations ({max_iterations}) " "I reached the maximum number of tool call iterations ({max_iterations}) "
"without completing the task. You can try breaking the task into smaller steps." "without completing the task. You can try breaking the task into smaller steps."
) )
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_SNIP_SAFETY_BUFFER = 1024
@dataclass(slots=True) @dataclass(slots=True)
class AgentRunSpec: class AgentRunSpec:
"""Configuration for a single agent execution.""" """Configuration for a single agent execution."""
@ -26,6 +35,7 @@ class AgentRunSpec:
tools: ToolRegistry tools: ToolRegistry
model: str model: str
max_iterations: int max_iterations: int
max_tool_result_chars: int
temperature: float | None = None temperature: float | None = None
max_tokens: int | None = None max_tokens: int | None = None
reasoning_effort: str | None = None reasoning_effort: str | None = None
@ -34,6 +44,13 @@ class AgentRunSpec:
max_iterations_message: str | None = None max_iterations_message: str | None = None
concurrent_tools: bool = False concurrent_tools: bool = False
fail_on_tool_error: bool = False fail_on_tool_error: bool = False
workspace: Path | None = None
session_key: str | None = None
context_window_tokens: int | None = None
context_block_limit: int | None = None
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
checkpoint_callback: Any | None = None
@dataclass(slots=True) @dataclass(slots=True)
@ -66,12 +83,25 @@ class AgentRunner:
tool_events: list[dict[str, str]] = [] tool_events: list[dict[str, str]] = []
for iteration in range(spec.max_iterations): for iteration in range(spec.max_iterations):
try:
messages = self._apply_tool_result_budget(spec, messages)
messages_for_model = self._snip_history(spec, messages)
except Exception as exc:
logger.warning(
"Context governance failed on turn {} for {}: {}; using raw messages",
iteration,
spec.session_key or "default",
exc,
)
messages_for_model = messages
context = AgentHookContext(iteration=iteration, messages=messages) context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context) await hook.before_iteration(context)
kwargs: dict[str, Any] = { kwargs: dict[str, Any] = {
"messages": messages, "messages": messages_for_model,
"tools": spec.tools.get_definitions(), "tools": spec.tools.get_definitions(),
"model": spec.model, "model": spec.model,
"retry_mode": spec.provider_retry_mode,
"on_retry_wait": spec.progress_callback,
} }
if spec.temperature is not None: if spec.temperature is not None:
kwargs["temperature"] = spec.temperature kwargs["temperature"] = spec.temperature
@ -104,13 +134,25 @@ class AgentRunner:
if hook.wants_streaming(): if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True) await hook.on_stream_end(context, resuming=True)
messages.append(build_assistant_message( assistant_message = build_assistant_message(
response.content or "", response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls], tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
)) )
messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls) tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint(
spec,
{
"phase": "awaiting_tools",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
},
)
await hook.before_execute_tools(context) await hook.before_execute_tools(context)
@ -125,13 +167,31 @@ class AgentRunner:
context.stop_reason = stop_reason context.stop_reason = stop_reason
await hook.after_iteration(context) await hook.after_iteration(context)
break break
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results): for tool_call, result in zip(response.tool_calls, results):
messages.append({ tool_message = {
"role": "tool", "role": "tool",
"tool_call_id": tool_call.id, "tool_call_id": tool_call.id,
"name": tool_call.name, "name": tool_call.name,
"content": result, "content": self._normalize_tool_result(
}) spec,
tool_call.id,
result,
),
}
messages.append(tool_message)
completed_tool_results.append(tool_message)
await self._emit_checkpoint(
spec,
{
"phase": "tools_completed",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": completed_tool_results,
"pending_tool_calls": [],
},
)
await hook.after_iteration(context) await hook.after_iteration(context)
continue continue
@ -143,6 +203,7 @@ class AgentRunner:
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error" stop_reason = "error"
error = final_content error = final_content
self._append_final_message(messages, final_content)
context.final_content = final_content context.final_content = final_content
context.error = error context.error = error
context.stop_reason = stop_reason context.stop_reason = stop_reason
@ -154,6 +215,17 @@ class AgentRunner:
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
)) ))
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": messages[-1],
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
final_content = clean final_content = clean
context.final_content = final_content context.final_content = final_content
context.stop_reason = stop_reason context.stop_reason = stop_reason
@ -163,6 +235,7 @@ class AgentRunner:
stop_reason = "max_iterations" stop_reason = "max_iterations"
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
final_content = template.format(max_iterations=spec.max_iterations) final_content = template.format(max_iterations=spec.max_iterations)
self._append_final_message(messages, final_content)
return AgentRunResult( return AgentRunResult(
final_content=final_content, final_content=final_content,
@ -179,16 +252,17 @@ class AgentRunner:
spec: AgentRunSpec, spec: AgentRunSpec,
tool_calls: list[ToolCallRequest], tool_calls: list[ToolCallRequest],
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]: ) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
if spec.concurrent_tools: batches = self._partition_tool_batches(spec, tool_calls)
tool_results = await asyncio.gather(*( tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
self._run_tool(spec, tool_call) for batch in batches:
for tool_call in tool_calls if spec.concurrent_tools and len(batch) > 1:
)) tool_results.extend(await asyncio.gather(*(
else: self._run_tool(spec, tool_call)
tool_results = [ for tool_call in batch
await self._run_tool(spec, tool_call) )))
for tool_call in tool_calls else:
] for tool_call in batch:
tool_results.append(await self._run_tool(spec, tool_call))
results: list[Any] = [] results: list[Any] = []
events: list[dict[str, str]] = [] events: list[dict[str, str]] = []
@ -205,8 +279,28 @@ class AgentRunner:
spec: AgentRunSpec, spec: AgentRunSpec,
tool_call: ToolCallRequest, tool_call: ToolCallRequest,
) -> tuple[Any, dict[str, str], BaseException | None]: ) -> tuple[Any, dict[str, str], BaseException | None]:
_HINT = "\n\n[Analyze the error above and try a different approach.]"
prepare_call = getattr(spec.tools, "prepare_call", None)
tool, params, prep_error = None, tool_call.arguments, None
if callable(prepare_call):
try:
prepared = prepare_call(tool_call.name, tool_call.arguments)
if isinstance(prepared, tuple) and len(prepared) == 3:
tool, params, prep_error = prepared
except Exception:
pass
if prep_error:
event = {
"name": tool_call.name,
"status": "error",
"detail": prep_error.split(": ", 1)[-1][:120],
}
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
try: try:
result = await spec.tools.execute(tool_call.name, tool_call.arguments) if tool is not None:
result = await tool.execute(**params)
else:
result = await spec.tools.execute(tool_call.name, params)
except asyncio.CancelledError: except asyncio.CancelledError:
raise raise
except BaseException as exc: except BaseException as exc:
@ -219,14 +313,175 @@ class AgentRunner:
return f"Error: {type(exc).__name__}: {exc}", event, exc return f"Error: {type(exc).__name__}: {exc}", event, exc
return f"Error: {type(exc).__name__}: {exc}", event, None return f"Error: {type(exc).__name__}: {exc}", event, None
if isinstance(result, str) and result.startswith("Error"):
event = {
"name": tool_call.name,
"status": "error",
"detail": result.replace("\n", " ").strip()[:120],
}
if spec.fail_on_tool_error:
return result + _HINT, event, RuntimeError(result)
return result + _HINT, event, None
detail = "" if result is None else str(result) detail = "" if result is None else str(result)
detail = detail.replace("\n", " ").strip() detail = detail.replace("\n", " ").strip()
if not detail: if not detail:
detail = "(empty)" detail = "(empty)"
elif len(detail) > 120: elif len(detail) > 120:
detail = detail[:120] + "..." detail = detail[:120] + "..."
return result, { return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
"name": tool_call.name,
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok", async def _emit_checkpoint(
"detail": detail, self,
}, None spec: AgentRunSpec,
payload: dict[str, Any],
) -> None:
callback = spec.checkpoint_callback
if callback is not None:
await callback(payload)
@staticmethod
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
    """Ensure *messages* ends with an assistant message carrying *content*.

    Empty content is a no-op. A trailing assistant message without tool
    calls is replaced in place (unless its content already matches);
    anything else gets a new assistant message appended.
    """
    if not content:
        return
    last = messages[-1] if messages else None
    replaceable = (
        last is not None
        and last.get("role") == "assistant"
        and not last.get("tool_calls")
    )
    if not replaceable:
        messages.append(build_assistant_message(content))
        return
    if last.get("content") != content:
        messages[-1] = build_assistant_message(content)
def _normalize_tool_result(
    self,
    spec: AgentRunSpec,
    tool_call_id: str,
    result: Any,
) -> Any:
    """Cap a tool result to the configured character budget.

    First offers the result to ``maybe_persist_tool_result`` (presumably
    spilling oversized payloads to the workspace and returning a shorter
    replacement — confirm against the helper); any failure there is logged
    and the raw result is kept, so persistence is strictly best-effort.
    String content still above ``spec.max_tool_result_chars`` afterwards
    is hard-truncated via ``truncate_text``.
    """
    try:
        content = maybe_persist_tool_result(
            spec.workspace,
            spec.session_key,
            tool_call_id,
            result,
            max_chars=spec.max_tool_result_chars,
        )
    except Exception as exc:
        # Persistence must never fail the turn; fall back to the raw result.
        logger.warning(
            "Tool result persist failed for {} in {}: {}; using raw result",
            tool_call_id,
            spec.session_key or "default",
            exc,
        )
        content = result
    if isinstance(content, str) and len(content) > spec.max_tool_result_chars:
        return truncate_text(content, spec.max_tool_result_chars)
    return content
def _apply_tool_result_budget(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Re-apply the tool-result size budget to every tool message.

    Copy-on-write: the input is returned unchanged (same object) when no
    tool content needs normalizing; the list and its message dicts are
    shallow-copied only on the first actual change, so the caller's
    original messages are never mutated.
    """
    updated = messages
    for idx, message in enumerate(messages):
        if message.get("role") != "tool":
            continue
        normalized = self._normalize_tool_result(
            spec,
            # Fall back to a positional id so persistence has a stable key.
            str(message.get("tool_call_id") or f"tool_{idx}"),
            message.get("content"),
        )
        if normalized != message.get("content"):
            if updated is messages:
                # First change: fork the list and each dict before writing.
                updated = [dict(m) for m in messages]
            updated[idx]["content"] = normalized
    return updated
def _snip_history(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Trim conversation history to fit the model's context budget.

    Returns *messages* unchanged when no window is configured or the
    estimated prompt already fits. Otherwise keeps all system messages
    plus the largest recent suffix of non-system messages that fits the
    remaining budget, aligned to a legal conversation start.
    """
    if not messages or not spec.context_window_tokens:
        return messages
    # Reserve room for the model's output plus a safety buffer for
    # estimate error; an explicit context_block_limit overrides this math.
    provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
    max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else (
        provider_max_tokens if isinstance(provider_max_tokens, int) else 4096
    )
    budget = spec.context_block_limit or (
        spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER
    )
    if budget <= 0:
        # Nonsensical budget — better to send everything than nothing.
        return messages
    estimate, _ = estimate_prompt_tokens_chain(
        self.provider,
        spec.model,
        messages,
        spec.tools.get_definitions(),
    )
    if estimate <= budget:
        return messages
    # System messages are always kept; copy dicts so later edits can't
    # alias the caller's message objects.
    system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"]
    non_system = [dict(msg) for msg in messages if msg.get("role") != "system"]
    if not non_system:
        return messages
    system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
    # Guarantee at least a tiny budget so the newest message always fits.
    remaining_budget = max(128, budget - system_tokens)
    # Walk backwards, keeping the newest messages until the budget fills.
    # The first (newest) message is always kept even if oversized.
    kept: list[dict[str, Any]] = []
    kept_tokens = 0
    for message in reversed(non_system):
        msg_tokens = estimate_message_tokens(message)
        if kept and kept_tokens + msg_tokens > remaining_budget:
            break
        kept.append(message)
        kept_tokens += msg_tokens
    kept.reverse()
    if kept:
        # Prefer starting at a user message, then snap to a legal start
        # (e.g. not an orphaned tool result — see find_legal_message_start).
        for i, message in enumerate(kept):
            if message.get("role") == "user":
                kept = kept[i:]
                break
        start = find_legal_message_start(kept)
        if start:
            kept = kept[start:]
    if not kept:
        # Trimming removed everything — fall back to the last few messages.
        kept = non_system[-min(len(non_system), 4) :]
        start = find_legal_message_start(kept)
        if start:
            kept = kept[start:]
    return system_messages + kept
def _partition_tool_batches(
self,
spec: AgentRunSpec,
tool_calls: list[ToolCallRequest],
) -> list[list[ToolCallRequest]]:
if not spec.concurrent_tools:
return [[tool_call] for tool_call in tool_calls]
batches: list[list[ToolCallRequest]] = []
current: list[ToolCallRequest] = []
for tool_call in tool_calls:
get_tool = getattr(spec.tools, "get", None)
tool = get_tool(tool_call.name) if callable(get_tool) else None
can_batch = bool(tool and tool.concurrency_safe)
if can_batch:
current.append(tool_call)
continue
if current:
batches.append(current)
current = []
batches.append([tool_call])
if current:
batches.append(current)
return batches

View File

@ -44,6 +44,7 @@ class SubagentManager:
provider: LLMProvider, provider: LLMProvider,
workspace: Path, workspace: Path,
bus: MessageBus, bus: MessageBus,
max_tool_result_chars: int,
model: str | None = None, model: str | None = None,
web_search_config: "WebSearchConfig | None" = None, web_search_config: "WebSearchConfig | None" = None,
web_proxy: str | None = None, web_proxy: str | None = None,
@ -56,6 +57,7 @@ class SubagentManager:
self.workspace = workspace self.workspace = workspace
self.bus = bus self.bus = bus
self.model = model or provider.get_default_model() self.model = model or provider.get_default_model()
self.max_tool_result_chars = max_tool_result_chars
self.web_search_config = web_search_config or WebSearchConfig() self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig() self.exec_config = exec_config or ExecToolConfig()
@ -136,6 +138,7 @@ class SubagentManager:
tools=tools, tools=tools,
model=self.model, model=self.model,
max_iterations=15, max_iterations=15,
max_tool_result_chars=self.max_tool_result_chars,
hook=_SubagentHook(task_id), hook=_SubagentHook(task_id),
max_iterations_message="Task completed but no final response was generated.", max_iterations_message="Task completed but no final response was generated.",
error_message=None, error_message=None,

View File

@ -53,6 +53,21 @@ class Tool(ABC):
"""JSON Schema for tool parameters.""" """JSON Schema for tool parameters."""
pass pass
@property
def read_only(self) -> bool:
    """Whether this tool is side-effect free and safe to parallelize.

    Defaults to False; read-only tools (file reads, directory listings,
    web search/fetch) override this to opt in to concurrent batching.
    """
    return False
@property
def concurrency_safe(self) -> bool:
    """Whether this tool can run alongside other concurrency-safe tools.

    True only for read-only tools that are not marked exclusive; consulted
    when partitioning tool calls into concurrent batches.
    """
    return self.read_only and not self.exclusive
@property
def exclusive(self) -> bool:
    """Whether this tool should run alone even if concurrency is enabled.

    Defaults to False; the exec tool overrides this so shell commands never
    run in parallel with other tools.
    """
    return False
@abstractmethod @abstractmethod
async def execute(self, **kwargs: Any) -> Any: async def execute(self, **kwargs: Any) -> Any:
""" """

View File

@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
"Use offset and limit to paginate through large files." "Use offset and limit to paginate through large files."
) )
@property
def read_only(self) -> bool:
return True
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored." "Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
) )
@property
def read_only(self) -> bool:
return True
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {

View File

@ -35,22 +35,35 @@ class ToolRegistry:
"""Get all tool definitions in OpenAI format.""" """Get all tool definitions in OpenAI format."""
return [tool.to_schema() for tool in self._tools.values()] return [tool.to_schema() for tool in self._tools.values()]
def prepare_call(
    self,
    name: str,
    params: dict[str, Any],
) -> "tuple[Tool | None, dict[str, Any], str | None]":
    """Resolve a tool by name, cast its parameters, and validate them.

    Returns a ``(tool, params, error)`` triple. ``error`` is None only when
    the call is ready to execute; otherwise it carries a user-facing message.
    """
    tool = self._tools.get(name)
    if not tool:
        available = ", ".join(self.tool_names)
        return None, params, f"Error: Tool '{name}' not found. Available: {available}"
    cast_params = tool.cast_params(params)
    problems = tool.validate_params(cast_params)
    if not problems:
        return tool, cast_params, None
    joined = "; ".join(problems)
    return tool, cast_params, f"Error: Invalid parameters for tool '{name}': " + joined
async def execute(self, name: str, params: dict[str, Any]) -> Any: async def execute(self, name: str, params: dict[str, Any]) -> Any:
"""Execute a tool by name with given parameters.""" """Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]" _HINT = "\n\n[Analyze the error above and try a different approach.]"
tool, params, error = self.prepare_call(name, params)
tool = self._tools.get(name) if error:
if not tool: return error + _HINT
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
try: try:
# Attempt to cast parameters to match schema types assert tool is not None # guarded by prepare_call()
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params)
if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
result = await tool.execute(**params) result = await tool.execute(**params)
if isinstance(result, str) and result.startswith("Error"): if isinstance(result, str) and result.startswith("Error"):
return result + _HINT return result + _HINT

View File

@ -52,6 +52,10 @@ class ExecTool(Tool):
def description(self) -> str: def description(self) -> str:
return "Execute a shell command and return its output. Use with caution." return "Execute a shell command and return its output. Use with caution."
@property
def exclusive(self) -> bool:
return True
@property @property
def parameters(self) -> dict[str, Any]: def parameters(self) -> dict[str, Any]:
return { return {

View File

@ -92,6 +92,10 @@ class WebSearchTool(Tool):
self.config = config if config is not None else WebSearchConfig() self.config = config if config is not None else WebSearchConfig()
self.proxy = proxy self.proxy = proxy
@property
def read_only(self) -> bool:
return True
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
provider = self.config.provider.strip().lower() or "brave" provider = self.config.provider.strip().lower() or "brave"
n = min(max(count or self.config.max_results, 1), 10) n = min(max(count or self.config.max_results, 1), 10)
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
self.max_chars = max_chars self.max_chars = max_chars
self.proxy = proxy self.proxy = proxy
@property
def read_only(self) -> bool:
return True
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any: async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
max_chars = maxChars or self.max_chars max_chars = maxChars or self.max_chars
is_valid, error_msg = _validate_url_safe(url) is_valid, error_msg = _validate_url_safe(url)

View File

@ -539,6 +539,9 @@ def serve(
model=runtime_config.agents.defaults.model, model=runtime_config.agents.defaults.model,
max_iterations=runtime_config.agents.defaults.max_tool_iterations, max_iterations=runtime_config.agents.defaults.max_tool_iterations,
context_window_tokens=runtime_config.agents.defaults.context_window_tokens, context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
context_block_limit=runtime_config.agents.defaults.context_block_limit,
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
web_search_config=runtime_config.tools.web.search, web_search_config=runtime_config.tools.web.search,
web_proxy=runtime_config.tools.web.proxy or None, web_proxy=runtime_config.tools.web.proxy or None,
exec_config=runtime_config.tools.exec, exec_config=runtime_config.tools.exec,
@ -626,6 +629,9 @@ def gateway(
model=config.agents.defaults.model, model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens, context_window_tokens=config.agents.defaults.context_window_tokens,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
web_search_config=config.tools.web.search, web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,
@ -832,6 +838,9 @@ def agent(
model=config.agents.defaults.model, model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations, max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens, context_window_tokens=config.agents.defaults.context_window_tokens,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
web_search_config=config.tools.web.search, web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,

View File

@ -38,8 +38,11 @@ class AgentDefaults(Base):
) )
max_tokens: int = 8192 max_tokens: int = 8192
context_window_tokens: int = 65_536 context_window_tokens: int = 65_536
context_block_limit: int | None = None
temperature: float = 0.1 temperature: float = 0.1
max_tool_iterations: int = 40 max_tool_iterations: int = 200
max_tool_result_chars: int = 16_000
provider_retry_mode: Literal["standard", "persistent"] = "standard"
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"

View File

@ -73,6 +73,9 @@ class Nanobot:
model=defaults.model, model=defaults.model,
max_iterations=defaults.max_tool_iterations, max_iterations=defaults.max_tool_iterations,
context_window_tokens=defaults.context_window_tokens, context_window_tokens=defaults.context_window_tokens,
context_block_limit=defaults.context_block_limit,
max_tool_result_chars=defaults.max_tool_result_chars,
provider_retry_mode=defaults.provider_retry_mode,
web_search_config=config.tools.web.search, web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None, web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec, exec_config=config.tools.exec,

View File

@ -2,6 +2,8 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import os
import re import re
import secrets import secrets
import string import string
@ -427,13 +429,33 @@ class AnthropicProvider(LLMProvider):
messages, tools, model, max_tokens, temperature, messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice, reasoning_effort, tool_choice,
) )
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try: try:
async with self._client.messages.stream(**kwargs) as stream: async with self._client.messages.stream(**kwargs) as stream:
if on_content_delta: if on_content_delta:
async for text in stream.text_stream: stream_iter = stream.text_stream.__aiter__()
while True:
try:
text = await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
await on_content_delta(text) await on_content_delta(text)
response = await stream.get_final_message() response = await asyncio.wait_for(
stream.get_final_message(),
timeout=idle_timeout_s,
)
return self._parse_response(response) return self._parse_response(response)
except asyncio.TimeoutError:
return LLMResponse(
content=(
f"Error calling LLM: stream stalled for more than "
f"{idle_timeout_s} seconds"
),
finish_reason="error",
)
except Exception as e: except Exception as e:
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")

View File

@ -2,6 +2,7 @@
import asyncio import asyncio
import json import json
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -9,6 +10,8 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.utils.helpers import image_placeholder_text
@dataclass @dataclass
class ToolCallRequest: class ToolCallRequest:
@ -57,13 +60,7 @@ class LLMResponse:
@dataclass(frozen=True) @dataclass(frozen=True)
class GenerationSettings: class GenerationSettings:
"""Default generation parameters for LLM calls. """Default generation settings."""
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
temperature: float = 0.7 temperature: float = 0.7
max_tokens: int = 4096 max_tokens: int = 4096
@ -71,14 +68,11 @@ class GenerationSettings:
class LLMProvider(ABC): class LLMProvider(ABC):
""" """Base class for LLM providers."""
Abstract base class for LLM providers.
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
_CHAT_RETRY_DELAYS = (1, 2, 4) _CHAT_RETRY_DELAYS = (1, 2, 4)
_PERSISTENT_MAX_DELAY = 60
_RETRY_HEARTBEAT_CHUNK = 30
_TRANSIENT_ERROR_MARKERS = ( _TRANSIENT_ERROR_MARKERS = (
"429", "429",
"rate limit", "rate limit",
@ -208,7 +202,7 @@ class LLMProvider(ABC):
for b in content: for b in content:
if isinstance(b, dict) and b.get("type") == "image_url": if isinstance(b, dict) and b.get("type") == "image_url":
path = (b.get("_meta") or {}).get("path", "") path = (b.get("_meta") or {}).get("path", "")
placeholder = f"[image: {path}]" if path else "[image omitted]" placeholder = image_placeholder_text(path, empty="[image omitted]")
new_content.append({"type": "text", "text": placeholder}) new_content.append({"type": "text", "text": placeholder})
found = True found = True
else: else:
@ -273,6 +267,8 @@ class LLMProvider(ABC):
reasoning_effort: object = _SENTINEL, reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None, tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse: ) -> LLMResponse:
"""Call chat_stream() with retry on transient provider failures.""" """Call chat_stream() with retry on transient provider failures."""
if max_tokens is self._SENTINEL: if max_tokens is self._SENTINEL:
@ -288,28 +284,13 @@ class LLMProvider(ABC):
reasoning_effort=reasoning_effort, tool_choice=tool_choice, reasoning_effort=reasoning_effort, tool_choice=tool_choice,
on_content_delta=on_content_delta, on_content_delta=on_content_delta,
) )
return await self._run_with_retry(
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): self._safe_chat_stream,
response = await self._safe_chat_stream(**kw) kw,
messages,
if response.finish_reason != "error": retry_mode=retry_mode,
return response on_retry_wait=on_retry_wait,
)
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat_stream(**{**kw, "messages": stripped})
return response
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
return await self._safe_chat_stream(**kw)
async def chat_with_retry( async def chat_with_retry(
self, self,
@ -320,6 +301,8 @@ class LLMProvider(ABC):
temperature: object = _SENTINEL, temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL, reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None, tool_choice: str | dict[str, Any] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse: ) -> LLMResponse:
"""Call chat() with retry on transient provider failures. """Call chat() with retry on transient provider failures.
@ -339,28 +322,102 @@ class LLMProvider(ABC):
max_tokens=max_tokens, temperature=temperature, max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice, reasoning_effort=reasoning_effort, tool_choice=tool_choice,
) )
return await self._run_with_retry(
self._safe_chat,
kw,
messages,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): @classmethod
response = await self._safe_chat(**kw) def _extract_retry_after(cls, content: str | None) -> float | None:
text = (content or "").lower()
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
if not match:
return None
value = float(match.group(1))
unit = (match.group(2) or "s").lower()
if unit in {"ms", "milliseconds"}:
return max(0.1, value / 1000.0)
if unit in {"m", "min", "minutes"}:
return value * 60.0
return value
async def _sleep_with_heartbeat(
self,
delay: float,
*,
attempt: int,
persistent: bool,
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> None:
remaining = max(0.0, delay)
while remaining > 0:
if on_retry_wait:
kind = "persistent retry" if persistent else "retry"
await on_retry_wait(
f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
f"(attempt {attempt})."
)
chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
await asyncio.sleep(chunk)
remaining -= chunk
async def _run_with_retry(
self,
call: Callable[..., Awaitable[LLMResponse]],
kw: dict[str, Any],
original_messages: list[dict[str, Any]],
*,
retry_mode: str,
on_retry_wait: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
attempt = 0
delays = list(self._CHAT_RETRY_DELAYS)
persistent = retry_mode == "persistent"
last_response: LLMResponse | None = None
while True:
attempt += 1
response = await call(**kw)
if response.finish_reason != "error": if response.finish_reason != "error":
return response return response
last_response = response
if not self._is_transient_error(response.content): if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages) stripped = self._strip_image_content(original_messages)
if stripped is not None: if stripped is not None and stripped != kw["messages"]:
logger.warning("Non-transient LLM error with image content, retrying without images") logger.warning(
return await self._safe_chat(**{**kw, "messages": stripped}) "Non-transient LLM error with image content, retrying without images"
)
retry_kw = dict(kw)
retry_kw["messages"] = stripped
return await call(**retry_kw)
return response return response
if not persistent and attempt > len(delays):
break
base_delay = delays[min(attempt - 1, len(delays) - 1)]
delay = self._extract_retry_after(response.content) or base_delay
if persistent:
delay = min(delay, self._PERSISTENT_MAX_DELAY)
logger.warning( logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}", "LLM transient error (attempt {}{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay, attempt,
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
int(round(delay)),
(response.content or "")[:120].lower(), (response.content or "")[:120].lower(),
) )
await asyncio.sleep(delay) await self._sleep_with_heartbeat(
delay,
attempt=attempt,
persistent=persistent,
on_retry_wait=on_retry_wait,
)
return await self._safe_chat(**kw) return last_response if last_response is not None else await call(**kw)
@abstractmethod @abstractmethod
def get_default_model(self) -> str: def get_default_model(self) -> str:

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import hashlib import hashlib
import os import os
import secrets import secrets
@ -20,7 +21,6 @@ if TYPE_CHECKING:
_ALLOWED_MSG_KEYS = frozenset({ _ALLOWED_MSG_KEYS = frozenset({
"role", "content", "tool_calls", "tool_call_id", "name", "role", "content", "tool_calls", "tool_call_id", "name",
"reasoning_content", "extra_content",
}) })
_ALNUM = string.ascii_letters + string.digits _ALNUM = string.ascii_letters + string.digits
@ -572,16 +572,33 @@ class OpenAICompatProvider(LLMProvider):
) )
kwargs["stream"] = True kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True} kwargs["stream_options"] = {"include_usage": True}
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try: try:
stream = await self._client.chat.completions.create(**kwargs) stream = await self._client.chat.completions.create(**kwargs)
chunks: list[Any] = [] chunks: list[Any] = []
async for chunk in stream: stream_iter = stream.__aiter__()
while True:
try:
chunk = await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
chunks.append(chunk) chunks.append(chunk)
if on_content_delta and chunk.choices: if on_content_delta and chunk.choices:
text = getattr(chunk.choices[0].delta, "content", None) text = getattr(chunk.choices[0].delta, "content", None)
if text: if text:
await on_content_delta(text) await on_content_delta(text)
return self._parse_chunks(chunks) return self._parse_chunks(chunks)
except asyncio.TimeoutError:
return LLMResponse(
content=(
f"Error calling LLM: stream stalled for more than "
f"{idle_timeout_s} seconds"
),
finish_reason="error",
)
except Exception as e: except Exception as e:
return self._handle_error(e) return self._handle_error(e)

View File

@ -10,20 +10,12 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
@dataclass @dataclass
class Session: class Session:
""" """A conversation session."""
A conversation session.
Stores messages in JSONL format for easy reading and persistence.
Important: Messages are append-only for LLM cache efficiency.
The consolidation process writes summaries to MEMORY.md/HISTORY.md
but does NOT modify the messages list or get_history() output.
"""
key: str # channel:chat_id key: str # channel:chat_id
messages: list[dict[str, Any]] = field(default_factory=list) messages: list[dict[str, Any]] = field(default_factory=list)
@ -43,43 +35,19 @@ class Session:
self.messages.append(msg) self.messages.append(msg)
self.updated_at = datetime.now() self.updated_at = datetime.now()
@staticmethod
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
"""Find first index where every tool result has a matching assistant tool_call."""
declared: set[str] = set()
start = 0
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
elif role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
start = i + 1
declared.clear()
for prev in messages[start:i + 1]:
if prev.get("role") == "assistant":
for tc in prev.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
return start
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:] unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:] sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid starting mid-turn when possible. # Avoid starting mid-turn when possible.
for i, message in enumerate(sliced): for i, message in enumerate(sliced):
if message.get("role") == "user": if message.get("role") == "user":
sliced = sliced[i:] sliced = sliced[i:]
break break
# Some providers reject orphan tool results if the matching assistant # Drop orphan tool results at the front.
# tool_calls message fell outside the fixed-size history window. start = find_legal_message_start(sliced)
start = self._find_legal_start(sliced)
if start: if start:
sliced = sliced[start:] sliced = sliced[start:]
@ -115,7 +83,7 @@ class Session:
retained = self.messages[start_idx:] retained = self.messages[start_idx:]
# Mirror get_history(): avoid persisting orphan tool results at the front. # Mirror get_history(): avoid persisting orphan tool results at the front.
start = self._find_legal_start(retained) start = find_legal_message_start(retained)
if start: if start:
retained = retained[start:] retained = retained[start:]

View File

@ -3,7 +3,9 @@
import base64 import base64
import json import json
import re import re
import shutil
import time import time
import uuid
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -56,11 +58,7 @@ def timestamp() -> str:
def current_time_str(timezone: str | None = None) -> str: def current_time_str(timezone: str | None = None) -> str:
"""Human-readable current time with weekday and UTC offset. """Return the current time string."""
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
is converted to that zone. Otherwise falls back to the host local time.
"""
from zoneinfo import ZoneInfo from zoneinfo import ZoneInfo
try: try:
@ -76,12 +74,164 @@ def current_time_str(timezone: str | None = None) -> str:
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]') _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
# Tool-result persistence knobs: oversized tool output is written to disk and
# replaced in the conversation by a short preview plus a file reference.
_TOOL_RESULT_PREVIEW_CHARS = 1200  # chars of the output kept inline as preview
_TOOL_RESULTS_DIR = ".nanobot/tool-results"  # workspace-relative storage root
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60  # prune session buckets older than 7 days
_TOOL_RESULT_MAX_BUCKETS = 32  # hard cap on retained per-session bucket dirs
def safe_filename(name: str) -> str: def safe_filename(name: str) -> str:
"""Replace unsafe path characters with underscores.""" """Replace unsafe path characters with underscores."""
return _UNSAFE_CHARS.sub("_", name).strip() return _UNSAFE_CHARS.sub("_", name).strip()
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
"""Build an image placeholder string."""
return f"[image: {path}]" if path else empty
def truncate_text(text: str, max_chars: int) -> str:
    """Clip ``text`` to ``max_chars`` characters, appending a truncation marker.

    A non-positive ``max_chars`` disables truncation entirely.
    """
    needs_clip = max_chars > 0 and len(text) > max_chars
    if not needs_clip:
        return text
    return text[:max_chars] + "\n... (truncated)"
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
    """Find the first index whose tool results have matching assistant calls.

    Providers reject a "tool" message whose ``tool_call_id`` was never
    declared by a preceding assistant ``tool_calls`` entry inside the window.
    Scanning forward, every orphan tool result pushes the legal start past
    itself and resets the set of declared ids — assistants before the new
    start fall outside the window, so their ids no longer count.

    Args:
        messages: Chat messages in chronological order.

    Returns:
        The smallest index such that ``messages[start:]`` contains no orphan
        tool results; 0 for an empty/already-legal list, up to
        ``len(messages)`` when no legal suffix exists.
    """
    declared: set[str] = set()
    start = 0
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict) and tc.get("id"):
                    declared.add(str(tc["id"]))
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                # Orphan tool result: everything up to and including it goes.
                # (The original re-scan of messages[start:i+1] was dead code —
                # with start == i + 1 the slice is always empty — so clearing
                # the declared set is all that is needed.)
                start = i + 1
                declared.clear()
    return start
def _stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
parts.append(text)
return "\n".join(parts)
def _render_tool_result_reference(
filepath: Path,
*,
original_size: int,
preview: str,
truncated_preview: bool,
) -> str:
result = (
f"[tool output persisted]\n"
f"Full output saved to: {filepath}\n"
f"Original size: {original_size} chars\n"
f"Preview:\n{preview}"
)
if truncated_preview:
result += "\n...\n(Read the saved file if you need the full output.)"
return result
def _bucket_mtime(path: Path) -> float:
try:
return path.stat().st_mtime
except OSError:
return 0.0
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
    """Prune old per-session tool-result buckets under ``root``.

    Two passes, never touching ``current_bucket``: first delete sibling
    bucket directories older than the retention window, then enforce the
    bucket-count cap by removing the oldest survivors (newest kept first).
    """
    others = [
        entry for entry in root.iterdir()
        if entry.is_dir() and entry != current_bucket
    ]
    # Pass 1: age-based retention.
    expiry = time.time() - _TOOL_RESULT_RETENTION_SECS
    for entry in others:
        if _bucket_mtime(entry) < expiry:
            shutil.rmtree(entry, ignore_errors=True)
    # Pass 2: cap the number of surviving sibling buckets (the current
    # bucket occupies one slot of the overall budget).
    allowed = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
    survivors = [entry for entry in others if entry.exists()]
    if len(survivors) > allowed:
        survivors.sort(key=_bucket_mtime, reverse=True)
        for entry in survivors[allowed:]:
            shutil.rmtree(entry, ignore_errors=True)
def _write_text_atomic(path: Path, content: str) -> None:
tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
try:
tmp.write_text(content, encoding="utf-8")
tmp.replace(path)
finally:
if tmp.exists():
tmp.unlink(missing_ok=True)
def maybe_persist_tool_result(
    workspace: Path | None,
    session_key: str | None,
    tool_call_id: str,
    content: Any,
    *,
    max_chars: int,
) -> Any:
    """Persist oversized tool output and replace it with a stable reference string.

    Only plain strings and lists of ``{"type": "text"}`` blocks are eligible;
    any other payload — or one within ``max_chars`` — is returned unchanged.
    When persisted, the returned reference string contains the saved file
    path plus an inline preview.

    Args:
        workspace: Root directory for the tool-result store; None disables
            persistence.
        session_key: Names the per-session bucket directory ("default" if None).
        tool_call_id: Sanitized and used as the saved file's stem, making the
            write idempotent per tool call.
        content: Raw tool output (str, list of text blocks, or anything else).
        max_chars: Size threshold; a value <= 0 disables persistence.
    """
    if workspace is None or max_chars <= 0:
        return content
    text_payload: str | None = None
    suffix = "txt"
    if isinstance(content, str):
        text_payload = content
    elif isinstance(content, list):
        # Lists are only persisted when they flatten losslessly to text.
        text_payload = _stringify_text_blocks(content)
        if text_payload is None:
            return content
        suffix = "json"
    else:
        return content
    if len(text_payload) <= max_chars:
        return content
    root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
    bucket = ensure_dir(root / safe_filename(session_key or "default"))
    try:
        # Best-effort retention sweep; never let cleanup break the tool call.
        _cleanup_tool_result_buckets(root, bucket)
    except Exception:
        pass
    path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
    if not path.exists():
        # First writer wins; replays of the same tool_call_id reuse the file.
        if suffix == "json" and isinstance(content, list):
            _write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2))
        else:
            _write_text_atomic(path, text_payload)
    preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS]
    return _render_tool_result_reference(
        path,
        original_size=len(text_payload),
        preview=preview,
        truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS,
    )
def split_message(content: str, max_len: int = 2000) -> list[str]: def split_message(content: str, max_len: int = 2000) -> list[str]:
""" """
Split content into chunks within max_len, preferring line breaks. Split content into chunks within max_len, preferring line breaks.

View File

@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
assert "Channel: cli" in user_content assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content assert "Return exactly: OK" in user_content
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
    """An assistant-role current message must merge into a trailing assistant
    history entry rather than being appended as a sibling, because some
    providers reject back-to-back assistant messages."""
    workspace = _make_workspace(tmp_path)
    builder = ContextBuilder(workspace)
    messages = builder.build_messages(
        history=[{"role": "assistant", "content": "previous result"}],
        current_message="subagent result",
        channel="cli",
        chat_id="direct",
        current_role="assistant",
    )
    # Walk adjacent pairs; no pair may be assistant followed by assistant.
    for left, right in zip(messages, messages[1:]):
        assert not (left.get("role") == right.get("role") == "assistant")

View File

@ -5,7 +5,9 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop: def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop) loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS from nanobot.config.schema import AgentDefaults
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
return loop return loop
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
) )
assert session.messages[0]["content"] == content assert session.messages[0]["content"] == content
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
    """Restoring a runtime checkpoint into an empty session should append the
    assistant message, its completed tool results, and a synthetic result for
    each still-pending tool call, then clear the checkpoint metadata."""
    loop = _mk_loop()
    session = Session(
        key="test:checkpoint",
        metadata={
            AgentLoop._RUNTIME_CHECKPOINT_KEY: {
                "assistant_message": {
                    "role": "assistant",
                    "content": "working",
                    "tool_calls": [
                        {
                            "id": "call_done",
                            "type": "function",
                            "function": {"name": "read_file", "arguments": "{}"},
                        },
                        {
                            "id": "call_pending",
                            "type": "function",
                            "function": {"name": "exec", "arguments": "{}"},
                        },
                    ],
                },
                "completed_tool_results": [
                    {
                        "role": "tool",
                        "tool_call_id": "call_done",
                        "name": "read_file",
                        "content": "ok",
                    }
                ],
                "pending_tool_calls": [
                    {
                        "id": "call_pending",
                        "type": "function",
                        "function": {"name": "exec", "arguments": "{}"},
                    }
                ],
            }
        },
    )
    restored = loop._restore_runtime_checkpoint(session)
    assert restored is True
    # The checkpoint is consumed exactly once.
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"
    # The pending call gets a placeholder result explaining the interruption.
    assert "interrupted before this tool finished" in session.messages[2]["content"].lower()
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
    """Restore must not duplicate messages already saved in the session.

    Here the session tail already contains the checkpointed assistant
    message and the completed tool result (they were persisted before the
    interruption). Restoring the checkpoint should only append what is
    missing — the synthetic result for the pending call — leaving exactly
    three messages rather than re-appending the overlapping tail.
    """
    loop = _mk_loop()
    session = Session(
        key="test:checkpoint-overlap",
        # Pre-existing tail: assistant turn + the already-saved tool result.
        messages=[
            {
                "role": "assistant",
                "content": "working",
                "tool_calls": [
                    {
                        "id": "call_done",
                        "type": "function",
                        "function": {"name": "read_file", "arguments": "{}"},
                    },
                    {
                        "id": "call_pending",
                        "type": "function",
                        "function": {"name": "exec", "arguments": "{}"},
                    },
                ],
            },
            {
                "role": "tool",
                "tool_call_id": "call_done",
                "name": "read_file",
                "content": "ok",
            },
        ],
        # Checkpoint describing the same turn (overlaps the tail above).
        metadata={
            AgentLoop._RUNTIME_CHECKPOINT_KEY: {
                "assistant_message": {
                    "role": "assistant",
                    "content": "working",
                    "tool_calls": [
                        {
                            "id": "call_done",
                            "type": "function",
                            "function": {"name": "read_file", "arguments": "{}"},
                        },
                        {
                            "id": "call_pending",
                            "type": "function",
                            "function": {"name": "exec", "arguments": "{}"},
                        },
                    ],
                },
                "completed_tool_results": [
                    {
                        "role": "tool",
                        "tool_call_id": "call_done",
                        "name": "read_file",
                        "content": "ok",
                    }
                ],
                "pending_tool_calls": [
                    {
                        "id": "call_pending",
                        "type": "function",
                        "function": {"name": "exec", "arguments": "{}"},
                    }
                ],
            }
        },
    )
    restored = loop._restore_runtime_checkpoint(session)
    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    # Assistant turn, completed result, pending stub — nothing duplicated.
    assert len(session.messages) == 3
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"

View File

@ -2,12 +2,20 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import os
import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from nanobot.config.schema import AgentDefaults
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMResponse, ToolCallRequest from nanobot.providers.base import LLMResponse, ToolCallRequest
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_loop(tmp_path): def _make_loop(tmp_path):
from nanobot.agent.loop import AgentLoop from nanobot.agent.loop import AgentLoop
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
tools=tools, tools=tools,
model="test-model", model="test-model",
max_iterations=3, max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)) ))
assert result.final_content == "done" assert result.final_content == "done"
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
tools=tools, tools=tools,
model="test-model", model="test-model",
max_iterations=3, max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=RecordingHook(), hook=RecordingHook(),
)) ))
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
tools=tools, tools=tools,
model="test-model", model="test-model",
max_iterations=1, max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=StreamingHook(), hook=StreamingHook(),
)) ))
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
tools=tools, tools=tools,
model="test-model", model="test-model",
max_iterations=2, max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)) ))
assert result.stop_reason == "max_iterations" assert result.stop_reason == "max_iterations"
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
"I reached the maximum number of tool call iterations (2) " "I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps." "without completing the task. You can try breaking the task into smaller steps."
) )
assert result.messages[-1]["role"] == "assistant"
assert result.messages[-1]["content"] == result.final_content
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_runner_returns_structured_tool_error(): async def test_runner_returns_structured_tool_error():
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
tools=tools, tools=tools,
model="test-model", model="test-model",
max_iterations=2, max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
fail_on_tool_error=True, fail_on_tool_error=True,
)) ))
@ -258,6 +272,232 @@ async def test_runner_returns_structured_tool_error():
] ]
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """Oversized tool output is spilled to disk before the follow-up LLM call.

    A 20k-char tool result with max_tool_result_chars=2048 must reach the
    second model call as a short "[tool output persisted]" pointer, and the
    full payload must land under .nanobot/tool-results/<bucket>/<call_id>.txt
    (bucket derived from the session key; "test:runner" -> "test_runner").
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # Turn 1: request a single tool call.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Turn 2: record exactly what the runner sends after the tool ran.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    # Tool output is far larger than the 2048-char cap below.
    tools.execute = AsyncMock(return_value="x" * 20_000)
    runner = AgentRunner(provider)
    result = await runner.run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    ))
    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    assert "[tool output persisted]" in tool_message["content"]
    assert "tool-results" in tool_message["content"]
    assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a new tool result prunes stale session buckets but keeps
    recent ones, and writes the new payload into the current bucket."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    results_root = tmp_path / ".nanobot" / "tool-results"
    stale_bucket = results_root / "old_session"
    fresh_bucket = results_root / "recent_session"
    for bucket in (stale_bucket, fresh_bucket):
        bucket.mkdir(parents=True)
    (stale_bucket / "old.txt").write_text("old", encoding="utf-8")
    (fresh_bucket / "recent.txt").write_text("recent", encoding="utf-8")

    # Backdate the old bucket (and its file) by eight days so it is stale.
    eight_days_ago = time.time() - (8 * 24 * 60 * 60)
    for target in (stale_bucket, stale_bucket / "old.txt"):
        os.utime(target, (eight_days_ago, eight_days_ago))

    summary = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in summary
    assert not stale_bucket.exists()
    assert fresh_bucket.exists()
    assert (results_root / "current_session" / "call_big.txt").exists()
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """Spilling a large tool result must not leave *.tmp artifacts behind."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    bucket = tmp_path / ".nanobot" / "tool-results" / "current_session"
    assert (bucket / "call_big.txt").exists()
    assert list(bucket.glob("*.tmp")) == []
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """History snipping is best-effort: if it raises, the run proceeds.

    _snip_history is forced to raise; the provider must still receive the
    original, unmodified message list and the run must finish normally.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    provider = MagicMock()
    captured_messages: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Record exactly what the runner sends to the provider.
        captured_messages[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]
    runner = AgentRunner(provider)
    # Simulate context governance blowing up on every call.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]
    result = await runner.run(AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))
    assert result.final_content == "done"
    # Fallback path: the raw messages go through untouched.
    assert captured_messages == initial_messages
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """Tool-result persistence is best-effort: a persistence failure must not
    abort the run, and the raw tool output is forwarded to the model."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    provider = MagicMock()
    captured_second_call: list[dict] = []
    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        if call_count["n"] == 1:
            # Turn 1: request one tool call.
            return LLMResponse(
                content="working",
                tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
                usage={"prompt_tokens": 5, "completion_tokens": 3},
            )
        # Turn 2: capture the follow-up messages after the tool executed.
        captured_second_call[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry
    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")
    runner = AgentRunner(provider)
    # Persistence raises on every attempt (e.g. disk full).
    with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
        result = await runner.run(AgentRunSpec(
            initial_messages=[{"role": "user", "content": "do task"}],
            tools=tools,
            model="test-model",
            max_iterations=2,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        ))
    assert result.final_content == "done"
    tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
    # The original output is passed through unchanged.
    assert tool_message["content"] == "tool result"
class _DelayTool(Tool):
    """Test double for Tool: sleeps for a configurable delay and records
    ``start:<name>`` / ``end:<name>`` events into a shared list so tests
    can assert cross-tool execution ordering. Returns its own name."""

    def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
        self._name = name
        self._delay = delay
        self._read_only = read_only
        # Shared across instances: one list collects events from all tools.
        self._shared_events = shared_events

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        # Description is irrelevant to these tests; reuse the name.
        return self._name

    @property
    def parameters(self) -> dict:
        # Accepts no arguments.
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        return self._read_only

    async def execute(self, **kwargs):
        self._shared_events.append(f"start:{self._name}")
        await asyncio.sleep(self._delay)
        self._shared_events.append(f"end:{self._name}")
        return self._name
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """With concurrent_tools=True, read-only calls run concurrently as one
    batch, and a non-read-only tool starts only after that batch completes."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    tools = ToolRegistry()
    shared_events: list[str] = []
    # Two slow read-only tools plus one fast exclusive (write) tool.
    read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
    read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
    write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)
    tools.register(read_a)
    tools.register(read_b)
    tools.register(write_a)
    runner = AgentRunner(MagicMock())
    await runner._execute_tools(
        AgentRunSpec(
            initial_messages=[],
            tools=tools,
            model="test-model",
            max_iterations=1,
            max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
            concurrent_tools=True,
        ),
        [
            ToolCallRequest(id="ro1", name="read_a", arguments={}),
            ToolCallRequest(id="ro2", name="read_b", arguments={}),
            ToolCallRequest(id="rw1", name="write_a", arguments={}),
        ],
    )
    # Both read-only tools start before either finishes (they overlap) ...
    assert shared_events[0:2] == ["start:read_a", "start:read_b"]
    assert "end:read_a" in shared_events and "end:read_b" in shared_events
    # ... and the exclusive tool begins only after both reads have ended.
    assert shared_events.index("end:read_a") < shared_events.index("start:write_a")
    assert shared_events.index("end:read_b") < shared_events.index("start:write_a")
    assert shared_events[-2:] == ["start:write_a", "end:write_a"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path): async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
@ -317,15 +557,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse( provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working", content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)) ))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock() mgr._announce_result = AsyncMock()
async def fake_execute(self, name, arguments): async def fake_execute(self, **kwargs):
return "tool result" return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})

View File

@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from nanobot.config.schema import AgentDefaults
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_loop(*, exec_config=None): def _make_loop(*, exec_config=None):
"""Create a minimal AgentLoop with mocked dependencies.""" """Create a minimal AgentLoop with mocked dependencies."""
@ -186,7 +190,12 @@ class TestSubagentCancellation:
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) mgr = SubagentManager(
provider=provider,
workspace=MagicMock(),
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
cancelled = asyncio.Event() cancelled = asyncio.Event()
@ -214,7 +223,12 @@ class TestSubagentCancellation:
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) mgr = SubagentManager(
provider=provider,
workspace=MagicMock(),
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
assert await mgr.cancel_by_session("nonexistent") == 0 assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio @pytest.mark.asyncio
@ -236,19 +250,24 @@ class TestSubagentCancellation:
if call_count["n"] == 1: if call_count["n"] == 1:
return LLMResponse( return LLMResponse(
content="thinking", content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
reasoning_content="hidden reasoning", reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}], thinking_blocks=[{"type": "thinking", "thinking": "step"}],
) )
captured_second_call[:] = messages captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[]) return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
async def fake_execute(self, name, arguments): async def fake_execute(self, **kwargs):
return "tool result" return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -273,6 +292,7 @@ class TestSubagentCancellation:
provider=provider, provider=provider,
workspace=tmp_path, workspace=tmp_path,
bus=bus, bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
exec_config=ExecToolConfig(enable=False), exec_config=ExecToolConfig(enable=False),
) )
mgr._announce_result = AsyncMock() mgr._announce_result = AsyncMock()
@ -304,20 +324,25 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse( provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking", content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)) ))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock() mgr._announce_result = AsyncMock()
calls = {"n": 0} calls = {"n": 0}
async def fake_execute(self, name, arguments): async def fake_execute(self, **kwargs):
calls["n"] += 1 calls["n"] += 1
if calls["n"] == 1: if calls["n"] == 1:
return "first result" return "first result"
raise RuntimeError("boom") raise RuntimeError("boom")
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -340,15 +365,20 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model" provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse( provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking", content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
)) ))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock() mgr._announce_result = AsyncMock()
started = asyncio.Event() started = asyncio.Event()
cancelled = asyncio.Event() cancelled = asyncio.Event()
async def fake_execute(self, name, arguments): async def fake_execute(self, **kwargs):
started.set() started.set()
try: try:
await asyncio.sleep(60) await asyncio.sleep(60)
@ -356,7 +386,7 @@ class TestSubagentCancellation:
cancelled.set() cancelled.set()
raise raise
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
task = asyncio.create_task( task = asyncio.create_task(
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -364,7 +394,7 @@ class TestSubagentCancellation:
mgr._running_tasks["sub-1"] = task mgr._running_tasks["sub-1"] = task
mgr._session_tasks["test:c1"] = {"sub-1"} mgr._session_tasks["test:c1"] = {"sub-1"}
await started.wait() await asyncio.wait_for(started.wait(), timeout=1.0)
count = await mgr.cancel_by_session("test:c1") count = await mgr.cancel_by_session("test:c1")

View File

@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
typing_channel.typing_enter_hook = slow_typing typing_channel.typing_enter_hook = slow_typing
await channel._start_typing(typing_channel) await channel._start_typing(typing_channel)
await start.wait() await asyncio.wait_for(start.wait(), timeout=1.0)
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
release.set() release.set()
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
typing_channel.typing_enter_hook = slow_typing_progress typing_channel.typing_enter_hook = slow_typing_progress
await channel._start_typing(typing_channel) await channel._start_typing(typing_channel)
await start.wait() await asyncio.wait_for(start.wait(), timeout=1.0)
await channel.send( await channel.send(
OutboundMessage( OutboundMessage(
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
typing_channel = _NoTriggerChannel(channel_id=123) typing_channel = _NoTriggerChannel(channel_id=123)
await channel._start_typing(typing_channel) # type: ignore[arg-type] await channel._start_typing(typing_channel) # type: ignore[arg-type]
await entered.wait() await asyncio.wait_for(entered.wait(), timeout=1.0)
assert "123" in channel._typing_tasks assert "123" in channel._typing_tasks

View File

@ -8,6 +8,7 @@ Validates that:
from __future__ import annotations from __future__ import annotations
import asyncio
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage) return SimpleNamespace(choices=[choice], usage=usage)
class _StalledStream:
def __aiter__(self):
return self
async def __anext__(self):
await asyncio.sleep(3600)
raise StopAsyncIteration
def test_openrouter_spec_is_gateway() -> None: def test_openrouter_spec_is_gateway() -> None:
spec = find_by_name("openrouter") spec = find_by_name("openrouter")
assert spec is not None assert spec is not None
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
spec=spec, spec=spec,
) )
assert provider.get_default_model() == "gpt-4o" assert provider.get_default_model() == "gpt-4o"
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
    """_sanitize_messages drops message-level reasoning/extra fields while
    preserving per-tool-call extra_content (here a thought signature)."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "fn", "arguments": "{}"},
        "extra_content": {"google": {"thought_signature": "sig"}},
    }
    message = {
        "role": "assistant",
        "content": "done",
        "reasoning_content": "hidden",
        "extra_content": {"debug": True},
        "tool_calls": [tool_call],
    }

    sanitized = provider._sanitize_messages([message])
    first = sanitized[0]

    # Message-level reasoning metadata is removed ...
    assert "reasoning_content" not in first
    assert "extra_content" not in first
    # ... but the tool-call-level extra_content survives intact.
    assert first["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
    """A stalled stream trips the idle watchdog and yields an error response.

    With NANOBOT_STREAM_IDLE_TIMEOUT_S=0 the watchdog fires immediately when
    the (never-yielding) _StalledStream produces no chunks, so chat_stream
    must return finish_reason="error" with a "stream stalled" message rather
    than hanging.
    """
    # Zero timeout: the watchdog should trigger on the first idle check.
    monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
    mock_create = AsyncMock(return_value=_StalledStream())
    spec = find_by_name("openai")
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        client_instance = MockClient.return_value
        client_instance.chat.completions.create = mock_create
        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=spec,
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-4o",
        )
    assert result.finish_reason == "error"
    assert result.content is not None
    assert "stream stalled" in result.content

View File

@ -211,3 +211,32 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
content = msg.get("content") content = msg.get("content")
if isinstance(content, list): if isinstance(content, list):
assert any("[image omitted]" in (b.get("text") or "") for b in content) assert any("[image omitted]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
    """Rate-limit retry honors the server's retry-after hint and reports it.

    The first scripted response is a 429-style error advertising "retry
    after 7s"; chat_with_retry must sleep exactly 7.0s (sleep is stubbed to
    record the delay) and notify on_retry_wait with a message containing
    "7s" before returning the second, successful response.
    """
    provider = ScriptedProvider([
        LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"),
        LLMResponse(content="ok"),
    ])
    delays: list[float] = []
    progress: list[str] = []

    async def _fake_sleep(delay: float) -> None:
        # Record the requested backoff instead of actually sleeping.
        delays.append(delay)

    async def _progress(msg: str) -> None:
        progress.append(msg)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
    response = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        on_retry_wait=_progress,
    )
    assert response.content == "ok"
    # Exactly one retry, using the server-suggested 7-second delay.
    assert delays == [7.0]
    assert progress and "7s" in progress[0]

View File

@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10) wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
task = asyncio.create_task(wrapper.execute()) task = asyncio.create_task(wrapper.execute())
await started.wait() await asyncio.wait_for(started.wait(), timeout=1.0)
task.cancel() task.cancel()