From fbedf7ad77a9999a2462ece74e97255e2e9ecb70 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Wed, 1 Apr 2026 19:12:49 +0000 Subject: [PATCH] feat: harden agent runtime for long-running tasks --- nanobot/agent/context.py | 25 +- nanobot/agent/loop.py | 154 ++++++++-- nanobot/agent/runner.py | 305 ++++++++++++++++++-- nanobot/agent/subagent.py | 3 + nanobot/agent/tools/base.py | 15 + nanobot/agent/tools/filesystem.py | 8 + nanobot/agent/tools/registry.py | 35 ++- nanobot/agent/tools/shell.py | 4 + nanobot/agent/tools/web.py | 8 + nanobot/cli/commands.py | 9 + nanobot/config/schema.py | 5 +- nanobot/nanobot.py | 3 + nanobot/providers/anthropic_provider.py | 26 +- nanobot/providers/base.py | 149 +++++++--- nanobot/providers/openai_compat_provider.py | 21 +- nanobot/session/manager.py | 44 +-- nanobot/utils/helpers.py | 160 +++++++++- tests/agent/test_context_prompt_cache.py | 16 + tests/agent/test_loop_save_turn.py | 130 ++++++++- tests/agent/test_runner.py | 255 +++++++++++++++- tests/agent/test_task_cancel.py | 60 +++- tests/channels/test_discord_channel.py | 6 +- tests/providers/test_litellm_kwargs.py | 61 ++++ tests/providers/test_provider_retry.py | 29 ++ tests/tools/test_mcp_tool.py | 2 +- 25 files changed, 1348 insertions(+), 185 deletions(-) diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index ce69d247b..8ce2873a9 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + @staticmethod + def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: + if isinstance(left, str) and isinstance(right, str): + return f"{left}\n\n{right}" if left else right + + def _to_blocks(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [item if isinstance(item, dict) else {"type": 
"text", "text": str(item)} for item in value] + if value is None: + return [] + return [{"type": "text", "text": str(value)}] + + return _to_blocks(left) + _to_blocks(right) + def _load_bootstrap_files(self) -> str: """Load all bootstrap files from workspace.""" parts = [] @@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST merged = f"{runtime_ctx}\n\n{user_content}" else: merged = [{"type": "text", "text": runtime_ctx}] + user_content - - return [ + messages = [ {"role": "system", "content": self.build_system_prompt(skill_names)}, *history, - {"role": current_role, "content": merged}, ] + if messages[-1].get("role") == current_role: + last = dict(messages[-1]) + last["content"] = self._merge_message_content(last.get("content"), merged) + messages[-1] = last + return messages + messages.append({"role": current_role, "content": merged}) + return messages def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: """Build user message content with optional base64-encoded images.""" diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index a9dc589e8..d231ba9a5 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -29,8 +29,10 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.bus.queue import MessageBus +from nanobot.config.schema import AgentDefaults from nanobot.providers.base import LLMProvider from nanobot.session.manager import Session, SessionManager +from nanobot.utils.helpers import image_placeholder_text, truncate_text if TYPE_CHECKING: from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig @@ -38,11 +40,7 @@ if TYPE_CHECKING: class _LoopHook(AgentHook): - """Core lifecycle hook for the main agent loop. 
- - Handles streaming delta relay, progress reporting, tool-call logging, - and think-tag stripping for the built-in agent path. - """ + """Core hook for the main loop.""" def __init__( self, @@ -102,11 +100,7 @@ class _LoopHook(AgentHook): class _LoopHookChain(AgentHook): - """Run the core loop hook first, then best-effort extra hooks. - - This preserves the historical failure behavior of ``_LoopHook`` while still - letting user-supplied hooks opt into ``CompositeHook`` isolation. - """ + """Run the core hook before extra hooks.""" __slots__ = ("_primary", "_extras") @@ -154,7 +148,7 @@ class AgentLoop: 5. Sends responses back """ - _TOOL_RESULT_MAX_CHARS = 16_000 + _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" def __init__( self, @@ -162,8 +156,11 @@ class AgentLoop: provider: LLMProvider, workspace: Path, model: str | None = None, - max_iterations: int = 40, - context_window_tokens: int = 65_536, + max_iterations: int | None = None, + context_window_tokens: int | None = None, + context_block_limit: int | None = None, + max_tool_result_chars: int | None = None, + provider_retry_mode: str = "standard", web_search_config: WebSearchConfig | None = None, web_proxy: str | None = None, exec_config: ExecToolConfig | None = None, @@ -177,13 +174,27 @@ class AgentLoop: ): from nanobot.config.schema import ExecToolConfig, WebSearchConfig + defaults = AgentDefaults() self.bus = bus self.channels_config = channels_config self.provider = provider self.workspace = workspace self.model = model or provider.get_default_model() - self.max_iterations = max_iterations - self.context_window_tokens = context_window_tokens + self.max_iterations = ( + max_iterations if max_iterations is not None else defaults.max_tool_iterations + ) + self.context_window_tokens = ( + context_window_tokens + if context_window_tokens is not None + else defaults.context_window_tokens + ) + self.context_block_limit = context_block_limit + self.max_tool_result_chars = ( + max_tool_result_chars + if 
max_tool_result_chars is not None + else defaults.max_tool_result_chars + ) + self.provider_retry_mode = provider_retry_mode self.web_search_config = web_search_config or WebSearchConfig() self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -202,6 +213,7 @@ class AgentLoop: workspace=workspace, bus=bus, model=self.model, + max_tool_result_chars=self.max_tool_result_chars, web_search_config=self.web_search_config, web_proxy=web_proxy, exec_config=self.exec_config, @@ -313,6 +325,7 @@ class AgentLoop: on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None, *, + session: Session | None = None, channel: str = "cli", chat_id: str = "direct", message_id: str | None = None, @@ -339,14 +352,27 @@ class AgentLoop: else loop_hook ) + async def _checkpoint(payload: dict[str, Any]) -> None: + if session is None: + return + self._set_runtime_checkpoint(session, payload) + result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, model=self.model, max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, hook=hook, error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, )) self._last_usage = result.usage if result.stop_reason == "max_iterations": @@ -484,6 +510,8 @@ class AgentLoop: logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) + if self._restore_runtime_checkpoint(session): + self.sessions.save(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, 
chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) @@ -494,10 +522,11 @@ class AgentLoop: current_role=current_role, ) final_content, _, all_msgs = await self._run_agent_loop( - messages, channel=channel, chat_id=chat_id, + messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) + self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) return OutboundMessage(channel=channel, chat_id=chat_id, @@ -508,6 +537,8 @@ class AgentLoop: key = session_key or msg.session_key session = self.sessions.get_or_create(key) + if self._restore_runtime_checkpoint(session): + self.sessions.save(session) # Slash commands raw = msg.content.strip() @@ -543,6 +574,7 @@ class AgentLoop: on_progress=on_progress or _bus_progress, on_stream=on_stream, on_stream_end=on_stream_end, + session=session, channel=msg.channel, chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), ) @@ -551,6 +583,7 @@ class AgentLoop: final_content = "I've completed processing but have no response to give." 
self._save_turn(session, all_msgs, 1 + len(history)) + self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) @@ -568,12 +601,6 @@ class AgentLoop: metadata=meta, ) - @staticmethod - def _image_placeholder(block: dict[str, Any]) -> dict[str, str]: - """Convert an inline image block into a compact text placeholder.""" - path = (block.get("_meta") or {}).get("path", "") - return {"type": "text", "text": f"[image: {path}]" if path else "[image]"} - def _sanitize_persisted_blocks( self, content: list[dict[str, Any]], @@ -600,13 +627,14 @@ class AgentLoop: block.get("type") == "image_url" and block.get("image_url", {}).get("url", "").startswith("data:image/") ): - filtered.append(self._image_placeholder(block)) + path = (block.get("_meta") or {}).get("path", "") + filtered.append({"type": "text", "text": image_placeholder_text(path)}) continue if block.get("type") == "text" and isinstance(block.get("text"), str): text = block["text"] - if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS: - text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" + if truncate_text and len(text) > self.max_tool_result_chars: + text = text[:self.max_tool_result_chars] + "\n... (truncated)" filtered.append({**block, "text": text}) continue @@ -623,8 +651,8 @@ class AgentLoop: if role == "assistant" and not content and not entry.get("tool_calls"): continue # skip empty assistant messages — they poison session context if role == "tool": - if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS: - entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... 
(truncated)" + if isinstance(content, str) and len(content) > self.max_tool_result_chars: + entry["content"] = truncate_text(content, self.max_tool_result_chars) elif isinstance(content, list): filtered = self._sanitize_persisted_blocks(content, truncate_text=True) if not filtered: @@ -647,6 +675,78 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() + def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None: + """Persist the latest in-flight turn state into session metadata.""" + session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload + self.sessions.save(session) + + def _clear_runtime_checkpoint(self, session: Session) -> None: + if self._RUNTIME_CHECKPOINT_KEY in session.metadata: + session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None) + + @staticmethod + def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]: + return ( + message.get("role"), + message.get("content"), + message.get("tool_call_id"), + message.get("name"), + message.get("tool_calls"), + message.get("reasoning_content"), + message.get("thinking_blocks"), + ) + + def _restore_runtime_checkpoint(self, session: Session) -> bool: + """Materialize an unfinished turn into session history before a new request.""" + from datetime import datetime + + checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY) + if not isinstance(checkpoint, dict): + return False + + assistant_message = checkpoint.get("assistant_message") + completed_tool_results = checkpoint.get("completed_tool_results") or [] + pending_tool_calls = checkpoint.get("pending_tool_calls") or [] + + restored_messages: list[dict[str, Any]] = [] + if isinstance(assistant_message, dict): + restored = dict(assistant_message) + restored.setdefault("timestamp", datetime.now().isoformat()) + restored_messages.append(restored) + for message in completed_tool_results: + if isinstance(message, dict): + restored = dict(message) + restored.setdefault("timestamp", 
datetime.now().isoformat()) + restored_messages.append(restored) + for tool_call in pending_tool_calls: + if not isinstance(tool_call, dict): + continue + tool_id = tool_call.get("id") + name = ((tool_call.get("function") or {}).get("name")) or "tool" + restored_messages.append({ + "role": "tool", + "tool_call_id": tool_id, + "name": name, + "content": "Error: Task interrupted before this tool finished.", + "timestamp": datetime.now().isoformat(), + }) + + overlap = 0 + max_overlap = min(len(session.messages), len(restored_messages)) + for size in range(max_overlap, 0, -1): + existing = session.messages[-size:] + restored = restored_messages[:size] + if all( + self._checkpoint_message_key(left) == self._checkpoint_message_key(right) + for left, right in zip(existing, restored) + ): + overlap = size + break + session.messages.extend(restored_messages[overlap:]) + + self._clear_runtime_checkpoint(session) + return True + async def process_direct( self, content: str, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index d6242a6b4..648073680 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -4,20 +4,29 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field +from pathlib import Path from typing import Any +from loguru import logger + from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMProvider, ToolCallRequest -from nanobot.utils.helpers import build_assistant_message +from nanobot.utils.helpers import ( + build_assistant_message, + estimate_message_tokens, + estimate_prompt_tokens_chain, + find_legal_message_start, + maybe_persist_tool_result, + truncate_text, +) _DEFAULT_MAX_ITERATIONS_MESSAGE = ( "I reached the maximum number of tool call iterations ({max_iterations}) " "without completing the task. You can try breaking the task into smaller steps." 
) _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." - - +_SNIP_SAFETY_BUFFER = 1024 @dataclass(slots=True) class AgentRunSpec: """Configuration for a single agent execution.""" @@ -26,6 +35,7 @@ class AgentRunSpec: tools: ToolRegistry model: str max_iterations: int + max_tool_result_chars: int temperature: float | None = None max_tokens: int | None = None reasoning_effort: str | None = None @@ -34,6 +44,13 @@ class AgentRunSpec: max_iterations_message: str | None = None concurrent_tools: bool = False fail_on_tool_error: bool = False + workspace: Path | None = None + session_key: str | None = None + context_window_tokens: int | None = None + context_block_limit: int | None = None + provider_retry_mode: str = "standard" + progress_callback: Any | None = None + checkpoint_callback: Any | None = None @dataclass(slots=True) @@ -66,12 +83,25 @@ class AgentRunner: tool_events: list[dict[str, str]] = [] for iteration in range(spec.max_iterations): + try: + messages = self._apply_tool_result_budget(spec, messages) + messages_for_model = self._snip_history(spec, messages) + except Exception as exc: + logger.warning( + "Context governance failed on turn {} for {}: {}; using raw messages", + iteration, + spec.session_key or "default", + exc, + ) + messages_for_model = messages context = AgentHookContext(iteration=iteration, messages=messages) await hook.before_iteration(context) kwargs: dict[str, Any] = { - "messages": messages, + "messages": messages_for_model, "tools": spec.tools.get_definitions(), "model": spec.model, + "retry_mode": spec.provider_retry_mode, + "on_retry_wait": spec.progress_callback, } if spec.temperature is not None: kwargs["temperature"] = spec.temperature @@ -104,13 +134,25 @@ class AgentRunner: if hook.wants_streaming(): await hook.on_stream_end(context, resuming=True) - messages.append(build_assistant_message( + assistant_message = build_assistant_message( response.content or "", tool_calls=[tc.to_openai_tool_call() for tc 
in response.tool_calls], reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, - )) + ) + messages.append(assistant_message) tools_used.extend(tc.name for tc in response.tool_calls) + await self._emit_checkpoint( + spec, + { + "phase": "awaiting_tools", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls], + }, + ) await hook.before_execute_tools(context) @@ -125,13 +167,31 @@ class AgentRunner: context.stop_reason = stop_reason await hook.after_iteration(context) break + completed_tool_results: list[dict[str, Any]] = [] for tool_call, result in zip(response.tool_calls, results): - messages.append({ + tool_message = { "role": "tool", "tool_call_id": tool_call.id, "name": tool_call.name, - "content": result, - }) + "content": self._normalize_tool_result( + spec, + tool_call.id, + result, + ), + } + messages.append(tool_message) + completed_tool_results.append(tool_message) + await self._emit_checkpoint( + spec, + { + "phase": "tools_completed", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": completed_tool_results, + "pending_tool_calls": [], + }, + ) await hook.after_iteration(context) continue @@ -143,6 +203,7 @@ class AgentRunner: final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE stop_reason = "error" error = final_content + self._append_final_message(messages, final_content) context.final_content = final_content context.error = error context.stop_reason = stop_reason @@ -154,6 +215,17 @@ class AgentRunner: reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, )) + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": messages[-1], + "completed_tool_results": [], + 
"pending_tool_calls": [], + }, + ) final_content = clean context.final_content = final_content context.stop_reason = stop_reason @@ -163,6 +235,7 @@ class AgentRunner: stop_reason = "max_iterations" template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE final_content = template.format(max_iterations=spec.max_iterations) + self._append_final_message(messages, final_content) return AgentRunResult( final_content=final_content, @@ -179,16 +252,17 @@ class AgentRunner: spec: AgentRunSpec, tool_calls: list[ToolCallRequest], ) -> tuple[list[Any], list[dict[str, str]], BaseException | None]: - if spec.concurrent_tools: - tool_results = await asyncio.gather(*( - self._run_tool(spec, tool_call) - for tool_call in tool_calls - )) - else: - tool_results = [ - await self._run_tool(spec, tool_call) - for tool_call in tool_calls - ] + batches = self._partition_tool_batches(spec, tool_calls) + tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = [] + for batch in batches: + if spec.concurrent_tools and len(batch) > 1: + tool_results.extend(await asyncio.gather(*( + self._run_tool(spec, tool_call) + for tool_call in batch + ))) + else: + for tool_call in batch: + tool_results.append(await self._run_tool(spec, tool_call)) results: list[Any] = [] events: list[dict[str, str]] = [] @@ -205,8 +279,28 @@ class AgentRunner: spec: AgentRunSpec, tool_call: ToolCallRequest, ) -> tuple[Any, dict[str, str], BaseException | None]: + _HINT = "\n\n[Analyze the error above and try a different approach.]" + prepare_call = getattr(spec.tools, "prepare_call", None) + tool, params, prep_error = None, tool_call.arguments, None + if callable(prepare_call): + try: + prepared = prepare_call(tool_call.name, tool_call.arguments) + if isinstance(prepared, tuple) and len(prepared) == 3: + tool, params, prep_error = prepared + except Exception: + pass + if prep_error: + event = { + "name": tool_call.name, + "status": "error", + "detail": prep_error.split(": ", 1)[-1][:120], 
+ } + return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None try: - result = await spec.tools.execute(tool_call.name, tool_call.arguments) + if tool is not None: + result = await tool.execute(**params) + else: + result = await spec.tools.execute(tool_call.name, params) except asyncio.CancelledError: raise except BaseException as exc: @@ -219,14 +313,175 @@ class AgentRunner: return f"Error: {type(exc).__name__}: {exc}", event, exc return f"Error: {type(exc).__name__}: {exc}", event, None + if isinstance(result, str) and result.startswith("Error"): + event = { + "name": tool_call.name, + "status": "error", + "detail": result.replace("\n", " ").strip()[:120], + } + if spec.fail_on_tool_error: + return result + _HINT, event, RuntimeError(result) + return result + _HINT, event, None + detail = "" if result is None else str(result) detail = detail.replace("\n", " ").strip() if not detail: detail = "(empty)" elif len(detail) > 120: detail = detail[:120] + "..." 
- return result, { - "name": tool_call.name, - "status": "error" if isinstance(result, str) and result.startswith("Error") else "ok", - "detail": detail, - }, None + return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None + + async def _emit_checkpoint( + self, + spec: AgentRunSpec, + payload: dict[str, Any], + ) -> None: + callback = spec.checkpoint_callback + if callback is not None: + await callback(payload) + + @staticmethod + def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None: + if not content: + return + if ( + messages + and messages[-1].get("role") == "assistant" + and not messages[-1].get("tool_calls") + ): + if messages[-1].get("content") == content: + return + messages[-1] = build_assistant_message(content) + return + messages.append(build_assistant_message(content)) + + def _normalize_tool_result( + self, + spec: AgentRunSpec, + tool_call_id: str, + result: Any, + ) -> Any: + try: + content = maybe_persist_tool_result( + spec.workspace, + spec.session_key, + tool_call_id, + result, + max_chars=spec.max_tool_result_chars, + ) + except Exception as exc: + logger.warning( + "Tool result persist failed for {} in {}: {}; using raw result", + tool_call_id, + spec.session_key or "default", + exc, + ) + content = result + if isinstance(content, str) and len(content) > spec.max_tool_result_chars: + return truncate_text(content, spec.max_tool_result_chars) + return content + + def _apply_tool_result_budget( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + updated = messages + for idx, message in enumerate(messages): + if message.get("role") != "tool": + continue + normalized = self._normalize_tool_result( + spec, + str(message.get("tool_call_id") or f"tool_{idx}"), + message.get("content"), + ) + if normalized != message.get("content"): + if updated is messages: + updated = [dict(m) for m in messages] + updated[idx]["content"] = normalized + return updated + 
+ def _snip_history( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + if not messages or not spec.context_window_tokens: + return messages + + provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096) + max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else ( + provider_max_tokens if isinstance(provider_max_tokens, int) else 4096 + ) + budget = spec.context_block_limit or ( + spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER + ) + if budget <= 0: + return messages + + estimate, _ = estimate_prompt_tokens_chain( + self.provider, + spec.model, + messages, + spec.tools.get_definitions(), + ) + if estimate <= budget: + return messages + + system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"] + non_system = [dict(msg) for msg in messages if msg.get("role") != "system"] + if not non_system: + return messages + + system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages) + remaining_budget = max(128, budget - system_tokens) + kept: list[dict[str, Any]] = [] + kept_tokens = 0 + for message in reversed(non_system): + msg_tokens = estimate_message_tokens(message) + if kept and kept_tokens + msg_tokens > remaining_budget: + break + kept.append(message) + kept_tokens += msg_tokens + kept.reverse() + + if kept: + for i, message in enumerate(kept): + if message.get("role") == "user": + kept = kept[i:] + break + start = find_legal_message_start(kept) + if start: + kept = kept[start:] + if not kept: + kept = non_system[-min(len(non_system), 4) :] + start = find_legal_message_start(kept) + if start: + kept = kept[start:] + return system_messages + kept + + def _partition_tool_batches( + self, + spec: AgentRunSpec, + tool_calls: list[ToolCallRequest], + ) -> list[list[ToolCallRequest]]: + if not spec.concurrent_tools: + return [[tool_call] for tool_call in tool_calls] + + batches: list[list[ToolCallRequest]] = [] + current: 
list[ToolCallRequest] = [] + for tool_call in tool_calls: + get_tool = getattr(spec.tools, "get", None) + tool = get_tool(tool_call.name) if callable(get_tool) else None + can_batch = bool(tool and tool.concurrency_safe) + if can_batch: + current.append(tool_call) + continue + if current: + batches.append(current) + current = [] + batches.append([tool_call]) + if current: + batches.append(current) + return batches + diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 9d936f034..c7643a486 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -44,6 +44,7 @@ class SubagentManager: provider: LLMProvider, workspace: Path, bus: MessageBus, + max_tool_result_chars: int, model: str | None = None, web_search_config: "WebSearchConfig | None" = None, web_proxy: str | None = None, @@ -56,6 +57,7 @@ class SubagentManager: self.workspace = workspace self.bus = bus self.model = model or provider.get_default_model() + self.max_tool_result_chars = max_tool_result_chars self.web_search_config = web_search_config or WebSearchConfig() self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -136,6 +138,7 @@ class SubagentManager: tools=tools, model=self.model, max_iterations=15, + max_tool_result_chars=self.max_tool_result_chars, hook=_SubagentHook(task_id), max_iterations_message="Task completed but no final response was generated.", error_message=None, diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 4017f7cf6..f119f6908 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -53,6 +53,21 @@ class Tool(ABC): """JSON Schema for tool parameters.""" pass + @property + def read_only(self) -> bool: + """Whether this tool is side-effect free and safe to parallelize.""" + return False + + @property + def concurrency_safe(self) -> bool: + """Whether this tool can run alongside other concurrency-safe tools.""" + return self.read_only and not self.exclusive + + @property + def 
exclusive(self) -> bool: + """Whether this tool should run alone even if concurrency is enabled.""" + return False + @abstractmethod async def execute(self, **kwargs: Any) -> Any: """ diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index da7778da3..d4094e7f3 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -73,6 +73,10 @@ class ReadFileTool(_FsTool): "Use offset and limit to paginate through large files." ) + @property + def read_only(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { @@ -344,6 +348,10 @@ class ListDirTool(_FsTool): "Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored." ) + @property + def read_only(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index c24659a70..725706dce 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -35,22 +35,35 @@ class ToolRegistry: """Get all tool definitions in OpenAI format.""" return [tool.to_schema() for tool in self._tools.values()] + def prepare_call( + self, + name: str, + params: dict[str, Any], + ) -> tuple[Tool | None, dict[str, Any], str | None]: + """Resolve, cast, and validate one tool call.""" + tool = self._tools.get(name) + if not tool: + return None, params, ( + f"Error: Tool '{name}' not found. 
Available: {', '.join(self.tool_names)}" + ) + + cast_params = tool.cast_params(params) + errors = tool.validate_params(cast_params) + if errors: + return tool, cast_params, ( + f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + ) + return tool, cast_params, None + async def execute(self, name: str, params: dict[str, Any]) -> Any: """Execute a tool by name with given parameters.""" _HINT = "\n\n[Analyze the error above and try a different approach.]" - - tool = self._tools.get(name) - if not tool: - return f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}" + tool, params, error = self.prepare_call(name, params) + if error: + return error + _HINT try: - # Attempt to cast parameters to match schema types - params = tool.cast_params(params) - - # Validate parameters - errors = tool.validate_params(params) - if errors: - return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT + assert tool is not None # guarded by prepare_call() result = await tool.execute(**params) if isinstance(result, str) and result.startswith("Error"): return result + _HINT diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index ed552b33e..89e3d0e8a 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -52,6 +52,10 @@ class ExecTool(Tool): def description(self) -> str: return "Execute a shell command and return its output. Use with caution." 
+ @property + def exclusive(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 9480e194f..1c0fde822 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -92,6 +92,10 @@ class WebSearchTool(Tool): self.config = config if config is not None else WebSearchConfig() self.proxy = proxy + @property + def read_only(self) -> bool: + return True + async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: provider = self.config.provider.strip().lower() or "brave" n = min(max(count or self.config.max_results, 1), 10) @@ -234,6 +238,10 @@ class WebFetchTool(Tool): self.max_chars = max_chars self.proxy = proxy + @property + def read_only(self) -> bool: + return True + async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any: max_chars = maxChars or self.max_chars is_valid, error_msg = _validate_url_safe(url) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 7f7d24f39..ad41355ee 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -539,6 +539,9 @@ def serve( model=runtime_config.agents.defaults.model, max_iterations=runtime_config.agents.defaults.max_tool_iterations, context_window_tokens=runtime_config.agents.defaults.context_window_tokens, + context_block_limit=runtime_config.agents.defaults.context_block_limit, + max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars, + provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode, web_search_config=runtime_config.tools.web.search, web_proxy=runtime_config.tools.web.proxy or None, exec_config=runtime_config.tools.exec, @@ -626,6 +629,9 @@ def gateway( model=config.agents.defaults.model, max_iterations=config.agents.defaults.max_tool_iterations, context_window_tokens=config.agents.defaults.context_window_tokens, + 
context_block_limit=config.agents.defaults.context_block_limit, + max_tool_result_chars=config.agents.defaults.max_tool_result_chars, + provider_retry_mode=config.agents.defaults.provider_retry_mode, web_search_config=config.tools.web.search, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, @@ -832,6 +838,9 @@ def agent( model=config.agents.defaults.model, max_iterations=config.agents.defaults.max_tool_iterations, context_window_tokens=config.agents.defaults.context_window_tokens, + context_block_limit=config.agents.defaults.context_block_limit, + max_tool_result_chars=config.agents.defaults.max_tool_result_chars, + provider_retry_mode=config.agents.defaults.provider_retry_mode, web_search_config=config.tools.web.search, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c4c927afd..602b8a911 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -38,8 +38,11 @@ class AgentDefaults(Base): ) max_tokens: int = 8192 context_window_tokens: int = 65_536 + context_block_limit: int | None = None temperature: float = 0.1 - max_tool_iterations: int = 40 + max_tool_iterations: int = 200 + max_tool_result_chars: int = 16_000 + provider_retry_mode: Literal["standard", "persistent"] = "standard" reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. 
"Asia/Shanghai", "America/New_York" diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 137688455..7e8dad0e6 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -73,6 +73,9 @@ class Nanobot: model=defaults.model, max_iterations=defaults.max_tool_iterations, context_window_tokens=defaults.context_window_tokens, + context_block_limit=defaults.context_block_limit, + max_tool_result_chars=defaults.max_tool_result_chars, + provider_retry_mode=defaults.provider_retry_mode, web_search_config=config.tools.web.search, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 3c789e730..a6d2519dd 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -2,6 +2,8 @@ from __future__ import annotations +import asyncio +import os import re import secrets import string @@ -427,13 +429,33 @@ class AnthropicProvider(LLMProvider): messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice, ) + idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: async with self._client.messages.stream(**kwargs) as stream: if on_content_delta: - async for text in stream.text_stream: + stream_iter = stream.text_stream.__aiter__() + while True: + try: + text = await asyncio.wait_for( + stream_iter.__anext__(), + timeout=idle_timeout_s, + ) + except StopAsyncIteration: + break await on_content_delta(text) - response = await stream.get_final_message() + response = await asyncio.wait_for( + stream.get_final_message(), + timeout=idle_timeout_s, + ) return self._parse_response(response) + except asyncio.TimeoutError: + return LLMResponse( + content=( + f"Error calling LLM: stream stalled for more than " + f"{idle_timeout_s} seconds" + ), + finish_reason="error", + ) except Exception as e: return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") diff --git 
a/nanobot/providers/base.py b/nanobot/providers/base.py index 9ce2b0c63..c51f5ddaf 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -2,6 +2,7 @@ import asyncio import json +import re from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass, field @@ -9,6 +10,8 @@ from typing import Any from loguru import logger +from nanobot.utils.helpers import image_placeholder_text + @dataclass class ToolCallRequest: @@ -57,13 +60,7 @@ class LLMResponse: @dataclass(frozen=True) class GenerationSettings: - """Default generation parameters for LLM calls. - - Stored on the provider so every call site inherits the same defaults - without having to pass temperature / max_tokens / reasoning_effort - through every layer. Individual call sites can still override by - passing explicit keyword arguments to chat() / chat_with_retry(). - """ + """Default generation settings.""" temperature: float = 0.7 max_tokens: int = 4096 @@ -71,14 +68,11 @@ class GenerationSettings: class LLMProvider(ABC): - """ - Abstract base class for LLM providers. - - Implementations should handle the specifics of each provider's API - while maintaining a consistent interface. 
- """ + """Base class for LLM providers.""" _CHAT_RETRY_DELAYS = (1, 2, 4) + _PERSISTENT_MAX_DELAY = 60 + _RETRY_HEARTBEAT_CHUNK = 30 _TRANSIENT_ERROR_MARKERS = ( "429", "rate limit", @@ -208,7 +202,7 @@ class LLMProvider(ABC): for b in content: if isinstance(b, dict) and b.get("type") == "image_url": path = (b.get("_meta") or {}).get("path", "") - placeholder = f"[image: {path}]" if path else "[image omitted]" + placeholder = image_placeholder_text(path, empty="[image omitted]") new_content.append({"type": "text", "text": placeholder}) found = True else: @@ -273,6 +267,8 @@ class LLMProvider(ABC): reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + retry_mode: str = "standard", + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Call chat_stream() with retry on transient provider failures.""" if max_tokens is self._SENTINEL: @@ -288,28 +284,13 @@ class LLMProvider(ABC): reasoning_effort=reasoning_effort, tool_choice=tool_choice, on_content_delta=on_content_delta, ) - - for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): - response = await self._safe_chat_stream(**kw) - - if response.finish_reason != "error": - return response - - if not self._is_transient_error(response.content): - stripped = self._strip_image_content(messages) - if stripped is not None: - logger.warning("Non-transient LLM error with image content, retrying without images") - return await self._safe_chat_stream(**{**kw, "messages": stripped}) - return response - - logger.warning( - "LLM transient error (attempt {}/{}), retrying in {}s: {}", - attempt, len(self._CHAT_RETRY_DELAYS), delay, - (response.content or "")[:120].lower(), - ) - await asyncio.sleep(delay) - - return await self._safe_chat_stream(**kw) + return await self._run_with_retry( + self._safe_chat_stream, + kw, + messages, + retry_mode=retry_mode, + on_retry_wait=on_retry_wait, + ) 
async def chat_with_retry( self, @@ -320,6 +301,8 @@ class LLMProvider(ABC): temperature: object = _SENTINEL, reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, + retry_mode: str = "standard", + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Call chat() with retry on transient provider failures. @@ -339,28 +322,102 @@ class LLMProvider(ABC): max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, ) + return await self._run_with_retry( + self._safe_chat, + kw, + messages, + retry_mode=retry_mode, + on_retry_wait=on_retry_wait, + ) - for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): - response = await self._safe_chat(**kw) + @classmethod + def _extract_retry_after(cls, content: str | None) -> float | None: + text = (content or "").lower() + match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text) + if not match: + return None + value = float(match.group(1)) + unit = (match.group(2) or "s").lower() + if unit in {"ms", "milliseconds"}: + return max(0.1, value / 1000.0) + if unit in {"m", "min", "minutes"}: + return value * 60.0 + return value + async def _sleep_with_heartbeat( + self, + delay: float, + *, + attempt: int, + persistent: bool, + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, + ) -> None: + remaining = max(0.0, delay) + while remaining > 0: + if on_retry_wait: + kind = "persistent retry" if persistent else "retry" + await on_retry_wait( + f"Model request failed, {kind} in {max(1, int(round(remaining)))}s " + f"(attempt {attempt})." 
+ ) + chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK) + await asyncio.sleep(chunk) + remaining -= chunk + + async def _run_with_retry( + self, + call: Callable[..., Awaitable[LLMResponse]], + kw: dict[str, Any], + original_messages: list[dict[str, Any]], + *, + retry_mode: str, + on_retry_wait: Callable[[str], Awaitable[None]] | None, + ) -> LLMResponse: + attempt = 0 + delays = list(self._CHAT_RETRY_DELAYS) + persistent = retry_mode == "persistent" + last_response: LLMResponse | None = None + while True: + attempt += 1 + response = await call(**kw) if response.finish_reason != "error": return response + last_response = response if not self._is_transient_error(response.content): - stripped = self._strip_image_content(messages) - if stripped is not None: - logger.warning("Non-transient LLM error with image content, retrying without images") - return await self._safe_chat(**{**kw, "messages": stripped}) + stripped = self._strip_image_content(original_messages) + if stripped is not None and stripped != kw["messages"]: + logger.warning( + "Non-transient LLM error with image content, retrying without images" + ) + retry_kw = dict(kw) + retry_kw["messages"] = stripped + return await call(**retry_kw) return response + if not persistent and attempt > len(delays): + break + + base_delay = delays[min(attempt - 1, len(delays) - 1)] + delay = self._extract_retry_after(response.content) or base_delay + if persistent: + delay = min(delay, self._PERSISTENT_MAX_DELAY) + logger.warning( - "LLM transient error (attempt {}/{}), retrying in {}s: {}", - attempt, len(self._CHAT_RETRY_DELAYS), delay, + "LLM transient error (attempt {}{}), retrying in {}s: {}", + attempt, + "+" if persistent and attempt > len(delays) else f"/{len(delays)}", + int(round(delay)), (response.content or "")[:120].lower(), ) - await asyncio.sleep(delay) + await self._sleep_with_heartbeat( + delay, + attempt=attempt, + persistent=persistent, + on_retry_wait=on_retry_wait, + ) - return await 
self._safe_chat(**kw) + return last_response if last_response is not None else await call(**kw) @abstractmethod def get_default_model(self) -> str: diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index 397b8e797..2b7728c25 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import hashlib import os import secrets @@ -20,7 +21,6 @@ if TYPE_CHECKING: _ALLOWED_MSG_KEYS = frozenset({ "role", "content", "tool_calls", "tool_call_id", "name", - "reasoning_content", "extra_content", }) _ALNUM = string.ascii_letters + string.digits @@ -572,16 +572,33 @@ class OpenAICompatProvider(LLMProvider): ) kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True} + idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: stream = await self._client.chat.completions.create(**kwargs) chunks: list[Any] = [] - async for chunk in stream: + stream_iter = stream.__aiter__() + while True: + try: + chunk = await asyncio.wait_for( + stream_iter.__anext__(), + timeout=idle_timeout_s, + ) + except StopAsyncIteration: + break chunks.append(chunk) if on_content_delta and chunk.choices: text = getattr(chunk.choices[0].delta, "content", None) if text: await on_content_delta(text) return self._parse_chunks(chunks) + except asyncio.TimeoutError: + return LLMResponse( + content=( + f"Error calling LLM: stream stalled for more than " + f"{idle_timeout_s} seconds" + ), + finish_reason="error", + ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 537ba42d0..95e3916b9 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -10,20 +10,12 @@ from typing import Any from loguru import logger from nanobot.config.paths import get_legacy_sessions_dir -from nanobot.utils.helpers import ensure_dir, 
safe_filename +from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename @dataclass class Session: - """ - A conversation session. - - Stores messages in JSONL format for easy reading and persistence. - - Important: Messages are append-only for LLM cache efficiency. - The consolidation process writes summaries to MEMORY.md/HISTORY.md - but does NOT modify the messages list or get_history() output. - """ + """A conversation session.""" key: str # channel:chat_id messages: list[dict[str, Any]] = field(default_factory=list) @@ -43,43 +35,19 @@ class Session: self.messages.append(msg) self.updated_at = datetime.now() - @staticmethod - def _find_legal_start(messages: list[dict[str, Any]]) -> int: - """Find first index where every tool result has a matching assistant tool_call.""" - declared: set[str] = set() - start = 0 - for i, msg in enumerate(messages): - role = msg.get("role") - if role == "assistant": - for tc in msg.get("tool_calls") or []: - if isinstance(tc, dict) and tc.get("id"): - declared.add(str(tc["id"])) - elif role == "tool": - tid = msg.get("tool_call_id") - if tid and str(tid) not in declared: - start = i + 1 - declared.clear() - for prev in messages[start:i + 1]: - if prev.get("role") == "assistant": - for tc in prev.get("tool_calls") or []: - if isinstance(tc, dict) and tc.get("id"): - declared.add(str(tc["id"])) - return start - def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" unconsolidated = self.messages[self.last_consolidated:] sliced = unconsolidated[-max_messages:] - # Drop leading non-user messages to avoid starting mid-turn when possible. + # Avoid starting mid-turn when possible. 
for i, message in enumerate(sliced): if message.get("role") == "user": sliced = sliced[i:] break - # Some providers reject orphan tool results if the matching assistant - # tool_calls message fell outside the fixed-size history window. - start = self._find_legal_start(sliced) + # Drop orphan tool results at the front. + start = find_legal_message_start(sliced) if start: sliced = sliced[start:] @@ -115,7 +83,7 @@ class Session: retained = self.messages[start_idx:] # Mirror get_history(): avoid persisting orphan tool results at the front. - start = self._find_legal_start(retained) + start = find_legal_message_start(retained) if start: retained = retained[start:] diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index a7c2c2574..6813c659e 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -3,7 +3,9 @@ import base64 import json import re +import shutil import time +import uuid from datetime import datetime from pathlib import Path from typing import Any @@ -56,11 +58,7 @@ def timestamp() -> str: def current_time_str(timezone: str | None = None) -> str: - """Human-readable current time with weekday and UTC offset. - - When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time - is converted to that zone. Otherwise falls back to the host local time. 
- """ + """Return the current time string.""" from zoneinfo import ZoneInfo try: @@ -76,12 +74,164 @@ def current_time_str(timezone: str | None = None) -> str: _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]') +_TOOL_RESULT_PREVIEW_CHARS = 1200 +_TOOL_RESULTS_DIR = ".nanobot/tool-results" +_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60 +_TOOL_RESULT_MAX_BUCKETS = 32 def safe_filename(name: str) -> str: """Replace unsafe path characters with underscores.""" return _UNSAFE_CHARS.sub("_", name).strip() +def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str: + """Build an image placeholder string.""" + return f"[image: {path}]" if path else empty + + +def truncate_text(text: str, max_chars: int) -> str: + """Truncate text with a stable suffix.""" + if max_chars <= 0 or len(text) <= max_chars: + return text + return text[:max_chars] + "\n... (truncated)" + + +def find_legal_message_start(messages: list[dict[str, Any]]) -> int: + """Find the first index whose tool results have matching assistant calls.""" + declared: set[str] = set() + start = 0 + for i, msg in enumerate(messages): + role = msg.get("role") + if role == "assistant": + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + elif role == "tool": + tid = msg.get("tool_call_id") + if tid and str(tid) not in declared: + start = i + 1 + declared.clear() + for prev in messages[start : i + 1]: + if prev.get("role") == "assistant": + for tc in prev.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + return start + + +def _stringify_text_blocks(content: list[dict[str, Any]]) -> str | None: + parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + return None + if block.get("type") != "text": + return None + text = block.get("text") + if not isinstance(text, str): + return None + parts.append(text) + return "\n".join(parts) + + +def _render_tool_result_reference( + 
filepath: Path, + *, + original_size: int, + preview: str, + truncated_preview: bool, +) -> str: + result = ( + f"[tool output persisted]\n" + f"Full output saved to: {filepath}\n" + f"Original size: {original_size} chars\n" + f"Preview:\n{preview}" + ) + if truncated_preview: + result += "\n...\n(Read the saved file if you need the full output.)" + return result + + +def _bucket_mtime(path: Path) -> float: + try: + return path.stat().st_mtime + except OSError: + return 0.0 + + +def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None: + siblings = [path for path in root.iterdir() if path.is_dir() and path != current_bucket] + cutoff = time.time() - _TOOL_RESULT_RETENTION_SECS + for path in siblings: + if _bucket_mtime(path) < cutoff: + shutil.rmtree(path, ignore_errors=True) + keep = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0) + siblings = [path for path in siblings if path.exists()] + if len(siblings) <= keep: + return + siblings.sort(key=_bucket_mtime, reverse=True) + for path in siblings[keep:]: + shutil.rmtree(path, ignore_errors=True) + + +def _write_text_atomic(path: Path, content: str) -> None: + tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp") + try: + tmp.write_text(content, encoding="utf-8") + tmp.replace(path) + finally: + if tmp.exists(): + tmp.unlink(missing_ok=True) + + +def maybe_persist_tool_result( + workspace: Path | None, + session_key: str | None, + tool_call_id: str, + content: Any, + *, + max_chars: int, +) -> Any: + """Persist oversized tool output and replace it with a stable reference string.""" + if workspace is None or max_chars <= 0: + return content + + text_payload: str | None = None + suffix = "txt" + if isinstance(content, str): + text_payload = content + elif isinstance(content, list): + text_payload = _stringify_text_blocks(content) + if text_payload is None: + return content + suffix = "json" + else: + return content + + if len(text_payload) <= max_chars: + return content + + root = ensure_dir(workspace / 
_TOOL_RESULTS_DIR) + bucket = ensure_dir(root / safe_filename(session_key or "default")) + try: + _cleanup_tool_result_buckets(root, bucket) + except Exception: + pass + path = bucket / f"{safe_filename(tool_call_id)}.{suffix}" + if not path.exists(): + if suffix == "json" and isinstance(content, list): + _write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2)) + else: + _write_text_atomic(path, text_payload) + + preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS] + return _render_tool_result_reference( + path, + original_size=len(text_payload), + preview=preview, + truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS, + ) + + def split_message(content: str, max_len: int = 2000) -> list[str]: """ Split content into chunks within max_len, preferring line breaks. diff --git a/tests/agent/test_context_prompt_cache.py b/tests/agent/test_context_prompt_cache.py index 6eb4b4f19..4484e5ed0 100644 --- a/tests/agent/test_context_prompt_cache.py +++ b/tests/agent/test_context_prompt_cache.py @@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: assert "Channel: cli" in user_content assert "Chat ID: direct" in user_content assert "Return exactly: OK" in user_content + + +def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None: + workspace = _make_workspace(tmp_path) + builder = ContextBuilder(workspace) + + messages = builder.build_messages( + history=[{"role": "assistant", "content": "previous result"}], + current_message="subagent result", + channel="cli", + chat_id="direct", + current_role="assistant", + ) + + for left, right in zip(messages, messages[1:]): + assert not (left.get("role") == right.get("role") == "assistant") diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index aed7653c3..8a0b54b86 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -5,7 +5,9 @@ from nanobot.session.manager 
import Session def _mk_loop() -> AgentLoop: loop = AgentLoop.__new__(AgentLoop) - loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS + from nanobot.config.schema import AgentDefaults + + loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars return loop @@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None: ) assert session.messages[0]["content"] == content + + +def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None: + loop = _mk_loop() + session = Session( + key="test:checkpoint", + metadata={ + AgentLoop._RUNTIME_CHECKPOINT_KEY: { + "assistant_message": { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + "completed_tool_results": [ + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + } + ], + "pending_tool_calls": [ + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + } + ], + } + }, + ) + + restored = loop._restore_runtime_checkpoint(session) + + assert restored is True + assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None + assert session.messages[0]["role"] == "assistant" + assert session.messages[1]["tool_call_id"] == "call_done" + assert session.messages[2]["tool_call_id"] == "call_pending" + assert "interrupted before this tool finished" in session.messages[2]["content"].lower() + + +def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None: + loop = _mk_loop() + session = Session( + key="test:checkpoint-overlap", + messages=[ + { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": 
"call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + }, + ], + metadata={ + AgentLoop._RUNTIME_CHECKPOINT_KEY: { + "assistant_message": { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + "completed_tool_results": [ + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + } + ], + "pending_tool_calls": [ + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + } + ], + } + }, + ) + + restored = loop._restore_runtime_checkpoint(session) + + assert restored is True + assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None + assert len(session.messages) == 3 + assert session.messages[0]["role"] == "assistant" + assert session.messages[1]["tool_call_id"] == "call_done" + assert session.messages[2]["tool_call_id"] == "call_pending" diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 86b0ba710..f2a26820e 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -2,12 +2,20 @@ from __future__ import annotations +import asyncio +import os +import time from unittest.mock import AsyncMock, MagicMock, patch import pytest +from nanobot.config.schema import AgentDefaults +from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMResponse, ToolCallRequest +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + def _make_loop(tmp_path): from nanobot.agent.loop import AgentLoop @@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results(): 
tools=tools, model="test-model", max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, )) assert result.final_content == "done" @@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order(): tools=tools, model="test-model", max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, hook=RecordingHook(), )) @@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal(): tools=tools, model="test-model", max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, hook=StreamingHook(), )) @@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback(): tools=tools, model="test-model", max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, )) assert result.stop_reason == "max_iterations" @@ -226,7 +238,8 @@ async def test_runner_returns_max_iterations_fallback(): "I reached the maximum number of tool call iterations (2) " "without completing the task. You can try breaking the task into smaller steps." ) - + assert result.messages[-1]["role"] == "assistant" + assert result.messages[-1]["content"] == result.final_content @pytest.mark.asyncio async def test_runner_returns_structured_tool_error(): @@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error(): tools=tools, model="test-model", max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, fail_on_tool_error=True, )) @@ -258,6 +272,232 @@ async def test_runner_returns_structured_tool_error(): ] +@pytest.mark.asyncio +async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})], + usage={"prompt_tokens": 5, 
"completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="x" * 20_000) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + workspace=tmp_path, + session_key="test:runner", + max_tool_result_chars=2048, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert "[tool output persisted]" in tool_message["content"] + assert "tool-results" in tool_message["content"] + assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists() + + +def test_persist_tool_result_prunes_old_session_buckets(tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + root = tmp_path / ".nanobot" / "tool-results" + old_bucket = root / "old_session" + recent_bucket = root / "recent_session" + old_bucket.mkdir(parents=True) + recent_bucket.mkdir(parents=True) + (old_bucket / "old.txt").write_text("old", encoding="utf-8") + (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8") + + stale = time.time() - (8 * 24 * 60 * 60) + os.utime(old_bucket, (stale, stale)) + os.utime(old_bucket / "old.txt", (stale, stale)) + + persisted = maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert "[tool output persisted]" in persisted + assert not old_bucket.exists() + assert recent_bucket.exists() + assert (root / "current_session" / "call_big.txt").exists() + + +def test_persist_tool_result_leaves_no_temp_files(tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + root = tmp_path / ".nanobot" / "tool-results" + 
maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert (root / "current_session" / "call_big.txt").exists() + assert list((root / "current_session").glob("*.tmp")) == [] + + +@pytest.mark.asyncio +async def test_runner_uses_raw_messages_when_context_governance_fails(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + initial_messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hello"}, + ] + + runner = AgentRunner(provider) + runner._snip_history = MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + result = await runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert captured_messages == initial_messages + + +@pytest.mark.asyncio +async def test_runner_keeps_going_when_tool_result_persistence_fails(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + 
tools.execute = AsyncMock(return_value="tool result") + + runner = AgentRunner(provider) + with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")): + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert tool_message["content"] == "tool result" + + +class _DelayTool(Tool): + def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]): + self._name = name + self._delay = delay + self._read_only = read_only + self._shared_events = shared_events + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._name + + @property + def parameters(self) -> dict: + return {"type": "object", "properties": {}, "required": []} + + @property + def read_only(self) -> bool: + return self._read_only + + async def execute(self, **kwargs): + self._shared_events.append(f"start:{self._name}") + await asyncio.sleep(self._delay) + self._shared_events.append(f"end:{self._name}") + return self._name + + +@pytest.mark.asyncio +async def test_runner_batches_read_only_tools_before_exclusive_work(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + tools = ToolRegistry() + shared_events: list[str] = [] + read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events) + read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events) + write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events) + tools.register(read_a) + tools.register(read_b) + tools.register(write_a) + + runner = AgentRunner(MagicMock()) + await runner._execute_tools( + AgentRunSpec( + initial_messages=[], + 
tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + concurrent_tools=True, + ), + [ + ToolCallRequest(id="ro1", name="read_a", arguments={}), + ToolCallRequest(id="ro2", name="read_b", arguments={}), + ToolCallRequest(id="rw1", name="write_a", arguments={}), + ], + ) + + assert shared_events[0:2] == ["start:read_a", "start:read_b"] + assert "end:read_a" in shared_events and "end:read_b" in shared_events + assert shared_events.index("end:read_a") < shared_events.index("start:write_a") + assert shared_events.index("end:read_b") < shared_events.index("start:write_a") + assert shared_events[-2:] == ["start:write_a", "end:write_a"] + + @pytest.mark.asyncio async def test_loop_max_iterations_message_stays_stable(tmp_path): loop = _make_loop(tmp_path) @@ -317,15 +557,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse( content="working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], )) - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) mgr._announce_result = AsyncMock() - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): return "tool result" - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) diff --git a/tests/agent/test_task_cancel.py b/tests/agent/test_task_cancel.py index 70f7621d1..7e84e57d8 100644 --- a/tests/agent/test_task_cancel.py +++ 
b/tests/agent/test_task_cancel.py @@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from nanobot.config.schema import AgentDefaults + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + def _make_loop(*, exec_config=None): """Create a minimal AgentLoop with mocked dependencies.""" @@ -186,7 +190,12 @@ class TestSubagentCancellation: bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" - mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=MagicMock(), + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) cancelled = asyncio.Event() @@ -214,7 +223,12 @@ class TestSubagentCancellation: bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" - mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=MagicMock(), + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) assert await mgr.cancel_by_session("nonexistent") == 0 @pytest.mark.asyncio @@ -236,19 +250,24 @@ class TestSubagentCancellation: if call_count["n"] == 1: return LLMResponse( content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], reasoning_content="hidden reasoning", thinking_blocks=[{"type": "thinking", "thinking": "step"}], ) captured_second_call[:] = messages return LLMResponse(content="done", tool_calls=[]) provider.chat_with_retry = scripted_chat_with_retry - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): return "tool result" - 
monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -273,6 +292,7 @@ class TestSubagentCancellation: provider=provider, workspace=tmp_path, bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, exec_config=ExecToolConfig(enable=False), ) mgr._announce_result = AsyncMock() @@ -304,20 +324,25 @@ class TestSubagentCancellation: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse( content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], )) - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) mgr._announce_result = AsyncMock() calls = {"n": 0} - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): calls["n"] += 1 if calls["n"] == 1: return "first result" raise RuntimeError("boom") - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -340,15 +365,20 @@ class TestSubagentCancellation: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse( content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], )) - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = 
SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) mgr._announce_result = AsyncMock() started = asyncio.Event() cancelled = asyncio.Event() - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): started.set() try: await asyncio.sleep(60) @@ -356,7 +386,7 @@ class TestSubagentCancellation: cancelled.set() raise - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) task = asyncio.create_task( mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -364,7 +394,7 @@ class TestSubagentCancellation: mgr._running_tasks["sub-1"] = task mgr._session_tasks["test:c1"] = {"sub-1"} - await started.wait() + await asyncio.wait_for(started.wait(), timeout=1.0) count = await mgr.cancel_by_session("test:c1") diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index d352c788c..845c03c57 100644 --- a/tests/channels/test_discord_channel.py +++ b/tests/channels/test_discord_channel.py @@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None: typing_channel.typing_enter_hook = slow_typing await channel._start_typing(typing_channel) - await start.wait() + await asyncio.wait_for(start.wait(), timeout=1.0) await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) release.set() @@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None: typing_channel.typing_enter_hook = slow_typing_progress await channel._start_typing(typing_channel) - await start.wait() + await asyncio.wait_for(start.wait(), timeout=1.0) await channel.send( OutboundMessage( @@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> typing_channel = _NoTriggerChannel(channel_id=123) await 
channel._start_typing(typing_channel) # type: ignore[arg-type] - await entered.wait() + await asyncio.wait_for(entered.wait(), timeout=1.0) assert "123" in channel._typing_tasks diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 62fb0a2cc..cc8347f0e 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -8,6 +8,7 @@ Validates that: from __future__ import annotations +import asyncio from types import SimpleNamespace from unittest.mock import AsyncMock, patch @@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace: return SimpleNamespace(choices=[choice], usage=usage) +class _StalledStream: + def __aiter__(self): + return self + + async def __anext__(self): + await asyncio.sleep(3600) + raise StopAsyncIteration + + def test_openrouter_spec_is_gateway() -> None: spec = find_by_name("openrouter") assert spec is not None @@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None: spec=spec, ) assert provider.get_default_model() == "gpt-4o" + + +def test_openai_compat_strips_message_level_reasoning_fields() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + sanitized = provider._sanitize_messages([ + { + "role": "assistant", + "content": "done", + "reasoning_content": "hidden", + "extra_content": {"debug": True}, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + "extra_content": {"google": {"thought_signature": "sig"}}, + } + ], + } + ]) + + assert "reasoning_content" not in sanitized[0] + assert "extra_content" not in sanitized[0] + assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}} + + +@pytest.mark.asyncio +async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None: + monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0") + mock_create = 
AsyncMock(return_value=_StalledStream()) + spec = find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, + ) + result = await provider.chat_stream( + messages=[{"role": "user", "content": "hello"}], + model="gpt-4o", + ) + + assert result.finish_reason == "error" + assert result.content is not None + assert "stream stalled" in result.content diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index d732054d5..6b5c8d8d6 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -211,3 +211,32 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None: content = msg.get("content") if isinstance(content, list): assert any("[image omitted]" in (b.get("text") or "") for b in content) + + +@pytest.mark.asyncio +async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + progress: list[str] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + async def _progress(msg: str) -> None: + progress.append(msg) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry( + messages=[{"role": "user", "content": "hello"}], + on_retry_wait=_progress, + ) + + assert response.content == "ok" + assert delays == [7.0] + assert progress and "7s" in progress[0] + + diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 28666f05f..9c1320251 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ 
-196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None: wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10) task = asyncio.create_task(wrapper.execute()) - await started.wait() + await asyncio.wait_for(started.wait(), timeout=1.0) task.cancel()