mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
feat: harden agent runtime for long-running tasks
This commit is contained in:
parent
63d646f731
commit
fbedf7ad77
@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
||||
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
|
||||
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
if isinstance(left, str) and isinstance(right, str):
|
||||
return f"{left}\n\n{right}" if left else right
|
||||
|
||||
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(value, list):
|
||||
return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value]
|
||||
if value is None:
|
||||
return []
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
return _to_blocks(left) + _to_blocks(right)
|
||||
|
||||
def _load_bootstrap_files(self) -> str:
|
||||
"""Load all bootstrap files from workspace."""
|
||||
parts = []
|
||||
@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST
|
||||
merged = f"{runtime_ctx}\n\n{user_content}"
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
|
||||
return [
|
||||
messages = [
|
||||
{"role": "system", "content": self.build_system_prompt(skill_names)},
|
||||
*history,
|
||||
{"role": current_role, "content": merged},
|
||||
]
|
||||
if messages[-1].get("role") == current_role:
|
||||
last = dict(messages[-1])
|
||||
last["content"] = self._merge_message_content(last.get("content"), merged)
|
||||
messages[-1] = last
|
||||
return messages
|
||||
messages.append({"role": current_role, "content": merged})
|
||||
return messages
|
||||
|
||||
def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]:
|
||||
"""Build user message content with optional base64-encoded images."""
|
||||
|
||||
@ -29,8 +29,10 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.helpers import image_placeholder_text, truncate_text
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig
|
||||
@ -38,11 +40,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core lifecycle hook for the main agent loop.
|
||||
|
||||
Handles streaming delta relay, progress reporting, tool-call logging,
|
||||
and think-tag stripping for the built-in agent path.
|
||||
"""
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -102,11 +100,7 @@ class _LoopHook(AgentHook):
|
||||
|
||||
|
||||
class _LoopHookChain(AgentHook):
|
||||
"""Run the core loop hook first, then best-effort extra hooks.
|
||||
|
||||
This preserves the historical failure behavior of ``_LoopHook`` while still
|
||||
letting user-supplied hooks opt into ``CompositeHook`` isolation.
|
||||
"""
|
||||
"""Run the core hook before extra hooks."""
|
||||
|
||||
__slots__ = ("_primary", "_extras")
|
||||
|
||||
@ -154,7 +148,7 @@ class AgentLoop:
|
||||
5. Sends responses back
|
||||
"""
|
||||
|
||||
_TOOL_RESULT_MAX_CHARS = 16_000
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -162,8 +156,11 @@ class AgentLoop:
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
model: str | None = None,
|
||||
max_iterations: int = 40,
|
||||
context_window_tokens: int = 65_536,
|
||||
max_iterations: int | None = None,
|
||||
context_window_tokens: int | None = None,
|
||||
context_block_limit: int | None = None,
|
||||
max_tool_result_chars: int | None = None,
|
||||
provider_retry_mode: str = "standard",
|
||||
web_search_config: WebSearchConfig | None = None,
|
||||
web_proxy: str | None = None,
|
||||
exec_config: ExecToolConfig | None = None,
|
||||
@ -177,13 +174,27 @@ class AgentLoop:
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebSearchConfig
|
||||
|
||||
defaults = AgentDefaults()
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = max_iterations
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.max_iterations = (
|
||||
max_iterations if max_iterations is not None else defaults.max_tool_iterations
|
||||
)
|
||||
self.context_window_tokens = (
|
||||
context_window_tokens
|
||||
if context_window_tokens is not None
|
||||
else defaults.context_window_tokens
|
||||
)
|
||||
self.context_block_limit = context_block_limit
|
||||
self.max_tool_result_chars = (
|
||||
max_tool_result_chars
|
||||
if max_tool_result_chars is not None
|
||||
else defaults.max_tool_result_chars
|
||||
)
|
||||
self.provider_retry_mode = provider_retry_mode
|
||||
self.web_search_config = web_search_config or WebSearchConfig()
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@ -202,6 +213,7 @@ class AgentLoop:
|
||||
workspace=workspace,
|
||||
bus=bus,
|
||||
model=self.model,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
web_search_config=self.web_search_config,
|
||||
web_proxy=web_proxy,
|
||||
exec_config=self.exec_config,
|
||||
@ -313,6 +325,7 @@ class AgentLoop:
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
*,
|
||||
session: Session | None = None,
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
@ -339,14 +352,27 @@ class AgentLoop:
|
||||
else loop_hook
|
||||
)
|
||||
|
||||
async def _checkpoint(payload: dict[str, Any]) -> None:
|
||||
if session is None:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
))
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
@ -484,6 +510,8 @@ class AgentLoop:
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
@ -494,10 +522,11 @@ class AgentLoop:
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs = await self._run_agent_loop(
|
||||
messages, channel=channel, chat_id=chat_id,
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
@ -508,6 +537,8 @@ class AgentLoop:
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
# Slash commands
|
||||
raw = msg.content.strip()
|
||||
@ -543,6 +574,7 @@ class AgentLoop:
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
session=session,
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
@ -551,6 +583,7 @@ class AgentLoop:
|
||||
final_content = "I've completed processing but have no response to give."
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
@ -568,12 +601,6 @@ class AgentLoop:
|
||||
metadata=meta,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _image_placeholder(block: dict[str, Any]) -> dict[str, str]:
|
||||
"""Convert an inline image block into a compact text placeholder."""
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
return {"type": "text", "text": f"[image: {path}]" if path else "[image]"}
|
||||
|
||||
def _sanitize_persisted_blocks(
|
||||
self,
|
||||
content: list[dict[str, Any]],
|
||||
@ -600,13 +627,14 @@ class AgentLoop:
|
||||
block.get("type") == "image_url"
|
||||
and block.get("image_url", {}).get("url", "").startswith("data:image/")
|
||||
):
|
||||
filtered.append(self._image_placeholder(block))
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text = block["text"]
|
||||
if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS:
|
||||
text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
if truncate_text and len(text) > self.max_tool_result_chars:
|
||||
text = truncate_text(text, self.max_tool_result_chars)
|
||||
filtered.append({**block, "text": text})
|
||||
continue
|
||||
|
||||
@ -623,8 +651,8 @@ class AgentLoop:
|
||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||
continue # skip empty assistant messages — they poison session context
|
||||
if role == "tool":
|
||||
if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS:
|
||||
entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)"
|
||||
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||
entry["content"] = truncate_text(content, self.max_tool_result_chars)
|
||||
elif isinstance(content, list):
|
||||
filtered = self._sanitize_persisted_blocks(content, truncate_text=True)
|
||||
if not filtered:
|
||||
@ -647,6 +675,78 @@ class AgentLoop:
|
||||
session.messages.append(entry)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
|
||||
"""Persist the latest in-flight turn state into session metadata."""
|
||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||
self.sessions.save(session)
|
||||
|
||||
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
||||
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
|
||||
return (
|
||||
message.get("role"),
|
||||
message.get("content"),
|
||||
message.get("tool_call_id"),
|
||||
message.get("name"),
|
||||
message.get("tool_calls"),
|
||||
message.get("reasoning_content"),
|
||||
message.get("thinking_blocks"),
|
||||
)
|
||||
|
||||
def _restore_runtime_checkpoint(self, session: Session) -> bool:
|
||||
"""Materialize an unfinished turn into session history before a new request."""
|
||||
from datetime import datetime
|
||||
|
||||
checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
|
||||
if not isinstance(checkpoint, dict):
|
||||
return False
|
||||
|
||||
assistant_message = checkpoint.get("assistant_message")
|
||||
completed_tool_results = checkpoint.get("completed_tool_results") or []
|
||||
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
|
||||
|
||||
restored_messages: list[dict[str, Any]] = []
|
||||
if isinstance(assistant_message, dict):
|
||||
restored = dict(assistant_message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for message in completed_tool_results:
|
||||
if isinstance(message, dict):
|
||||
restored = dict(message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for tool_call in pending_tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
})
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
for size in range(max_overlap, 0, -1):
|
||||
existing = session.messages[-size:]
|
||||
restored = restored_messages[:size]
|
||||
if all(
|
||||
self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
|
||||
for left, right in zip(existing, restored)
|
||||
):
|
||||
overlap = size
|
||||
break
|
||||
session.messages.extend(restored_messages[overlap:])
|
||||
|
||||
self._clear_runtime_checkpoint(session)
|
||||
return True
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@ -4,20 +4,29 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, ToolCallRequest
|
||||
from nanobot.utils.helpers import build_assistant_message
|
||||
from nanobot.utils.helpers import (
|
||||
build_assistant_message,
|
||||
estimate_message_tokens,
|
||||
estimate_prompt_tokens_chain,
|
||||
find_legal_message_start,
|
||||
maybe_persist_tool_result,
|
||||
truncate_text,
|
||||
)
|
||||
|
||||
_DEFAULT_MAX_ITERATIONS_MESSAGE = (
|
||||
"I reached the maximum number of tool call iterations ({max_iterations}) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
|
||||
|
||||
_SNIP_SAFETY_BUFFER = 1024
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
@ -26,6 +35,7 @@ class AgentRunSpec:
|
||||
tools: ToolRegistry
|
||||
model: str
|
||||
max_iterations: int
|
||||
max_tool_result_chars: int
|
||||
temperature: float | None = None
|
||||
max_tokens: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
@ -34,6 +44,13 @@ class AgentRunSpec:
|
||||
max_iterations_message: str | None = None
|
||||
concurrent_tools: bool = False
|
||||
fail_on_tool_error: bool = False
|
||||
workspace: Path | None = None
|
||||
session_key: str | None = None
|
||||
context_window_tokens: int | None = None
|
||||
context_block_limit: int | None = None
|
||||
provider_retry_mode: str = "standard"
|
||||
progress_callback: Any | None = None
|
||||
checkpoint_callback: Any | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -66,12 +83,25 @@ class AgentRunner:
|
||||
tool_events: list[dict[str, str]] = []
|
||||
|
||||
for iteration in range(spec.max_iterations):
|
||||
try:
|
||||
messages = self._apply_tool_result_budget(spec, messages)
|
||||
messages_for_model = self._snip_history(spec, messages)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Context governance failed on turn {} for {}: {}; using raw messages",
|
||||
iteration,
|
||||
spec.session_key or "default",
|
||||
exc,
|
||||
)
|
||||
messages_for_model = messages
|
||||
context = AgentHookContext(iteration=iteration, messages=messages)
|
||||
await hook.before_iteration(context)
|
||||
kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"messages": messages_for_model,
|
||||
"tools": spec.tools.get_definitions(),
|
||||
"model": spec.model,
|
||||
"retry_mode": spec.provider_retry_mode,
|
||||
"on_retry_wait": spec.progress_callback,
|
||||
}
|
||||
if spec.temperature is not None:
|
||||
kwargs["temperature"] = spec.temperature
|
||||
@ -104,13 +134,25 @@ class AgentRunner:
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
assistant_message = build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "awaiting_tools",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
await hook.before_execute_tools(context)
|
||||
|
||||
@ -125,13 +167,31 @@ class AgentRunner:
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
messages.append({
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
"name": tool_call.name,
|
||||
"content": result,
|
||||
})
|
||||
"content": self._normalize_tool_result(
|
||||
spec,
|
||||
tool_call.id,
|
||||
result,
|
||||
),
|
||||
}
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "tools_completed",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": completed_tool_results,
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -143,6 +203,7 @@ class AgentRunner:
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
stop_reason = "error"
|
||||
error = final_content
|
||||
self._append_final_message(messages, final_content)
|
||||
context.final_content = final_content
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
@ -154,6 +215,17 @@ class AgentRunner:
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
))
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": messages[-1],
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
final_content = clean
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
@ -163,6 +235,7 @@ class AgentRunner:
|
||||
stop_reason = "max_iterations"
|
||||
template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE
|
||||
final_content = template.format(max_iterations=spec.max_iterations)
|
||||
self._append_final_message(messages, final_content)
|
||||
|
||||
return AgentRunResult(
|
||||
final_content=final_content,
|
||||
@ -179,16 +252,17 @@ class AgentRunner:
|
||||
spec: AgentRunSpec,
|
||||
tool_calls: list[ToolCallRequest],
|
||||
) -> tuple[list[Any], list[dict[str, str]], BaseException | None]:
|
||||
if spec.concurrent_tools:
|
||||
tool_results = await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call)
|
||||
for tool_call in tool_calls
|
||||
))
|
||||
else:
|
||||
tool_results = [
|
||||
await self._run_tool(spec, tool_call)
|
||||
for tool_call in tool_calls
|
||||
]
|
||||
batches = self._partition_tool_batches(spec, tool_calls)
|
||||
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||
for batch in batches:
|
||||
if spec.concurrent_tools and len(batch) > 1:
|
||||
tool_results.extend(await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call)
|
||||
for tool_call in batch
|
||||
)))
|
||||
else:
|
||||
for tool_call in batch:
|
||||
tool_results.append(await self._run_tool(spec, tool_call))
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
@ -205,8 +279,28 @@ class AgentRunner:
|
||||
spec: AgentRunSpec,
|
||||
tool_call: ToolCallRequest,
|
||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||
tool, params, prep_error = None, tool_call.arguments, None
|
||||
if callable(prepare_call):
|
||||
try:
|
||||
prepared = prepare_call(tool_call.name, tool_call.arguments)
|
||||
if isinstance(prepared, tuple) and len(prepared) == 3:
|
||||
tool, params, prep_error = prepared
|
||||
except Exception:
|
||||
pass
|
||||
if prep_error:
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
"status": "error",
|
||||
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||
}
|
||||
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
try:
|
||||
result = await spec.tools.execute(tool_call.name, tool_call.arguments)
|
||||
if tool is not None:
|
||||
result = await tool.execute(**params)
|
||||
else:
|
||||
result = await spec.tools.execute(tool_call.name, params)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except BaseException as exc:
|
||||
@ -219,14 +313,175 @@ class AgentRunner:
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
event = {
|
||||
"name": tool_call.name,
|
||||
"status": "error",
|
||||
"detail": result.replace("\n", " ").strip()[:120],
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return result + _HINT, event, RuntimeError(result)
|
||||
return result + _HINT, event, None
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
detail = detail.replace("\n", " ").strip()
|
||||
if not detail:
|
||||
detail = "(empty)"
|
||||
elif len(detail) > 120:
|
||||
detail = detail[:120] + "..."
|
||||
return result, {
|
||||
"name": tool_call.name,
|
||||
"status": "error" if isinstance(result, str) and result.startswith("Error") else "ok",
|
||||
"detail": detail,
|
||||
}, None
|
||||
return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None
|
||||
|
||||
async def _emit_checkpoint(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
callback = spec.checkpoint_callback
|
||||
if callback is not None:
|
||||
await callback(payload)
|
||||
|
||||
@staticmethod
|
||||
def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None:
|
||||
if not content:
|
||||
return
|
||||
if (
|
||||
messages
|
||||
and messages[-1].get("role") == "assistant"
|
||||
and not messages[-1].get("tool_calls")
|
||||
):
|
||||
if messages[-1].get("content") == content:
|
||||
return
|
||||
messages[-1] = build_assistant_message(content)
|
||||
return
|
||||
messages.append(build_assistant_message(content))
|
||||
|
||||
def _normalize_tool_result(
    self,
    spec: AgentRunSpec,
    tool_call_id: str,
    result: Any,
) -> Any:
    """Bring a tool result within ``spec.max_tool_result_chars``.

    The result is first passed through ``maybe_persist_tool_result``; if that
    raises, a warning is logged and the raw result is used instead. A string
    that is still longer than the limit is shortened with ``truncate_text``.
    Non-string content is returned as produced.
    """
    try:
        content = maybe_persist_tool_result(
            spec.workspace,
            spec.session_key,
            tool_call_id,
            result,
            max_chars=spec.max_tool_result_chars,
        )
    except Exception as exc:
        # Best effort: a persistence failure must not lose the tool output.
        session_label = spec.session_key or "default"
        logger.warning(
            "Tool result persist failed for {} in {}: {}; using raw result",
            tool_call_id,
            session_label,
            exc,
        )
        content = result

    if not isinstance(content, str):
        return content
    if len(content) <= spec.max_tool_result_chars:
        return content
    return truncate_text(content, spec.max_tool_result_chars)
|
||||
|
||||
def _apply_tool_result_budget(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
updated = messages
|
||||
for idx, message in enumerate(messages):
|
||||
if message.get("role") != "tool":
|
||||
continue
|
||||
normalized = self._normalize_tool_result(
|
||||
spec,
|
||||
str(message.get("tool_call_id") or f"tool_{idx}"),
|
||||
message.get("content"),
|
||||
)
|
||||
if normalized != message.get("content"):
|
||||
if updated is messages:
|
||||
updated = [dict(m) for m in messages]
|
||||
updated[idx]["content"] = normalized
|
||||
return updated
|
||||
|
||||
def _snip_history(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
if not messages or not spec.context_window_tokens:
|
||||
return messages
|
||||
|
||||
provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
|
||||
max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else (
|
||||
provider_max_tokens if isinstance(provider_max_tokens, int) else 4096
|
||||
)
|
||||
budget = spec.context_block_limit or (
|
||||
spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER
|
||||
)
|
||||
if budget <= 0:
|
||||
return messages
|
||||
|
||||
estimate, _ = estimate_prompt_tokens_chain(
|
||||
self.provider,
|
||||
spec.model,
|
||||
messages,
|
||||
spec.tools.get_definitions(),
|
||||
)
|
||||
if estimate <= budget:
|
||||
return messages
|
||||
|
||||
system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"]
|
||||
non_system = [dict(msg) for msg in messages if msg.get("role") != "system"]
|
||||
if not non_system:
|
||||
return messages
|
||||
|
||||
system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages)
|
||||
remaining_budget = max(128, budget - system_tokens)
|
||||
kept: list[dict[str, Any]] = []
|
||||
kept_tokens = 0
|
||||
for message in reversed(non_system):
|
||||
msg_tokens = estimate_message_tokens(message)
|
||||
if kept and kept_tokens + msg_tokens > remaining_budget:
|
||||
break
|
||||
kept.append(message)
|
||||
kept_tokens += msg_tokens
|
||||
kept.reverse()
|
||||
|
||||
if kept:
|
||||
for i, message in enumerate(kept):
|
||||
if message.get("role") == "user":
|
||||
kept = kept[i:]
|
||||
break
|
||||
start = find_legal_message_start(kept)
|
||||
if start:
|
||||
kept = kept[start:]
|
||||
if not kept:
|
||||
kept = non_system[-min(len(non_system), 4) :]
|
||||
start = find_legal_message_start(kept)
|
||||
if start:
|
||||
kept = kept[start:]
|
||||
return system_messages + kept
|
||||
|
||||
def _partition_tool_batches(
|
||||
self,
|
||||
spec: AgentRunSpec,
|
||||
tool_calls: list[ToolCallRequest],
|
||||
) -> list[list[ToolCallRequest]]:
|
||||
if not spec.concurrent_tools:
|
||||
return [[tool_call] for tool_call in tool_calls]
|
||||
|
||||
batches: list[list[ToolCallRequest]] = []
|
||||
current: list[ToolCallRequest] = []
|
||||
for tool_call in tool_calls:
|
||||
get_tool = getattr(spec.tools, "get", None)
|
||||
tool = get_tool(tool_call.name) if callable(get_tool) else None
|
||||
can_batch = bool(tool and tool.concurrency_safe)
|
||||
if can_batch:
|
||||
current.append(tool_call)
|
||||
continue
|
||||
if current:
|
||||
batches.append(current)
|
||||
current = []
|
||||
batches.append([tool_call])
|
||||
if current:
|
||||
batches.append(current)
|
||||
return batches
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ class SubagentManager:
|
||||
provider: LLMProvider,
|
||||
workspace: Path,
|
||||
bus: MessageBus,
|
||||
max_tool_result_chars: int,
|
||||
model: str | None = None,
|
||||
web_search_config: "WebSearchConfig | None" = None,
|
||||
web_proxy: str | None = None,
|
||||
@ -56,6 +57,7 @@ class SubagentManager:
|
||||
self.workspace = workspace
|
||||
self.bus = bus
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_tool_result_chars = max_tool_result_chars
|
||||
self.web_search_config = web_search_config or WebSearchConfig()
|
||||
self.web_proxy = web_proxy
|
||||
self.exec_config = exec_config or ExecToolConfig()
|
||||
@ -136,6 +138,7 @@ class SubagentManager:
|
||||
tools=tools,
|
||||
model=self.model,
|
||||
max_iterations=15,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=_SubagentHook(task_id),
|
||||
max_iterations_message="Task completed but no final response was generated.",
|
||||
error_message=None,
|
||||
|
||||
@ -53,6 +53,21 @@ class Tool(ABC):
|
||||
"""JSON Schema for tool parameters."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether this tool is side-effect free and safe to parallelize."""
|
||||
# Base-class default is False; a tool opts in by overriding this property.
return False
|
||||
|
||||
@property
def concurrency_safe(self) -> bool:
    """Whether this tool may share a batch with other concurrency-safe tools.

    True only for a tool that is ``read_only`` and not ``exclusive``.
    """
    safe = self.read_only
    if safe:
        safe = not self.exclusive
    return safe
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
"""Whether this tool should run alone even if concurrency is enabled."""
|
||||
# Base-class default is False; a tool opts in by overriding this property.
return False
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs: Any) -> Any:
|
||||
"""
|
||||
|
||||
@ -73,6 +73,10 @@ class ReadFileTool(_FsTool):
|
||||
"Use offset and limit to paginate through large files."
|
||||
)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
# Overrides the Tool default (False): this tool reports itself as read-only.
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
@ -344,6 +348,10 @@ class ListDirTool(_FsTool):
|
||||
"Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored."
|
||||
)
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
# Overrides the Tool default (False): this tool reports itself as read-only.
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@ -35,22 +35,35 @@ class ToolRegistry:
|
||||
"""Get all tool definitions in OpenAI format."""
|
||||
return [tool.to_schema() for tool in self._tools.values()]
|
||||
|
||||
def prepare_call(
    self,
    name: str,
    params: dict[str, Any],
) -> tuple[Tool | None, dict[str, Any], str | None]:
    """Look up a tool by name, then cast and validate the call parameters.

    Returns:
        A ``(tool, params, error)`` triple. For an unknown tool this is
        ``(None, <params as given>, <error text>)``. For a known tool the
        cast parameters are returned together with either ``None`` or an
        error text listing the validation problems.
    """
    resolved = self._tools.get(name)
    if not resolved:
        not_found = f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
        return None, params, not_found

    coerced = resolved.cast_params(params)
    problems = resolved.validate_params(coerced)
    message = None
    if problems:
        message = f"Error: Invalid parameters for tool '{name}': " + "; ".join(problems)
    return resolved, coerced, message
|
||||
|
||||
async def execute(self, name: str, params: dict[str, Any]) -> Any:
|
||||
"""Execute a tool by name with given parameters."""
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}"
|
||||
tool, params, error = self.prepare_call(name, params)
|
||||
if error:
|
||||
return error + _HINT
|
||||
|
||||
try:
|
||||
# Attempt to cast parameters to match schema types
|
||||
params = tool.cast_params(params)
|
||||
|
||||
# Validate parameters
|
||||
errors = tool.validate_params(params)
|
||||
if errors:
|
||||
return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT
|
||||
assert tool is not None # guarded by prepare_call()
|
||||
result = await tool.execute(**params)
|
||||
if isinstance(result, str) and result.startswith("Error"):
|
||||
return result + _HINT
|
||||
|
||||
@ -52,6 +52,10 @@ class ExecTool(Tool):
|
||||
def description(self) -> str:
|
||||
return "Execute a shell command and return its output. Use with caution."
|
||||
|
||||
@property
def exclusive(self) -> bool:
    """Flag this tool as exclusive; always returns ``True``."""
    # NOTE(review): presumably overrides a base-class default - confirm.
    return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
||||
@ -92,6 +92,10 @@ class WebSearchTool(Tool):
|
||||
self.config = config if config is not None else WebSearchConfig()
|
||||
self.proxy = proxy
|
||||
|
||||
@property
def read_only(self) -> bool:
    """Flag this tool as read-only; always returns ``True``."""
    # NOTE(review): presumably overrides a base-class default - confirm.
    return True
|
||||
|
||||
async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str:
|
||||
provider = self.config.provider.strip().lower() or "brave"
|
||||
n = min(max(count or self.config.max_results, 1), 10)
|
||||
@ -234,6 +238,10 @@ class WebFetchTool(Tool):
|
||||
self.max_chars = max_chars
|
||||
self.proxy = proxy
|
||||
|
||||
@property
def read_only(self) -> bool:
    """Flag this tool as read-only; always returns ``True``."""
    # NOTE(review): presumably overrides a base-class default - confirm.
    return True
|
||||
|
||||
async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any:
|
||||
max_chars = maxChars or self.max_chars
|
||||
is_valid, error_msg = _validate_url_safe(url)
|
||||
|
||||
@ -539,6 +539,9 @@ def serve(
|
||||
model=runtime_config.agents.defaults.model,
|
||||
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
|
||||
context_block_limit=runtime_config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
|
||||
web_search_config=runtime_config.tools.web.search,
|
||||
web_proxy=runtime_config.tools.web.proxy or None,
|
||||
exec_config=runtime_config.tools.exec,
|
||||
@ -626,6 +629,9 @@ def gateway(
|
||||
model=config.agents.defaults.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
@ -832,6 +838,9 @@ def agent(
|
||||
model=config.agents.defaults.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
provider_retry_mode=config.agents.defaults.provider_retry_mode,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
|
||||
@ -38,8 +38,11 @@ class AgentDefaults(Base):
|
||||
)
|
||||
max_tokens: int = 8192
|
||||
context_window_tokens: int = 65_536
|
||||
context_block_limit: int | None = None
|
||||
temperature: float = 0.1
|
||||
max_tool_iterations: int = 40
|
||||
max_tool_iterations: int = 200
|
||||
max_tool_result_chars: int = 16_000
|
||||
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||
reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
|
||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||
|
||||
|
||||
@ -73,6 +73,9 @@ class Nanobot:
|
||||
model=defaults.model,
|
||||
max_iterations=defaults.max_tool_iterations,
|
||||
context_window_tokens=defaults.context_window_tokens,
|
||||
context_block_limit=defaults.context_block_limit,
|
||||
max_tool_result_chars=defaults.max_tool_result_chars,
|
||||
provider_retry_mode=defaults.provider_retry_mode,
|
||||
web_search_config=config.tools.web.search,
|
||||
web_proxy=config.tools.web.proxy or None,
|
||||
exec_config=config.tools.exec,
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
@ -427,13 +429,33 @@ class AnthropicProvider(LLMProvider):
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
async with self._client.messages.stream(**kwargs) as stream:
|
||||
if on_content_delta:
|
||||
async for text in stream.text_stream:
|
||||
stream_iter = stream.text_stream.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
text = await asyncio.wait_for(
|
||||
stream_iter.__anext__(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
await on_content_delta(text)
|
||||
response = await stream.get_final_message()
|
||||
response = await asyncio.wait_for(
|
||||
stream.get_final_message(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
return self._parse_response(response)
|
||||
except asyncio.TimeoutError:
|
||||
return LLMResponse(
|
||||
content=(
|
||||
f"Error calling LLM: stream stalled for more than "
|
||||
f"{idle_timeout_s} seconds"
|
||||
),
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
@ -9,6 +10,8 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallRequest:
|
||||
@ -57,13 +60,7 @@ class LLMResponse:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationSettings:
|
||||
"""Default generation parameters for LLM calls.
|
||||
|
||||
Stored on the provider so every call site inherits the same defaults
|
||||
without having to pass temperature / max_tokens / reasoning_effort
|
||||
through every layer. Individual call sites can still override by
|
||||
passing explicit keyword arguments to chat() / chat_with_retry().
|
||||
"""
|
||||
"""Default generation settings."""
|
||||
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
@ -71,14 +68,11 @@ class GenerationSettings:
|
||||
|
||||
|
||||
class LLMProvider(ABC):
|
||||
"""
|
||||
Abstract base class for LLM providers.
|
||||
|
||||
Implementations should handle the specifics of each provider's API
|
||||
while maintaining a consistent interface.
|
||||
"""
|
||||
"""Base class for LLM providers."""
|
||||
|
||||
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
||||
_PERSISTENT_MAX_DELAY = 60
|
||||
_RETRY_HEARTBEAT_CHUNK = 30
|
||||
_TRANSIENT_ERROR_MARKERS = (
|
||||
"429",
|
||||
"rate limit",
|
||||
@ -208,7 +202,7 @@ class LLMProvider(ABC):
|
||||
for b in content:
|
||||
if isinstance(b, dict) and b.get("type") == "image_url":
|
||||
path = (b.get("_meta") or {}).get("path", "")
|
||||
placeholder = f"[image: {path}]" if path else "[image omitted]"
|
||||
placeholder = image_placeholder_text(path, empty="[image omitted]")
|
||||
new_content.append({"type": "text", "text": placeholder})
|
||||
found = True
|
||||
else:
|
||||
@ -273,6 +267,8 @@ class LLMProvider(ABC):
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat_stream() with retry on transient provider failures."""
|
||||
if max_tokens is self._SENTINEL:
|
||||
@ -288,28 +284,13 @@ class LLMProvider(ABC):
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
on_content_delta=on_content_delta,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
response = await self._safe_chat_stream(**kw)
|
||||
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
||||
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
||||
return response
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
return await self._safe_chat_stream(**kw)
|
||||
return await self._run_with_retry(
|
||||
self._safe_chat_stream,
|
||||
kw,
|
||||
messages,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
)
|
||||
|
||||
async def chat_with_retry(
|
||||
self,
|
||||
@ -320,6 +301,8 @@ class LLMProvider(ABC):
|
||||
temperature: object = _SENTINEL,
|
||||
reasoning_effort: object = _SENTINEL,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
retry_mode: str = "standard",
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Call chat() with retry on transient provider failures.
|
||||
|
||||
@ -339,28 +322,102 @@ class LLMProvider(ABC):
|
||||
max_tokens=max_tokens, temperature=temperature,
|
||||
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
||||
)
|
||||
return await self._run_with_retry(
|
||||
self._safe_chat,
|
||||
kw,
|
||||
messages,
|
||||
retry_mode=retry_mode,
|
||||
on_retry_wait=on_retry_wait,
|
||||
)
|
||||
|
||||
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
||||
response = await self._safe_chat(**kw)
|
||||
@classmethod
|
||||
def _extract_retry_after(cls, content: str | None) -> float | None:
|
||||
text = (content or "").lower()
|
||||
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
|
||||
if not match:
|
||||
return None
|
||||
value = float(match.group(1))
|
||||
unit = (match.group(2) or "s").lower()
|
||||
if unit in {"ms", "milliseconds"}:
|
||||
return max(0.1, value / 1000.0)
|
||||
if unit in {"m", "min", "minutes"}:
|
||||
return value * 60.0
|
||||
return value
|
||||
|
||||
async def _sleep_with_heartbeat(
|
||||
self,
|
||||
delay: float,
|
||||
*,
|
||||
attempt: int,
|
||||
persistent: bool,
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
remaining = max(0.0, delay)
|
||||
while remaining > 0:
|
||||
if on_retry_wait:
|
||||
kind = "persistent retry" if persistent else "retry"
|
||||
await on_retry_wait(
|
||||
f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
|
||||
f"(attempt {attempt})."
|
||||
)
|
||||
chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
|
||||
await asyncio.sleep(chunk)
|
||||
remaining -= chunk
|
||||
|
||||
async def _run_with_retry(
|
||||
self,
|
||||
call: Callable[..., Awaitable[LLMResponse]],
|
||||
kw: dict[str, Any],
|
||||
original_messages: list[dict[str, Any]],
|
||||
*,
|
||||
retry_mode: str,
|
||||
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||
) -> LLMResponse:
|
||||
attempt = 0
|
||||
delays = list(self._CHAT_RETRY_DELAYS)
|
||||
persistent = retry_mode == "persistent"
|
||||
last_response: LLMResponse | None = None
|
||||
while True:
|
||||
attempt += 1
|
||||
response = await call(**kw)
|
||||
if response.finish_reason != "error":
|
||||
return response
|
||||
last_response = response
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
stripped = self._strip_image_content(messages)
|
||||
if stripped is not None:
|
||||
logger.warning("Non-transient LLM error with image content, retrying without images")
|
||||
return await self._safe_chat(**{**kw, "messages": stripped})
|
||||
stripped = self._strip_image_content(original_messages)
|
||||
if stripped is not None and stripped != kw["messages"]:
|
||||
logger.warning(
|
||||
"Non-transient LLM error with image content, retrying without images"
|
||||
)
|
||||
retry_kw = dict(kw)
|
||||
retry_kw["messages"] = stripped
|
||||
return await call(**retry_kw)
|
||||
return response
|
||||
|
||||
if not persistent and attempt > len(delays):
|
||||
break
|
||||
|
||||
base_delay = delays[min(attempt - 1, len(delays) - 1)]
|
||||
delay = self._extract_retry_after(response.content) or base_delay
|
||||
if persistent:
|
||||
delay = min(delay, self._PERSISTENT_MAX_DELAY)
|
||||
|
||||
logger.warning(
|
||||
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
||||
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
||||
"LLM transient error (attempt {}{}), retrying in {}s: {}",
|
||||
attempt,
|
||||
"+" if persistent and attempt > len(delays) else f"/{len(delays)}",
|
||||
int(round(delay)),
|
||||
(response.content or "")[:120].lower(),
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
await self._sleep_with_heartbeat(
|
||||
delay,
|
||||
attempt=attempt,
|
||||
persistent=persistent,
|
||||
on_retry_wait=on_retry_wait,
|
||||
)
|
||||
|
||||
return await self._safe_chat(**kw)
|
||||
return last_response if last_response is not None else await call(**kw)
|
||||
|
||||
@abstractmethod
|
||||
def get_default_model(self) -> str:
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
@ -20,7 +21,6 @@ if TYPE_CHECKING:
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({
|
||||
"role", "content", "tool_calls", "tool_call_id", "name",
|
||||
"reasoning_content", "extra_content",
|
||||
})
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
@ -572,16 +572,33 @@ class OpenAICompatProvider(LLMProvider):
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90"))
|
||||
try:
|
||||
stream = await self._client.chat.completions.create(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
stream_iter = stream.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(
|
||||
stream_iter.__anext__(),
|
||||
timeout=idle_timeout_s,
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
chunks.append(chunk)
|
||||
if on_content_delta and chunk.choices:
|
||||
text = getattr(chunk.choices[0].delta, "content", None)
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
return self._parse_chunks(chunks)
|
||||
except asyncio.TimeoutError:
|
||||
return LLMResponse(
|
||||
content=(
|
||||
f"Error calling LLM: stream stalled for more than "
|
||||
f"{idle_timeout_s} seconds"
|
||||
),
|
||||
finish_reason="error",
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
|
||||
@ -10,20 +10,12 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_legacy_sessions_dir
|
||||
from nanobot.utils.helpers import ensure_dir, safe_filename
|
||||
from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""
|
||||
A conversation session.
|
||||
|
||||
Stores messages in JSONL format for easy reading and persistence.
|
||||
|
||||
Important: Messages are append-only for LLM cache efficiency.
|
||||
The consolidation process writes summaries to MEMORY.md/HISTORY.md
|
||||
but does NOT modify the messages list or get_history() output.
|
||||
"""
|
||||
"""A conversation session."""
|
||||
|
||||
key: str # channel:chat_id
|
||||
messages: list[dict[str, Any]] = field(default_factory=list)
|
||||
@ -43,43 +35,19 @@ class Session:
|
||||
self.messages.append(msg)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
@staticmethod
|
||||
def _find_legal_start(messages: list[dict[str, Any]]) -> int:
|
||||
"""Find first index where every tool result has a matching assistant tool_call."""
|
||||
declared: set[str] = set()
|
||||
start = 0
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role")
|
||||
if role == "assistant":
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
elif role == "tool":
|
||||
tid = msg.get("tool_call_id")
|
||||
if tid and str(tid) not in declared:
|
||||
start = i + 1
|
||||
declared.clear()
|
||||
for prev in messages[start:i + 1]:
|
||||
if prev.get("role") == "assistant":
|
||||
for tc in prev.get("tool_calls") or []:
|
||||
if isinstance(tc, dict) and tc.get("id"):
|
||||
declared.add(str(tc["id"]))
|
||||
return start
|
||||
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
|
||||
# Drop leading non-user messages to avoid starting mid-turn when possible.
|
||||
# Avoid starting mid-turn when possible.
|
||||
for i, message in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
sliced = sliced[i:]
|
||||
break
|
||||
|
||||
# Some providers reject orphan tool results if the matching assistant
|
||||
# tool_calls message fell outside the fixed-size history window.
|
||||
start = self._find_legal_start(sliced)
|
||||
# Drop orphan tool results at the front.
|
||||
start = find_legal_message_start(sliced)
|
||||
if start:
|
||||
sliced = sliced[start:]
|
||||
|
||||
@ -115,7 +83,7 @@ class Session:
|
||||
retained = self.messages[start_idx:]
|
||||
|
||||
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
||||
start = self._find_legal_start(retained)
|
||||
start = find_legal_message_start(retained)
|
||||
if start:
|
||||
retained = retained[start:]
|
||||
|
||||
|
||||
@ -3,7 +3,9 @@
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -56,11 +58,7 @@ def timestamp() -> str:
|
||||
|
||||
|
||||
def current_time_str(timezone: str | None = None) -> str:
|
||||
"""Human-readable current time with weekday and UTC offset.
|
||||
|
||||
When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time
|
||||
is converted to that zone. Otherwise falls back to the host local time.
|
||||
"""
|
||||
"""Return the current time string."""
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
try:
|
||||
@ -76,12 +74,164 @@ def current_time_str(timezone: str | None = None) -> str:
|
||||
|
||||
|
||||
_UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]')
|
||||
_TOOL_RESULT_PREVIEW_CHARS = 1200
|
||||
_TOOL_RESULTS_DIR = ".nanobot/tool-results"
|
||||
_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60
|
||||
_TOOL_RESULT_MAX_BUCKETS = 32
|
||||
|
||||
def safe_filename(name: str) -> str:
    """Return *name* with unsafe path characters replaced by underscores.

    Each match of ``_UNSAFE_CHARS`` becomes ``_``; surrounding whitespace is
    then stripped.
    """
    sanitized = _UNSAFE_CHARS.sub("_", name)
    return sanitized.strip()
|
||||
|
||||
|
||||
def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str:
|
||||
"""Build an image placeholder string."""
|
||||
return f"[image: {path}]" if path else empty
|
||||
|
||||
|
||||
def truncate_text(text: str, max_chars: int) -> str:
    """Cut *text* to *max_chars* characters and append a fixed marker.

    Args:
        text: Text to shorten.
        max_chars: Maximum number of characters kept from *text*. Values
            less than or equal to zero return *text* unchanged.

    Returns:
        *text* itself when it already fits, otherwise the first *max_chars*
        characters followed by ``"\\n... (truncated)"``.
    """
    already_fits = len(text) <= max_chars
    if max_chars <= 0 or already_fits:
        return text
    head = text[:max_chars]
    return head + "\n... (truncated)"
|
||||
|
||||
|
||||
def find_legal_message_start(messages: list[dict[str, Any]]) -> int:
    """Find the first index whose tool results have matching assistant calls.

    Messages are scanned in order while collecting the tool-call ids declared
    by assistant messages. A ``tool`` message whose ``tool_call_id`` has not
    been declared is an orphan, and the start moves to just past it.

    Args:
        messages: Chat messages as dicts with a ``role`` key.

    Returns:
        Index of the first message to keep: ``0`` when there is no orphan,
        and possibly ``len(messages)`` when the last message is an orphan.
    """
    declared: set[str] = set()
    start = 0
    for i, msg in enumerate(messages):
        role = msg.get("role")
        if role == "assistant":
            declared.update(
                str(tc["id"])
                for tc in msg.get("tool_calls") or []
                if isinstance(tc, dict) and tc.get("id")
            )
        elif role == "tool":
            tid = msg.get("tool_call_id")
            if tid and str(tid) not in declared:
                # Everything up to and including the orphan is dropped, so
                # earlier declarations no longer count. The window
                # messages[start : i + 1] is empty once start == i + 1, so
                # there is nothing to re-scan.
                start = i + 1
                declared.clear()
    return start
|
||||
|
||||
|
||||
def _stringify_text_blocks(content: list[dict[str, Any]]) -> str | None:
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
return None
|
||||
if block.get("type") != "text":
|
||||
return None
|
||||
text = block.get("text")
|
||||
if not isinstance(text, str):
|
||||
return None
|
||||
parts.append(text)
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def _render_tool_result_reference(
|
||||
filepath: Path,
|
||||
*,
|
||||
original_size: int,
|
||||
preview: str,
|
||||
truncated_preview: bool,
|
||||
) -> str:
|
||||
result = (
|
||||
f"[tool output persisted]\n"
|
||||
f"Full output saved to: {filepath}\n"
|
||||
f"Original size: {original_size} chars\n"
|
||||
f"Preview:\n{preview}"
|
||||
)
|
||||
if truncated_preview:
|
||||
result += "\n...\n(Read the saved file if you need the full output.)"
|
||||
return result
|
||||
|
||||
|
||||
def _bucket_mtime(path: Path) -> float:
|
||||
try:
|
||||
return path.stat().st_mtime
|
||||
except OSError:
|
||||
return 0.0
|
||||
|
||||
|
||||
def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None:
    """Prune sibling bucket directories under *root* by age, then by count.

    *current_bucket* is never considered for removal.

    Args:
        root: Directory whose sub-directories are the buckets.
        current_bucket: Bucket to leave untouched.
    """
    others = []
    for entry in root.iterdir():
        if entry.is_dir() and entry != current_bucket:
            others.append(entry)

    # Age pass: drop buckets last modified before the retention window.
    expiry = time.time() - _TOOL_RESULT_RETENTION_SECS
    for entry in others:
        if _bucket_mtime(entry) < expiry:
            shutil.rmtree(entry, ignore_errors=True)

    # Count pass: one slot is reserved for the current bucket, so at most
    # MAX - 1 siblings may remain; the most recently modified ones win.
    budget = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0)
    survivors = [entry for entry in others if entry.exists()]
    if len(survivors) <= budget:
        return
    ranked = sorted(survivors, key=_bucket_mtime, reverse=True)
    for entry in ranked[budget:]:
        shutil.rmtree(entry, ignore_errors=True)
|
||||
|
||||
|
||||
def _write_text_atomic(path: Path, content: str) -> None:
|
||||
tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
|
||||
try:
|
||||
tmp.write_text(content, encoding="utf-8")
|
||||
tmp.replace(path)
|
||||
finally:
|
||||
if tmp.exists():
|
||||
tmp.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def maybe_persist_tool_result(
    workspace: Path | None,
    session_key: str | None,
    tool_call_id: str,
    content: Any,
    *,
    max_chars: int,
) -> Any:
    """Persist oversized tool output and replace it with a stable reference string.

    Args:
        workspace: Base directory for saved outputs; ``None`` disables
            persistence.
        session_key: Names the per-session bucket directory; falls back to
            ``"default"`` when empty.
        tool_call_id: Names the saved file inside the bucket.
        content: Tool output. Only strings and lists made solely of text
            blocks are eligible.
        max_chars: Size threshold; values less than or equal to zero disable
            persistence.

    Returns:
        *content* unchanged when persistence is disabled, the content is not
        eligible, or its text fits within *max_chars*. Otherwise a reference
        string naming the saved file with a preview of the output.
    """
    if workspace is None or max_chars <= 0:
        return content

    if isinstance(content, str):
        flattened, extension = content, "txt"
    elif isinstance(content, list):
        # None here means the list holds something other than text blocks.
        flattened, extension = _stringify_text_blocks(content), "json"
    else:
        return content

    if flattened is None or len(flattened) <= max_chars:
        return content

    results_root = ensure_dir(workspace / _TOOL_RESULTS_DIR)
    session_dir = ensure_dir(results_root / safe_filename(session_key or "default"))
    try:
        _cleanup_tool_result_buckets(results_root, session_dir)
    except Exception:
        # Pruning old buckets is best-effort and must not block the save.
        pass

    target = session_dir / f"{safe_filename(tool_call_id)}.{extension}"
    if not target.exists():
        # An existing file for the same call id is reused as-is.
        if extension == "json":
            serialized = json.dumps(content, ensure_ascii=False, indent=2)
        else:
            serialized = flattened
        _write_text_atomic(target, serialized)

    total = len(flattened)
    return _render_tool_result_reference(
        target,
        original_size=total,
        preview=flattened[:_TOOL_RESULT_PREVIEW_CHARS],
        truncated_preview=total > _TOOL_RESULT_PREVIEW_CHARS,
    )
|
||||
|
||||
|
||||
def split_message(content: str, max_len: int = 2000) -> list[str]:
|
||||
"""
|
||||
Split content into chunks within max_len, preferring line breaks.
|
||||
|
||||
@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None:
|
||||
assert "Channel: cli" in user_content
|
||||
assert "Chat ID: direct" in user_content
|
||||
assert "Return exactly: OK" in user_content
|
||||
|
||||
|
||||
def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None:
    """An assistant-role message after assistant history must not yield two assistant turns in a row."""
    builder = ContextBuilder(_make_workspace(tmp_path))

    built = builder.build_messages(
        history=[{"role": "assistant", "content": "previous result"}],
        current_message="subagent result",
        channel="cli",
        chat_id="direct",
        current_role="assistant",
    )

    roles = [message.get("role") for message in built]
    for earlier, later in zip(roles, roles[1:]):
        assert not (earlier == later == "assistant")
|
||||
|
||||
@ -5,7 +5,9 @@ from nanobot.session.manager import Session
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
|
||||
return loop
|
||||
|
||||
|
||||
@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None:
|
||||
)
|
||||
|
||||
assert session.messages[0]["content"] == content
|
||||
|
||||
|
||||
def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None:
    """A checkpoint alone restores the assistant turn, the finished result, and the pending call."""

    def tool_call(call_id: str, tool_name: str) -> dict:
        # Fresh dicts on every call so no object is shared between fixtures.
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": tool_name, "arguments": "{}"},
        }

    checkpoint = {
        "assistant_message": {
            "role": "assistant",
            "content": "working",
            "tool_calls": [
                tool_call("call_done", "read_file"),
                tool_call("call_pending", "exec"),
            ],
        },
        "completed_tool_results": [
            {
                "role": "tool",
                "tool_call_id": "call_done",
                "name": "read_file",
                "content": "ok",
            }
        ],
        "pending_tool_calls": [tool_call("call_pending", "exec")],
    }

    loop = _mk_loop()
    session = Session(
        key="test:checkpoint",
        metadata={AgentLoop._RUNTIME_CHECKPOINT_KEY: checkpoint},
    )

    restored = loop._restore_runtime_checkpoint(session)

    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    first, second, third = session.messages[0], session.messages[1], session.messages[2]
    assert first["role"] == "assistant"
    assert second["tool_call_id"] == "call_done"
    assert third["tool_call_id"] == "call_pending"
    assert "interrupted before this tool finished" in third["content"].lower()
|
||||
|
||||
|
||||
def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None:
    """Messages already in the session tail are not duplicated by the restore."""

    def tool_call(call_id: str, tool_name: str) -> dict:
        return {
            "id": call_id,
            "type": "function",
            "function": {"name": tool_name, "arguments": "{}"},
        }

    # Factories return fresh dicts each time, so the session history and the
    # checkpoint hold equal but distinct objects.
    def assistant_turn() -> dict:
        return {
            "role": "assistant",
            "content": "working",
            "tool_calls": [
                tool_call("call_done", "read_file"),
                tool_call("call_pending", "exec"),
            ],
        }

    def finished_result() -> dict:
        return {
            "role": "tool",
            "tool_call_id": "call_done",
            "name": "read_file",
            "content": "ok",
        }

    loop = _mk_loop()
    session = Session(
        key="test:checkpoint-overlap",
        messages=[assistant_turn(), finished_result()],
        metadata={
            AgentLoop._RUNTIME_CHECKPOINT_KEY: {
                "assistant_message": assistant_turn(),
                "completed_tool_results": [finished_result()],
                "pending_tool_calls": [tool_call("call_pending", "exec")],
            }
        },
    )

    restored = loop._restore_runtime_checkpoint(session)

    assert restored is True
    assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None
    assert len(session.messages) == 3
    assert session.messages[0]["role"] == "assistant"
    assert session.messages[1]["tool_call_id"] == "call_done"
    assert session.messages[2]["tool_call_id"] == "call_pending"
|
||||
|
||||
@ -2,12 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.agent.tools.base import Tool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=RecordingHook(),
|
||||
))
|
||||
|
||||
@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=StreamingHook(),
|
||||
))
|
||||
|
||||
@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "max_iterations"
|
||||
@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback():
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
assert result.messages[-1]["content"] == result.final_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_returns_structured_tool_error():
|
||||
@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error():
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=2,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
fail_on_tool_error=True,
|
||||
))
|
||||
|
||||
@ -258,6 +272,232 @@ async def test_runner_returns_structured_tool_error():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path):
    """Oversized tool output is persisted and referenced in the follow-up request."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    follow_up_messages: list[dict] = []
    calls_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        # The first call requests a tool; any later call records what it was sent.
        nonlocal calls_seen
        calls_seen += 1
        if calls_seen > 1:
            follow_up_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        big_listing = ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})
        return LLMResponse(
            content="working",
            tool_calls=[big_listing],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    # 20,000 characters: far larger than the 2048-character limit on the spec.
    tools.execute = AsyncMock(return_value="x" * 20_000)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        workspace=tmp_path,
        session_key="test:runner",
        max_tool_result_chars=2048,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "done"

    tool_message = next(m for m in follow_up_messages if m.get("role") == "tool")
    tool_text = tool_message["content"]
    assert "[tool output persisted]" in tool_text
    assert "tool-results" in tool_text

    persisted_file = tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt"
    assert persisted_file.exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_prunes_old_session_buckets(tmp_path):
    """Persisting a result removes a backdated session bucket but keeps a fresh one."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    root = tmp_path / ".nanobot" / "tool-results"
    old_bucket = root / "old_session"
    recent_bucket = root / "recent_session"
    for bucket in (old_bucket, recent_bucket):
        bucket.mkdir(parents=True)

    old_file = old_bucket / "old.txt"
    old_file.write_text("old", encoding="utf-8")
    (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8")

    # Backdate the old bucket and its file by eight days. This happens after the
    # write above, because creating a file refreshes the directory's mtime.
    eight_days_ago = time.time() - (8 * 24 * 60 * 60)
    backdated = (eight_days_ago, eight_days_ago)
    os.utime(old_bucket, backdated)
    os.utime(old_file, backdated)

    persisted = maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert "[tool output persisted]" in persisted
    assert not old_bucket.exists()
    assert recent_bucket.exists()
    assert (root / "current_session" / "call_big.txt").exists()
|
||||
|
||||
|
||||
def test_persist_tool_result_leaves_no_temp_files(tmp_path):
    """After persisting, the session bucket holds the result and no ``*.tmp`` files."""
    from nanobot.utils.helpers import maybe_persist_tool_result

    session_bucket = tmp_path / ".nanobot" / "tool-results" / "current_session"

    maybe_persist_tool_result(
        tmp_path,
        "current:session",
        "call_big",
        "x" * 5000,
        max_chars=64,
    )

    assert (session_bucket / "call_big.txt").exists()
    leftovers = list(session_bucket.glob("*.tmp"))
    assert leftovers == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_uses_raw_messages_when_context_governance_fails():
    """If ``_snip_history`` raises, the provider still receives the initial messages."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    sent_to_provider: list[dict] = []

    async def chat_with_retry(*, messages, **kwargs):
        sent_to_provider[:] = messages
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    initial_messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hello"},
    ]

    runner = AgentRunner(provider)
    # Force the history-snipping step to fail on every call.
    runner._snip_history = MagicMock(side_effect=RuntimeError("boom"))  # type: ignore[method-assign]

    spec = AgentRunSpec(
        initial_messages=initial_messages,
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    result = await runner.run(spec)

    assert result.final_content == "done"
    assert sent_to_provider == initial_messages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_keeps_going_when_tool_result_persistence_fails():
    """A failing persistence helper leaves the raw tool result in the follow-up call."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    follow_up_messages: list[dict] = []
    calls_seen = 0

    async def chat_with_retry(*, messages, **kwargs):
        # The first call requests a tool; any later call records what it was sent.
        nonlocal calls_seen
        calls_seen += 1
        if calls_seen > 1:
            follow_up_messages[:] = messages
            return LLMResponse(content="done", tool_calls=[], usage={})
        listing = ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})
        return LLMResponse(
            content="working",
            tool_calls=[listing],
            usage={"prompt_tokens": 5, "completion_tokens": 3},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="tool result")

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "do task"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    # Make the runner's persistence helper blow up whenever it is invoked.
    persistence_failure = patch(
        "nanobot.agent.runner.maybe_persist_tool_result",
        side_effect=RuntimeError("disk full"),
    )
    with persistence_failure:
        result = await runner.run(spec)

    assert result.final_content == "done"
    tool_message = next(m for m in follow_up_messages if m.get("role") == "tool")
    assert tool_message["content"] == "tool result"
|
||||
|
||||
|
||||
class _DelayTool(Tool):
    """Test tool that sleeps for a fixed delay and logs its start/end to a shared list."""

    def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]):
        # ``shared_events`` is the caller's list; every instance appends to the same one.
        self._shared_events = shared_events
        self._read_only = read_only
        self._delay = delay
        self._name = name

    @property
    def name(self) -> str:
        """Tool name supplied at construction."""
        return self._name

    @property
    def description(self) -> str:
        """Description; this test tool simply reuses its name."""
        return self._name

    @property
    def parameters(self) -> dict:
        """Schema of an object with no properties and no required fields."""
        schema = {"type": "object", "properties": {}, "required": []}
        return schema

    @property
    def read_only(self) -> bool:
        """Read-only flag supplied at construction."""
        return self._read_only

    async def execute(self, **kwargs):
        """Log ``start:<name>``, sleep for the delay, log ``end:<name>``, return the name."""
        events = self._shared_events
        events.append(f"start:{self._name}")
        await asyncio.sleep(self._delay)
        events.append(f"end:{self._name}")
        return self._name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runner_batches_read_only_tools_before_exclusive_work():
    """Both read-only tools start together and finish before the writing tool starts."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    shared_events: list[str] = []
    read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events)
    read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events)
    write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events)

    tools = ToolRegistry()
    for tool in (read_a, read_b, write_a):
        tools.register(tool)

    spec = AgentRunSpec(
        initial_messages=[],
        tools=tools,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        concurrent_tools=True,
    )
    requested_calls = [
        ToolCallRequest(id="ro1", name="read_a", arguments={}),
        ToolCallRequest(id="ro2", name="read_b", arguments={}),
        ToolCallRequest(id="rw1", name="write_a", arguments={}),
    ]

    runner = AgentRunner(MagicMock())
    await runner._execute_tools(spec, requested_calls)

    # Both readers must have started before either of them finished.
    assert shared_events[0:2] == ["start:read_a", "start:read_b"]

    reader_ends = ("end:read_a", "end:read_b")
    assert all(marker in shared_events for marker in reader_ends)
    for marker in reader_ends:
        assert shared_events.index(marker) < shared_events.index("start:write_a")

    # The writer runs last, with nothing interleaved.
    assert shared_events[-2:] == ["start:write_a", "end:write_a"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
@ -317,15 +557,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
|
||||
@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
# Tool-result character limit read from a default AgentDefaults instance;
# the tests below pass it as ``max_tool_result_chars``.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_loop(*, exec_config=None):
|
||||
"""Create a minimal AgentLoop with mocked dependencies."""
|
||||
@ -186,7 +190,12 @@ class TestSubagentCancellation:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=MagicMock(),
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
@ -214,7 +223,12 @@ class TestSubagentCancellation:
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=MagicMock(),
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
assert await mgr.cancel_by_session("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -236,19 +250,24 @@ class TestSubagentCancellation:
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
reasoning_content="hidden reasoning",
|
||||
thinking_blocks=[{"type": "thinking", "thinking": "step"}],
|
||||
)
|
||||
captured_second_call[:] = messages
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
provider.chat_with_retry = scripted_chat_with_retry
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
return "tool result"
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
@ -273,6 +292,7 @@ class TestSubagentCancellation:
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
exec_config=ExecToolConfig(enable=False),
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
@ -304,20 +324,25 @@ class TestSubagentCancellation:
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
calls = {"n": 0}
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return "first result"
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
|
||||
@ -340,15 +365,20 @@ class TestSubagentCancellation:
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="thinking",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})],
|
||||
))
|
||||
mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus)
|
||||
mgr = SubagentManager(
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
bus=bus,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
)
|
||||
mgr._announce_result = AsyncMock()
|
||||
|
||||
started = asyncio.Event()
|
||||
cancelled = asyncio.Event()
|
||||
|
||||
async def fake_execute(self, name, arguments):
|
||||
async def fake_execute(self, **kwargs):
|
||||
started.set()
|
||||
try:
|
||||
await asyncio.sleep(60)
|
||||
@ -356,7 +386,7 @@ class TestSubagentCancellation:
|
||||
cancelled.set()
|
||||
raise
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute)
|
||||
monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute)
|
||||
|
||||
task = asyncio.create_task(
|
||||
mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"})
|
||||
@ -364,7 +394,7 @@ class TestSubagentCancellation:
|
||||
mgr._running_tasks["sub-1"] = task
|
||||
mgr._session_tasks["test:c1"] = {"sub-1"}
|
||||
|
||||
await started.wait()
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
count = await mgr.cancel_by_session("test:c1")
|
||||
|
||||
|
||||
@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None:
|
||||
typing_channel.typing_enter_hook = slow_typing
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||
|
||||
await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello"))
|
||||
release.set()
|
||||
@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None:
|
||||
typing_channel.typing_enter_hook = slow_typing_progress
|
||||
|
||||
await channel._start_typing(typing_channel)
|
||||
await start.wait()
|
||||
await asyncio.wait_for(start.wait(), timeout=1.0)
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() ->
|
||||
|
||||
typing_channel = _NoTriggerChannel(channel_id=123)
|
||||
await channel._start_typing(typing_channel) # type: ignore[arg-type]
|
||||
await entered.wait()
|
||||
await asyncio.wait_for(entered.wait(), timeout=1.0)
|
||||
|
||||
assert "123" in channel._typing_tasks
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ Validates that:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace:
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
class _StalledStream:
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
await asyncio.sleep(3600)
|
||||
raise StopAsyncIteration
|
||||
|
||||
|
||||
def test_openrouter_spec_is_gateway() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
assert spec is not None
|
||||
@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None:
|
||||
spec=spec,
|
||||
)
|
||||
assert provider.get_default_model() == "gpt-4o"
|
||||
|
||||
|
||||
def test_openai_compat_strips_message_level_reasoning_fields() -> None:
    """Message-level reasoning/extra fields are removed; tool-call extras are kept."""
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        provider = OpenAICompatProvider()

    tool_call = {
        "id": "call_1",
        "type": "function",
        "function": {"name": "fn", "arguments": "{}"},
        "extra_content": {"google": {"thought_signature": "sig"}},
    }
    assistant_message = {
        "role": "assistant",
        "content": "done",
        "reasoning_content": "hidden",
        "extra_content": {"debug": True},
        "tool_calls": [tool_call],
    }

    sanitized = provider._sanitize_messages([assistant_message])
    cleaned = sanitized[0]

    for stripped_key in ("reasoning_content", "extra_content"):
        assert stripped_key not in cleaned
    assert cleaned["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None:
    """A stream that never yields produces an error response mentioning the stall."""
    # Set the idle-timeout environment variable to zero for this test.
    monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0")
    stalled_create = AsyncMock(return_value=_StalledStream())
    openai_spec = find_by_name("openai")

    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
        MockClient.return_value.chat.completions.create = stalled_create

        provider = OpenAICompatProvider(
            api_key="sk-test-key",
            default_model="gpt-4o",
            spec=openai_spec,
        )
        user_turn = {"role": "user", "content": "hello"}
        result = await provider.chat_stream(messages=[user_turn], model="gpt-4o")

    assert result.finish_reason == "error"
    assert result.content is not None
    assert "stream stalled" in result.content
|
||||
|
||||
@ -211,3 +211,32 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
assert any("[image omitted]" in (b.get("text") or "") for b in content)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
    """The retry waits the 7s named in the error text and reports that wait."""
    rate_limited = LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error")
    success = LLMResponse(content="ok")
    provider = ScriptedProvider([rate_limited, success])

    slept_for: list[float] = []
    wait_notices: list[str] = []

    async def _fake_sleep(delay: float) -> None:
        # Record the requested delay instead of actually sleeping.
        slept_for.append(delay)

    async def _progress(msg: str) -> None:
        wait_notices.append(msg)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)

    response = await provider.chat_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        on_retry_wait=_progress,
    )

    assert response.content == "ok"
    assert slept_for == [7.0]
    assert wait_notices and "7s" in wait_notices[0]
|
||||
|
||||
|
||||
|
||||
@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None:
|
||||
|
||||
wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10)
|
||||
task = asyncio.create_task(wrapper.execute())
|
||||
await started.wait()
|
||||
await asyncio.wait_for(started.wait(), timeout=1.0)
|
||||
|
||||
task.cancel()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user