feat: harden agent runtime for long-running tasks

This commit is contained in:
Xubin Ren 2026-04-01 19:12:49 +00:00
parent 63d646f731
commit fbedf7ad77
25 changed files with 1348 additions and 185 deletions

View File

@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
if isinstance(left, str) and isinstance(right, str):
return f"{left}\n\n{right}" if left else right
def _to_blocks(value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
if value is None:
return []
return [{"type": "text", "text": str(value)}]
return _to_blocks(left) + _to_blocks(right)
def _load_bootstrap_files(self) -> str:
"""Load all bootstrap files from workspace."""
parts = []
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
merged = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
return [
messages = [
{"role": "system", "content": self.build_system_prompt(skill_names)},
*history,
{"role": current_role, "content": merged},
]
if messages[-1].get("role") == current_role:
last = dict(messages[-1])
last["content"] = self._merge_message_content(last.get("content"), merged)
messages[-1] = last
return messages
messages.append({"role": current_role, "content": merged})
return messages
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
"""Build user message content with optional base64-encoded images."""

View File

@ -29,8 +29,10 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.bus.queue import MessageBus
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text, truncate_text
if TYPE_CHECKING:
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
@ -38,11 +40,7 @@ if TYPE_CHECKING:
class _LoopHook(AgentHook):
"""Core lifecycle hook for the main agent loop.
Handles streaming delta relay, progress reporting, tool-call logging,
and think-tag stripping for the built-in agent path.
"""
"""Core hook for the main loop."""
def __init__(
self,
@ -102,11 +100,7 @@ class _LoopHook(AgentHook):
class _LoopHookChain(AgentHook):
"""Run the core loop hook first, then best-effort extra hooks.
This preserves the historical failure behavior of ``_LoopHook`` while still
letting user-supplied hooks opt into ``CompositeHook`` isolation.
"""
"""Run the core hook before extra hooks."""
__slots__ = ("_primary", "_extras")
@ -154,7 +148,7 @@ class AgentLoop:
5. Sends responses back
"""
_TOOL_RESULT_MAX_CHARS = 16_000
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
def __init__(
self,
@ -162,8 +156,11 @@ class AgentLoop:
provider: LLMProvider,
workspace: Path,
model: str | None = None,
max_iterations: int = 40,
context_window_tokens: int = 65_536,
max_iterations: int | None = None,
context_window_tokens: int | None = None,
context_block_limit: int | None = None,
max_tool_result_chars: int | None = None,
provider_retry_mode: str = "standard",
web_search_config: WebSearchConfig | None = None,
web_proxy: str | None = None,
exec_config: ExecToolConfig | None = None,
@ -177,13 +174,27 @@ class AgentLoop:
):
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
defaults = AgentDefaults()
self.bus = bus
self.channels_config = channels_config
self.provider = provider
self.workspace = workspace
self.model = model or provider.get_default_model()
self.max_iterations = max_iterations
self.context_window_tokens = context_window_tokens
self.max_iterations = (
max_iterations if max_iterations is not None else defaults.max_tool_iterations
)
self.context_window_tokens = (
context_window_tokens
if context_window_tokens is not None
else defaults.context_window_tokens
)
self.context_block_limit = context_block_limit
self.max_tool_result_chars = (
max_tool_result_chars
if max_tool_result_chars is not None
else defaults.max_tool_result_chars
)
self.provider_retry_mode = provider_retry_mode
self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@ -202,6 +213,7 @@ class AgentLoop:
workspace=workspace,
bus=bus,
model=self.model,
max_tool_result_chars=self.max_tool_result_chars,
web_search_config=self.web_search_config,
web_proxy=web_proxy,
exec_config=self.exec_config,
@ -313,6 +325,7 @@ class AgentLoop:
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
*,
session: Session | None = None,
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
@ -339,14 +352,27 @@ class AgentLoop:
else loop_hook
)
async def _checkpoint(payload: dict[str, Any]) -> None:
if session is None:
return
self._set_runtime_checkpoint(session, payload)
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
))
self._last_usage = result.usage
if result.stop_reason == "max_iterations":
@ -484,6 +510,8 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0)
@ -494,10 +522,11 @@ class AgentLoop:
current_role=current_role,
)
final_content, _, all_msgs = await self._run_agent_loop(
messages, channel=channel, chat_id=chat_id,
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
)
self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
return OutboundMessage(channel=channel, chat_id=chat_id,
@ -508,6 +537,8 @@ class AgentLoop:
key = session_key or msg.session_key
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
# Slash commands
raw = msg.content.strip()
@ -543,6 +574,7 @@ class AgentLoop:
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
session=session,
channel=msg.channel, chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
)
@ -551,6 +583,7 @@ class AgentLoop:
final_content = "I've completed processing but have no response to give."
self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
@ -568,12 +601,6 @@ class AgentLoop:
metadata=meta,
)
@staticmethod
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
"""Convert an inline image block into a compact text placeholder."""
path = (block.get("_meta") or {}).get("path", "")
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
def _sanitize_persisted_blocks(
self,
content: list[dict[str, Any]],
@ -600,13 +627,14 @@ class AgentLoop:
block.get("type") == "image_url"
and block.get("image_url", {}).get("url", "").startswith("data:image/")
):
filtered.append(self._image_placeholder(block))
path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"]
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
if truncate_text and len(text) > self.max_tool_result_chars:
text = truncate_text(text, self.max_tool_result_chars)
filtered.append({**block, "text": text})
continue
@ -623,8 +651,8 @@ class AgentLoop:
if role == "assistant" and not content and not entry.get("tool_calls"):
continue # skip empty assistant messages — they poison session context
if role == "tool":
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
entry["content"] = truncate_text(content, self.max_tool_result_chars)
elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
if not filtered:
@ -647,6 +675,78 @@ class AgentLoop:
session.messages.append(entry)
session.updated_at = datetime.now()
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
    """Persist the latest in-flight turn state into session metadata.

    Saves the session immediately so the checkpoint survives a crash and
    can be materialized by ``_restore_runtime_checkpoint`` on the next
    request for this session.
    """
    session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
    # Flush right away: the checkpoint is only useful if it hits storage
    # before the process dies mid-turn.
    self.sessions.save(session)
def _clear_runtime_checkpoint(self, session: Session) -> None:
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
@staticmethod
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
return (
message.get("role"),
message.get("content"),
message.get("tool_call_id"),
message.get("name"),
message.get("tool_calls"),
message.get("reasoning_content"),
message.get("thinking_blocks"),
)
def _restore_runtime_checkpoint(self, session: Session) -> bool:
    """Materialize an unfinished turn into session history before a new request.

    Reads the checkpoint payload written by ``_set_runtime_checkpoint``
    (assistant message, completed tool results, still-pending tool calls),
    appends only the parts not already present at the tail of
    ``session.messages``, then clears the checkpoint.

    Returns True when a checkpoint dict was found (even if every restored
    message was already persisted), False otherwise. Callers are expected
    to save the session afterwards.
    """
    from datetime import datetime
    checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
    if not isinstance(checkpoint, dict):
        return False
    assistant_message = checkpoint.get("assistant_message")
    completed_tool_results = checkpoint.get("completed_tool_results") or []
    pending_tool_calls = checkpoint.get("pending_tool_calls") or []
    restored_messages: list[dict[str, Any]] = []
    if isinstance(assistant_message, dict):
        restored = dict(assistant_message)
        # Checkpointed messages carry no timestamp; stamp restore time.
        restored.setdefault("timestamp", datetime.now().isoformat())
        restored_messages.append(restored)
    for message in completed_tool_results:
        if isinstance(message, dict):
            restored = dict(message)
            restored.setdefault("timestamp", datetime.now().isoformat())
            restored_messages.append(restored)
    for tool_call in pending_tool_calls:
        if not isinstance(tool_call, dict):
            continue
        tool_id = tool_call.get("id")
        name = ((tool_call.get("function") or {}).get("name")) or "tool"
        # Synthesize an error result for each tool call that never
        # finished, so the assistant's tool_calls stay answered and the
        # transcript remains well-formed for the provider.
        restored_messages.append({
            "role": "tool",
            "tool_call_id": tool_id,
            "name": name,
            "content": "Error: Task interrupted before this tool finished.",
            "timestamp": datetime.now().isoformat(),
        })
    # The tail of session.messages may already contain a prefix of the
    # restored messages (the turn partially persisted before the crash).
    # Find the largest such overlap and append only the remainder.
    overlap = 0
    max_overlap = min(len(session.messages), len(restored_messages))
    for size in range(max_overlap, 0, -1):
        existing = session.messages[-size:]
        restored = restored_messages[:size]
        if all(
            self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
            for left, right in zip(existing, restored)
        ):
            overlap = size
            break
    session.messages.extend(restored_messages[overlap:])
    self._clear_runtime_checkpoint(session)
    return True
async def process_direct(
self,
content: str,

View File

@ -4,20 +4,29 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, ToolCallRequest
from nanobot.utils.helpers import build_assistant_message
from nanobot.utils.helpers import (
build_assistant_message,
estimate_message_tokens,
estimate_prompt_tokens_chain,
find_legal_message_start,
maybe_persist_tool_result,
truncate_text,
)
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
"I reached the maximum number of tool call iterations ({max_iterations}) "
"without completing the task. You can try breaking the task into smaller steps."
)
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_SNIP_SAFETY_BUFFER = 1024
@dataclass(slots=True)
class AgentRunSpec:
"""Configuration for a single agent execution."""
@ -26,6 +35,7 @@ class AgentRunSpec:
tools: ToolRegistry
model: str
max_iterations: int
max_tool_result_chars: int
temperature: float | None = None
max_tokens: int | None = None
reasoning_effort: str | None = None
@ -34,6 +44,13 @@ class AgentRunSpec:
max_iterations_message: str | None = None
concurrent_tools: bool = False
fail_on_tool_error: bool = False
workspace: Path | None = None
session_key: str | None = None
context_window_tokens: int | None = None
context_block_limit: int | None = None
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
checkpoint_callback: Any | None = None
@dataclass(slots=True)
@ -66,12 +83,25 @@ class AgentRunner:
tool_events: list[dict[str, str]] = []
for iteration in range(spec.max_iterations):
try:
messages = self._apply_tool_result_budget(spec, messages)
messages_for_model = self._snip_history(spec, messages)
except Exception as exc:
logger.warning(
"Context governance failed on turn {} for {}: {}; using raw messages",
iteration,
spec.session_key or "default",
exc,
)
messages_for_model = messages
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
kwargs: dict[str, Any] = {
"messages": messages,
"messages": messages_for_model,
"tools": spec.tools.get_definitions(),
"model": spec.model,
"retry_mode": spec.provider_retry_mode,
"on_retry_wait": spec.progress_callback,
}
if spec.temperature is not None:
kwargs["temperature"] = spec.temperature
@ -104,13 +134,25 @@ class AgentRunner:
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True)
messages.append(build_assistant_message(
assistant_message = build_assistant_message(
response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
))
)
messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint(
spec,
{
"phase": "awaiting_tools",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
},
)
await hook.before_execute_tools(context)
@ -125,13 +167,31 @@ class AgentRunner:
context.stop_reason = stop_reason
await hook.after_iteration(context)
break
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results):
messages.append({
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
"name": tool_call.name,
"content": result,
})
"content": self._normalize_tool_result(
spec,
tool_call.id,
result,
),
}
messages.append(tool_message)
completed_tool_results.append(tool_message)
await self._emit_checkpoint(
spec,
{
"phase": "tools_completed",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": completed_tool_results,
"pending_tool_calls": [],
},
)
await hook.after_iteration(context)
continue
@ -143,6 +203,7 @@ class AgentRunner:
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error"
error = final_content
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
@ -154,6 +215,17 @@ class AgentRunner:
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
))
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": messages[-1],
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
final_content = clean
context.final_content = final_content
context.stop_reason = stop_reason
@ -163,6 +235,7 @@ class AgentRunner:
stop_reason = "max_iterations"
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
final_content = template.format(max_iterations=spec.max_iterations)
self._append_final_message(messages, final_content)
return AgentRunResult(
final_content=final_content,
@ -179,16 +252,17 @@ class AgentRunner:
spec: AgentRunSpec,
tool_calls: list[ToolCallRequest],
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
if spec.concurrent_tools:
tool_results = await asyncio.gather(*(
self._run_tool(spec, tool_call)
for tool_call in tool_calls
))
else:
tool_results = [
await self._run_tool(spec, tool_call)
for tool_call in tool_calls
]
batches = self._partition_tool_batches(spec, tool_calls)
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
for batch in batches:
if spec.concurrent_tools and len(batch) > 1:
tool_results.extend(await asyncio.gather(*(
self._run_tool(spec, tool_call)
for tool_call in batch
)))
else:
for tool_call in batch:
tool_results.append(await self._run_tool(spec, tool_call))
results: list[Any] = []
events: list[dict[str, str]] = []
@ -205,8 +279,28 @@ class AgentRunner:
spec: AgentRunSpec,
tool_call: ToolCallRequest,
) -> tuple[Any, dict[str, str], BaseException | None]:
_HINT = "\n\n[Analyze the error above and try a different approach.]"
prepare_call = getattr(spec.tools, "prepare_call", None)
tool, params, prep_error = None, tool_call.arguments, None
if callable(prepare_call):
try:
prepared = prepare_call(tool_call.name, tool_call.arguments)
if isinstance(prepared, tuple) and len(prepared) == 3:
tool, params, prep_error = prepared
except Exception:
pass
if prep_error:
event = {
"name": tool_call.name,
"status": "error",
"detail": prep_error.split(": ", 1)[-1][:120],
}
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
try:
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
if tool is not None:
result = await tool.execute(**params)
else:
result = await spec.tools.execute(tool_call.name, params)
except asyncio.CancelledError:
raise
except BaseException as exc:
@ -219,14 +313,175 @@ class AgentRunner:
return f"Error: {type(exc).__name__}: {exc}", event, exc
return f"Error: {type(exc).__name__}: {exc}", event, None
if isinstance(result, str) and result.startswith("Error"):
event = {
"name": tool_call.name,
"status": "error",
"detail": result.replace("\n", " ").strip()[:120],
}
if spec.fail_on_tool_error:
return result + _HINT, event, RuntimeError(result)
return result + _HINT, event, None
detail = "" if result is None else str(result)
detail = detail.replace("\n", " ").strip()
if not detail:
detail = "(empty)"
elif len(detail) > 120:
detail = detail[:120] + "..."
return result, {
"name": tool_call.name,
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
"detail": detail,
}, None
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
async def _emit_checkpoint(
self,
spec: AgentRunSpec,
payload: dict[str, Any],
) -> None:
callback = spec.checkpoint_callback
if callback is not None:
await callback(payload)
@staticmethod
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
    """Ensure *content* is the trailing assistant message of the transcript.

    Empty content is ignored. A plain assistant message (no tool calls)
    already at the tail is kept when its content matches, otherwise it is
    replaced; in every other case a fresh assistant message is appended.
    """
    if not content:
        return
    last = messages[-1] if messages else None
    if last is not None and last.get("role") == "assistant" and not last.get("tool_calls"):
        if last.get("content") != content:
            messages[-1] = build_assistant_message(content)
        return
    messages.append(build_assistant_message(content))
def _normalize_tool_result(
    self,
    spec: AgentRunSpec,
    tool_call_id: str,
    result: Any,
) -> Any:
    """Spill oversized tool output to the workspace and clamp what stays inline.

    Persistence is best-effort: on any failure the raw result is used so
    the model still receives something. Whatever string survives is
    truncated to ``spec.max_tool_result_chars``.
    """
    try:
        content = maybe_persist_tool_result(
            spec.workspace,
            spec.session_key,
            tool_call_id,
            result,
            max_chars=spec.max_tool_result_chars,
        )
    except Exception as exc:
        logger.warning(
            "Tool result persist failed for {} in {}: {}; using raw result",
            tool_call_id,
            spec.session_key or "default",
            exc,
        )
        content = result
    if not isinstance(content, str):
        return content
    if len(content) <= spec.max_tool_result_chars:
        return content
    return truncate_text(content, spec.max_tool_result_chars)
def _apply_tool_result_budget(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
updated = messages
for idx, message in enumerate(messages):
if message.get("role") != "tool":
continue
normalized = self._normalize_tool_result(
spec,
str(message.get("tool_call_id") or f"tool_{idx}"),
message.get("content"),
)
if normalized != message.get("content"):
if updated is messages:
updated = [dict(m) for m in messages]
updated[idx]["content"] = normalized
return updated
def _snip_history(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Trim *messages* to the prompt-token budget, always keeping system messages.

    Returns the input list unchanged when no budget applies or the
    estimate already fits; otherwise returns a new list of the system
    messages plus the newest non-system messages that fit.
    """
    if not messages or not spec.context_window_tokens:
        return messages
    # Reserve room for the model's own output on top of the prompt.
    provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
    max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else (
        provider_max_tokens if isinstance(provider_max_tokens, int) else 4096
    )
    # An explicit context_block_limit overrides the derived window budget.
    budget = spec.context_block_limit or (
        spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER
    )
    if budget <= 0:
        return messages
    estimate, _ = estimate_prompt_tokens_chain(
        self.provider,
        spec.model,
        messages,
        spec.tools.get_definitions(),
    )
    if estimate <= budget:
        return messages
    # System messages are never trimmed; only the conversational tail is cut.
    system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"]
    non_system = [dict(msg) for msg in messages if msg.get("role") != "system"]
    if not non_system:
        return messages
    system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
    remaining_budget = max(128, budget - system_tokens)
    # Walk backwards so the newest messages win; the first (newest) message
    # is always kept even if it alone exceeds the budget.
    kept: list[dict[str, Any]] = []
    kept_tokens = 0
    for message in reversed(non_system):
        msg_tokens = estimate_message_tokens(message)
        if kept and kept_tokens + msg_tokens > remaining_budget:
            break
        kept.append(message)
        kept_tokens += msg_tokens
    kept.reverse()
    if kept:
        # Re-anchor on the first user message so the kept window starts at
        # a user turn rather than mid-exchange.
        for i, message in enumerate(kept):
            if message.get("role") == "user":
                kept = kept[i:]
                break
        # find_legal_message_start presumably skips leading messages that
        # cannot legally open a transcript (e.g. orphaned tool results) —
        # TODO confirm against the helper's definition.
        start = find_legal_message_start(kept)
        if start:
            kept = kept[start:]
    if not kept:
        # Everything was trimmed away: fall back to the last few messages.
        kept = non_system[-min(len(non_system), 4) :]
        start = find_legal_message_start(kept)
        if start:
            kept = kept[start:]
    return system_messages + kept
def _partition_tool_batches(
self,
spec: AgentRunSpec,
tool_calls: list[ToolCallRequest],
) -> list[list[ToolCallRequest]]:
if not spec.concurrent_tools:
return [[tool_call] for tool_call in tool_calls]
batches: list[list[ToolCallRequest]] = []
current: list[ToolCallRequest] = []
for tool_call in tool_calls:
get_tool = getattr(spec.tools, "get", None)
tool = get_tool(tool_call.name) if callable(get_tool) else None
can_batch = bool(tool and tool.concurrency_safe)
if can_batch:
current.append(tool_call)
continue
if current:
batches.append(current)
current = []
batches.append([tool_call])
if current:
batches.append(current)
return batches

View File

@ -44,6 +44,7 @@ class SubagentManager:
provider: LLMProvider,
workspace: Path,
bus: MessageBus,
max_tool_result_chars: int,
model: str | None = None,
web_search_config: "WebSearchConfig | None" = None,
web_proxy: str | None = None,
@ -56,6 +57,7 @@ class SubagentManager:
self.workspace = workspace
self.bus = bus
self.model = model or provider.get_default_model()
self.max_tool_result_chars = max_tool_result_chars
self.web_search_config = web_search_config or WebSearchConfig()
self.web_proxy = web_proxy
self.exec_config = exec_config or ExecToolConfig()
@ -136,6 +138,7 @@ class SubagentManager:
tools=tools,
model=self.model,
max_iterations=15,
max_tool_result_chars=self.max_tool_result_chars,
hook=_SubagentHook(task_id),
max_iterations_message="Task completed but no final response was generated.",
error_message=None,

View File

@ -53,6 +53,21 @@ class Tool(ABC):
"""JSON Schema for tool parameters."""
pass
@property
def read_only(self) -> bool:
    """Whether this tool is side-effect free and safe to parallelize."""
    # Conservative default: tools must explicitly opt in by overriding.
    return False
@property
def concurrency_safe(self) -> bool:
    """Whether this tool can run alongside other concurrency-safe tools."""
    # Derived property: a tool may share a batch only when it is
    # read-only and has not demanded exclusive execution.
    return self.read_only and not self.exclusive
@property
def exclusive(self) -> bool:
    """Whether this tool should run alone even if concurrency is enabled."""
    # Default: no exclusivity; override (e.g. shell execution) to force
    # solo execution.
    return False
@abstractmethod
async def execute(self, **kwargs: Any) -> Any:
"""

View File

@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
"Use offset and limit to paginate through large files."
)
@property
def read_only(self) -> bool:
    """Reading a file has no side effects, so it may share a concurrent batch."""
    return True
@property
def parameters(self) -> dict[str, Any]:
return {
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
)
@property
def read_only(self) -> bool:
    """Listing a directory has no side effects, so it may share a concurrent batch."""
    return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@ -35,22 +35,35 @@ class ToolRegistry:
"""Get all tool definitions in OpenAI format."""
return [tool.to_schema() for tool in self._tools.values()]
def prepare_call(
    self,
    name: str,
    params: dict[str, Any],
) -> tuple[Tool | None, dict[str, Any], str | None]:
    """Resolve, cast, and validate one tool call.

    Returns ``(tool, params, error)``: on an unknown tool name the tool is
    None and params are untouched; on validation failure the tool and cast
    params are returned with an error string; otherwise error is None.
    """
    tool = self._tools.get(name)
    if tool is None:
        available = ", ".join(self.tool_names)
        return None, params, f"Error: Tool '{name}' not found. Available: {available}"
    # Coerce raw arguments toward the schema types before validating.
    cast_params = tool.cast_params(params)
    problems = tool.validate_params(cast_params)
    if problems:
        detail = "; ".join(problems)
        return tool, cast_params, f"Error: Invalid parameters for tool '{name}': {detail}"
    return tool, cast_params, None
async def execute(self, name: str, params: dict[str, Any]) -> Any:
"""Execute a tool by name with given parameters."""
_HINT = "\n\n[Analyze the error above and try a different approach.]"
tool = self._tools.get(name)
if not tool:
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
tool, params, error = self.prepare_call(name, params)
if error:
return error + _HINT
try:
# Attempt to cast parameters to match schema types
params = tool.cast_params(params)
# Validate parameters
errors = tool.validate_params(params)
if errors:
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
assert tool is not None # guarded by prepare_call()
result = await tool.execute(**params)
if isinstance(result, str) and result.startswith("Error"):
return result + _HINT

View File

@ -52,6 +52,10 @@ class ExecTool(Tool):
def description(self) -> str:
return "Execute a shell command and return its output. Use with caution."
@property
def exclusive(self) -> bool:
    """Shell commands can have arbitrary side effects, so they always run alone."""
    return True
@property
def parameters(self) -> dict[str, Any]:
return {

View File

@ -92,6 +92,10 @@ class WebSearchTool(Tool):
self.config = config if config is not None else WebSearchConfig()
self.proxy = proxy
@property
def read_only(self) -> bool:
    """Web search only reads remote data, so it may share a concurrent batch."""
    return True
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
provider = self.config.provider.strip().lower() or "brave"
n = min(max(count or self.config.max_results, 1), 10)
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
self.max_chars = max_chars
self.proxy = proxy
@property
def read_only(self) -> bool:
    """Fetching a page only reads remote data, so it may share a concurrent batch."""
    return True
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
max_chars = maxChars or self.max_chars
is_valid, error_msg = _validate_url_safe(url)

View File

@ -539,6 +539,9 @@ def serve(
model=runtime_config.agents.defaults.model,
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
context_block_limit=runtime_config.agents.defaults.context_block_limit,
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
web_search_config=runtime_config.tools.web.search,
web_proxy=runtime_config.tools.web.proxy or None,
exec_config=runtime_config.tools.exec,
@ -626,6 +629,9 @@ def gateway(
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,
@ -832,6 +838,9 @@ def agent(
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,

View File

@ -38,8 +38,11 @@ class AgentDefaults(Base):
)
max_tokens: int = 8192
context_window_tokens: int = 65_536
context_block_limit: int | None = None
temperature: float = 0.1
max_tool_iterations: int = 40
max_tool_iterations: int = 200
max_tool_result_chars: int = 16_000
provider_retry_mode: Literal["standard", "persistent"] = "standard"
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"

View File

@ -73,6 +73,9 @@ class Nanobot:
model=defaults.model,
max_iterations=defaults.max_tool_iterations,
context_window_tokens=defaults.context_window_tokens,
context_block_limit=defaults.context_block_limit,
max_tool_result_chars=defaults.max_tool_result_chars,
provider_retry_mode=defaults.provider_retry_mode,
web_search_config=config.tools.web.search,
web_proxy=config.tools.web.proxy or None,
exec_config=config.tools.exec,

View File

@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
import os
import re
import secrets
import string
@ -427,13 +429,33 @@ class AnthropicProvider(LLMProvider):
messages, tools, model, max_tokens, temperature,
reasoning_effort, tool_choice,
)
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try:
async with self._client.messages.stream(**kwargs) as stream:
if on_content_delta:
async for text in stream.text_stream:
stream_iter = stream.text_stream.__aiter__()
while True:
try:
text = await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
await on_content_delta(text)
response = await stream.get_final_message()
response = await asyncio.wait_for(
stream.get_final_message(),
timeout=idle_timeout_s,
)
return self._parse_response(response)
except asyncio.TimeoutError:
return LLMResponse(
content=(
f"Error calling LLM: stream stalled for more than "
f"{idle_timeout_s} seconds"
),
finish_reason="error",
)
except Exception as e:
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")

View File

@ -2,6 +2,7 @@
import asyncio
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
@ -9,6 +10,8 @@ from typing import Any
from loguru import logger
from nanobot.utils.helpers import image_placeholder_text
@dataclass
class ToolCallRequest:
@ -57,13 +60,7 @@ class LLMResponse:
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation parameters for LLM calls.
Stored on the provider so every call site inherits the same defaults
without having to pass temperature / max_tokens / reasoning_effort
through every layer. Individual call sites can still override by
passing explicit keyword arguments to chat() / chat_with_retry().
"""
"""Default generation settings."""
temperature: float = 0.7
max_tokens: int = 4096
@ -71,14 +68,11 @@ class GenerationSettings:
class LLMProvider(ABC):
"""
Abstract base class for LLM providers.
Implementations should handle the specifics of each provider's API
while maintaining a consistent interface.
"""
"""Base class for LLM providers."""
_CHAT_RETRY_DELAYS = (1, 2, 4)
_PERSISTENT_MAX_DELAY = 60
_RETRY_HEARTBEAT_CHUNK = 30
_TRANSIENT_ERROR_MARKERS = (
"429",
"rate limit",
@ -208,7 +202,7 @@ class LLMProvider(ABC):
for b in content:
if isinstance(b, dict) and b.get("type") == "image_url":
path = (b.get("_meta") or {}).get("path", "")
placeholder = f"[image: {path}]" if path else "[image omitted]"
placeholder = image_placeholder_text(path, empty="[image omitted]")
new_content.append({"type": "text", "text": placeholder})
found = True
else:
@ -273,6 +267,8 @@ class LLMProvider(ABC):
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Call chat_stream() with retry on transient provider failures."""
if max_tokens is self._SENTINEL:
@ -288,28 +284,13 @@ class LLMProvider(ABC):
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
on_content_delta=on_content_delta,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat_stream(**kw)
if response.finish_reason != "error":
return response
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat_stream(**{**kw, "messages": stripped})
return response
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
return await self._safe_chat_stream(**kw)
return await self._run_with_retry(
self._safe_chat_stream,
kw,
messages,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
async def chat_with_retry(
self,
@ -320,6 +301,8 @@ class LLMProvider(ABC):
temperature: object = _SENTINEL,
reasoning_effort: object = _SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
"""Call chat() with retry on transient provider failures.
@ -339,28 +322,102 @@ class LLMProvider(ABC):
max_tokens=max_tokens, temperature=temperature,
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
)
return await self._run_with_retry(
self._safe_chat,
kw,
messages,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
response = await self._safe_chat(**kw)
@classmethod
def _extract_retry_after(cls, content: str | None) -> float | None:
text = (content or "").lower()
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
if not match:
return None
value = float(match.group(1))
unit = (match.group(2) or "s").lower()
if unit in {"ms", "milliseconds"}:
return max(0.1, value / 1000.0)
if unit in {"m", "min", "minutes"}:
return value * 60.0
return value
async def _sleep_with_heartbeat(
    self,
    delay: float,
    *,
    attempt: int,
    persistent: bool,
    on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> None:
    """Sleep for *delay* seconds, periodically reporting progress.

    The wait is split into chunks of at most ``_RETRY_HEARTBEAT_CHUNK``
    seconds; before each chunk, *on_retry_wait* (when provided) is awaited
    with a human-readable status line so long retry waits stay visible to
    the caller.

    Args:
        delay: Total seconds to wait; negative values are treated as 0.
        attempt: 1-based attempt counter, included in the status text.
        persistent: Whether the caller runs in persistent-retry mode;
            only changes the wording of the status message.
        on_retry_wait: Optional async callback receiving the status text.
    """
    remaining = max(0.0, delay)
    while remaining > 0:
        if on_retry_wait:
            kind = "persistent retry" if persistent else "retry"
            # Remaining time is rounded and clamped to at least 1s so the
            # message never reads "retry in 0s" while still sleeping.
            await on_retry_wait(
                f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
                f"(attempt {attempt})."
            )
        chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
        await asyncio.sleep(chunk)
        remaining -= chunk
async def _run_with_retry(
self,
call: Callable[..., Awaitable[LLMResponse]],
kw: dict[str, Any],
original_messages: list[dict[str, Any]],
*,
retry_mode: str,
on_retry_wait: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
attempt = 0
delays = list(self._CHAT_RETRY_DELAYS)
persistent = retry_mode == "persistent"
last_response: LLMResponse | None = None
while True:
attempt += 1
response = await call(**kw)
if response.finish_reason != "error":
return response
last_response = response
if not self._is_transient_error(response.content):
stripped = self._strip_image_content(messages)
if stripped is not None:
logger.warning("Non-transient LLM error with image content, retrying without images")
return await self._safe_chat(**{**kw, "messages": stripped})
stripped = self._strip_image_content(original_messages)
if stripped is not None and stripped != kw["messages"]:
logger.warning(
"Non-transient LLM error with image content, retrying without images"
)
retry_kw = dict(kw)
retry_kw["messages"] = stripped
return await call(**retry_kw)
return response
if not persistent and attempt > len(delays):
break
base_delay = delays[min(attempt - 1, len(delays) - 1)]
delay = self._extract_retry_after(response.content) or base_delay
if persistent:
delay = min(delay, self._PERSISTENT_MAX_DELAY)
logger.warning(
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
attempt, len(self._CHAT_RETRY_DELAYS), delay,
"LLM transient error (attempt {}{}), retrying in {}s: {}",
attempt,
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
int(round(delay)),
(response.content or "")[:120].lower(),
)
await asyncio.sleep(delay)
await self._sleep_with_heartbeat(
delay,
attempt=attempt,
persistent=persistent,
on_retry_wait=on_retry_wait,
)
return await self._safe_chat(**kw)
return last_response if last_response is not None else await call(**kw)
@abstractmethod
def get_default_model(self) -> str:

View File

@ -2,6 +2,7 @@
from __future__ import annotations
import asyncio
import hashlib
import os
import secrets
@ -20,7 +21,6 @@ if TYPE_CHECKING:
_ALLOWED_MSG_KEYS = frozenset({
"role", "content", "tool_calls", "tool_call_id", "name",
"reasoning_content", "extra_content",
})
_ALNUM = string.ascii_letters + string.digits
@ -572,16 +572,33 @@ class OpenAICompatProvider(LLMProvider):
)
kwargs["stream"] = True
kwargs["stream_options"] = {"include_usage": True}
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
try:
stream = await self._client.chat.completions.create(**kwargs)
chunks: list[Any] = []
async for chunk in stream:
stream_iter = stream.__aiter__()
while True:
try:
chunk = await asyncio.wait_for(
stream_iter.__anext__(),
timeout=idle_timeout_s,
)
except StopAsyncIteration:
break
chunks.append(chunk)
if on_content_delta and chunk.choices:
text = getattr(chunk.choices[0].delta, "content", None)
if text:
await on_content_delta(text)
return self._parse_chunks(chunks)
except asyncio.TimeoutError:
return LLMResponse(
content=(
f"Error calling LLM: stream stalled for more than "
f"{idle_timeout_s} seconds"
),
finish_reason="error",
)
except Exception as e:
return self._handle_error(e)

View File

@ -10,20 +10,12 @@ from typing import Any
from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ensure_dir, safe_filename
from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
@dataclass
class Session:
"""
A conversation session.
Stores messages in JSONL format for easy reading and persistence.
Important: Messages are append-only for LLM cache efficiency.
The consolidation process writes summaries to MEMORY.md/HISTORY.md
but does NOT modify the messages list or get_history() output.
"""
"""A conversation session."""
key: str # channel:chat_id
messages: list[dict[str, Any]] = field(default_factory=list)
@ -43,43 +35,19 @@ class Session:
self.messages.append(msg)
self.updated_at = datetime.now()
@staticmethod
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
"""Find first index where every tool result has a matching assistant tool_call."""
declared: set[str] = set()
start = 0
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
for tc in msg.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
elif role == "tool":
tid = msg.get("tool_call_id")
if tid and str(tid) not in declared:
start = i + 1
declared.clear()
for prev in messages[start:i + 1]:
if prev.get("role") == "assistant":
for tc in prev.get("tool_calls") or []:
if isinstance(tc, dict) and tc.get("id"):
declared.add(str(tc["id"]))
return start
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
# Drop leading non-user messages to avoid starting mid-turn when possible.
# Avoid starting mid-turn when possible.
for i, message in enumerate(sliced):
if message.get("role") == "user":
sliced = sliced[i:]
break
# Some providers reject orphan tool results if the matching assistant
# tool_calls message fell outside the fixed-size history window.
start = self._find_legal_start(sliced)
# Drop orphan tool results at the front.
start = find_legal_message_start(sliced)
if start:
sliced = sliced[start:]
@ -115,7 +83,7 @@ class Session:
retained = self.messages[start_idx:]
# Mirror get_history(): avoid persisting orphan tool results at the front.
start = self._find_legal_start(retained)
start = find_legal_message_start(retained)
if start:
retained = retained[start:]

View File

@ -3,7 +3,9 @@
import base64
import json
import re
import shutil
import time
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any
@ -56,11 +58,7 @@ def timestamp() -> str:
def current_time_str(timezone: str | None = None) -> str:
"""Human-readable current time with weekday and UTC offset.
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
is converted to that zone. Otherwise falls back to the host local time.
"""
"""Return the current time string."""
from zoneinfo import ZoneInfo
try:
@ -76,12 +74,164 @@ def current_time_str(timezone: str | None = None) -> str:
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
_TOOL_RESULT_PREVIEW_CHARS = 1200
_TOOL_RESULTS_DIR = ".nanobot/tool-results"
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60
_TOOL_RESULT_MAX_BUCKETS = 32
def safe_filename(name: str) -> str:
    """Make *name* safe to use as a file name.

    Characters that are illegal in Windows/POSIX file names are replaced
    with underscores, then surrounding whitespace is removed.
    """
    sanitized = re.sub(r'[<>:"/\\|?*]', "_", name)
    return sanitized.strip()
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
"""Build an image placeholder string."""
return f"[image: {path}]" if path else empty
def truncate_text(text: str, max_chars: int) -> str:
    """Clip *text* to at most *max_chars* characters.

    A non-positive limit disables truncation entirely. When clipping
    happens, a fixed "... (truncated)" marker is appended so readers can
    tell the output is incomplete.
    """
    needs_clip = max_chars > 0 and len(text) > max_chars
    if not needs_clip:
        return text
    return f"{text[:max_chars]}\n... (truncated)"
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
    """Return the first index from which *messages* is a legal history.

    A slice is "legal" when every ``role == "tool"`` result in it is
    preceded, within the slice, by the assistant message declaring the
    matching ``tool_call_id``. Orphan tool results appear when a
    fixed-size history window cuts a turn in half, and providers reject
    them, so callers trim the history to the returned index.

    The scan tracks declared tool-call ids going forward; whenever an
    orphan tool result is found at index ``i`` the legal start moves to
    ``i + 1`` and the declared set is reset, because every assistant
    message at or before ``i`` now falls outside the candidate slice.
    (The original rebuild loop over ``messages[start : i + 1]`` was dead
    code: with ``start == i + 1`` the slice is always empty.)

    Returns 0 when the full list is already legal.
    """
    declared: set[str] = set()
    start = 0
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role == "assistant":
            for tc in msg.get("tool_calls") or []:
                if isinstance(tc, dict) and tc.get("id"):
                    declared.add(str(tc["id"]))
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                # Orphan result: drop everything up to and including it,
                # and forget earlier declarations — their assistant
                # messages are outside the new slice.
                start = i + 1
                declared.clear()
    return start
def _stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
parts: list[str] = []
for block in content:
if not isinstance(block, dict):
return None
if block.get("type") != "text":
return None
text = block.get("text")
if not isinstance(text, str):
return None
parts.append(text)
return "\n".join(parts)
def _render_tool_result_reference(
filepath: Path,
*,
original_size: int,
preview: str,
truncated_preview: bool,
) -> str:
result = (
f"[tool output persisted]\n"
f"Full output saved to: {filepath}\n"
f"Original size: {original_size} chars\n"
f"Preview:\n{preview}"
)
if truncated_preview:
result += "\n...\n(Read the saved file if you need the full output.)"
return result
def _bucket_mtime(path: Path) -> float:
try:
return path.stat().st_mtime
except OSError:
return 0.0
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
    """Prune old per-session tool-result directories under *root*.

    Two policies apply to sibling buckets (never to *current_bucket*):
    first, any bucket older than ``_TOOL_RESULT_RETENTION_SECS`` is
    removed; then, if more than ``_TOOL_RESULT_MAX_BUCKETS`` buckets would
    remain (counting the current one), the oldest survivors are removed.

    Propagates whatever ``root.iterdir()`` raises (e.g. when *root* is
    missing); callers wrap this in try/except as best-effort housekeeping.
    """
    siblings = [path for path in root.iterdir() if path.is_dir() and path != current_bucket]
    cutoff = time.time() - _TOOL_RESULT_RETENTION_SECS
    for path in siblings:
        if _bucket_mtime(path) < cutoff:
            shutil.rmtree(path, ignore_errors=True)
    # Reserve one slot for the current bucket in the overall cap.
    keep = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
    # Re-check existence: stale buckets were just deleted above.
    siblings = [path for path in siblings if path.exists()]
    if len(siblings) <= keep:
        return
    # Newest first; everything past the cap is deleted.
    siblings.sort(key=_bucket_mtime, reverse=True)
    for path in siblings[keep:]:
        shutil.rmtree(path, ignore_errors=True)
def _write_text_atomic(path: Path, content: str) -> None:
tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
try:
tmp.write_text(content, encoding="utf-8")
tmp.replace(path)
finally:
if tmp.exists():
tmp.unlink(missing_ok=True)
def maybe_persist_tool_result(
    workspace: Path | None,
    session_key: str | None,
    tool_call_id: str,
    content: Any,
    *,
    max_chars: int,
) -> Any:
    """Persist oversized tool output and replace it with a stable reference string.

    Only plain strings and lists of ``{"type": "text"}`` blocks are
    persisted; any other payload, or one at/under *max_chars*, is returned
    unchanged. Oversized payloads are written under
    ``<workspace>/.nanobot/tool-results/<session>/<tool_call_id>.<ext>``
    and replaced by a short reference message containing the file path,
    the original size, and the first ``_TOOL_RESULT_PREVIEW_CHARS``
    characters as a preview.

    Args:
        workspace: Root directory for persistence; ``None`` disables it.
        session_key: Per-session bucket name (sanitized); falls back to
            ``"default"`` when empty.
        tool_call_id: Used (sanitized) as the file name, so repeated calls
            for the same tool call reuse the same file.
        content: Raw tool output.
        max_chars: Size threshold; non-positive disables persistence.

    Returns:
        Either the original *content* or the reference string.
    """
    if workspace is None or max_chars <= 0:
        return content
    text_payload: str | None = None
    suffix = "txt"
    if isinstance(content, str):
        text_payload = content
    elif isinstance(content, list):
        # Lists are persisted as JSON, but sizing/preview use the joined text.
        text_payload = _stringify_text_blocks(content)
        if text_payload is None:
            # Mixed or non-text blocks cannot be flattened; leave untouched.
            return content
        suffix = "json"
    else:
        return content
    if len(text_payload) <= max_chars:
        return content
    root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
    bucket = ensure_dir(root / safe_filename(session_key or "default"))
    try:
        # Housekeeping is best-effort; never let cleanup break the tool call.
        _cleanup_tool_result_buckets(root, bucket)
    except Exception:
        pass
    path = bucket / f"{safe_filename(tool_call_id)}.{suffix}"
    if not path.exists():
        # First write wins: the tool_call_id-derived name keeps the
        # reference stable if the same call is rendered again.
        if suffix == "json" and isinstance(content, list):
            _write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2))
        else:
            _write_text_atomic(path, text_payload)
    preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS]
    return _render_tool_result_reference(
        path,
        original_size=len(text_payload),
        preview=preview,
        truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS,
    )
def split_message(content: str, max_len: int = 2000) -> list[str]:
"""
Split content into chunks within max_len, preferring line breaks.

View File

@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
assert "Channel: cli" in user_content
assert "Chat ID: direct" in user_content
assert "Return exactly: OK" in user_content
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
    """build_messages must merge, not append, when history already ends with
    an assistant turn and the incoming message is also assistant-role.

    NOTE(review): some providers appear to reject back-to-back assistant
    messages — this pins the merge behavior added to build_messages.
    """
    workspace = _make_workspace(tmp_path)
    builder = ContextBuilder(workspace)
    messages = builder.build_messages(
        history=[{"role": "assistant", "content": "previous result"}],
        current_message="subagent result",
        channel="cli",
        chat_id="direct",
        current_role="assistant",
    )
    # No adjacent pair may both carry the assistant role.
    for left, right in zip(messages, messages[1:]):
        assert not (left.get("role") == right.get("role") == "assistant")

View File

@ -5,7 +5,9 @@ from nanobot.session.manager import Session
def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop)
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
from nanobot.config.schema import AgentDefaults
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
return loop
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
)
assert session.messages[0]["content"] == content
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
loop = _mk_loop()
session = Session(
key="test:checkpoint",
metadata={
AgentLoop._RUNTIME_CHECKPOINT_KEY: {
"assistant_message": {
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
"completed_tool_results": [
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
}
],
"pending_tool_calls": [
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
}
},
)
restored = loop._restore_runtime_checkpoint(session)
assert restored is True
assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
assert session.messages[0]["role"] == "assistant"
assert session.messages[1]["tool_call_id"] == "call_done"
assert session.messages[2]["tool_call_id"] == "call_pending"
assert "interrupted before this tool finished" in session.messages[2]["content"].lower()
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
loop = _mk_loop()
session = Session(
key="test:checkpoint-overlap",
messages=[
{
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
},
],
metadata={
AgentLoop._RUNTIME_CHECKPOINT_KEY: {
"assistant_message": {
"role": "assistant",
"content": "working",
"tool_calls": [
{
"id": "call_done",
"type": "function",
"function": {"name": "read_file", "arguments": "{}"},
},
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
},
],
},
"completed_tool_results": [
{
"role": "tool",
"tool_call_id": "call_done",
"name": "read_file",
"content": "ok",
}
],
"pending_tool_calls": [
{
"id": "call_pending",
"type": "function",
"function": {"name": "exec", "arguments": "{}"},
}
],
}
},
)
restored = loop._restore_runtime_checkpoint(session)
assert restored is True
assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
assert len(session.messages) == 3
assert session.messages[0]["role"] == "assistant"
assert session.messages[1]["tool_call_id"] == "call_done"
assert session.messages[2]["tool_call_id"] == "call_pending"

View File

@ -2,12 +2,20 @@
from __future__ import annotations
import asyncio
import os
import time
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.config.schema import AgentDefaults
from nanobot.agent.tools.base import Tool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMResponse, ToolCallRequest
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_loop(tmp_path):
from nanobot.agent.loop import AgentLoop
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=RecordingHook(),
))
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=StreamingHook(),
))
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.stop_reason == "max_iterations"
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."
)
assert result.messages[-1]["role"] == "assistant"
assert result.messages[-1]["content"] == result.final_content
@pytest.mark.asyncio
async def test_runner_returns_structured_tool_error():
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
fail_on_tool_error=True,
))
@ -258,6 +272,232 @@ async def test_runner_returns_structured_tool_error():
]
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})],
usage={"prompt_tokens": 5, "completion_tokens": 3},
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="x" * 20_000)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=2,
workspace=tmp_path,
session_key="test:runner",
max_tool_result_chars=2048,
))
assert result.final_content == "done"
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
assert "[tool output persisted]" in tool_message["content"]
assert "tool-results" in tool_message["content"]
assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists()
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a new oversized result prunes stale sibling session
    buckets (older than the 7-day retention) but keeps recent ones."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    old_bucket = root / "old_session"
    recent_bucket = root / "recent_session"
    old_bucket.mkdir(parents=True)
    recent_bucket.mkdir(parents=True)
    (old_bucket / "old.txt").write_text("old", encoding="utf-8")
    (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")
    # Back-date the old bucket past the 7-day retention window.
    stale = time.time() - (8 * 24 * 60 * 60)
    os.utime(old_bucket, (stale, stale))
    os.utime(old_bucket / "old.txt", (stale, stale))

    # 5000 chars with a 64-char limit forces persistence (and cleanup).
    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not old_bucket.exists()
    assert recent_bucket.exists()
    # Session key "current:session" is sanitized to "current_session".
    assert (root / "current_session" / "call_big.txt").exists()
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """Atomic persistence must leave only the final file — no orphaned
    intermediate ``*.tmp`` files in the session bucket."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    # 5000 chars with a 64-char limit forces persistence to disk.
    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )
    assert (root / "current_session" / "call_big.txt").exists()
    assert list((root / "current_session").glob("*.tmp")) == []
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_messages: list[dict] = []
async def chat_with_retry(*, messages, **kwargs):
captured_messages[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
initial_messages = [
{"role": "system", "content": "system"},
{"role": "user", "content": "hello"},
]
runner = AgentRunner(provider)
runner._snip_history = MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign]
result = await runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
assert captured_messages == initial_messages
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
captured_second_call: list[dict] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
usage={"prompt_tokens": 5, "completion_tokens": 3},
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="tool result")
runner = AgentRunner(provider)
with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")):
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "do task"}],
tools=tools,
model="test-model",
max_iterations=2,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.final_content == "done"
tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool")
assert tool_message["content"] == "tool result"
class _DelayTool(Tool):
    """Test double: a Tool that sleeps for a fixed delay and records
    ``start:<name>`` / ``end:<name>`` events into a shared list, so tests
    can assert the ordering of concurrent tool execution.
    """

    def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
        self._name = name
        self._delay = delay
        self._read_only = read_only
        # Shared across instances so the test observes one global ordering.
        self._shared_events = shared_events

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        # Tests never read the description; reuse the name.
        return self._name

    @property
    def parameters(self) -> dict:
        # No arguments: an empty JSON-schema object.
        return {"type": "object", "properties": {}, "required": []}

    @property
    def read_only(self) -> bool:
        return self._read_only

    async def execute(self, **kwargs):
        self._shared_events.append(f"start:{self._name}")
        await asyncio.sleep(self._delay)
        self._shared_events.append(f"end:{self._name}")
        return self._name
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
from nanobot.agent.runner import AgentRunSpec, AgentRunner
tools = ToolRegistry()
shared_events: list[str] = []
read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)
tools.register(read_a)
tools.register(read_b)
tools.register(write_a)
runner = AgentRunner(MagicMock())
await runner._execute_tools(
AgentRunSpec(
initial_messages=[],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
concurrent_tools=True,
),
[
ToolCallRequest(id="ro1", name="read_a", arguments={}),
ToolCallRequest(id="ro2", name="read_b", arguments={}),
ToolCallRequest(id="rw1", name="write_a", arguments={}),
],
)
assert shared_events[0:2] == ["start:read_a", "start:read_b"]
assert "end:read_a" in shared_events and "end:read_b" in shared_events
assert shared_events.index("end:read_a") < shared_events.index("start:write_a")
assert shared_events.index("end:read_b") < shared_events.index("start:write_a")
assert shared_events[-2:] == ["start:write_a", "end:write_a"]
@pytest.mark.asyncio
async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop = _make_loop(tmp_path)
@ -317,15 +557,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="working",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})

View File

@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from nanobot.config.schema import AgentDefaults
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_loop(*, exec_config=None):
"""Create a minimal AgentLoop with mocked dependencies."""
@ -186,7 +190,12 @@ class TestSubagentCancellation:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=MagicMock(),
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
cancelled = asyncio.Event()
@ -214,7 +223,12 @@ class TestSubagentCancellation:
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=MagicMock(),
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
assert await mgr.cancel_by_session("nonexistent") == 0
@pytest.mark.asyncio
@ -236,19 +250,24 @@ class TestSubagentCancellation:
if call_count["n"] == 1:
return LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
reasoning_content="hidden reasoning",
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
)
captured_second_call[:] = messages
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = scripted_chat_with_retry
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
return "tool result"
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -273,6 +292,7 @@ class TestSubagentCancellation:
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
exec_config=ExecToolConfig(enable=False),
)
mgr._announce_result = AsyncMock()
@ -304,20 +324,25 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
calls = {"n": 0}
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
calls["n"] += 1
if calls["n"] == 1:
return "first result"
raise RuntimeError("boom")
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -340,15 +365,20 @@ class TestSubagentCancellation:
provider.get_default_model.return_value = "test-model"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
content="thinking",
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
))
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
mgr = SubagentManager(
provider=provider,
workspace=tmp_path,
bus=bus,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
)
mgr._announce_result = AsyncMock()
started = asyncio.Event()
cancelled = asyncio.Event()
async def fake_execute(self, name, arguments):
async def fake_execute(self, **kwargs):
started.set()
try:
await asyncio.sleep(60)
@ -356,7 +386,7 @@ class TestSubagentCancellation:
cancelled.set()
raise
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
task = asyncio.create_task(
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
@ -364,7 +394,7 @@ class TestSubagentCancellation:
mgr._running_tasks["sub-1"] = task
mgr._session_tasks["test:c1"] = {"sub-1"}
await started.wait()
await asyncio.wait_for(started.wait(), timeout=1.0)
count = await mgr.cancel_by_session("test:c1")

View File

@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
typing_channel.typing_enter_hook = slow_typing
await channel._start_typing(typing_channel)
await start.wait()
await asyncio.wait_for(start.wait(), timeout=1.0)
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
release.set()
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
typing_channel.typing_enter_hook = slow_typing_progress
await channel._start_typing(typing_channel)
await start.wait()
await asyncio.wait_for(start.wait(), timeout=1.0)
await channel.send(
OutboundMessage(
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
typing_channel = _NoTriggerChannel(channel_id=123)
await channel._start_typing(typing_channel) # type: ignore[arg-type]
await entered.wait()
await asyncio.wait_for(entered.wait(), timeout=1.0)
assert "123" in channel._typing_tasks

View File

@ -8,6 +8,7 @@ Validates that:
from __future__ import annotations
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
return SimpleNamespace(choices=[choice], usage=usage)
class _StalledStream:
def __aiter__(self):
return self
async def __anext__(self):
await asyncio.sleep(3600)
raise StopAsyncIteration
def test_openrouter_spec_is_gateway() -> None:
spec = find_by_name("openrouter")
assert spec is not None
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
spec=spec,
)
assert provider.get_default_model() == "gpt-4o"
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
    """Message-level reasoning fields are dropped by sanitization, while
    tool-call-level ``extra_content`` (e.g. provider thought signatures)
    is preserved untouched."""
    tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "fn", "arguments": "{}"},
        "extra_content": {"google": {"thought_signature": "sig"}},
    }
    message = {
        "role": "assistant",
        "content": "done",
        "reasoning_content": "hidden",
        "extra_content": {"debug": True},
        "tool_calls": [tool_call],
    }
    # Patch the SDK client so constructing the provider does no real I/O.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        sanitized = OpenAICompatProvider()._sanitize_messages([message])
    cleaned = sanitized[0]
    assert "reasoning_content" not in cleaned
    assert "extra_content" not in cleaned
    assert cleaned["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
    """With the idle timeout forced to zero, a stalled stream must surface as
    an error response mentioning the stall rather than hanging forever."""
    monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
    # The SDK call returns a stream that never produces a chunk.
    stalled_create = AsyncMock(return_value=_StalledStream())
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_client_cls:
        mock_client_cls.return_value.chat.completions.create = stalled_create
        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=find_by_name("openai"),
        )
        result = await provider.chat_stream(
            messages=[{"role": "user", "content": "hello"}],
            model="gpt-4o",
        )
    assert result.finish_reason == "error"
    assert result.content is not None
    assert "stream stalled" in result.content

View File

@ -211,3 +211,32 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
content = msg.get("content")
if isinstance(content, list):
assert any("[image omitted]" in (b.get("text") or "") for b in content)
@pytest.mark.asyncio
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
    """A 429 that advertises ``retry after 7s`` must drive a 7-second backoff
    and emit a wait-progress message before the retried call succeeds."""
    provider = ScriptedProvider([
        LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"),
        LLMResponse(content="ok"),
    ])
    observed_delays: list[float] = []
    observed_progress: list[str] = []

    async def _record_sleep(delay: float) -> None:
        # Capture the backoff instead of actually waiting.
        observed_delays.append(delay)

    async def _record_progress(msg: str) -> None:
        observed_progress.append(msg)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)
    response = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        on_retry_wait=_record_progress,
    )
    assert response.content == "ok"
    assert observed_delays == [7.0]
    assert observed_progress and "7s" in observed_progress[0]

View File

@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
task = asyncio.create_task(wrapper.execute())
await started.wait()
await asyncio.wait_for(started.wait(), timeout=1.0)
task.cancel()