diff --git a/.gitignore b/.gitignore index fce6e07f8..08217c5b1 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ .assets .docs .env +.web *.pyc dist/ build/ diff --git a/README.md b/README.md index 8a8c864d0..60714b34b 100644 --- a/README.md +++ b/README.md @@ -20,13 +20,20 @@ ## 📢 News -> [!IMPORTANT] -> **Security note:** Due to `litellm` supply chain poisoning, **please check your Python environment ASAP** and refer to this [advisory](https://github.com/HKUDS/nanobot/discussions/2445) for details. We have fully removed the `litellm` since **v0.1.4.post6**. - +- **2026-04-02** 🧱 **Long-running tasks** run more reliably — core runtime hardening. +- **2026-04-01** 🔑 GitHub Copilot auth restored; stricter workspace paths; OpenRouter Claude caching fix. +- **2026-03-31** 🛰️ WeChat multimodal alignment, Discord/Matrix polish, Python SDK facade, MCP and tool fixes. +- **2026-03-30** 🧩 OpenAI-compatible API tightened; composable agent lifecycle hooks. +- **2026-03-29** 💬 WeChat voice, typing, QR/media resilience; fixed-session OpenAI-compatible API. +- **2026-03-28** 📚 Provider docs refresh; skill template wording fix. - **2026-03-27** 🚀 Released **v0.1.4.post6** — architecture decoupling, litellm removal, end-to-end streaming, WeChat channel, and a security fix. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post6) for details. - **2026-03-26** 🏗️ Agent runner extracted and lifecycle hooks unified; stream delta coalescing at boundaries. - **2026-03-25** 🌏 StepFun provider, configurable timezone, Gemini thought signatures. - **2026-03-24** 🔧 WeChat compatibility, Feishu CardKit streaming, test suite restructured. + +
+Earlier news + - **2026-03-23** 🔧 Command routing refactored for plugins, WhatsApp/WeChat media, unified channel login CLI. - **2026-03-22** ⚡ End-to-end streaming, WeChat channel, Anthropic cache optimization, `/status` command. - **2026-03-21** 🔒 Replace `litellm` with native `openai` + `anthropic` SDKs. Please see [commit](https://github.com/HKUDS/nanobot/commit/3dfdab7). @@ -34,10 +41,6 @@ - **2026-03-19** 💬 Telegram gets more resilient under load; Feishu now renders code blocks properly. - **2026-03-18** 📷 Telegram can now send media via URL. Cron schedules show human-readable details. - **2026-03-17** ✨ Feishu formatting glow-up, Slack reacts when done, custom endpoints support extra headers, and image handling is more reliable. - -
-Earlier news - - **2026-03-16** πŸš€ Released **v0.1.4.post5** β€” a refinement-focused release with stronger reliability and channel support, and a more dependable day-to-day experience. Please see [release notes](https://github.com/HKUDS/nanobot/releases/tag/v0.1.4.post5) for details. - **2026-03-15** 🧩 DingTalk rich media, smarter built-in skills, and cleaner model compatibility. - **2026-03-14** πŸ’¬ Channel plugins, Feishu replies, and steadier MCP, QQ, and media handling. diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index ce69d247b..8ce2873a9 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -110,6 +110,20 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + @staticmethod + def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: + if isinstance(left, str) and isinstance(right, str): + return f"{left}\n\n{right}" if left else right + + def _to_blocks(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [item if isinstance(item, dict) else {"type": "text", "text": str(item)} for item in value] + if value is None: + return [] + return [{"type": "text", "text": str(value)}] + + return _to_blocks(left) + _to_blocks(right) + def _load_bootstrap_files(self) -> str: """Load all bootstrap files from workspace.""" parts = [] @@ -142,12 +156,17 @@ IMPORTANT: To send files (images, documents, audio, video) to the user, you MUST merged = f"{runtime_ctx}\n\n{user_content}" else: merged = [{"type": "text", "text": runtime_ctx}] + user_content - - return [ + messages = [ {"role": "system", "content": self.build_system_prompt(skill_names)}, *history, - {"role": current_role, "content": merged}, ] + if messages[-1].get("role") == current_role: + last = dict(messages[-1]) + last["content"] = 
self._merge_message_content(last.get("content"), merged) + messages[-1] = last + return messages + messages.append({"role": current_role, "content": merged}) + return messages def _build_user_content(self, text: str, media: list[str] | None) -> str | list[dict[str, Any]]: """Build user message content with optional base64-encoded images.""" diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 50fef58fd..4a68a19fc 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -29,8 +29,11 @@ from nanobot.agent.tools.web import WebFetchTool, WebSearchTool from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.bus.queue import MessageBus +from nanobot.config.schema import AgentDefaults from nanobot.providers.base import LLMProvider from nanobot.session.manager import Session, SessionManager +from nanobot.utils.helpers import image_placeholder_text, truncate_text +from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE if TYPE_CHECKING: from nanobot.config.schema import ChannelsConfig, ExecToolConfig, WebSearchConfig @@ -38,11 +41,7 @@ if TYPE_CHECKING: class _LoopHook(AgentHook): - """Core lifecycle hook for the main agent loop. - - Handles streaming delta relay, progress reporting, tool-call logging, - and think-tag stripping for the built-in agent path. - """ + """Core hook for the main loop.""" def __init__( self, @@ -111,11 +110,7 @@ class _LoopHook(AgentHook): class _LoopHookChain(AgentHook): - """Run the core loop hook first, then best-effort extra hooks. - - This preserves the historical failure behavior of ``_LoopHook`` while still - letting user-supplied hooks opt into ``CompositeHook`` isolation. - """ + """Run the core hook before extra hooks.""" __slots__ = ("_primary", "_extras") @@ -163,7 +158,7 @@ class AgentLoop: 5. 
Sends responses back """ - _TOOL_RESULT_MAX_CHARS = 16_000 + _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" def __init__( self, @@ -171,8 +166,11 @@ class AgentLoop: provider: LLMProvider, workspace: Path, model: str | None = None, - max_iterations: int = 40, - context_window_tokens: int = 65_536, + max_iterations: int | None = None, + context_window_tokens: int | None = None, + context_block_limit: int | None = None, + max_tool_result_chars: int | None = None, + provider_retry_mode: str = "standard", web_search_config: WebSearchConfig | None = None, web_proxy: str | None = None, exec_config: ExecToolConfig | None = None, @@ -186,13 +184,27 @@ class AgentLoop: ): from nanobot.config.schema import ExecToolConfig, WebSearchConfig + defaults = AgentDefaults() self.bus = bus self.channels_config = channels_config self.provider = provider self.workspace = workspace self.model = model or provider.get_default_model() - self.max_iterations = max_iterations - self.context_window_tokens = context_window_tokens + self.max_iterations = ( + max_iterations if max_iterations is not None else defaults.max_tool_iterations + ) + self.context_window_tokens = ( + context_window_tokens + if context_window_tokens is not None + else defaults.context_window_tokens + ) + self.context_block_limit = context_block_limit + self.max_tool_result_chars = ( + max_tool_result_chars + if max_tool_result_chars is not None + else defaults.max_tool_result_chars + ) + self.provider_retry_mode = provider_retry_mode self.web_search_config = web_search_config or WebSearchConfig() self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -211,6 +223,7 @@ class AgentLoop: workspace=workspace, bus=bus, model=self.model, + max_tool_result_chars=self.max_tool_result_chars, web_search_config=self.web_search_config, web_proxy=web_proxy, exec_config=self.exec_config, @@ -322,6 +335,7 @@ class AgentLoop: on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream_end: Callable[..., 
Awaitable[None]] | None = None, *, + session: Session | None = None, channel: str = "cli", chat_id: str = "direct", message_id: str | None = None, @@ -348,14 +362,27 @@ class AgentLoop: else loop_hook ) + async def _checkpoint(payload: dict[str, Any]) -> None: + if session is None: + return + self._set_runtime_checkpoint(session, payload) + result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, model=self.model, max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, hook=hook, error_message="Sorry, I encountered an error calling the AI model.", concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, )) self._last_usage = result.usage if result.stop_reason == "max_iterations": @@ -493,6 +520,8 @@ class AgentLoop: logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) + if self._restore_runtime_checkpoint(session): + self.sessions.save(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) @@ -503,10 +532,11 @@ class AgentLoop: current_role=current_role, ) final_content, _, all_msgs = await self._run_agent_loop( - messages, channel=channel, chat_id=chat_id, + messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) + self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) return OutboundMessage(channel=channel, 
chat_id=chat_id, @@ -517,6 +547,8 @@ class AgentLoop: key = session_key or msg.session_key session = self.sessions.get_or_create(key) + if self._restore_runtime_checkpoint(session): + self.sessions.save(session) # Slash commands raw = msg.content.strip() @@ -552,14 +584,16 @@ class AgentLoop: on_progress=on_progress or _bus_progress, on_stream=on_stream, on_stream_end=on_stream_end, + session=session, channel=msg.channel, chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), ) - if final_content is None: - final_content = "I've completed processing but have no response to give." + if final_content is None or not final_content.strip(): + final_content = EMPTY_FINAL_RESPONSE_MESSAGE self._save_turn(session, all_msgs, 1 + len(history)) + self._clear_runtime_checkpoint(session) self.sessions.save(session) self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) @@ -577,12 +611,6 @@ class AgentLoop: metadata=meta, ) - @staticmethod - def _image_placeholder(block: dict[str, Any]) -> dict[str, str]: - """Convert an inline image block into a compact text placeholder.""" - path = (block.get("_meta") or {}).get("path", "") - return {"type": "text", "text": f"[image: {path}]" if path else "[image]"} - def _sanitize_persisted_blocks( self, content: list[dict[str, Any]], @@ -609,13 +637,14 @@ class AgentLoop: block.get("type") == "image_url" and block.get("image_url", {}).get("url", "").startswith("data:image/") ): - filtered.append(self._image_placeholder(block)) + path = (block.get("_meta") or {}).get("path", "") + filtered.append({"type": "text", "text": image_placeholder_text(path)}) continue if block.get("type") == "text" and isinstance(block.get("text"), str): text = block["text"] - if truncate_text and len(text) > self._TOOL_RESULT_MAX_CHARS: - text = text[:self._TOOL_RESULT_MAX_CHARS] + "\n... 
(truncated)" + if truncate_text and len(text) > self.max_tool_result_chars: + text = truncate_text(text, self.max_tool_result_chars) filtered.append({**block, "text": text}) continue @@ -632,8 +661,8 @@ class AgentLoop: if role == "assistant" and not content and not entry.get("tool_calls"): continue # skip empty assistant messages β€” they poison session context if role == "tool": - if isinstance(content, str) and len(content) > self._TOOL_RESULT_MAX_CHARS: - entry["content"] = content[:self._TOOL_RESULT_MAX_CHARS] + "\n... (truncated)" + if isinstance(content, str) and len(content) > self.max_tool_result_chars: + entry["content"] = truncate_text(content, self.max_tool_result_chars) elif isinstance(content, list): filtered = self._sanitize_persisted_blocks(content, truncate_text=True) if not filtered: @@ -656,6 +685,78 @@ class AgentLoop: session.messages.append(entry) session.updated_at = datetime.now() + def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None: + """Persist the latest in-flight turn state into session metadata.""" + session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload + self.sessions.save(session) + + def _clear_runtime_checkpoint(self, session: Session) -> None: + if self._RUNTIME_CHECKPOINT_KEY in session.metadata: + session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None) + + @staticmethod + def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]: + return ( + message.get("role"), + message.get("content"), + message.get("tool_call_id"), + message.get("name"), + message.get("tool_calls"), + message.get("reasoning_content"), + message.get("thinking_blocks"), + ) + + def _restore_runtime_checkpoint(self, session: Session) -> bool: + """Materialize an unfinished turn into session history before a new request.""" + from datetime import datetime + + checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY) + if not isinstance(checkpoint, dict): + return False + + assistant_message = 
checkpoint.get("assistant_message") + completed_tool_results = checkpoint.get("completed_tool_results") or [] + pending_tool_calls = checkpoint.get("pending_tool_calls") or [] + + restored_messages: list[dict[str, Any]] = [] + if isinstance(assistant_message, dict): + restored = dict(assistant_message) + restored.setdefault("timestamp", datetime.now().isoformat()) + restored_messages.append(restored) + for message in completed_tool_results: + if isinstance(message, dict): + restored = dict(message) + restored.setdefault("timestamp", datetime.now().isoformat()) + restored_messages.append(restored) + for tool_call in pending_tool_calls: + if not isinstance(tool_call, dict): + continue + tool_id = tool_call.get("id") + name = ((tool_call.get("function") or {}).get("name")) or "tool" + restored_messages.append({ + "role": "tool", + "tool_call_id": tool_id, + "name": name, + "content": "Error: Task interrupted before this tool finished.", + "timestamp": datetime.now().isoformat(), + }) + + overlap = 0 + max_overlap = min(len(session.messages), len(restored_messages)) + for size in range(max_overlap, 0, -1): + existing = session.messages[-size:] + restored = restored_messages[:size] + if all( + self._checkpoint_message_key(left) == self._checkpoint_message_key(right) + for left, right in zip(existing, restored) + ): + overlap = size + break + session.messages.extend(restored_messages[overlap:]) + + self._clear_runtime_checkpoint(session) + return True + async def process_direct( self, content: str, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 4fec539dd..a8676a8e0 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -4,20 +4,36 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field +from pathlib import Path from typing import Any +from loguru import logger + from nanobot.agent.hook import AgentHook, AgentHookContext from nanobot.agent.tools.registry import ToolRegistry from 
nanobot.providers.base import LLMProvider, ToolCallRequest -from nanobot.utils.helpers import build_assistant_message +from nanobot.utils.helpers import ( + build_assistant_message, + estimate_message_tokens, + estimate_prompt_tokens_chain, + find_legal_message_start, + maybe_persist_tool_result, + truncate_text, +) +from nanobot.utils.runtime import ( + EMPTY_FINAL_RESPONSE_MESSAGE, + build_finalization_retry_message, + ensure_nonempty_tool_result, + is_blank_text, + repeated_external_lookup_error, +) _DEFAULT_MAX_ITERATIONS_MESSAGE = ( "I reached the maximum number of tool call iterations ({max_iterations}) " "without completing the task. You can try breaking the task into smaller steps." ) _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." - - +_SNIP_SAFETY_BUFFER = 1024 @dataclass(slots=True) class AgentRunSpec: """Configuration for a single agent execution.""" @@ -26,6 +42,7 @@ class AgentRunSpec: tools: ToolRegistry model: str max_iterations: int + max_tool_result_chars: int temperature: float | None = None max_tokens: int | None = None reasoning_effort: str | None = None @@ -34,6 +51,13 @@ class AgentRunSpec: max_iterations_message: str | None = None concurrent_tools: bool = False fail_on_tool_error: bool = False + workspace: Path | None = None + session_key: str | None = None + context_window_tokens: int | None = None + context_block_limit: int | None = None + provider_retry_mode: str = "standard" + progress_callback: Any | None = None + checkpoint_callback: Any | None = None @dataclass(slots=True) @@ -60,91 +84,142 @@ class AgentRunner: messages = list(spec.initial_messages) final_content: str | None = None tools_used: list[str] = [] - usage: dict[str, int] = {} + usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} error: str | None = None stop_reason = "completed" tool_events: list[dict[str, str]] = [] + external_lookup_counts: dict[str, int] = {} for iteration in range(spec.max_iterations): + try: + messages = 
self._apply_tool_result_budget(spec, messages) + messages_for_model = self._snip_history(spec, messages) + except Exception as exc: + logger.warning( + "Context governance failed on turn {} for {}: {}; using raw messages", + iteration, + spec.session_key or "default", + exc, + ) + messages_for_model = messages context = AgentHookContext(iteration=iteration, messages=messages) await hook.before_iteration(context) - kwargs: dict[str, Any] = { - "messages": messages, - "tools": spec.tools.get_definitions(), - "model": spec.model, - } - if spec.temperature is not None: - kwargs["temperature"] = spec.temperature - if spec.max_tokens is not None: - kwargs["max_tokens"] = spec.max_tokens - if spec.reasoning_effort is not None: - kwargs["reasoning_effort"] = spec.reasoning_effort - - if hook.wants_streaming(): - async def _stream(delta: str) -> None: - await hook.on_stream(context, delta) - - response = await self.provider.chat_stream_with_retry( - **kwargs, - on_content_delta=_stream, - ) - else: - response = await self.provider.chat_with_retry(**kwargs) - - raw_usage = response.usage or {} + response = await self._request_model(spec, messages_for_model, hook, context) + raw_usage = self._usage_dict(response.usage) context.response = response - context.usage = raw_usage + context.usage = dict(raw_usage) context.tool_calls = list(response.tool_calls) - # Accumulate standard fields into result usage. 
- usage["prompt_tokens"] = usage.get("prompt_tokens", 0) + int(raw_usage.get("prompt_tokens", 0) or 0) - usage["completion_tokens"] = usage.get("completion_tokens", 0) + int(raw_usage.get("completion_tokens", 0) or 0) - cached = raw_usage.get("cached_tokens") - if cached: - usage["cached_tokens"] = usage.get("cached_tokens", 0) + int(cached) + self._accumulate_usage(usage, raw_usage) if response.has_tool_calls: if hook.wants_streaming(): await hook.on_stream_end(context, resuming=True) - messages.append(build_assistant_message( + assistant_message = build_assistant_message( response.content or "", tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls], reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, - )) + ) + messages.append(assistant_message) tools_used.extend(tc.name for tc in response.tool_calls) + await self._emit_checkpoint( + spec, + { + "phase": "awaiting_tools", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls], + }, + ) await hook.before_execute_tools(context) - results, new_events, fatal_error = await self._execute_tools(spec, response.tool_calls) + results, new_events, fatal_error = await self._execute_tools( + spec, + response.tool_calls, + external_lookup_counts, + ) tool_events.extend(new_events) context.tool_results = list(results) context.tool_events = list(new_events) if fatal_error is not None: error = f"Error: {type(fatal_error).__name__}: {fatal_error}" + final_content = error stop_reason = "tool_error" + self._append_final_message(messages, final_content) + context.final_content = final_content context.error = error context.stop_reason = stop_reason await hook.after_iteration(context) break + completed_tool_results: list[dict[str, Any]] = [] for tool_call, result in zip(response.tool_calls, results): - messages.append({ + tool_message = { 
"role": "tool", "tool_call_id": tool_call.id, "name": tool_call.name, - "content": result, - }) + "content": self._normalize_tool_result( + spec, + tool_call.id, + tool_call.name, + result, + ), + } + messages.append(tool_message) + completed_tool_results.append(tool_message) + await self._emit_checkpoint( + spec, + { + "phase": "tools_completed", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": completed_tool_results, + "pending_tool_calls": [], + }, + ) await hook.after_iteration(context) continue + clean = hook.finalize_content(context, response.content) + if response.finish_reason != "error" and is_blank_text(clean): + logger.warning( + "Empty final response on turn {} for {}; retrying with explicit finalization prompt", + iteration, + spec.session_key or "default", + ) + if hook.wants_streaming(): + await hook.on_stream_end(context, resuming=False) + response = await self._request_finalization_retry(spec, messages_for_model) + retry_usage = self._usage_dict(response.usage) + self._accumulate_usage(usage, retry_usage) + raw_usage = self._merge_usage(raw_usage, retry_usage) + context.response = response + context.usage = dict(raw_usage) + context.tool_calls = list(response.tool_calls) + clean = hook.finalize_content(context, response.content) + if hook.wants_streaming(): await hook.on_stream_end(context, resuming=False) - clean = hook.finalize_content(context, response.content) if response.finish_reason == "error": final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE stop_reason = "error" error = final_content + self._append_final_message(messages, final_content) + context.final_content = final_content + context.error = error + context.stop_reason = stop_reason + await hook.after_iteration(context) + break + if is_blank_text(clean): + final_content = EMPTY_FINAL_RESPONSE_MESSAGE + stop_reason = "empty_final_response" + error = final_content + 
self._append_final_message(messages, final_content) context.final_content = final_content context.error = error context.stop_reason = stop_reason @@ -156,6 +231,17 @@ class AgentRunner: reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, )) + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": messages[-1], + "completed_tool_results": [], + "pending_tool_calls": [], + }, + ) final_content = clean context.final_content = final_content context.stop_reason = stop_reason @@ -165,6 +251,7 @@ class AgentRunner: stop_reason = "max_iterations" template = spec.max_iterations_message or _DEFAULT_MAX_ITERATIONS_MESSAGE final_content = template.format(max_iterations=spec.max_iterations) + self._append_final_message(messages, final_content) return AgentRunResult( final_content=final_content, @@ -176,21 +263,101 @@ class AgentRunner: tool_events=tool_events, ) + def _build_request_kwargs( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + *, + tools: list[dict[str, Any]] | None, + ) -> dict[str, Any]: + kwargs: dict[str, Any] = { + "messages": messages, + "tools": tools, + "model": spec.model, + "retry_mode": spec.provider_retry_mode, + "on_retry_wait": spec.progress_callback, + } + if spec.temperature is not None: + kwargs["temperature"] = spec.temperature + if spec.max_tokens is not None: + kwargs["max_tokens"] = spec.max_tokens + if spec.reasoning_effort is not None: + kwargs["reasoning_effort"] = spec.reasoning_effort + return kwargs + + async def _request_model( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + hook: AgentHook, + context: AgentHookContext, + ): + kwargs = self._build_request_kwargs( + spec, + messages, + tools=spec.tools.get_definitions(), + ) + if hook.wants_streaming(): + async def _stream(delta: str) -> None: + await hook.on_stream(context, delta) + + return await 
self.provider.chat_stream_with_retry( + **kwargs, + on_content_delta=_stream, + ) + return await self.provider.chat_with_retry(**kwargs) + + async def _request_finalization_retry( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + ): + retry_messages = list(messages) + retry_messages.append(build_finalization_retry_message()) + kwargs = self._build_request_kwargs(spec, retry_messages, tools=None) + return await self.provider.chat_with_retry(**kwargs) + + @staticmethod + def _usage_dict(usage: dict[str, Any] | None) -> dict[str, int]: + if not usage: + return {} + result: dict[str, int] = {} + for key, value in usage.items(): + try: + result[key] = int(value or 0) + except (TypeError, ValueError): + continue + return result + + @staticmethod + def _accumulate_usage(target: dict[str, int], addition: dict[str, int]) -> None: + for key, value in addition.items(): + target[key] = target.get(key, 0) + value + + @staticmethod + def _merge_usage(left: dict[str, int], right: dict[str, int]) -> dict[str, int]: + merged = dict(left) + for key, value in right.items(): + merged[key] = merged.get(key, 0) + value + return merged + async def _execute_tools( self, spec: AgentRunSpec, tool_calls: list[ToolCallRequest], + external_lookup_counts: dict[str, int], ) -> tuple[list[Any], list[dict[str, str]], BaseException | None]: - if spec.concurrent_tools: - tool_results = await asyncio.gather(*( - self._run_tool(spec, tool_call) - for tool_call in tool_calls - )) - else: - tool_results = [ - await self._run_tool(spec, tool_call) - for tool_call in tool_calls - ] + batches = self._partition_tool_batches(spec, tool_calls) + tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = [] + for batch in batches: + if spec.concurrent_tools and len(batch) > 1: + tool_results.extend(await asyncio.gather(*( + self._run_tool(spec, tool_call, external_lookup_counts) + for tool_call in batch + ))) + else: + for tool_call in batch: + tool_results.append(await 
self._run_tool(spec, tool_call, external_lookup_counts)) results: list[Any] = [] events: list[dict[str, str]] = [] @@ -206,9 +373,44 @@ class AgentRunner: self, spec: AgentRunSpec, tool_call: ToolCallRequest, + external_lookup_counts: dict[str, int], ) -> tuple[Any, dict[str, str], BaseException | None]: + _HINT = "\n\n[Analyze the error above and try a different approach.]" + lookup_error = repeated_external_lookup_error( + tool_call.name, + tool_call.arguments, + external_lookup_counts, + ) + if lookup_error: + event = { + "name": tool_call.name, + "status": "error", + "detail": "repeated external lookup blocked", + } + if spec.fail_on_tool_error: + return lookup_error + _HINT, event, RuntimeError(lookup_error) + return lookup_error + _HINT, event, None + prepare_call = getattr(spec.tools, "prepare_call", None) + tool, params, prep_error = None, tool_call.arguments, None + if callable(prepare_call): + try: + prepared = prepare_call(tool_call.name, tool_call.arguments) + if isinstance(prepared, tuple) and len(prepared) == 3: + tool, params, prep_error = prepared + except Exception: + pass + if prep_error: + event = { + "name": tool_call.name, + "status": "error", + "detail": prep_error.split(": ", 1)[-1][:120], + } + return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None try: - result = await spec.tools.execute(tool_call.name, tool_call.arguments) + if tool is not None: + result = await tool.execute(**params) + else: + result = await spec.tools.execute(tool_call.name, params) except asyncio.CancelledError: raise except BaseException as exc: @@ -221,14 +423,178 @@ class AgentRunner: return f"Error: {type(exc).__name__}: {exc}", event, exc return f"Error: {type(exc).__name__}: {exc}", event, None + if isinstance(result, str) and result.startswith("Error"): + event = { + "name": tool_call.name, + "status": "error", + "detail": result.replace("\n", " ").strip()[:120], + } + if spec.fail_on_tool_error: + return result + _HINT, 
event, RuntimeError(result) + return result + _HINT, event, None + detail = "" if result is None else str(result) detail = detail.replace("\n", " ").strip() if not detail: detail = "(empty)" elif len(detail) > 120: detail = detail[:120] + "..." - return result, { - "name": tool_call.name, - "status": "error" if isinstance(result, str) and result.startswith("Error") else "ok", - "detail": detail, - }, None + return result, {"name": tool_call.name, "status": "ok", "detail": detail}, None + + async def _emit_checkpoint( + self, + spec: AgentRunSpec, + payload: dict[str, Any], + ) -> None: + callback = spec.checkpoint_callback + if callback is not None: + await callback(payload) + + @staticmethod + def _append_final_message(messages: list[dict[str, Any]], content: str | None) -> None: + if not content: + return + if ( + messages + and messages[-1].get("role") == "assistant" + and not messages[-1].get("tool_calls") + ): + if messages[-1].get("content") == content: + return + messages[-1] = build_assistant_message(content) + return + messages.append(build_assistant_message(content)) + + def _normalize_tool_result( + self, + spec: AgentRunSpec, + tool_call_id: str, + tool_name: str, + result: Any, + ) -> Any: + result = ensure_nonempty_tool_result(tool_name, result) + try: + content = maybe_persist_tool_result( + spec.workspace, + spec.session_key, + tool_call_id, + result, + max_chars=spec.max_tool_result_chars, + ) + except Exception as exc: + logger.warning( + "Tool result persist failed for {} in {}: {}; using raw result", + tool_call_id, + spec.session_key or "default", + exc, + ) + content = result + if isinstance(content, str) and len(content) > spec.max_tool_result_chars: + return truncate_text(content, spec.max_tool_result_chars) + return content + + def _apply_tool_result_budget( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + updated = messages + for idx, message in enumerate(messages): + if message.get("role") != 
"tool": + continue + normalized = self._normalize_tool_result( + spec, + str(message.get("tool_call_id") or f"tool_{idx}"), + str(message.get("name") or "tool"), + message.get("content"), + ) + if normalized != message.get("content"): + if updated is messages: + updated = [dict(m) for m in messages] + updated[idx]["content"] = normalized + return updated + + def _snip_history( + self, + spec: AgentRunSpec, + messages: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + if not messages or not spec.context_window_tokens: + return messages + + provider_max_tokens = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096) + max_output = spec.max_tokens if isinstance(spec.max_tokens, int) else ( + provider_max_tokens if isinstance(provider_max_tokens, int) else 4096 + ) + budget = spec.context_block_limit or ( + spec.context_window_tokens - max_output - _SNIP_SAFETY_BUFFER + ) + if budget <= 0: + return messages + + estimate, _ = estimate_prompt_tokens_chain( + self.provider, + spec.model, + messages, + spec.tools.get_definitions(), + ) + if estimate <= budget: + return messages + + system_messages = [dict(msg) for msg in messages if msg.get("role") == "system"] + non_system = [dict(msg) for msg in messages if msg.get("role") != "system"] + if not non_system: + return messages + + system_tokens = sum(estimate_message_tokens(msg) for msg in system_messages) + remaining_budget = max(128, budget - system_tokens) + kept: list[dict[str, Any]] = [] + kept_tokens = 0 + for message in reversed(non_system): + msg_tokens = estimate_message_tokens(message) + if kept and kept_tokens + msg_tokens > remaining_budget: + break + kept.append(message) + kept_tokens += msg_tokens + kept.reverse() + + if kept: + for i, message in enumerate(kept): + if message.get("role") == "user": + kept = kept[i:] + break + start = find_legal_message_start(kept) + if start: + kept = kept[start:] + if not kept: + kept = non_system[-min(len(non_system), 4) :] + start = 
find_legal_message_start(kept) + if start: + kept = kept[start:] + return system_messages + kept + + def _partition_tool_batches( + self, + spec: AgentRunSpec, + tool_calls: list[ToolCallRequest], + ) -> list[list[ToolCallRequest]]: + if not spec.concurrent_tools: + return [[tool_call] for tool_call in tool_calls] + + batches: list[list[ToolCallRequest]] = [] + current: list[ToolCallRequest] = [] + for tool_call in tool_calls: + get_tool = getattr(spec.tools, "get", None) + tool = get_tool(tool_call.name) if callable(get_tool) else None + can_batch = bool(tool and tool.concurrency_safe) + if can_batch: + current.append(tool_call) + continue + if current: + batches.append(current) + current = [] + batches.append([tool_call]) + if current: + batches.append(current) + return batches + diff --git a/nanobot/agent/subagent.py b/nanobot/agent/subagent.py index 9d936f034..c7643a486 100644 --- a/nanobot/agent/subagent.py +++ b/nanobot/agent/subagent.py @@ -44,6 +44,7 @@ class SubagentManager: provider: LLMProvider, workspace: Path, bus: MessageBus, + max_tool_result_chars: int, model: str | None = None, web_search_config: "WebSearchConfig | None" = None, web_proxy: str | None = None, @@ -56,6 +57,7 @@ class SubagentManager: self.workspace = workspace self.bus = bus self.model = model or provider.get_default_model() + self.max_tool_result_chars = max_tool_result_chars self.web_search_config = web_search_config or WebSearchConfig() self.web_proxy = web_proxy self.exec_config = exec_config or ExecToolConfig() @@ -136,6 +138,7 @@ class SubagentManager: tools=tools, model=self.model, max_iterations=15, + max_tool_result_chars=self.max_tool_result_chars, hook=_SubagentHook(task_id), max_iterations_message="Task completed but no final response was generated.", error_message=None, diff --git a/nanobot/agent/tools/base.py b/nanobot/agent/tools/base.py index 4017f7cf6..f119f6908 100644 --- a/nanobot/agent/tools/base.py +++ b/nanobot/agent/tools/base.py @@ -53,6 +53,21 @@ class 
Tool(ABC): """JSON Schema for tool parameters.""" pass + @property + def read_only(self) -> bool: + """Whether this tool is side-effect free and safe to parallelize.""" + return False + + @property + def concurrency_safe(self) -> bool: + """Whether this tool can run alongside other concurrency-safe tools.""" + return self.read_only and not self.exclusive + + @property + def exclusive(self) -> bool: + """Whether this tool should run alone even if concurrency is enabled.""" + return False + @abstractmethod async def execute(self, **kwargs: Any) -> Any: """ diff --git a/nanobot/agent/tools/filesystem.py b/nanobot/agent/tools/filesystem.py index da7778da3..d4094e7f3 100644 --- a/nanobot/agent/tools/filesystem.py +++ b/nanobot/agent/tools/filesystem.py @@ -73,6 +73,10 @@ class ReadFileTool(_FsTool): "Use offset and limit to paginate through large files." ) + @property + def read_only(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { @@ -344,6 +348,10 @@ class ListDirTool(_FsTool): "Common noise directories (.git, node_modules, __pycache__, etc.) are auto-ignored." ) + @property + def read_only(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index 3ac813248..520020735 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -84,6 +84,9 @@ class MessageTool(Tool): media: list[str] | None = None, **kwargs: Any ) -> str: + from nanobot.utils.helpers import strip_think + content = strip_think(content) + channel = channel or self._default_channel chat_id = chat_id or self._default_chat_id # Only inherit default message_id when targeting the same channel+chat. 
diff --git a/nanobot/agent/tools/registry.py b/nanobot/agent/tools/registry.py index c24659a70..725706dce 100644 --- a/nanobot/agent/tools/registry.py +++ b/nanobot/agent/tools/registry.py @@ -35,22 +35,35 @@ class ToolRegistry: """Get all tool definitions in OpenAI format.""" return [tool.to_schema() for tool in self._tools.values()] + def prepare_call( + self, + name: str, + params: dict[str, Any], + ) -> tuple[Tool | None, dict[str, Any], str | None]: + """Resolve, cast, and validate one tool call.""" + tool = self._tools.get(name) + if not tool: + return None, params, ( + f"Error: Tool '{name}' not found. Available: {', '.join(self.tool_names)}" + ) + + cast_params = tool.cast_params(params) + errors = tool.validate_params(cast_params) + if errors: + return tool, cast_params, ( + f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + ) + return tool, cast_params, None + async def execute(self, name: str, params: dict[str, Any]) -> Any: """Execute a tool by name with given parameters.""" _HINT = "\n\n[Analyze the error above and try a different approach.]" - - tool = self._tools.get(name) - if not tool: - return f"Error: Tool '{name}' not found. 
Available: {', '.join(self.tool_names)}" + tool, params, error = self.prepare_call(name, params) + if error: + return error + _HINT try: - # Attempt to cast parameters to match schema types - params = tool.cast_params(params) - - # Validate parameters - errors = tool.validate_params(params) - if errors: - return f"Error: Invalid parameters for tool '{name}': " + "; ".join(errors) + _HINT + assert tool is not None # guarded by prepare_call() result = await tool.execute(**params) if isinstance(result, str) and result.startswith("Error"): return result + _HINT diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index b051edffc..dd3a44335 100644 --- a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -52,6 +52,10 @@ class ExecTool(Tool): def description(self) -> str: return "Execute a shell command and return its output. Use with caution." + @property + def exclusive(self) -> bool: + return True + @property def parameters(self) -> dict[str, Any]: return { diff --git a/nanobot/agent/tools/web.py b/nanobot/agent/tools/web.py index 9480e194f..1c0fde822 100644 --- a/nanobot/agent/tools/web.py +++ b/nanobot/agent/tools/web.py @@ -92,6 +92,10 @@ class WebSearchTool(Tool): self.config = config if config is not None else WebSearchConfig() self.proxy = proxy + @property + def read_only(self) -> bool: + return True + async def execute(self, query: str, count: int | None = None, **kwargs: Any) -> str: provider = self.config.provider.strip().lower() or "brave" n = min(max(count or self.config.max_results, 1), 10) @@ -234,6 +238,10 @@ class WebFetchTool(Tool): self.max_chars = max_chars self.proxy = proxy + @property + def read_only(self) -> bool: + return True + async def execute(self, url: str, extractMode: str = "markdown", maxChars: int | None = None, **kwargs: Any) -> Any: max_chars = maxChars or self.max_chars is_valid, error_msg = _validate_url_safe(url) diff --git a/nanobot/api/server.py b/nanobot/api/server.py index 
9494b6e31..2bfeddd05 100644 --- a/nanobot/api/server.py +++ b/nanobot/api/server.py @@ -14,6 +14,8 @@ from typing import Any from aiohttp import web from loguru import logger +from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE + API_SESSION_KEY = "api:default" API_CHAT_ID = "default" @@ -98,7 +100,7 @@ async def handle_chat_completions(request: web.Request) -> web.Response: logger.info("API request session_key={} content={}", session_key, user_content[:80]) - _FALLBACK = "I've completed processing but have no response to give." + _FALLBACK = EMPTY_FINAL_RESPONSE_MESSAGE try: async with session_lock: diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index b9d2d64d8..bef2cf27a 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -134,6 +134,7 @@ class QQConfig(Base): secret: str = "" allow_from: list[str] = Field(default_factory=list) msg_format: Literal["plain", "markdown"] = "plain" + ack_message: str = "⏳ Processing..." # Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq"). 
media_dir: str = "" @@ -484,6 +485,17 @@ class QQChannel(BaseChannel): if not content and not media_paths: return + if self.config.ack_message: + try: + await self._send_text_only( + chat_id=chat_id, + is_group=is_group, + msg_id=data.id, + content=self.config.ack_message, + ) + except Exception: + logger.debug("QQ ack message failed for chat_id={}", chat_id) + await self._handle_message( sender_id=user_id, chat_id=chat_id, diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index 916b9ba64..a6bd810f2 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -275,13 +275,10 @@ class TelegramChannel(BaseChannel): self._app = builder.build() self._app.add_error_handler(self._on_error) - # Add command handlers - self._app.add_handler(CommandHandler("start", self._on_start)) - self._app.add_handler(CommandHandler("new", self._forward_command)) - self._app.add_handler(CommandHandler("stop", self._forward_command)) - self._app.add_handler(CommandHandler("restart", self._forward_command)) - self._app.add_handler(CommandHandler("status", self._forward_command)) - self._app.add_handler(CommandHandler("help", self._on_help)) + # Add command handlers (using Regex to support @username suffixes before bot initialization) + self._app.add_handler(MessageHandler(filters.Regex(r"^/start(?:@\w+)?$"), self._on_start)) + self._app.add_handler(MessageHandler(filters.Regex(r"^/(new|stop|restart|status)(?:@\w+)?$"), self._forward_command)) + self._app.add_handler(MessageHandler(filters.Regex(r"^/help(?:@\w+)?$"), self._on_help)) # Add message handler for text, photos, voice, documents self._app.add_handler( @@ -313,7 +310,7 @@ class TelegramChannel(BaseChannel): # Start polling (this runs until stopped) await self._app.updater.start_polling( allowed_updates=["message"], - drop_pending_updates=True # Ignore old messages on startup + drop_pending_updates=False # Process pending messages on startup ) # Keep running until stopped @@ -362,9 +359,14 
@@ class TelegramChannel(BaseChannel): logger.warning("Telegram bot not running") return - # Only stop typing indicator for final responses + # Only stop typing indicator and remove reaction for final responses if not msg.metadata.get("_progress", False): self._stop_typing(msg.chat_id) + if reply_to_message_id := msg.metadata.get("message_id"): + try: + await self._remove_reaction(msg.chat_id, int(reply_to_message_id)) + except ValueError: + pass try: chat_id = int(msg.chat_id) @@ -435,7 +437,9 @@ class TelegramChannel(BaseChannel): await self._send_text(chat_id, chunk, reply_params, thread_kwargs) async def _call_with_retry(self, fn, *args, **kwargs): - """Call an async Telegram API function with retry on pool/network timeout.""" + """Call an async Telegram API function with retry on pool/network timeout and RetryAfter.""" + from telegram.error import RetryAfter + for attempt in range(1, _SEND_MAX_RETRIES + 1): try: return await fn(*args, **kwargs) @@ -448,6 +452,15 @@ class TelegramChannel(BaseChannel): attempt, _SEND_MAX_RETRIES, delay, ) await asyncio.sleep(delay) + except RetryAfter as e: + if attempt == _SEND_MAX_RETRIES: + raise + delay = float(e.retry_after) + logger.warning( + "Telegram Flood Control (attempt {}/{}), retrying in {:.1f}s", + attempt, _SEND_MAX_RETRIES, delay, + ) + await asyncio.sleep(delay) async def _send_text( self, @@ -498,6 +511,11 @@ class TelegramChannel(BaseChannel): if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id: return self._stop_typing(chat_id) + if reply_to_message_id := meta.get("message_id"): + try: + await self._remove_reaction(chat_id, int(reply_to_message_id)) + except ValueError: + pass try: html = _markdown_to_telegram_html(buf.text) await self._call_with_retry( @@ -619,8 +637,7 @@ class TelegramChannel(BaseChannel): "reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None, } - @staticmethod - def _extract_reply_context(message) -> str | None: + async def 
_extract_reply_context(self, message) -> str | None: """Extract text from the message being replied to, if any.""" reply = getattr(message, "reply_to_message", None) if not reply: @@ -628,7 +645,21 @@ class TelegramChannel(BaseChannel): text = getattr(reply, "text", None) or getattr(reply, "caption", None) or "" if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN: text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..." - return f"[Reply to: {text}]" if text else None + + if not text: + return None + + bot_id, _ = await self._ensure_bot_identity() + reply_user = getattr(reply, "from_user", None) + + if bot_id and reply_user and getattr(reply_user, "id", None) == bot_id: + return f"[Reply to bot: {text}]" + elif reply_user and getattr(reply_user, "username", None): + return f"[Reply to @{reply_user.username}: {text}]" + elif reply_user and getattr(reply_user, "first_name", None): + return f"[Reply to {reply_user.first_name}: {text}]" + else: + return f"[Reply to: {text}]" async def _download_message_media( self, msg, *, add_failure_content: bool = False @@ -765,10 +796,18 @@ class TelegramChannel(BaseChannel): message = update.message user = update.effective_user self._remember_thread_context(message) + + # Strip @bot_username suffix if present + content = message.text or "" + if content.startswith("/") and "@" in content: + cmd_part, *rest = content.split(" ", 1) + cmd_part = cmd_part.split("@")[0] + content = f"{cmd_part} {rest[0]}" if rest else cmd_part + await self._handle_message( sender_id=self._sender_id(user), chat_id=str(message.chat_id), - content=message.text or "", + content=content, metadata=self._build_message_metadata(message, user), session_key=self._derive_topic_session_key(message), ) @@ -812,7 +851,7 @@ class TelegramChannel(BaseChannel): # Reply context: text and/or media from the replied-to message reply = getattr(message, "reply_to_message", None) if reply is not None: - reply_ctx = self._extract_reply_context(message) + reply_ctx = await 
self._extract_reply_context(message) reply_media, reply_media_parts = await self._download_message_media(reply) if reply_media: media_paths = reply_media + media_paths @@ -903,6 +942,19 @@ class TelegramChannel(BaseChannel): except Exception as e: logger.debug("Telegram reaction failed: {}", e) + async def _remove_reaction(self, chat_id: str, message_id: int) -> None: + """Remove emoji reaction from a message (best-effort, non-blocking).""" + if not self._app: + return + try: + await self._app.bot.set_message_reaction( + chat_id=int(chat_id), + message_id=message_id, + reaction=[], + ) + except Exception as e: + logger.debug("Telegram reaction removal failed: {}", e) + async def _typing_loop(self, chat_id: str) -> None: """Repeatedly send 'typing' action until cancelled.""" try: diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 49521aa16..d611c2772 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -542,6 +542,9 @@ def serve( model=runtime_config.agents.defaults.model, max_iterations=runtime_config.agents.defaults.max_tool_iterations, context_window_tokens=runtime_config.agents.defaults.context_window_tokens, + context_block_limit=runtime_config.agents.defaults.context_block_limit, + max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars, + provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode, web_search_config=runtime_config.tools.web.search, web_proxy=runtime_config.tools.web.proxy or None, exec_config=runtime_config.tools.exec, @@ -629,6 +632,9 @@ def gateway( model=config.agents.defaults.model, max_iterations=config.agents.defaults.max_tool_iterations, context_window_tokens=config.agents.defaults.context_window_tokens, + context_block_limit=config.agents.defaults.context_block_limit, + max_tool_result_chars=config.agents.defaults.max_tool_result_chars, + provider_retry_mode=config.agents.defaults.provider_retry_mode, web_search_config=config.tools.web.search, 
web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, @@ -835,6 +841,9 @@ def agent( model=config.agents.defaults.model, max_iterations=config.agents.defaults.max_tool_iterations, context_window_tokens=config.agents.defaults.context_window_tokens, + context_block_limit=config.agents.defaults.context_block_limit, + max_tool_result_chars=config.agents.defaults.max_tool_result_chars, + provider_retry_mode=config.agents.defaults.provider_retry_mode, web_search_config=config.tools.web.search, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, @@ -1023,12 +1032,18 @@ app.add_typer(channels_app, name="channels") @channels_app.command("status") -def channels_status(): +def channels_status( + config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), +): """Show channel status.""" from nanobot.channels.registry import discover_all - from nanobot.config.loader import load_config + from nanobot.config.loader import load_config, set_config_path - config = load_config() + resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None + if resolved_config_path is not None: + set_config_path(resolved_config_path) + + config = load_config(resolved_config_path) table = Table(title="Channel Status") table.add_column("Channel", style="cyan") @@ -1115,12 +1130,17 @@ def _get_bridge_dir() -> Path: def channels_login( channel_name: str = typer.Argument(..., help="Channel name (e.g. 
weixin, whatsapp)"), force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"), + config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), ): """Authenticate with a channel via QR code or other interactive login.""" from nanobot.channels.registry import discover_all - from nanobot.config.loader import load_config + from nanobot.config.loader import load_config, set_config_path - config = load_config() + resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None + if resolved_config_path is not None: + set_config_path(resolved_config_path) + + config = load_config(resolved_config_path) channel_cfg = getattr(config.channels, channel_name, None) or {} # Validate channel exists diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 643397057..05d4fc163 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -26,7 +26,10 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage: sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key) total = cancelled + sub_cancelled content = f"Stopped {total} task(s)." if total else "No active task to stop." 
- return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content=content) + return OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content=content, + metadata=dict(msg.metadata or {}) + ) async def cmd_restart(ctx: CommandContext) -> OutboundMessage: @@ -38,7 +41,10 @@ async def cmd_restart(ctx: CommandContext) -> OutboundMessage: os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:]) asyncio.create_task(_do_restart()) - return OutboundMessage(channel=msg.channel, chat_id=msg.chat_id, content="Restarting...") + return OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content="Restarting...", + metadata=dict(msg.metadata or {}) + ) async def cmd_status(ctx: CommandContext) -> OutboundMessage: @@ -62,7 +68,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage: session_msg_count=len(session.get_history(max_messages=0)), context_tokens_estimate=ctx_est, ), - metadata={"render_as": "text"}, + metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"}, ) @@ -79,6 +85,7 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: return OutboundMessage( channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content="New session started.", + metadata=dict(ctx.msg.metadata or {}) ) @@ -88,7 +95,7 @@ async def cmd_help(ctx: CommandContext) -> OutboundMessage: channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content=build_help_text(), - metadata={"render_as": "text"}, + metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"}, ) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c4c927afd..602b8a911 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -38,8 +38,11 @@ class AgentDefaults(Base): ) max_tokens: int = 8192 context_window_tokens: int = 65_536 + context_block_limit: int | None = None temperature: float = 0.1 - max_tool_iterations: int = 40 + max_tool_iterations: int = 200 + max_tool_result_chars: int = 16_000 + provider_retry_mode: Literal["standard", 
"persistent"] = "standard" reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 84fb70934..30282cb4f 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -73,6 +73,9 @@ class Nanobot: model=defaults.model, max_iterations=defaults.max_tool_iterations, context_window_tokens=defaults.context_window_tokens, + context_block_limit=defaults.context_block_limit, + max_tool_result_chars=defaults.max_tool_result_chars, + provider_retry_mode=defaults.provider_retry_mode, web_search_config=config.tools.web.search, web_proxy=config.tools.web.proxy or None, exec_config=config.tools.exec, diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py index 8e102d305..eaec77789 100644 --- a/nanobot/providers/anthropic_provider.py +++ b/nanobot/providers/anthropic_provider.py @@ -2,6 +2,8 @@ from __future__ import annotations +import asyncio +import os import re import secrets import string @@ -434,13 +436,33 @@ class AnthropicProvider(LLMProvider): messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice, ) + idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: async with self._client.messages.stream(**kwargs) as stream: if on_content_delta: - async for text in stream.text_stream: + stream_iter = stream.text_stream.__aiter__() + while True: + try: + text = await asyncio.wait_for( + stream_iter.__anext__(), + timeout=idle_timeout_s, + ) + except StopAsyncIteration: + break await on_content_delta(text) - response = await stream.get_final_message() + response = await asyncio.wait_for( + stream.get_final_message(), + timeout=idle_timeout_s, + ) return self._parse_response(response) + except asyncio.TimeoutError: + return LLMResponse( + content=( + f"Error calling LLM: stream stalled for more than " + f"{idle_timeout_s} seconds" 
+ ), + finish_reason="error", + ) except Exception as e: return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py index d71dae917..12c74be02 100644 --- a/nanobot/providers/azure_openai_provider.py +++ b/nanobot/providers/azure_openai_provider.py @@ -1,31 +1,36 @@ -"""Azure OpenAI provider implementation with API version 2024-10-21.""" +"""Azure OpenAI provider using the OpenAI SDK Responses API. + +Uses ``AsyncOpenAI`` pointed at ``https://{endpoint}/openai/v1/`` which +routes to the Responses API (``/responses``). Reuses shared conversion +helpers from :mod:`nanobot.providers.openai_responses`. +""" from __future__ import annotations -import json import uuid from collections.abc import Awaitable, Callable from typing import Any -from urllib.parse import urljoin -import httpx -import json_repair +from openai import AsyncOpenAI -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest - -_AZURE_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name"}) +from nanobot.providers.base import LLMProvider, LLMResponse +from nanobot.providers.openai_responses import ( + consume_sdk_stream, + convert_messages, + convert_tools, + parse_response_output, +) class AzureOpenAIProvider(LLMProvider): - """ - Azure OpenAI provider with API version 2024-10-21 compliance. - + """Azure OpenAI provider backed by the Responses API. 
+ Features: - - Hardcoded API version 2024-10-21 - - Uses model field as Azure deployment name in URL path - - Uses api-key header instead of Authorization Bearer - - Uses max_completion_tokens instead of max_tokens - - Direct HTTP calls, bypasses LiteLLM + - Uses the OpenAI Python SDK (``AsyncOpenAI``) with + ``base_url = {endpoint}/openai/v1/`` + - Calls ``client.responses.create()`` (Responses API) + - Reuses shared message/tool/SSE conversion from + ``openai_responses`` """ def __init__( @@ -36,40 +41,28 @@ class AzureOpenAIProvider(LLMProvider): ): super().__init__(api_key, api_base) self.default_model = default_model - self.api_version = "2024-10-21" - - # Validate required parameters + if not api_key: raise ValueError("Azure OpenAI api_key is required") if not api_base: raise ValueError("Azure OpenAI api_base is required") - - # Ensure api_base ends with / - if not api_base.endswith('/'): - api_base += '/' + + # Normalise: ensure trailing slash + if not api_base.endswith("/"): + api_base += "/" self.api_base = api_base - def _build_chat_url(self, deployment_name: str) -> str: - """Build the Azure OpenAI chat completions URL.""" - # Azure OpenAI URL format: - # https://{resource}.openai.azure.com/openai/deployments/{deployment}/chat/completions?api-version={version} - base_url = self.api_base - if not base_url.endswith('/'): - base_url += '/' - - url = urljoin( - base_url, - f"openai/deployments/{deployment_name}/chat/completions" + # SDK client targeting the Azure Responses API endpoint + base_url = f"{api_base.rstrip('/')}/openai/v1/" + self._client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + default_headers={"x-session-affinity": uuid.uuid4().hex}, ) - return f"{url}?api-version={self.api_version}" - def _build_headers(self) -> dict[str, str]: - """Build headers for Azure OpenAI API with api-key header.""" - return { - "Content-Type": "application/json", - "api-key": self.api_key, # Azure OpenAI uses api-key header, not Authorization - 
"x-session-affinity": uuid.uuid4().hex, # For cache locality - } + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ @staticmethod def _supports_temperature( @@ -82,36 +75,51 @@ class AzureOpenAIProvider(LLMProvider): name = deployment_name.lower() return not any(token in name for token in ("gpt-5", "o1", "o3", "o4")) - def _prepare_request_payload( + def _build_body( self, - deployment_name: str, messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, ) -> dict[str, Any]: - """Prepare the request payload with Azure OpenAI 2024-10-21 compliance.""" - payload: dict[str, Any] = { - "messages": self._sanitize_request_messages( - self._sanitize_empty_content(messages), - _AZURE_MSG_KEYS, - ), - "max_completion_tokens": max(1, max_tokens), # Azure API 2024-10-21 uses max_completion_tokens + """Build the Responses API request body from Chat-Completions-style args.""" + deployment = model or self.default_model + instructions, input_items = convert_messages(self._sanitize_empty_content(messages)) + + body: dict[str, Any] = { + "model": deployment, + "instructions": instructions or None, + "input": input_items, + "max_output_tokens": max(1, max_tokens), + "store": False, + "stream": False, } - if self._supports_temperature(deployment_name, reasoning_effort): - payload["temperature"] = temperature + if self._supports_temperature(deployment, reasoning_effort): + body["temperature"] = temperature if reasoning_effort: - payload["reasoning_effort"] = reasoning_effort + body["reasoning"] = {"effort": reasoning_effort} + body["include"] = 
["reasoning.encrypted_content"] if tools: - payload["tools"] = tools - payload["tool_choice"] = tool_choice or "auto" + body["tools"] = convert_tools(tools) + body["tool_choice"] = tool_choice or "auto" - return payload + return body + + @staticmethod + def _handle_error(e: Exception) -> LLMResponse: + body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None) + msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}" + return LLMResponse(content=msg, finish_reason="error") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ async def chat( self, @@ -123,92 +131,15 @@ class AzureOpenAIProvider(LLMProvider): reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: - """ - Send a chat completion request to Azure OpenAI. - - Args: - messages: List of message dicts with 'role' and 'content'. - tools: Optional list of tool definitions in OpenAI format. - model: Model identifier (used as deployment name). - max_tokens: Maximum tokens in response (mapped to max_completion_tokens). - temperature: Sampling temperature. - reasoning_effort: Optional reasoning effort parameter. - - Returns: - LLMResponse with content and/or tool calls. 
- """ - deployment_name = model or self.default_model - url = self._build_chat_url(deployment_name) - headers = self._build_headers() - payload = self._prepare_request_payload( - deployment_name, messages, tools, max_tokens, temperature, reasoning_effort, - tool_choice=tool_choice, + body = self._build_body( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, ) - try: - async with httpx.AsyncClient(timeout=60.0, verify=True) as client: - response = await client.post(url, headers=headers, json=payload) - if response.status_code != 200: - return LLMResponse( - content=f"Azure OpenAI API Error {response.status_code}: {response.text}", - finish_reason="error", - ) - - response_data = response.json() - return self._parse_response(response_data) - + response = await self._client.responses.create(**body) + return parse_response_output(response) except Exception as e: - return LLMResponse( - content=f"Error calling Azure OpenAI: {repr(e)}", - finish_reason="error", - ) - - def _parse_response(self, response: dict[str, Any]) -> LLMResponse: - """Parse Azure OpenAI response into our standard format.""" - try: - choice = response["choices"][0] - message = choice["message"] - - tool_calls = [] - if message.get("tool_calls"): - for tc in message["tool_calls"]: - # Parse arguments from JSON string if needed - args = tc["function"]["arguments"] - if isinstance(args, str): - args = json_repair.loads(args) - - tool_calls.append( - ToolCallRequest( - id=tc["id"], - name=tc["function"]["name"], - arguments=args, - ) - ) - - usage = {} - if response.get("usage"): - usage_data = response["usage"] - usage = { - "prompt_tokens": usage_data.get("prompt_tokens", 0), - "completion_tokens": usage_data.get("completion_tokens", 0), - "total_tokens": usage_data.get("total_tokens", 0), - } - - reasoning_content = message.get("reasoning_content") or None - - return LLMResponse( - content=message.get("content"), - tool_calls=tool_calls, - 
finish_reason=choice.get("finish_reason", "stop"), - usage=usage, - reasoning_content=reasoning_content, - ) - - except (KeyError, IndexError) as e: - return LLMResponse( - content=f"Error parsing Azure OpenAI response: {str(e)}", - finish_reason="error", - ) + return self._handle_error(e) async def chat_stream( self, @@ -221,89 +152,26 @@ class AzureOpenAIProvider(LLMProvider): tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: - """Stream a chat completion via Azure OpenAI SSE.""" - deployment_name = model or self.default_model - url = self._build_chat_url(deployment_name) - headers = self._build_headers() - payload = self._prepare_request_payload( - deployment_name, messages, tools, max_tokens, temperature, - reasoning_effort, tool_choice=tool_choice, + body = self._build_body( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, ) - payload["stream"] = True + body["stream"] = True try: - async with httpx.AsyncClient(timeout=60.0, verify=True) as client: - async with client.stream("POST", url, headers=headers, json=payload) as response: - if response.status_code != 200: - text = await response.aread() - return LLMResponse( - content=f"Azure OpenAI API Error {response.status_code}: {text.decode('utf-8', 'ignore')}", - finish_reason="error", - ) - return await self._consume_stream(response, on_content_delta) - except Exception as e: - return LLMResponse(content=f"Error calling Azure OpenAI: {repr(e)}", finish_reason="error") - - async def _consume_stream( - self, - response: httpx.Response, - on_content_delta: Callable[[str], Awaitable[None]] | None, - ) -> LLMResponse: - """Parse Azure OpenAI SSE stream into an LLMResponse.""" - content_parts: list[str] = [] - tool_call_buffers: dict[int, dict[str, str]] = {} - finish_reason = "stop" - - async for line in response.aiter_lines(): - if not line.startswith("data: "): - continue - data = line[6:].strip() - 
if data == "[DONE]": - break - try: - chunk = json.loads(data) - except Exception: - continue - - choices = chunk.get("choices") or [] - if not choices: - continue - choice = choices[0] - if choice.get("finish_reason"): - finish_reason = choice["finish_reason"] - delta = choice.get("delta") or {} - - text = delta.get("content") - if text: - content_parts.append(text) - if on_content_delta: - await on_content_delta(text) - - for tc in delta.get("tool_calls") or []: - idx = tc.get("index", 0) - buf = tool_call_buffers.setdefault(idx, {"id": "", "name": "", "arguments": ""}) - if tc.get("id"): - buf["id"] = tc["id"] - fn = tc.get("function") or {} - if fn.get("name"): - buf["name"] = fn["name"] - if fn.get("arguments"): - buf["arguments"] += fn["arguments"] - - tool_calls = [ - ToolCallRequest( - id=buf["id"], name=buf["name"], - arguments=json_repair.loads(buf["arguments"]) if buf["arguments"] else {}, + stream = await self._client.responses.create(**body) + content, tool_calls, finish_reason, usage, reasoning_content = ( + await consume_sdk_stream(stream, on_content_delta) ) - for buf in tool_call_buffers.values() - ] - - return LLMResponse( - content="".join(content_parts) or None, - tool_calls=tool_calls, - finish_reason=finish_reason, - ) + return LLMResponse( + content=content or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + reasoning_content=reasoning_content, + ) + except Exception as e: + return self._handle_error(e) def get_default_model(self) -> str: - """Get the default model (also used as default deployment name).""" return self.default_model \ No newline at end of file diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 9ce2b0c63..852e9c973 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -2,6 +2,7 @@ import asyncio import json +import re from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass, field @@ -9,6 +10,8 
@@ from typing import Any from loguru import logger +from nanobot.utils.helpers import image_placeholder_text + @dataclass class ToolCallRequest: @@ -57,13 +60,7 @@ class LLMResponse: @dataclass(frozen=True) class GenerationSettings: - """Default generation parameters for LLM calls. - - Stored on the provider so every call site inherits the same defaults - without having to pass temperature / max_tokens / reasoning_effort - through every layer. Individual call sites can still override by - passing explicit keyword arguments to chat() / chat_with_retry(). - """ + """Default generation settings.""" temperature: float = 0.7 max_tokens: int = 4096 @@ -71,14 +68,12 @@ class GenerationSettings: class LLMProvider(ABC): - """ - Abstract base class for LLM providers. - - Implementations should handle the specifics of each provider's API - while maintaining a consistent interface. - """ + """Base class for LLM providers.""" _CHAT_RETRY_DELAYS = (1, 2, 4) + _PERSISTENT_MAX_DELAY = 60 + _PERSISTENT_IDENTICAL_ERROR_LIMIT = 10 + _RETRY_HEARTBEAT_CHUNK = 30 _TRANSIENT_ERROR_MARKERS = ( "429", "rate limit", @@ -208,7 +203,7 @@ class LLMProvider(ABC): for b in content: if isinstance(b, dict) and b.get("type") == "image_url": path = (b.get("_meta") or {}).get("path", "") - placeholder = f"[image: {path}]" if path else "[image omitted]" + placeholder = image_placeholder_text(path, empty="[image omitted]") new_content.append({"type": "text", "text": placeholder}) found = True else: @@ -273,6 +268,8 @@ class LLMProvider(ABC): reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, + retry_mode: str = "standard", + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Call chat_stream() with retry on transient provider failures.""" if max_tokens is self._SENTINEL: @@ -288,28 +285,13 @@ class LLMProvider(ABC): reasoning_effort=reasoning_effort, 
tool_choice=tool_choice, on_content_delta=on_content_delta, ) - - for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): - response = await self._safe_chat_stream(**kw) - - if response.finish_reason != "error": - return response - - if not self._is_transient_error(response.content): - stripped = self._strip_image_content(messages) - if stripped is not None: - logger.warning("Non-transient LLM error with image content, retrying without images") - return await self._safe_chat_stream(**{**kw, "messages": stripped}) - return response - - logger.warning( - "LLM transient error (attempt {}/{}), retrying in {}s: {}", - attempt, len(self._CHAT_RETRY_DELAYS), delay, - (response.content or "")[:120].lower(), - ) - await asyncio.sleep(delay) - - return await self._safe_chat_stream(**kw) + return await self._run_with_retry( + self._safe_chat_stream, + kw, + messages, + retry_mode=retry_mode, + on_retry_wait=on_retry_wait, + ) async def chat_with_retry( self, @@ -320,6 +302,8 @@ class LLMProvider(ABC): temperature: object = _SENTINEL, reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, + retry_mode: str = "standard", + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Call chat() with retry on transient provider failures. 
@@ -339,28 +323,118 @@ class LLMProvider(ABC): max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, ) + return await self._run_with_retry( + self._safe_chat, + kw, + messages, + retry_mode=retry_mode, + on_retry_wait=on_retry_wait, + ) - for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1): - response = await self._safe_chat(**kw) + @classmethod + def _extract_retry_after(cls, content: str | None) -> float | None: + text = (content or "").lower() + match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text) + if not match: + return None + value = float(match.group(1)) + unit = (match.group(2) or "s").lower() + if unit in {"ms", "milliseconds"}: + return max(0.1, value / 1000.0) + if unit in {"m", "min", "minutes"}: + return value * 60.0 + return value + async def _sleep_with_heartbeat( + self, + delay: float, + *, + attempt: int, + persistent: bool, + on_retry_wait: Callable[[str], Awaitable[None]] | None = None, + ) -> None: + remaining = max(0.0, delay) + while remaining > 0: + if on_retry_wait: + kind = "persistent retry" if persistent else "retry" + await on_retry_wait( + f"Model request failed, {kind} in {max(1, int(round(remaining)))}s " + f"(attempt {attempt})." 
+ ) + chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK) + await asyncio.sleep(chunk) + remaining -= chunk + + async def _run_with_retry( + self, + call: Callable[..., Awaitable[LLMResponse]], + kw: dict[str, Any], + original_messages: list[dict[str, Any]], + *, + retry_mode: str, + on_retry_wait: Callable[[str], Awaitable[None]] | None, + ) -> LLMResponse: + attempt = 0 + delays = list(self._CHAT_RETRY_DELAYS) + persistent = retry_mode == "persistent" + last_response: LLMResponse | None = None + last_error_key: str | None = None + identical_error_count = 0 + while True: + attempt += 1 + response = await call(**kw) if response.finish_reason != "error": return response + last_response = response + error_key = ((response.content or "").strip().lower() or None) + if error_key and error_key == last_error_key: + identical_error_count += 1 + else: + last_error_key = error_key + identical_error_count = 1 if error_key else 0 if not self._is_transient_error(response.content): - stripped = self._strip_image_content(messages) - if stripped is not None: - logger.warning("Non-transient LLM error with image content, retrying without images") - return await self._safe_chat(**{**kw, "messages": stripped}) + stripped = self._strip_image_content(original_messages) + if stripped is not None and stripped != kw["messages"]: + logger.warning( + "Non-transient LLM error with image content, retrying without images" + ) + retry_kw = dict(kw) + retry_kw["messages"] = stripped + return await call(**retry_kw) return response + if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT: + logger.warning( + "Stopping persistent retry after {} identical transient errors: {}", + identical_error_count, + (response.content or "")[:120].lower(), + ) + return response + + if not persistent and attempt > len(delays): + break + + base_delay = delays[min(attempt - 1, len(delays) - 1)] + delay = self._extract_retry_after(response.content) or base_delay + if persistent: + delay = 
min(delay, self._PERSISTENT_MAX_DELAY) + logger.warning( - "LLM transient error (attempt {}/{}), retrying in {}s: {}", - attempt, len(self._CHAT_RETRY_DELAYS), delay, + "LLM transient error (attempt {}{}), retrying in {}s: {}", + attempt, + "+" if persistent and attempt > len(delays) else f"/{len(delays)}", + int(round(delay)), (response.content or "")[:120].lower(), ) - await asyncio.sleep(delay) + await self._sleep_with_heartbeat( + delay, + attempt=attempt, + persistent=persistent, + on_retry_wait=on_retry_wait, + ) - return await self._safe_chat(**kw) + return last_response if last_response is not None else await call(**kw) @abstractmethod def get_default_model(self) -> str: diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py index 1c6bc7075..265b4b106 100644 --- a/nanobot/providers/openai_codex_provider.py +++ b/nanobot/providers/openai_codex_provider.py @@ -6,13 +6,18 @@ import asyncio import hashlib import json from collections.abc import Awaitable, Callable -from typing import Any, AsyncGenerator +from typing import Any import httpx from loguru import logger from oauth_cli_kit import get_token as get_codex_token from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.providers.openai_responses import ( + consume_sse, + convert_messages, + convert_tools, +) DEFAULT_CODEX_URL = "https://chatgpt.com/backend-api/codex/responses" DEFAULT_ORIGINATOR = "nanobot" @@ -36,7 +41,7 @@ class OpenAICodexProvider(LLMProvider): ) -> LLMResponse: """Shared request logic for both chat() and chat_stream().""" model = model or self.default_model - system_prompt, input_items = _convert_messages(messages) + system_prompt, input_items = convert_messages(messages) token = await asyncio.to_thread(get_codex_token) headers = _build_headers(token.account_id, token.access) @@ -56,7 +61,7 @@ class OpenAICodexProvider(LLMProvider): if reasoning_effort: body["reasoning"] = {"effort": reasoning_effort} if 
tools: - body["tools"] = _convert_tools(tools) + body["tools"] = convert_tools(tools) try: try: @@ -127,96 +132,7 @@ async def _request_codex( if response.status_code != 200: text = await response.aread() raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) - return await _consume_sse(response, on_content_delta) - - -def _convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """Convert OpenAI function-calling schema to Codex flat format.""" - converted: list[dict[str, Any]] = [] - for tool in tools: - fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool - name = fn.get("name") - if not name: - continue - params = fn.get("parameters") or {} - converted.append({ - "type": "function", - "name": name, - "description": fn.get("description") or "", - "parameters": params if isinstance(params, dict) else {}, - }) - return converted - - -def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: - system_prompt = "" - input_items: list[dict[str, Any]] = [] - - for idx, msg in enumerate(messages): - role = msg.get("role") - content = msg.get("content") - - if role == "system": - system_prompt = content if isinstance(content, str) else "" - continue - - if role == "user": - input_items.append(_convert_user_message(content)) - continue - - if role == "assistant": - if isinstance(content, str) and content: - input_items.append({ - "type": "message", "role": "assistant", - "content": [{"type": "output_text", "text": content}], - "status": "completed", "id": f"msg_{idx}", - }) - for tool_call in msg.get("tool_calls", []) or []: - fn = tool_call.get("function") or {} - call_id, item_id = _split_tool_call_id(tool_call.get("id")) - input_items.append({ - "type": "function_call", - "id": item_id or f"fc_{idx}", - "call_id": call_id or f"call_{idx}", - "name": fn.get("name"), - "arguments": fn.get("arguments") or "{}", - }) - continue - - if role == "tool": - call_id, _ = 
_split_tool_call_id(msg.get("tool_call_id")) - output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) - input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text}) - - return system_prompt, input_items - - -def _convert_user_message(content: Any) -> dict[str, Any]: - if isinstance(content, str): - return {"role": "user", "content": [{"type": "input_text", "text": content}]} - if isinstance(content, list): - converted: list[dict[str, Any]] = [] - for item in content: - if not isinstance(item, dict): - continue - if item.get("type") == "text": - converted.append({"type": "input_text", "text": item.get("text", "")}) - elif item.get("type") == "image_url": - url = (item.get("image_url") or {}).get("url") - if url: - converted.append({"type": "input_image", "image_url": url, "detail": "auto"}) - if converted: - return {"role": "user", "content": converted} - return {"role": "user", "content": [{"type": "input_text", "text": ""}]} - - -def _split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]: - if isinstance(tool_call_id, str) and tool_call_id: - if "|" in tool_call_id: - call_id, item_id = tool_call_id.split("|", 1) - return call_id, item_id or None - return tool_call_id, None - return "call_0", None + return await consume_sse(response, on_content_delta) def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: @@ -224,96 +140,6 @@ def _prompt_cache_key(messages: list[dict[str, Any]]) -> str: return hashlib.sha256(raw.encode("utf-8")).hexdigest() -async def _iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: - buffer: list[str] = [] - async for line in response.aiter_lines(): - if line == "": - if buffer: - data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")] - buffer = [] - if not data_lines: - continue - data = "\n".join(data_lines).strip() - if not data or data == "[DONE]": - continue - try: - yield json.loads(data) - except Exception: 
- continue - continue - buffer.append(line) - - -async def _consume_sse( - response: httpx.Response, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, -) -> tuple[str, list[ToolCallRequest], str]: - content = "" - tool_calls: list[ToolCallRequest] = [] - tool_call_buffers: dict[str, dict[str, Any]] = {} - finish_reason = "stop" - - async for event in _iter_sse(response): - event_type = event.get("type") - if event_type == "response.output_item.added": - item = event.get("item") or {} - if item.get("type") == "function_call": - call_id = item.get("call_id") - if not call_id: - continue - tool_call_buffers[call_id] = { - "id": item.get("id") or "fc_0", - "name": item.get("name"), - "arguments": item.get("arguments") or "", - } - elif event_type == "response.output_text.delta": - delta_text = event.get("delta") or "" - content += delta_text - if on_content_delta and delta_text: - await on_content_delta(delta_text) - elif event_type == "response.function_call_arguments.delta": - call_id = event.get("call_id") - if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] += event.get("delta") or "" - elif event_type == "response.function_call_arguments.done": - call_id = event.get("call_id") - if call_id and call_id in tool_call_buffers: - tool_call_buffers[call_id]["arguments"] = event.get("arguments") or "" - elif event_type == "response.output_item.done": - item = event.get("item") or {} - if item.get("type") == "function_call": - call_id = item.get("call_id") - if not call_id: - continue - buf = tool_call_buffers.get(call_id) or {} - args_raw = buf.get("arguments") or item.get("arguments") or "{}" - try: - args = json.loads(args_raw) - except Exception: - args = {"raw": args_raw} - tool_calls.append( - ToolCallRequest( - id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}", - name=buf.get("name") or item.get("name"), - arguments=args, - ) - ) - elif event_type == "response.completed": - status = 
(event.get("response") or {}).get("status") - finish_reason = _map_finish_reason(status) - elif event_type in {"error", "response.failed"}: - raise RuntimeError("Codex response failed") - - return content, tool_calls, finish_reason - - -_FINISH_REASON_MAP = {"completed": "stop", "incomplete": "length", "failed": "error", "cancelled": "error"} - - -def _map_finish_reason(status: str | None) -> str: - return _FINISH_REASON_MAP.get(status or "completed", "stop") - - def _friendly_error(status_code: int, raw: str) -> str: if status_code == 429: return "ChatGPT usage quota exceeded or rate limit triggered. Please try again later." diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index f89879c90..3e0a34fbf 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio import hashlib import os import secrets @@ -20,7 +21,6 @@ if TYPE_CHECKING: _ALLOWED_MSG_KEYS = frozenset({ "role", "content", "tool_calls", "tool_call_id", "name", - "reasoning_content", "extra_content", }) _ALNUM = string.ascii_letters + string.digits @@ -615,16 +615,33 @@ class OpenAICompatProvider(LLMProvider): ) kwargs["stream"] = True kwargs["stream_options"] = {"include_usage": True} + idle_timeout_s = int(os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")) try: stream = await self._client.chat.completions.create(**kwargs) chunks: list[Any] = [] - async for chunk in stream: + stream_iter = stream.__aiter__() + while True: + try: + chunk = await asyncio.wait_for( + stream_iter.__anext__(), + timeout=idle_timeout_s, + ) + except StopAsyncIteration: + break chunks.append(chunk) if on_content_delta and chunk.choices: text = getattr(chunk.choices[0].delta, "content", None) if text: await on_content_delta(text) return self._parse_chunks(chunks) + except asyncio.TimeoutError: + return LLMResponse( + content=( + f"Error calling LLM: stream 
stalled for more than " + f"{idle_timeout_s} seconds" + ), + finish_reason="error", + ) except Exception as e: return self._handle_error(e) diff --git a/nanobot/providers/openai_responses/__init__.py b/nanobot/providers/openai_responses/__init__.py new file mode 100644 index 000000000..b40e896ed --- /dev/null +++ b/nanobot/providers/openai_responses/__init__.py @@ -0,0 +1,29 @@ +"""Shared helpers for OpenAI Responses API providers (Codex, Azure OpenAI).""" + +from nanobot.providers.openai_responses.converters import ( + convert_messages, + convert_tools, + convert_user_message, + split_tool_call_id, +) +from nanobot.providers.openai_responses.parsing import ( + FINISH_REASON_MAP, + consume_sdk_stream, + consume_sse, + iter_sse, + map_finish_reason, + parse_response_output, +) + +__all__ = [ + "convert_messages", + "convert_tools", + "convert_user_message", + "split_tool_call_id", + "iter_sse", + "consume_sse", + "consume_sdk_stream", + "map_finish_reason", + "parse_response_output", + "FINISH_REASON_MAP", +] diff --git a/nanobot/providers/openai_responses/converters.py b/nanobot/providers/openai_responses/converters.py new file mode 100644 index 000000000..e0bfe832d --- /dev/null +++ b/nanobot/providers/openai_responses/converters.py @@ -0,0 +1,110 @@ +"""Convert Chat Completions messages/tools to Responses API format.""" + +from __future__ import annotations + +import json +from typing import Any + + +def convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]: + """Convert Chat Completions messages to Responses API input items. + + Returns ``(system_prompt, input_items)`` where *system_prompt* is extracted + from any ``system`` role message and *input_items* is the Responses API + ``input`` array. 
+ """ + system_prompt = "" + input_items: list[dict[str, Any]] = [] + + for idx, msg in enumerate(messages): + role = msg.get("role") + content = msg.get("content") + + if role == "system": + system_prompt = content if isinstance(content, str) else "" + continue + + if role == "user": + input_items.append(convert_user_message(content)) + continue + + if role == "assistant": + if isinstance(content, str) and content: + input_items.append({ + "type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": content}], + "status": "completed", "id": f"msg_{idx}", + }) + for tool_call in msg.get("tool_calls", []) or []: + fn = tool_call.get("function") or {} + call_id, item_id = split_tool_call_id(tool_call.get("id")) + input_items.append({ + "type": "function_call", + "id": item_id or f"fc_{idx}", + "call_id": call_id or f"call_{idx}", + "name": fn.get("name"), + "arguments": fn.get("arguments") or "{}", + }) + continue + + if role == "tool": + call_id, _ = split_tool_call_id(msg.get("tool_call_id")) + output_text = content if isinstance(content, str) else json.dumps(content, ensure_ascii=False) + input_items.append({"type": "function_call_output", "call_id": call_id, "output": output_text}) + + return system_prompt, input_items + + +def convert_user_message(content: Any) -> dict[str, Any]: + """Convert a user message's content to Responses API format. + + Handles plain strings, ``text`` blocks -> ``input_text``, and + ``image_url`` blocks -> ``input_image``. 
+ """ + if isinstance(content, str): + return {"role": "user", "content": [{"type": "input_text", "text": content}]} + if isinstance(content, list): + converted: list[dict[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + continue + if item.get("type") == "text": + converted.append({"type": "input_text", "text": item.get("text", "")}) + elif item.get("type") == "image_url": + url = (item.get("image_url") or {}).get("url") + if url: + converted.append({"type": "input_image", "image_url": url, "detail": "auto"}) + if converted: + return {"role": "user", "content": converted} + return {"role": "user", "content": [{"type": "input_text", "text": ""}]} + + +def convert_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert OpenAI function-calling tool schema to Responses API flat format.""" + converted: list[dict[str, Any]] = [] + for tool in tools: + fn = (tool.get("function") or {}) if tool.get("type") == "function" else tool + name = fn.get("name") + if not name: + continue + params = fn.get("parameters") or {} + converted.append({ + "type": "function", + "name": name, + "description": fn.get("description") or "", + "parameters": params if isinstance(params, dict) else {}, + }) + return converted + + +def split_tool_call_id(tool_call_id: Any) -> tuple[str, str | None]: + """Split a compound ``call_id|item_id`` string. + + Returns ``(call_id, item_id)`` where *item_id* may be ``None``. 
+ """ + if isinstance(tool_call_id, str) and tool_call_id: + if "|" in tool_call_id: + call_id, item_id = tool_call_id.split("|", 1) + return call_id, item_id or None + return tool_call_id, None + return "call_0", None diff --git a/nanobot/providers/openai_responses/parsing.py b/nanobot/providers/openai_responses/parsing.py new file mode 100644 index 000000000..9e3f0ef02 --- /dev/null +++ b/nanobot/providers/openai_responses/parsing.py @@ -0,0 +1,297 @@ +"""Parse Responses API SSE streams and SDK response objects.""" + +from __future__ import annotations + +import json +from collections.abc import Awaitable, Callable +from typing import Any, AsyncGenerator + +import httpx +import json_repair +from loguru import logger + +from nanobot.providers.base import LLMResponse, ToolCallRequest + +FINISH_REASON_MAP = { + "completed": "stop", + "incomplete": "length", + "failed": "error", + "cancelled": "error", +} + + +def map_finish_reason(status: str | None) -> str: + """Map a Responses API status string to a Chat-Completions-style finish_reason.""" + return FINISH_REASON_MAP.get(status or "completed", "stop") + + +async def iter_sse(response: httpx.Response) -> AsyncGenerator[dict[str, Any], None]: + """Yield parsed JSON events from a Responses API SSE stream.""" + buffer: list[str] = [] + + def _flush() -> dict[str, Any] | None: + data_lines = [l[5:].strip() for l in buffer if l.startswith("data:")] + buffer.clear() + if not data_lines: + return None + data = "\n".join(data_lines).strip() + if not data or data == "[DONE]": + return None + try: + return json.loads(data) + except Exception: + logger.warning("Failed to parse SSE event JSON: {}", data[:200]) + return None + + async for line in response.aiter_lines(): + if line == "": + if buffer: + event = _flush() + if event is not None: + yield event + continue + buffer.append(line) + + # Flush any remaining buffer at EOF (#10) + if buffer: + event = _flush() + if event is not None: + yield event + + +async def consume_sse( 
+ response: httpx.Response, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, +) -> tuple[str, list[ToolCallRequest], str]: + """Consume a Responses API SSE stream into ``(content, tool_calls, finish_reason)``.""" + content = "" + tool_calls: list[ToolCallRequest] = [] + tool_call_buffers: dict[str, dict[str, Any]] = {} + finish_reason = "stop" + + async for event in iter_sse(response): + event_type = event.get("type") + if event_type == "response.output_item.added": + item = event.get("item") or {} + if item.get("type") == "function_call": + call_id = item.get("call_id") + if not call_id: + continue + tool_call_buffers[call_id] = { + "id": item.get("id") or "fc_0", + "name": item.get("name"), + "arguments": item.get("arguments") or "", + } + elif event_type == "response.output_text.delta": + delta_text = event.get("delta") or "" + content += delta_text + if on_content_delta and delta_text: + await on_content_delta(delta_text) + elif event_type == "response.function_call_arguments.delta": + call_id = event.get("call_id") + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] += event.get("delta") or "" + elif event_type == "response.function_call_arguments.done": + call_id = event.get("call_id") + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] = event.get("arguments") or "" + elif event_type == "response.output_item.done": + item = event.get("item") or {} + if item.get("type") == "function_call": + call_id = item.get("call_id") + if not call_id: + continue + buf = tool_call_buffers.get(call_id) or {} + args_raw = buf.get("arguments") or item.get("arguments") or "{}" + try: + args = json.loads(args_raw) + except Exception: + logger.warning( + "Failed to parse tool call arguments for '{}': {}", + buf.get("name") or item.get("name"), + args_raw[:200], + ) + args = json_repair.loads(args_raw) + if not isinstance(args, dict): + args = {"raw": args_raw} + tool_calls.append( + 
ToolCallRequest( + id=f"{call_id}|{buf.get('id') or item.get('id') or 'fc_0'}", + name=buf.get("name") or item.get("name") or "", + arguments=args, + ) + ) + elif event_type == "response.completed": + status = (event.get("response") or {}).get("status") + finish_reason = map_finish_reason(status) + elif event_type in {"error", "response.failed"}: + detail = event.get("error") or event.get("message") or event + raise RuntimeError(f"Response failed: {str(detail)[:500]}") + + return content, tool_calls, finish_reason + + +def parse_response_output(response: Any) -> LLMResponse: + """Parse an SDK ``Response`` object into an ``LLMResponse``.""" + if not isinstance(response, dict): + dump = getattr(response, "model_dump", None) + response = dump() if callable(dump) else vars(response) + + output = response.get("output") or [] + content_parts: list[str] = [] + tool_calls: list[ToolCallRequest] = [] + reasoning_content: str | None = None + + for item in output: + if not isinstance(item, dict): + dump = getattr(item, "model_dump", None) + item = dump() if callable(dump) else vars(item) + + item_type = item.get("type") + if item_type == "message": + for block in item.get("content") or []: + if not isinstance(block, dict): + dump = getattr(block, "model_dump", None) + block = dump() if callable(dump) else vars(block) + if block.get("type") == "output_text": + content_parts.append(block.get("text") or "") + elif item_type == "reasoning": + for s in item.get("summary") or []: + if not isinstance(s, dict): + dump = getattr(s, "model_dump", None) + s = dump() if callable(dump) else vars(s) + if s.get("type") == "summary_text" and s.get("text"): + reasoning_content = (reasoning_content or "") + s["text"] + elif item_type == "function_call": + call_id = item.get("call_id") or "" + item_id = item.get("id") or "fc_0" + args_raw = item.get("arguments") or "{}" + try: + args = json.loads(args_raw) if isinstance(args_raw, str) else args_raw + except Exception: + logger.warning( + 
"Failed to parse tool call arguments for '{}': {}", + item.get("name"), + str(args_raw)[:200], + ) + args = json_repair.loads(args_raw) if isinstance(args_raw, str) else args_raw + if not isinstance(args, dict): + args = {"raw": args_raw} + tool_calls.append(ToolCallRequest( + id=f"{call_id}|{item_id}", + name=item.get("name") or "", + arguments=args if isinstance(args, dict) else {}, + )) + + usage_raw = response.get("usage") or {} + if not isinstance(usage_raw, dict): + dump = getattr(usage_raw, "model_dump", None) + usage_raw = dump() if callable(dump) else vars(usage_raw) + usage = {} + if usage_raw: + usage = { + "prompt_tokens": int(usage_raw.get("input_tokens") or 0), + "completion_tokens": int(usage_raw.get("output_tokens") or 0), + "total_tokens": int(usage_raw.get("total_tokens") or 0), + } + + status = response.get("status") + finish_reason = map_finish_reason(status) + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None, + ) + + +async def consume_sdk_stream( + stream: Any, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, +) -> tuple[str, list[ToolCallRequest], str, dict[str, int], str | None]: + """Consume an SDK async stream from ``client.responses.create(stream=True)``.""" + content = "" + tool_calls: list[ToolCallRequest] = [] + tool_call_buffers: dict[str, dict[str, Any]] = {} + finish_reason = "stop" + usage: dict[str, int] = {} + reasoning_content: str | None = None + + async for event in stream: + event_type = getattr(event, "type", None) + if event_type == "response.output_item.added": + item = getattr(event, "item", None) + if item and getattr(item, "type", None) == "function_call": + call_id = getattr(item, "call_id", None) + if not call_id: + continue + tool_call_buffers[call_id] = { + "id": getattr(item, "id", None) or "fc_0", + "name": getattr(item, 
"name", None), + "arguments": getattr(item, "arguments", None) or "", + } + elif event_type == "response.output_text.delta": + delta_text = getattr(event, "delta", "") or "" + content += delta_text + if on_content_delta and delta_text: + await on_content_delta(delta_text) + elif event_type == "response.function_call_arguments.delta": + call_id = getattr(event, "call_id", None) + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] += getattr(event, "delta", "") or "" + elif event_type == "response.function_call_arguments.done": + call_id = getattr(event, "call_id", None) + if call_id and call_id in tool_call_buffers: + tool_call_buffers[call_id]["arguments"] = getattr(event, "arguments", "") or "" + elif event_type == "response.output_item.done": + item = getattr(event, "item", None) + if item and getattr(item, "type", None) == "function_call": + call_id = getattr(item, "call_id", None) + if not call_id: + continue + buf = tool_call_buffers.get(call_id) or {} + args_raw = buf.get("arguments") or getattr(item, "arguments", None) or "{}" + try: + args = json.loads(args_raw) + except Exception: + logger.warning( + "Failed to parse tool call arguments for '{}': {}", + buf.get("name") or getattr(item, "name", None), + str(args_raw)[:200], + ) + args = json_repair.loads(args_raw) + if not isinstance(args, dict): + args = {"raw": args_raw} + tool_calls.append( + ToolCallRequest( + id=f"{call_id}|{buf.get('id') or getattr(item, 'id', None) or 'fc_0'}", + name=buf.get("name") or getattr(item, "name", None) or "", + arguments=args, + ) + ) + elif event_type == "response.completed": + resp = getattr(event, "response", None) + status = getattr(resp, "status", None) if resp else None + finish_reason = map_finish_reason(status) + if resp: + usage_obj = getattr(resp, "usage", None) + if usage_obj: + usage = { + "prompt_tokens": int(getattr(usage_obj, "input_tokens", 0) or 0), + "completion_tokens": int(getattr(usage_obj, "output_tokens", 0) or 
0), + "total_tokens": int(getattr(usage_obj, "total_tokens", 0) or 0), + } + for out_item in getattr(resp, "output", None) or []: + if getattr(out_item, "type", None) == "reasoning": + for s in getattr(out_item, "summary", None) or []: + if getattr(s, "type", None) == "summary_text": + text = getattr(s, "text", None) + if text: + reasoning_content = (reasoning_content or "") + text + elif event_type in {"error", "response.failed"}: + detail = getattr(event, "error", None) or getattr(event, "message", None) or event + raise RuntimeError(f"Response failed: {str(detail)[:500]}") + + return content, tool_calls, finish_reason, usage, reasoning_content diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 537ba42d0..95e3916b9 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -10,20 +10,12 @@ from typing import Any from loguru import logger from nanobot.config.paths import get_legacy_sessions_dir -from nanobot.utils.helpers import ensure_dir, safe_filename +from nanobot.utils.helpers import ensure_dir, find_legal_message_start, safe_filename @dataclass class Session: - """ - A conversation session. - - Stores messages in JSONL format for easy reading and persistence. - - Important: Messages are append-only for LLM cache efficiency. - The consolidation process writes summaries to MEMORY.md/HISTORY.md - but does NOT modify the messages list or get_history() output. 
- """ + """A conversation session.""" key: str # channel:chat_id messages: list[dict[str, Any]] = field(default_factory=list) @@ -43,43 +35,19 @@ class Session: self.messages.append(msg) self.updated_at = datetime.now() - @staticmethod - def _find_legal_start(messages: list[dict[str, Any]]) -> int: - """Find first index where every tool result has a matching assistant tool_call.""" - declared: set[str] = set() - start = 0 - for i, msg in enumerate(messages): - role = msg.get("role") - if role == "assistant": - for tc in msg.get("tool_calls") or []: - if isinstance(tc, dict) and tc.get("id"): - declared.add(str(tc["id"])) - elif role == "tool": - tid = msg.get("tool_call_id") - if tid and str(tid) not in declared: - start = i + 1 - declared.clear() - for prev in messages[start:i + 1]: - if prev.get("role") == "assistant": - for tc in prev.get("tool_calls") or []: - if isinstance(tc, dict) and tc.get("id"): - declared.add(str(tc["id"])) - return start - def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]: """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" unconsolidated = self.messages[self.last_consolidated:] sliced = unconsolidated[-max_messages:] - # Drop leading non-user messages to avoid starting mid-turn when possible. + # Avoid starting mid-turn when possible. for i, message in enumerate(sliced): if message.get("role") == "user": sliced = sliced[i:] break - # Some providers reject orphan tool results if the matching assistant - # tool_calls message fell outside the fixed-size history window. - start = self._find_legal_start(sliced) + # Drop orphan tool results at the front. + start = find_legal_message_start(sliced) if start: sliced = sliced[start:] @@ -115,7 +83,7 @@ class Session: retained = self.messages[start_idx:] # Mirror get_history(): avoid persisting orphan tool results at the front. 
- start = self._find_legal_start(retained) + start = find_legal_message_start(retained) if start: retained = retained[start:] diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 406a4dd45..9e0a69d5e 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -3,12 +3,15 @@ import base64 import json import re +import shutil import time +import uuid from datetime import datetime from pathlib import Path from typing import Any import tiktoken +from loguru import logger def strip_think(text: str) -> str: @@ -56,11 +59,7 @@ def timestamp() -> str: def current_time_str(timezone: str | None = None) -> str: - """Human-readable current time with weekday and UTC offset. - - When *timezone* is a valid IANA name (e.g. ``"Asia/Shanghai"``), the time - is converted to that zone. Otherwise falls back to the host local time. - """ + """Return the current time string.""" from zoneinfo import ZoneInfo try: @@ -76,12 +75,164 @@ def current_time_str(timezone: str | None = None) -> str: _UNSAFE_CHARS = re.compile(r'[<>:"/\\|?*]') +_TOOL_RESULT_PREVIEW_CHARS = 1200 +_TOOL_RESULTS_DIR = ".nanobot/tool-results" +_TOOL_RESULT_RETENTION_SECS = 7 * 24 * 60 * 60 +_TOOL_RESULT_MAX_BUCKETS = 32 def safe_filename(name: str) -> str: """Replace unsafe path characters with underscores.""" return _UNSAFE_CHARS.sub("_", name).strip() +def image_placeholder_text(path: str | None, *, empty: str = "[image]") -> str: + """Build an image placeholder string.""" + return f"[image: {path}]" if path else empty + + +def truncate_text(text: str, max_chars: int) -> str: + """Truncate text with a stable suffix.""" + if max_chars <= 0 or len(text) <= max_chars: + return text + return text[:max_chars] + "\n... 
(truncated)" + + +def find_legal_message_start(messages: list[dict[str, Any]]) -> int: + """Find the first index whose tool results have matching assistant calls.""" + declared: set[str] = set() + start = 0 + for i, msg in enumerate(messages): + role = msg.get("role") + if role == "assistant": + for tc in msg.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + elif role == "tool": + tid = msg.get("tool_call_id") + if tid and str(tid) not in declared: + start = i + 1 + declared.clear() + for prev in messages[start : i + 1]: + if prev.get("role") == "assistant": + for tc in prev.get("tool_calls") or []: + if isinstance(tc, dict) and tc.get("id"): + declared.add(str(tc["id"])) + return start + + +def stringify_text_blocks(content: list[dict[str, Any]]) -> str | None: + parts: list[str] = [] + for block in content: + if not isinstance(block, dict): + return None + if block.get("type") != "text": + return None + text = block.get("text") + if not isinstance(text, str): + return None + parts.append(text) + return "\n".join(parts) + + +def _render_tool_result_reference( + filepath: Path, + *, + original_size: int, + preview: str, + truncated_preview: bool, +) -> str: + result = ( + f"[tool output persisted]\n" + f"Full output saved to: {filepath}\n" + f"Original size: {original_size} chars\n" + f"Preview:\n{preview}" + ) + if truncated_preview: + result += "\n...\n(Read the saved file if you need the full output.)" + return result + + +def _bucket_mtime(path: Path) -> float: + try: + return path.stat().st_mtime + except OSError: + return 0.0 + + +def _cleanup_tool_result_buckets(root: Path, current_bucket: Path) -> None: + siblings = [path for path in root.iterdir() if path.is_dir() and path != current_bucket] + cutoff = time.time() - _TOOL_RESULT_RETENTION_SECS + for path in siblings: + if _bucket_mtime(path) < cutoff: + shutil.rmtree(path, ignore_errors=True) + keep = max(_TOOL_RESULT_MAX_BUCKETS - 1, 0) + siblings = [path 
for path in siblings if path.exists()] + if len(siblings) <= keep: + return + siblings.sort(key=_bucket_mtime, reverse=True) + for path in siblings[keep:]: + shutil.rmtree(path, ignore_errors=True) + + +def _write_text_atomic(path: Path, content: str) -> None: + tmp = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp") + try: + tmp.write_text(content, encoding="utf-8") + tmp.replace(path) + finally: + if tmp.exists(): + tmp.unlink(missing_ok=True) + + +def maybe_persist_tool_result( + workspace: Path | None, + session_key: str | None, + tool_call_id: str, + content: Any, + *, + max_chars: int, +) -> Any: + """Persist oversized tool output and replace it with a stable reference string.""" + if workspace is None or max_chars <= 0: + return content + + text_payload: str | None = None + suffix = "txt" + if isinstance(content, str): + text_payload = content + elif isinstance(content, list): + text_payload = stringify_text_blocks(content) + if text_payload is None: + return content + suffix = "json" + else: + return content + + if len(text_payload) <= max_chars: + return content + + root = ensure_dir(workspace / _TOOL_RESULTS_DIR) + bucket = ensure_dir(root / safe_filename(session_key or "default")) + try: + _cleanup_tool_result_buckets(root, bucket) + except Exception as exc: + logger.warning("Failed to clean stale tool result buckets in {}: {}", root, exc) + path = bucket / f"{safe_filename(tool_call_id)}.{suffix}" + if not path.exists(): + if suffix == "json" and isinstance(content, list): + _write_text_atomic(path, json.dumps(content, ensure_ascii=False, indent=2)) + else: + _write_text_atomic(path, text_payload) + + preview = text_payload[:_TOOL_RESULT_PREVIEW_CHARS] + return _render_tool_result_reference( + path, + original_size=len(text_payload), + preview=preview, + truncated_preview=len(text_payload) > _TOOL_RESULT_PREVIEW_CHARS, + ) + + def split_message(content: str, max_len: int = 2000) -> list[str]: """ Split content into chunks within max_len, preferring 
line breaks. diff --git a/nanobot/utils/runtime.py b/nanobot/utils/runtime.py new file mode 100644 index 000000000..7164629c5 --- /dev/null +++ b/nanobot/utils/runtime.py @@ -0,0 +1,88 @@ +"""Runtime-specific helper functions and constants.""" + +from __future__ import annotations + +from typing import Any + +from loguru import logger + +from nanobot.utils.helpers import stringify_text_blocks + +_MAX_REPEAT_EXTERNAL_LOOKUPS = 2 + +EMPTY_FINAL_RESPONSE_MESSAGE = ( + "I completed the tool steps but couldn't produce a final answer. " + "Please try again or narrow the task." +) + +FINALIZATION_RETRY_PROMPT = ( + "You have already finished the tool work. Do not call any more tools. " + "Using only the conversation and tool results above, provide the final answer for the user now." +) + + +def empty_tool_result_message(tool_name: str) -> str: + """Short prompt-safe marker for tools that completed without visible output.""" + return f"({tool_name} completed with no output)" + + +def ensure_nonempty_tool_result(tool_name: str, content: Any) -> Any: + """Replace semantically empty tool results with a short marker string.""" + if content is None: + return empty_tool_result_message(tool_name) + if isinstance(content, str) and not content.strip(): + return empty_tool_result_message(tool_name) + if isinstance(content, list): + if not content: + return empty_tool_result_message(tool_name) + text_payload = stringify_text_blocks(content) + if text_payload is not None and not text_payload.strip(): + return empty_tool_result_message(tool_name) + return content + + +def is_blank_text(content: str | None) -> bool: + """True when *content* is missing or only whitespace.""" + return content is None or not content.strip() + + +def build_finalization_retry_message() -> dict[str, str]: + """A short no-tools-allowed prompt for final answer recovery.""" + return {"role": "user", "content": FINALIZATION_RETRY_PROMPT} + + +def external_lookup_signature(tool_name: str, arguments: dict[str, 
Any]) -> str | None: + """Stable signature for repeated external lookups we want to throttle.""" + if tool_name == "web_fetch": + url = str(arguments.get("url") or "").strip() + if url: + return f"web_fetch:{url.lower()}" + if tool_name == "web_search": + query = str(arguments.get("query") or arguments.get("search_term") or "").strip() + if query: + return f"web_search:{query.lower()}" + return None + + +def repeated_external_lookup_error( + tool_name: str, + arguments: dict[str, Any], + seen_counts: dict[str, int], +) -> str | None: + """Block repeated external lookups after a small retry budget.""" + signature = external_lookup_signature(tool_name, arguments) + if signature is None: + return None + count = seen_counts.get(signature, 0) + 1 + seen_counts[signature] = count + if count <= _MAX_REPEAT_EXTERNAL_LOOKUPS: + return None + logger.warning( + "Blocking repeated external lookup {} on attempt {}", + signature[:160], + count, + ) + return ( + "Error: repeated external lookup blocked. " + "Use the results you already have to answer, or try a meaningfully different source." 
+ ) diff --git a/tests/agent/test_context_prompt_cache.py b/tests/agent/test_context_prompt_cache.py index 6eb4b4f19..4484e5ed0 100644 --- a/tests/agent/test_context_prompt_cache.py +++ b/tests/agent/test_context_prompt_cache.py @@ -71,3 +71,19 @@ def test_runtime_context_is_separate_untrusted_user_message(tmp_path) -> None: assert "Channel: cli" in user_content assert "Chat ID: direct" in user_content assert "Return exactly: OK" in user_content + + +def test_subagent_result_does_not_create_consecutive_assistant_messages(tmp_path) -> None: + workspace = _make_workspace(tmp_path) + builder = ContextBuilder(workspace) + + messages = builder.build_messages( + history=[{"role": "assistant", "content": "previous result"}], + current_message="subagent result", + channel="cli", + chat_id="direct", + current_role="assistant", + ) + + for left, right in zip(messages, messages[1:]): + assert not (left.get("role") == right.get("role") == "assistant") diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index aed7653c3..8a0b54b86 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -5,7 +5,9 @@ from nanobot.session.manager import Session def _mk_loop() -> AgentLoop: loop = AgentLoop.__new__(AgentLoop) - loop._TOOL_RESULT_MAX_CHARS = AgentLoop._TOOL_RESULT_MAX_CHARS + from nanobot.config.schema import AgentDefaults + + loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars return loop @@ -72,3 +74,129 @@ def test_save_turn_keeps_tool_results_under_16k() -> None: ) assert session.messages[0]["content"] == content + + +def test_restore_runtime_checkpoint_rehydrates_completed_and_pending_tools() -> None: + loop = _mk_loop() + session = Session( + key="test:checkpoint", + metadata={ + AgentLoop._RUNTIME_CHECKPOINT_KEY: { + "assistant_message": { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", 
"arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + "completed_tool_results": [ + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + } + ], + "pending_tool_calls": [ + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + } + ], + } + }, + ) + + restored = loop._restore_runtime_checkpoint(session) + + assert restored is True + assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None + assert session.messages[0]["role"] == "assistant" + assert session.messages[1]["tool_call_id"] == "call_done" + assert session.messages[2]["tool_call_id"] == "call_pending" + assert "interrupted before this tool finished" in session.messages[2]["content"].lower() + + +def test_restore_runtime_checkpoint_dedupes_overlapping_tail() -> None: + loop = _mk_loop() + session = Session( + key="test:checkpoint-overlap", + messages=[ + { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + }, + ], + metadata={ + AgentLoop._RUNTIME_CHECKPOINT_KEY: { + "assistant_message": { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + "completed_tool_results": [ + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + } + ], + "pending_tool_calls": [ + { + "id": "call_pending", + "type": "function", + "function": 
{"name": "exec", "arguments": "{}"}, + } + ], + } + }, + ) + + restored = loop._restore_runtime_checkpoint(session) + + assert restored is True + assert session.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is None + assert len(session.messages) == 3 + assert session.messages[0]["role"] == "assistant" + assert session.messages[1]["tool_call_id"] == "call_done" + assert session.messages[2]["tool_call_id"] == "call_pending" diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 98f1d73ae..dcdd15031 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -2,12 +2,20 @@ from __future__ import annotations +import asyncio +import os +import time from unittest.mock import AsyncMock, MagicMock, patch import pytest +from nanobot.config.schema import AgentDefaults +from nanobot.agent.tools.base import Tool +from nanobot.agent.tools.registry import ToolRegistry from nanobot.providers.base import LLMResponse, ToolCallRequest +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + def _make_loop(tmp_path): from nanobot.agent.loop import AgentLoop @@ -60,6 +68,7 @@ async def test_runner_preserves_reasoning_fields_and_tool_results(): tools=tools, model="test-model", max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, )) assert result.final_content == "done" @@ -135,6 +144,7 @@ async def test_runner_calls_hooks_in_order(): tools=tools, model="test-model", max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, hook=RecordingHook(), )) @@ -191,6 +201,7 @@ async def test_runner_streaming_hook_receives_deltas_and_end_signal(): tools=tools, model="test-model", max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, hook=StreamingHook(), )) @@ -219,6 +230,7 @@ async def test_runner_returns_max_iterations_fallback(): tools=tools, model="test-model", max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, )) assert result.stop_reason == "max_iterations" @@ -226,7 +238,8 @@ async def 
test_runner_returns_max_iterations_fallback(): "I reached the maximum number of tool call iterations (2) " "without completing the task. You can try breaking the task into smaller steps." ) - + assert result.messages[-1]["role"] == "assistant" + assert result.messages[-1]["content"] == result.final_content @pytest.mark.asyncio async def test_runner_returns_structured_tool_error(): @@ -248,6 +261,7 @@ async def test_runner_returns_structured_tool_error(): tools=tools, model="test-model", max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, fail_on_tool_error=True, )) @@ -258,6 +272,457 @@ async def test_runner_returns_structured_tool_error(): ] +@pytest.mark.asyncio +async def test_runner_persists_large_tool_results_for_follow_up_calls(tmp_path): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_big", name="list_dir", arguments={"path": "."})], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="x" * 20_000) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + workspace=tmp_path, + session_key="test:runner", + max_tool_result_chars=2048, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert "[tool output persisted]" in tool_message["content"] + assert "tool-results" in 
tool_message["content"] + assert (tmp_path / ".nanobot" / "tool-results" / "test_runner" / "call_big.txt").exists() + + +def test_persist_tool_result_prunes_old_session_buckets(tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + root = tmp_path / ".nanobot" / "tool-results" + old_bucket = root / "old_session" + recent_bucket = root / "recent_session" + old_bucket.mkdir(parents=True) + recent_bucket.mkdir(parents=True) + (old_bucket / "old.txt").write_text("old", encoding="utf-8") + (recent_bucket / "recent.txt").write_text("recent", encoding="utf-8") + + stale = time.time() - (8 * 24 * 60 * 60) + os.utime(old_bucket, (stale, stale)) + os.utime(old_bucket / "old.txt", (stale, stale)) + + persisted = maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert "[tool output persisted]" in persisted + assert not old_bucket.exists() + assert recent_bucket.exists() + assert (root / "current_session" / "call_big.txt").exists() + + +def test_persist_tool_result_leaves_no_temp_files(tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + root = tmp_path / ".nanobot" / "tool-results" + maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + max_chars=64, + ) + + assert (root / "current_session" / "call_big.txt").exists() + assert list((root / "current_session").glob("*.tmp")) == [] + + +def test_persist_tool_result_logs_cleanup_failures(monkeypatch, tmp_path): + from nanobot.utils.helpers import maybe_persist_tool_result + + warnings: list[str] = [] + + monkeypatch.setattr( + "nanobot.utils.helpers._cleanup_tool_result_buckets", + lambda *_args, **_kwargs: (_ for _ in ()).throw(OSError("busy")), + ) + monkeypatch.setattr( + "nanobot.utils.helpers.logger.warning", + lambda message, *args: warnings.append(message.format(*args)), + ) + + persisted = maybe_persist_tool_result( + tmp_path, + "current:session", + "call_big", + "x" * 5000, + 
max_chars=64, + ) + + assert "[tool output persisted]" in persisted + assert warnings and "Failed to clean stale tool result buckets" in warnings[0] + + +@pytest.mark.asyncio +async def test_runner_replaces_empty_tool_result_with_marker(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="noop", arguments={})], + usage={}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert tool_message["content"] == "(noop completed with no output)" + + +@pytest.mark.asyncio +async def test_runner_uses_raw_messages_when_context_governance_fails(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_messages: list[dict] = [] + + async def chat_with_retry(*, messages, **kwargs): + captured_messages[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + initial_messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hello"}, + ] + + runner = AgentRunner(provider) + runner._snip_history = 
MagicMock(side_effect=RuntimeError("boom")) # type: ignore[method-assign] + result = await runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert captured_messages == initial_messages + + +@pytest.mark.asyncio +async def test_runner_retries_empty_final_response_with_summary_prompt(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + calls: list[dict] = [] + + async def chat_with_retry(*, messages, tools=None, **kwargs): + calls.append({"messages": messages, "tools": tools}) + if len(calls) == 1: + return LLMResponse( + content=None, + tool_calls=[], + usage={"prompt_tokens": 10, "completion_tokens": 1}, + ) + return LLMResponse( + content="final answer", + tool_calls=[], + usage={"prompt_tokens": 3, "completion_tokens": 7}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "final answer" + assert len(calls) == 2 + assert calls[1]["tools"] is None + assert "Do not call any more tools" in calls[1]["messages"][-1]["content"] + assert result.usage["prompt_tokens"] == 13 + assert result.usage["completion_tokens"] == 8 + + +@pytest.mark.asyncio +async def test_runner_uses_specific_message_after_empty_finalization_retry(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE + + provider = MagicMock() + + async def chat_with_retry(*, messages, **kwargs): + return LLMResponse(content=None, tool_calls=[], usage={}) + + provider.chat_with_retry = 
chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == EMPTY_FINAL_RESPONSE_MESSAGE + assert result.stop_reason == "empty_final_response" + + +def test_snip_history_drops_orphaned_tool_results_from_trimmed_slice(monkeypatch): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + runner = AgentRunner(provider) + messages = [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "old user"}, + { + "role": "assistant", + "content": "tool call", + "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "ls", "arguments": "{}"}}], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "tool output"}, + {"role": "assistant", "content": "after tool"}, + ] + spec = AgentRunSpec( + initial_messages=messages, + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + context_window_tokens=2000, + context_block_limit=100, + ) + + monkeypatch.setattr("nanobot.agent.runner.estimate_prompt_tokens_chain", lambda *_args, **_kwargs: (500, None)) + token_sizes = { + "old user": 120, + "tool call": 120, + "tool output": 40, + "after tool": 40, + "system": 0, + } + monkeypatch.setattr( + "nanobot.agent.runner.estimate_message_tokens", + lambda msg: token_sizes.get(str(msg.get("content")), 40), + ) + + trimmed = runner._snip_history(spec, messages) + + assert trimmed == [ + {"role": "system", "content": "system"}, + {"role": "assistant", "content": "after tool"}, + ] + + +@pytest.mark.asyncio +async def test_runner_keeps_going_when_tool_result_persistence_fails(): + from nanobot.agent.runner 
import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_second_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], + usage={"prompt_tokens": 5, "completion_tokens": 3}, + ) + captured_second_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="tool result") + + runner = AgentRunner(provider) + with patch("nanobot.agent.runner.maybe_persist_tool_result", side_effect=RuntimeError("disk full")): + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=2, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + tool_message = next(msg for msg in captured_second_call if msg.get("role") == "tool") + assert tool_message["content"] == "tool result" + + +class _DelayTool(Tool): + def __init__(self, name: str, *, delay: float, read_only: bool, shared_events: list[str]): + self._name = name + self._delay = delay + self._read_only = read_only + self._shared_events = shared_events + + @property + def name(self) -> str: + return self._name + + @property + def description(self) -> str: + return self._name + + @property + def parameters(self) -> dict: + return {"type": "object", "properties": {}, "required": []} + + @property + def read_only(self) -> bool: + return self._read_only + + async def execute(self, **kwargs): + self._shared_events.append(f"start:{self._name}") + await asyncio.sleep(self._delay) + self._shared_events.append(f"end:{self._name}") + return self._name + + +@pytest.mark.asyncio +async def 
test_runner_batches_read_only_tools_before_exclusive_work(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + tools = ToolRegistry() + shared_events: list[str] = [] + read_a = _DelayTool("read_a", delay=0.05, read_only=True, shared_events=shared_events) + read_b = _DelayTool("read_b", delay=0.05, read_only=True, shared_events=shared_events) + write_a = _DelayTool("write_a", delay=0.01, read_only=False, shared_events=shared_events) + tools.register(read_a) + tools.register(read_b) + tools.register(write_a) + + runner = AgentRunner(MagicMock()) + await runner._execute_tools( + AgentRunSpec( + initial_messages=[], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + concurrent_tools=True, + ), + [ + ToolCallRequest(id="ro1", name="read_a", arguments={}), + ToolCallRequest(id="ro2", name="read_b", arguments={}), + ToolCallRequest(id="rw1", name="write_a", arguments={}), + ], + {}, + ) + + assert shared_events[0:2] == ["start:read_a", "start:read_b"] + assert "end:read_a" in shared_events and "end:read_b" in shared_events + assert shared_events.index("end:read_a") < shared_events.index("start:write_a") + assert shared_events.index("end:read_b") < shared_events.index("start:write_a") + assert shared_events[-2:] == ["start:write_a", "end:write_a"] + + +@pytest.mark.asyncio +async def test_runner_blocks_repeated_external_fetches(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + captured_final_call: list[dict] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + if call_count["n"] <= 3: + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id=f"call_{call_count['n']}", name="web_fetch", arguments={"url": "https://example.com"})], + usage={}, + ) + captured_final_call[:] = messages + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + 
tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="page content") + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "research task"}], + tools=tools, + model="test-model", + max_iterations=4, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.final_content == "done" + assert tools.execute.await_count == 2 + blocked_tool_message = [ + msg for msg in captured_final_call + if msg.get("role") == "tool" and msg.get("tool_call_id") == "call_3" + ][0] + assert "repeated external lookup blocked" in blocked_tool_message["content"] + + @pytest.mark.asyncio async def test_loop_max_iterations_message_stays_stable(tmp_path): loop = _make_loop(tmp_path) @@ -307,6 +772,57 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp assert endings == [False] +@pytest.mark.asyncio +async def test_loop_retries_think_only_final_response(tmp_path): + loop = _make_loop(tmp_path) + call_count = {"n": 0} + + async def chat_with_retry(**kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content="hidden", tool_calls=[], usage={}) + return LLMResponse(content="Recovered answer", tool_calls=[], usage={}) + + loop.provider.chat_with_retry = chat_with_retry + + final_content, _, _ = await loop._run_agent_loop([]) + + assert final_content == "Recovered answer" + assert call_count["n"] == 2 + + +@pytest.mark.asyncio +async def test_runner_tool_error_sets_final_content(): + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(*, messages, **kwargs): + return LLMResponse( + content="working", + tool_calls=[ToolCallRequest(id="call_1", name="read_file", arguments={"path": "x"})], + usage={}, + ) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = 
AsyncMock(side_effect=RuntimeError("boom")) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "do task"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + fail_on_tool_error=True, + )) + + assert result.final_content == "Error: RuntimeError: boom" + assert result.stop_reason == "tool_error" + + @pytest.mark.asyncio async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, monkeypatch): from nanobot.agent.subagent import SubagentManager @@ -317,15 +833,20 @@ async def test_subagent_max_iterations_announces_existing_fallback(tmp_path, mon provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse( content="working", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], )) - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) mgr._announce_result = AsyncMock() - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): return "tool result" - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -369,6 +890,7 @@ async def test_runner_accumulates_usage_and_preserves_cached_tokens(): tools=tools, model="test-model", max_iterations=3, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, )) # Usage should be accumulated across iterations @@ -407,6 +929,7 @@ async def test_runner_passes_cached_tokens_to_hook_context(): tools=tools, model="test-model", max_iterations=1, + 
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, hook=UsageHook(), )) diff --git a/tests/agent/test_task_cancel.py b/tests/agent/test_task_cancel.py index 70f7621d1..7e84e57d8 100644 --- a/tests/agent/test_task_cancel.py +++ b/tests/agent/test_task_cancel.py @@ -8,6 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from nanobot.config.schema import AgentDefaults + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + def _make_loop(*, exec_config=None): """Create a minimal AgentLoop with mocked dependencies.""" @@ -186,7 +190,12 @@ class TestSubagentCancellation: bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" - mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=MagicMock(), + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) cancelled = asyncio.Event() @@ -214,7 +223,12 @@ class TestSubagentCancellation: bus = MessageBus() provider = MagicMock() provider.get_default_model.return_value = "test-model" - mgr = SubagentManager(provider=provider, workspace=MagicMock(), bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=MagicMock(), + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) assert await mgr.cancel_by_session("nonexistent") == 0 @pytest.mark.asyncio @@ -236,19 +250,24 @@ class TestSubagentCancellation: if call_count["n"] == 1: return LLMResponse( content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], reasoning_content="hidden reasoning", thinking_blocks=[{"type": "thinking", "thinking": "step"}], ) captured_second_call[:] = messages return LLMResponse(content="done", tool_calls=[]) provider.chat_with_retry = scripted_chat_with_retry - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + 
provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): return "tool result" - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -273,6 +292,7 @@ class TestSubagentCancellation: provider=provider, workspace=tmp_path, bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, exec_config=ExecToolConfig(enable=False), ) mgr._announce_result = AsyncMock() @@ -304,20 +324,25 @@ class TestSubagentCancellation: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse( content="thinking", - tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], )) - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) mgr._announce_result = AsyncMock() calls = {"n": 0} - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): calls["n"] += 1 if calls["n"] == 1: return "first result" raise RuntimeError("boom") - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) await mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -340,15 +365,20 @@ class TestSubagentCancellation: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse( content="thinking", - 
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})], + tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={"path": "."})], )) - mgr = SubagentManager(provider=provider, workspace=tmp_path, bus=bus) + mgr = SubagentManager( + provider=provider, + workspace=tmp_path, + bus=bus, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + ) mgr._announce_result = AsyncMock() started = asyncio.Event() cancelled = asyncio.Event() - async def fake_execute(self, name, arguments): + async def fake_execute(self, **kwargs): started.set() try: await asyncio.sleep(60) @@ -356,7 +386,7 @@ class TestSubagentCancellation: cancelled.set() raise - monkeypatch.setattr("nanobot.agent.tools.registry.ToolRegistry.execute", fake_execute) + monkeypatch.setattr("nanobot.agent.tools.filesystem.ListDirTool.execute", fake_execute) task = asyncio.create_task( mgr._run_subagent("sub-1", "do task", "label", {"channel": "test", "chat_id": "c1"}) @@ -364,7 +394,7 @@ class TestSubagentCancellation: mgr._running_tasks["sub-1"] = task mgr._session_tasks["test:c1"] = {"sub-1"} - await started.wait() + await asyncio.wait_for(started.wait(), timeout=1.0) count = await mgr.cancel_by_session("test:c1") diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index a0b458a08..4cf4fab21 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -208,7 +208,7 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch): seen["config"] = self.config return True - monkeypatch.setattr("nanobot.config.loader.load_config", lambda: Config()) + monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) monkeypatch.setattr( "nanobot.channels.registry.discover_all", lambda: {"fakeplugin": _LoginPlugin}, @@ -220,6 +220,57 @@ def test_channels_login_uses_discovered_plugin_class(monkeypatch): assert seen["force"] is True +def 
test_channels_login_sets_custom_config_path(monkeypatch, tmp_path): + from nanobot.cli.commands import app + from nanobot.config.schema import Config + from typer.testing import CliRunner + + runner = CliRunner() + seen: dict[str, object] = {} + config_path = tmp_path / "custom-config.json" + + class _LoginPlugin(_FakePlugin): + async def login(self, force: bool = False) -> bool: + return True + + monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) + monkeypatch.setattr( + "nanobot.config.loader.set_config_path", + lambda path: seen.__setitem__("config_path", path), + ) + monkeypatch.setattr( + "nanobot.channels.registry.discover_all", + lambda: {"fakeplugin": _LoginPlugin}, + ) + + result = runner.invoke(app, ["channels", "login", "fakeplugin", "--config", str(config_path)]) + + assert result.exit_code == 0 + assert seen["config_path"] == config_path.resolve() + + +def test_channels_status_sets_custom_config_path(monkeypatch, tmp_path): + from nanobot.cli.commands import app + from nanobot.config.schema import Config + from typer.testing import CliRunner + + runner = CliRunner() + seen: dict[str, object] = {} + config_path = tmp_path / "custom-config.json" + + monkeypatch.setattr("nanobot.config.loader.load_config", lambda config_path=None: Config()) + monkeypatch.setattr( + "nanobot.config.loader.set_config_path", + lambda path: seen.__setitem__("config_path", path), + ) + monkeypatch.setattr("nanobot.channels.registry.discover_all", lambda: {}) + + result = runner.invoke(app, ["channels", "status", "--config", str(config_path)]) + + assert result.exit_code == 0 + assert seen["config_path"] == config_path.resolve() + + @pytest.mark.asyncio async def test_manager_skips_disabled_plugin(): fake_config = SimpleNamespace( diff --git a/tests/channels/test_discord_channel.py b/tests/channels/test_discord_channel.py index d352c788c..845c03c57 100644 --- a/tests/channels/test_discord_channel.py +++ 
b/tests/channels/test_discord_channel.py @@ -594,7 +594,7 @@ async def test_send_stops_typing_after_send() -> None: typing_channel.typing_enter_hook = slow_typing await channel._start_typing(typing_channel) - await start.wait() + await asyncio.wait_for(start.wait(), timeout=1.0) await channel.send(OutboundMessage(channel="discord", chat_id="123", content="hello")) release.set() @@ -614,7 +614,7 @@ async def test_send_stops_typing_after_send() -> None: typing_channel.typing_enter_hook = slow_typing_progress await channel._start_typing(typing_channel) - await start.wait() + await asyncio.wait_for(start.wait(), timeout=1.0) await channel.send( OutboundMessage( @@ -665,7 +665,7 @@ async def test_start_typing_uses_typing_context_when_trigger_typing_missing() -> typing_channel = _NoTriggerChannel(channel_id=123) await channel._start_typing(typing_channel) # type: ignore[arg-type] - await entered.wait() + await asyncio.wait_for(entered.wait(), timeout=1.0) assert "123" in channel._typing_tasks diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index 18a8e1097..27b7e1255 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -3,16 +3,14 @@ from pathlib import Path from types import SimpleNamespace import pytest + +pytest.importorskip("nio") +pytest.importorskip("nh3") +pytest.importorskip("mistune") from nio import RoomSendResponse from nanobot.channels.matrix import _build_matrix_text_content -# Check optional matrix dependencies before importing -try: - import nh3 # noqa: F401 -except ImportError: - pytest.skip("Matrix dependencies not installed (nh3)", allow_module_level=True) - import nanobot.channels.matrix as matrix_module from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus diff --git a/tests/channels/test_qq_ack_message.py b/tests/channels/test_qq_ack_message.py new file mode 100644 index 000000000..0f3a2dbec --- /dev/null +++ 
b/tests/channels/test_qq_ack_message.py @@ -0,0 +1,172 @@ +"""Tests for QQ channel ack_message feature. + +Covers the four verification points from the PR: +1. C2C message: ack appears instantly +2. Group message: ack appears instantly +3. ack_message set to "": no ack sent +4. Custom ack_message text: correct text delivered +Each test also verifies that normal message processing is not blocked. +""" + +from types import SimpleNamespace + +import pytest + +try: + from nanobot.channels import qq + + QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False) +except ImportError: + QQ_AVAILABLE = False + +if not QQ_AVAILABLE: + pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) + +from nanobot.bus.queue import MessageBus +from nanobot.channels.qq import QQChannel, QQConfig + + +class _FakeApi: + def __init__(self) -> None: + self.c2c_calls: list[dict] = [] + self.group_calls: list[dict] = [] + + async def post_c2c_message(self, **kwargs) -> None: + self.c2c_calls.append(kwargs) + + async def post_group_message(self, **kwargs) -> None: + self.group_calls.append(kwargs) + + +class _FakeClient: + def __init__(self) -> None: + self.api = _FakeApi() + + +@pytest.mark.asyncio +async def test_ack_sent_on_c2c_message() -> None: + """Ack is sent immediately for C2C messages, then normal processing continues.""" + channel = QQChannel( + QQConfig( + app_id="app", + secret="secret", + allow_from=["*"], + ack_message="⏳ Processing...", + ), + MessageBus(), + ) + channel._client = _FakeClient() + + data = SimpleNamespace( + id="msg1", + content="hello", + author=SimpleNamespace(user_openid="user1"), + attachments=[], + ) + await channel._on_message(data, is_group=False) + + assert len(channel._client.api.c2c_calls) >= 1 + ack_call = channel._client.api.c2c_calls[0] + assert ack_call["content"] == "⏳ Processing..." 
+ assert ack_call["openid"] == "user1" + assert ack_call["msg_id"] == "msg1" + assert ack_call["msg_type"] == 0 + + msg = await channel.bus.consume_inbound() + assert msg.content == "hello" + assert msg.sender_id == "user1" + + +@pytest.mark.asyncio +async def test_ack_sent_on_group_message() -> None: + """Ack is sent immediately for group messages, then normal processing continues.""" + channel = QQChannel( + QQConfig( + app_id="app", + secret="secret", + allow_from=["*"], + ack_message="⏳ Processing...", + ), + MessageBus(), + ) + channel._client = _FakeClient() + + data = SimpleNamespace( + id="msg2", + content="hello group", + group_openid="group123", + author=SimpleNamespace(member_openid="user1"), + attachments=[], + ) + await channel._on_message(data, is_group=True) + + assert len(channel._client.api.group_calls) >= 1 + ack_call = channel._client.api.group_calls[0] + assert ack_call["content"] == "⏳ Processing..." + assert ack_call["group_openid"] == "group123" + assert ack_call["msg_id"] == "msg2" + assert ack_call["msg_type"] == 0 + + msg = await channel.bus.consume_inbound() + assert msg.content == "hello group" + assert msg.chat_id == "group123" + + +@pytest.mark.asyncio +async def test_no_ack_when_ack_message_empty() -> None: + """Setting ack_message to empty string disables the ack entirely.""" + channel = QQChannel( + QQConfig( + app_id="app", + secret="secret", + allow_from=["*"], + ack_message="", + ), + MessageBus(), + ) + channel._client = _FakeClient() + + data = SimpleNamespace( + id="msg3", + content="hello", + author=SimpleNamespace(user_openid="user1"), + attachments=[], + ) + await channel._on_message(data, is_group=False) + + assert len(channel._client.api.c2c_calls) == 0 + assert len(channel._client.api.group_calls) == 0 + + msg = await channel.bus.consume_inbound() + assert msg.content == "hello" + + +@pytest.mark.asyncio +async def test_custom_ack_message_text() -> None: + """Custom Chinese ack_message text is delivered correctly.""" + 
custom = "ζ­£εœ¨ε€„η†δΈ­οΌŒθ―·η¨ε€™..." + channel = QQChannel( + QQConfig( + app_id="app", + secret="secret", + allow_from=["*"], + ack_message=custom, + ), + MessageBus(), + ) + channel._client = _FakeClient() + + data = SimpleNamespace( + id="msg4", + content="test input", + author=SimpleNamespace(user_openid="user1"), + attachments=[], + ) + await channel._on_message(data, is_group=False) + + assert len(channel._client.api.c2c_calls) >= 1 + ack_call = channel._client.api.c2c_calls[0] + assert ack_call["content"] == custom + + msg = await channel.bus.consume_inbound() + assert msg.content == "test input" diff --git a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 972f8ab6e..c793b1224 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -647,43 +647,56 @@ async def test_group_policy_open_accepts_plain_group_message() -> None: assert channel._app.bot.get_me_calls == 0 -def test_extract_reply_context_no_reply() -> None: +@pytest.mark.asyncio +async def test_extract_reply_context_no_reply() -> None: """When there is no reply_to_message, _extract_reply_context returns None.""" + channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus()) message = SimpleNamespace(reply_to_message=None) - assert TelegramChannel._extract_reply_context(message) is None + assert await channel._extract_reply_context(message) is None -def test_extract_reply_context_with_text() -> None: +@pytest.mark.asyncio +async def test_extract_reply_context_with_text() -> None: """When reply has text, return prefixed string.""" - reply = SimpleNamespace(text="Hello world", caption=None) + channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus()) + channel._app = _FakeApp(lambda: None) + reply = SimpleNamespace(text="Hello world", caption=None, from_user=SimpleNamespace(id=2, username="testuser", first_name="Test")) message = 
SimpleNamespace(reply_to_message=reply) - assert TelegramChannel._extract_reply_context(message) == "[Reply to: Hello world]" + assert await channel._extract_reply_context(message) == "[Reply to @testuser: Hello world]" -def test_extract_reply_context_with_caption_only() -> None: +@pytest.mark.asyncio +async def test_extract_reply_context_with_caption_only() -> None: """When reply has only caption (no text), caption is used.""" - reply = SimpleNamespace(text=None, caption="Photo caption") + channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus()) + channel._app = _FakeApp(lambda: None) + reply = SimpleNamespace(text=None, caption="Photo caption", from_user=SimpleNamespace(id=2, username=None, first_name="Test")) message = SimpleNamespace(reply_to_message=reply) - assert TelegramChannel._extract_reply_context(message) == "[Reply to: Photo caption]" + assert await channel._extract_reply_context(message) == "[Reply to Test: Photo caption]" -def test_extract_reply_context_truncation() -> None: +@pytest.mark.asyncio +async def test_extract_reply_context_truncation() -> None: """Reply text is truncated at TELEGRAM_REPLY_CONTEXT_MAX_LEN.""" + channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus()) + channel._app = _FakeApp(lambda: None) long_text = "x" * (TELEGRAM_REPLY_CONTEXT_MAX_LEN + 100) - reply = SimpleNamespace(text=long_text, caption=None) + reply = SimpleNamespace(text=long_text, caption=None, from_user=SimpleNamespace(id=2, username=None, first_name=None)) message = SimpleNamespace(reply_to_message=reply) - result = TelegramChannel._extract_reply_context(message) + result = await channel._extract_reply_context(message) assert result is not None assert result.startswith("[Reply to: ") assert result.endswith("...]") assert len(result) == len("[Reply to: ]") + TELEGRAM_REPLY_CONTEXT_MAX_LEN + len("...") -def test_extract_reply_context_no_text_returns_none() -> None: +@pytest.mark.asyncio +async def 
test_extract_reply_context_no_text_returns_none() -> None: """When reply has no text/caption, _extract_reply_context returns None (media handled separately).""" + channel = TelegramChannel(TelegramConfig(enabled=True, token="123:abc"), MessageBus()) reply = SimpleNamespace(text=None, caption=None) message = SimpleNamespace(reply_to_message=reply) - assert TelegramChannel._extract_reply_context(message) is None + assert await channel._extract_reply_context(message) is None @pytest.mark.asyncio diff --git a/tests/providers/test_azure_openai_provider.py b/tests/providers/test_azure_openai_provider.py index 77f36d468..89cea64f0 100644 --- a/tests/providers/test_azure_openai_provider.py +++ b/tests/providers/test_azure_openai_provider.py @@ -1,6 +1,6 @@ -"""Test Azure OpenAI provider implementation (updated for model-based deployment names).""" +"""Test Azure OpenAI provider (Responses API via OpenAI SDK).""" -from unittest.mock import AsyncMock, Mock, patch +from unittest.mock import AsyncMock, MagicMock import pytest @@ -8,392 +8,401 @@ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider from nanobot.providers.base import LLMResponse -def test_azure_openai_provider_init(): - """Test AzureOpenAIProvider initialization without deployment_name.""" +# --------------------------------------------------------------------------- +# Init & validation +# --------------------------------------------------------------------------- + + +def test_init_creates_sdk_client(): + """Provider creates an AsyncOpenAI client with correct base_url.""" provider = AzureOpenAIProvider( api_key="test-key", api_base="https://test-resource.openai.azure.com", default_model="gpt-4o-deployment", ) - assert provider.api_key == "test-key" assert provider.api_base == "https://test-resource.openai.azure.com/" assert provider.default_model == "gpt-4o-deployment" - assert provider.api_version == "2024-10-21" + # SDK client base_url ends with /openai/v1/ + assert 
str(provider._client.base_url).rstrip("/").endswith("/openai/v1") -def test_azure_openai_provider_init_validation(): - """Test AzureOpenAIProvider initialization validation.""" - # Missing api_key +def test_init_base_url_no_trailing_slash(): + """Trailing slashes are normalised before building base_url.""" + provider = AzureOpenAIProvider( + api_key="k", api_base="https://res.openai.azure.com", + ) + assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1") + + +def test_init_base_url_with_trailing_slash(): + provider = AzureOpenAIProvider( + api_key="k", api_base="https://res.openai.azure.com/", + ) + assert str(provider._client.base_url).rstrip("/").endswith("/openai/v1") + + +def test_init_validation_missing_key(): with pytest.raises(ValueError, match="Azure OpenAI api_key is required"): AzureOpenAIProvider(api_key="", api_base="https://test.com") - - # Missing api_base + + +def test_init_validation_missing_base(): with pytest.raises(ValueError, match="Azure OpenAI api_base is required"): AzureOpenAIProvider(api_key="test", api_base="") -def test_build_chat_url(): - """Test Azure OpenAI URL building with different deployment names.""" +def test_no_api_version_in_base_url(): + """The /openai/v1/ path should NOT contain an api-version query param.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://res.openai.azure.com") + base = str(provider._client.base_url) + assert "api-version" not in base + + +# --------------------------------------------------------------------------- +# _supports_temperature +# --------------------------------------------------------------------------- + + +def test_supports_temperature_standard_model(): + assert AzureOpenAIProvider._supports_temperature("gpt-4o") is True + + +def test_supports_temperature_reasoning_model(): + assert AzureOpenAIProvider._supports_temperature("o3-mini") is False + assert AzureOpenAIProvider._supports_temperature("gpt-5-chat") is False + assert 
AzureOpenAIProvider._supports_temperature("o4-mini") is False + + +def test_supports_temperature_with_reasoning_effort(): + assert AzureOpenAIProvider._supports_temperature("gpt-4o", reasoning_effort="medium") is False + + +# --------------------------------------------------------------------------- +# _build_body β€” Responses API body construction +# --------------------------------------------------------------------------- + + +def test_build_body_basic(): provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", + api_key="k", api_base="https://res.openai.azure.com", default_model="gpt-4o", ) - - # Test various deployment names - test_cases = [ - ("gpt-4o-deployment", "https://test-resource.openai.azure.com/openai/deployments/gpt-4o-deployment/chat/completions?api-version=2024-10-21"), - ("gpt-35-turbo", "https://test-resource.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2024-10-21"), - ("custom-model", "https://test-resource.openai.azure.com/openai/deployments/custom-model/chat/completions?api-version=2024-10-21"), - ] - - for deployment_name, expected_url in test_cases: - url = provider._build_chat_url(deployment_name) - assert url == expected_url + messages = [{"role": "system", "content": "You are helpful."}, {"role": "user", "content": "Hi"}] + body = provider._build_body(messages, None, None, 4096, 0.7, None, None) - -def test_build_chat_url_api_base_without_slash(): - """Test URL building when api_base doesn't end with slash.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", # No trailing slash - default_model="gpt-4o", + assert body["model"] == "gpt-4o" + assert body["instructions"] == "You are helpful." 
+ assert body["temperature"] == 0.7 + assert body["max_output_tokens"] == 4096 + assert body["store"] is False + assert "reasoning" not in body + # input should contain the converted user message only (system extracted) + assert any( + item.get("role") == "user" + for item in body["input"] ) - - url = provider._build_chat_url("test-deployment") - expected = "https://test-resource.openai.azure.com/openai/deployments/test-deployment/chat/completions?api-version=2024-10-21" - assert url == expected -def test_build_headers(): - """Test Azure OpenAI header building with api-key authentication.""" - provider = AzureOpenAIProvider( - api_key="test-api-key-123", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - headers = provider._build_headers() - assert headers["Content-Type"] == "application/json" - assert headers["api-key"] == "test-api-key-123" # Azure OpenAI specific header - assert "x-session-affinity" in headers +def test_build_body_max_tokens_minimum(): + """max_output_tokens should never be less than 1.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + body = provider._build_body([{"role": "user", "content": "x"}], None, None, 0, 0.7, None, None) + assert body["max_output_tokens"] == 1 -def test_prepare_request_payload(): - """Test request payload preparation with Azure OpenAI 2024-10-21 compliance.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - messages = [{"role": "user", "content": "Hello"}] - payload = provider._prepare_request_payload("gpt-4o", messages, max_tokens=1500, temperature=0.8) - - assert payload["messages"] == messages - assert payload["max_completion_tokens"] == 1500 # Azure API 2024-10-21 uses max_completion_tokens - assert payload["temperature"] == 0.8 - assert "tools" not in payload - - # Test with tools +def test_build_body_with_tools(): + provider = 
AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] - payload_with_tools = provider._prepare_request_payload("gpt-4o", messages, tools=tools) - assert payload_with_tools["tools"] == tools - assert payload_with_tools["tool_choice"] == "auto" - - # Test with reasoning_effort - payload_with_reasoning = provider._prepare_request_payload( - "gpt-5-chat", messages, reasoning_effort="medium" + body = provider._build_body( + [{"role": "user", "content": "weather?"}], tools, None, 4096, 0.7, None, None, ) - assert payload_with_reasoning["reasoning_effort"] == "medium" - assert "temperature" not in payload_with_reasoning + assert body["tools"] == [{"type": "function", "name": "get_weather", "description": "", "parameters": {}}] + assert body["tool_choice"] == "auto" -def test_prepare_request_payload_sanitizes_messages(): - """Test Azure payload strips non-standard message keys before sending.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", +def test_build_body_with_reasoning(): + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-5-chat") + body = provider._build_body( + [{"role": "user", "content": "think"}], None, "gpt-5-chat", 4096, 0.7, "medium", None, ) + assert body["reasoning"] == {"effort": "medium"} + assert "reasoning.encrypted_content" in body.get("include", []) + # temperature omitted for reasoning models + assert "temperature" not in body - messages = [ - { - "role": "assistant", - "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], - "reasoning_content": "hidden chain-of-thought", - }, - { - "role": "tool", - "tool_call_id": "call_123", - "name": "x", - "content": "ok", - "extra_field": "should be removed", - }, - ] - payload = provider._prepare_request_payload("gpt-4o", messages) +def 
test_build_body_image_conversion(): + """image_url content blocks should be converted to input_image.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.png"}}, + ], + }] + body = provider._build_body(messages, None, None, 4096, 0.7, None, None) + user_item = body["input"][0] + content_types = [b["type"] for b in user_item["content"]] + assert "input_text" in content_types + assert "input_image" in content_types + image_block = next(b for b in user_item["content"] if b["type"] == "input_image") + assert image_block["image_url"] == "https://example.com/img.png" - assert payload["messages"] == [ - { - "role": "assistant", - "content": None, - "tool_calls": [{"id": "call_123", "type": "function", "function": {"name": "x"}}], + +def test_build_body_sanitizes_single_dict_content_block(): + """Single content dicts should be preserved via shared message sanitization.""" + provider = AzureOpenAIProvider(api_key="k", api_base="https://r.com", default_model="gpt-4o") + messages = [{ + "role": "user", + "content": {"type": "text", "text": "Hi from dict content"}, + }] + + body = provider._build_body(messages, None, None, 4096, 0.7, None, None) + + assert body["input"][0]["content"] == [{"type": "input_text", "text": "Hi from dict content"}] + + +# --------------------------------------------------------------------------- +# chat() — non-streaming +# --------------------------------------------------------------------------- + + +def _make_sdk_response( + content="Hello!", tool_calls=None, status="completed", + usage=None, +): + """Build a mock that quacks like an openai Response object.""" + resp = MagicMock() + resp.model_dump = MagicMock(return_value={ + "output": [ + {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": 
content}]}, + *([{ + "type": "function_call", + "call_id": tc["call_id"], "id": tc["id"], + "name": tc["name"], "arguments": tc["arguments"], + } for tc in (tool_calls or [])]), + ], + "status": status, + "usage": { + "input_tokens": (usage or {}).get("input_tokens", 10), + "output_tokens": (usage or {}).get("output_tokens", 5), + "total_tokens": (usage or {}).get("total_tokens", 15), }, - { - "role": "tool", - "tool_call_id": "call_123", - "name": "x", - "content": "ok", - }, - ] + }) + return resp @pytest.mark.asyncio async def test_chat_success(): - """Test successful chat request using model as deployment name.""" provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o-deployment", + api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - - # Mock response data - mock_response_data = { - "choices": [{ - "message": { - "content": "Hello! How can I help you today?", - "role": "assistant" - }, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": 12, - "completion_tokens": 18, - "total_tokens": 30 - } - } - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_response.json = Mock(return_value=mock_response_data) - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - # Test with specific model (deployment name) - messages = [{"role": "user", "content": "Hello"}] - result = await provider.chat(messages, model="custom-deployment") - - assert isinstance(result, LLMResponse) - assert result.content == "Hello! How can I help you today?" 
- assert result.finish_reason == "stop" - assert result.usage["prompt_tokens"] == 12 - assert result.usage["completion_tokens"] == 18 - assert result.usage["total_tokens"] == 30 - - # Verify URL was built with the provided model as deployment name - call_args = mock_context.post.call_args - expected_url = "https://test-resource.openai.azure.com/openai/deployments/custom-deployment/chat/completions?api-version=2024-10-21" - assert call_args[0][0] == expected_url + mock_resp = _make_sdk_response(content="Hello!") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + result = await provider.chat([{"role": "user", "content": "Hi"}]) + + assert isinstance(result, LLMResponse) + assert result.content == "Hello!" + assert result.finish_reason == "stop" + assert result.usage["prompt_tokens"] == 10 @pytest.mark.asyncio -async def test_chat_uses_default_model_when_no_model_provided(): - """Test that chat uses default_model when no model is specified.""" +async def test_chat_uses_default_model(): provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="default-deployment", + api_key="k", api_base="https://test.openai.azure.com", default_model="my-deployment", ) - - mock_response_data = { - "choices": [{ - "message": {"content": "Response", "role": "assistant"}, - "finish_reason": "stop" - }], - "usage": {"prompt_tokens": 5, "completion_tokens": 5, "total_tokens": 10} - } - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_response.json = Mock(return_value=mock_response_data) - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "Test"}] - await provider.chat(messages) # No model specified - - # Verify URL was built with default 
model as deployment name - call_args = mock_context.post.call_args - expected_url = "https://test-resource.openai.azure.com/openai/deployments/default-deployment/chat/completions?api-version=2024-10-21" - assert call_args[0][0] == expected_url + mock_resp = _make_sdk_response(content="ok") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + await provider.chat([{"role": "user", "content": "test"}]) + + call_kwargs = provider._client.responses.create.call_args[1] + assert call_kwargs["model"] == "my-deployment" + + +@pytest.mark.asyncio +async def test_chat_custom_model(): + provider = AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + mock_resp = _make_sdk_response(content="ok") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + await provider.chat([{"role": "user", "content": "test"}], model="custom-deploy") + + call_kwargs = provider._client.responses.create.call_args[1] + assert call_kwargs["model"] == "custom-deploy" @pytest.mark.asyncio async def test_chat_with_tool_calls(): - """Test chat request with tool calls in response.""" provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - - # Mock response with tool calls - mock_response_data = { - "choices": [{ - "message": { - "content": None, - "role": "assistant", - "tool_calls": [{ - "id": "call_12345", - "function": { - "name": "get_weather", - "arguments": '{"location": "San Francisco"}' - } - }] - }, - "finish_reason": "tool_calls" + mock_resp = _make_sdk_response( + content=None, + tool_calls=[{ + "call_id": "call_123", "id": "fc_1", + "name": "get_weather", "arguments": '{"location": "SF"}', }], - "usage": { - "prompt_tokens": 20, - "completion_tokens": 15, 
- "total_tokens": 35 - } - } - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 200 - mock_response.json = Mock(return_value=mock_response_data) - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "What's the weather?"}] - tools = [{"type": "function", "function": {"name": "get_weather", "parameters": {}}}] - result = await provider.chat(messages, tools=tools, model="weather-model") - - assert isinstance(result, LLMResponse) - assert result.content is None - assert result.finish_reason == "tool_calls" - assert len(result.tool_calls) == 1 - assert result.tool_calls[0].name == "get_weather" - assert result.tool_calls[0].arguments == {"location": "San Francisco"} + ) + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + result = await provider.chat( + [{"role": "user", "content": "Weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments == {"location": "SF"} @pytest.mark.asyncio -async def test_chat_api_error(): - """Test chat request API error handling.""" +async def test_chat_error_handling(): provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", ) - - with patch("httpx.AsyncClient") as mock_client: - mock_response = AsyncMock() - mock_response.status_code = 401 - mock_response.text = "Invalid authentication credentials" - - mock_context = AsyncMock() - mock_context.post = AsyncMock(return_value=mock_response) - mock_client.return_value.__aenter__.return_value = 
mock_context - - messages = [{"role": "user", "content": "Hello"}] - result = await provider.chat(messages) - - assert isinstance(result, LLMResponse) - assert "Azure OpenAI API Error 401" in result.content - assert "Invalid authentication credentials" in result.content - assert result.finish_reason == "error" + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed")) + result = await provider.chat([{"role": "user", "content": "Hi"}]) -@pytest.mark.asyncio -async def test_chat_connection_error(): - """Test chat request connection error handling.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - with patch("httpx.AsyncClient") as mock_client: - mock_context = AsyncMock() - mock_context.post = AsyncMock(side_effect=Exception("Connection failed")) - mock_client.return_value.__aenter__.return_value = mock_context - - messages = [{"role": "user", "content": "Hello"}] - result = await provider.chat(messages) - - assert isinstance(result, LLMResponse) - assert "Error calling Azure OpenAI: Exception('Connection failed')" in result.content - assert result.finish_reason == "error" - - -def test_parse_response_malformed(): - """Test response parsing with malformed data.""" - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o", - ) - - # Test with missing choices - malformed_response = {"usage": {"prompt_tokens": 10}} - result = provider._parse_response(malformed_response) - assert isinstance(result, LLMResponse) - assert "Error parsing Azure OpenAI response" in result.content + assert "Connection failed" in result.content assert result.finish_reason == "error" +@pytest.mark.asyncio +async def test_chat_reasoning_param_format(): + """reasoning_effort should be sent as reasoning={effort: ...} not a flat string.""" + provider = 
AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-5-chat", + ) + mock_resp = _make_sdk_response(content="thought") + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_resp) + + await provider.chat( + [{"role": "user", "content": "think"}], reasoning_effort="medium", + ) + + call_kwargs = provider._client.responses.create.call_args[1] + assert call_kwargs["reasoning"] == {"effort": "medium"} + assert "reasoning_effort" not in call_kwargs + + +# --------------------------------------------------------------------------- +# chat_stream() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_chat_stream_success(): + """Streaming should call on_content_delta and return combined response.""" + provider = AzureOpenAIProvider( + api_key="test-key", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + + # Build mock SDK stream events + events = [] + ev1 = MagicMock(type="response.output_text.delta", delta="Hello") + ev2 = MagicMock(type="response.output_text.delta", delta=" world") + resp_obj = MagicMock(status="completed") + ev3 = MagicMock(type="response.completed", response=resp_obj) + events = [ev1, ev2, ev3] + + async def mock_stream(): + for e in events: + yield e + + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_stream()) + + deltas: list[str] = [] + + async def on_delta(text: str) -> None: + deltas.append(text) + + result = await provider.chat_stream( + [{"role": "user", "content": "Hi"}], on_content_delta=on_delta, + ) + + assert result.content == "Hello world" + assert result.finish_reason == "stop" + assert deltas == ["Hello", " world"] + + +@pytest.mark.asyncio +async def test_chat_stream_with_tool_calls(): + """Streaming tool calls should be accumulated correctly.""" + provider = AzureOpenAIProvider( + api_key="k", 
api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + + item_added = MagicMock(type="function_call", call_id="call_1", id="fc_1", arguments="") + item_added.name = "get_weather" + ev_added = MagicMock(type="response.output_item.added", item=item_added) + ev_args_delta = MagicMock(type="response.function_call_arguments.delta", call_id="call_1", delta='{"loc') + ev_args_done = MagicMock( + type="response.function_call_arguments.done", + call_id="call_1", arguments='{"location":"SF"}', + ) + item_done = MagicMock( + type="function_call", call_id="call_1", id="fc_1", + arguments='{"location":"SF"}', + ) + item_done.name = "get_weather" + ev_item_done = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed") + ev_completed = MagicMock(type="response.completed", response=resp_obj) + + async def mock_stream(): + for e in [ev_added, ev_args_delta, ev_args_done, ev_item_done, ev_completed]: + yield e + + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(return_value=mock_stream()) + + result = await provider.chat_stream( + [{"role": "user", "content": "weather?"}], + tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {}}}], + ) + + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments == {"location": "SF"} + + +@pytest.mark.asyncio +async def test_chat_stream_error(): + """Streaming should return error when SDK raises.""" + provider = AzureOpenAIProvider( + api_key="k", api_base="https://test.openai.azure.com", default_model="gpt-4o", + ) + provider._client.responses = MagicMock() + provider._client.responses.create = AsyncMock(side_effect=Exception("Connection failed")) + + result = await provider.chat_stream([{"role": "user", "content": "Hi"}]) + + assert "Connection failed" in result.content + assert result.finish_reason == "error" + + +# 
--------------------------------------------------------------------------- +# get_default_model +# --------------------------------------------------------------------------- + + def test_get_default_model(): - """Test get_default_model method.""" provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="my-custom-deployment", + api_key="k", api_base="https://r.com", default_model="my-deploy", ) - - assert provider.get_default_model() == "my-custom-deployment" - - -if __name__ == "__main__": - # Run basic tests - print("Running basic Azure OpenAI provider tests...") - - # Test initialization - provider = AzureOpenAIProvider( - api_key="test-key", - api_base="https://test-resource.openai.azure.com", - default_model="gpt-4o-deployment", - ) - print("βœ… Provider initialization successful") - - # Test URL building - url = provider._build_chat_url("my-deployment") - expected = "https://test-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-10-21" - assert url == expected - print("βœ… URL building works correctly") - - # Test headers - headers = provider._build_headers() - assert headers["api-key"] == "test-key" - assert headers["Content-Type"] == "application/json" - print("βœ… Header building works correctly") - - # Test payload preparation - messages = [{"role": "user", "content": "Test"}] - payload = provider._prepare_request_payload("gpt-4o-deployment", messages, max_tokens=1000) - assert payload["max_completion_tokens"] == 1000 # Azure 2024-10-21 format - print("βœ… Payload preparation works correctly") - - print("βœ… All basic tests passed! 
Updated test file is working correctly.") \ No newline at end of file + assert provider.get_default_model() == "my-deploy" diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 62fb0a2cc..cc8347f0e 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -8,6 +8,7 @@ Validates that: from __future__ import annotations +import asyncio from types import SimpleNamespace from unittest.mock import AsyncMock, patch @@ -53,6 +54,15 @@ def _fake_tool_call_response() -> SimpleNamespace: return SimpleNamespace(choices=[choice], usage=usage) +class _StalledStream: + def __aiter__(self): + return self + + async def __anext__(self): + await asyncio.sleep(3600) + raise StopAsyncIteration + + def test_openrouter_spec_is_gateway() -> None: spec = find_by_name("openrouter") assert spec is not None @@ -214,3 +224,54 @@ def test_openai_model_passthrough() -> None: spec=spec, ) assert provider.get_default_model() == "gpt-4o" + + +def test_openai_compat_strips_message_level_reasoning_fields() -> None: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() + + sanitized = provider._sanitize_messages([ + { + "role": "assistant", + "content": "done", + "reasoning_content": "hidden", + "extra_content": {"debug": True}, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": {"name": "fn", "arguments": "{}"}, + "extra_content": {"google": {"thought_signature": "sig"}}, + } + ], + } + ]) + + assert "reasoning_content" not in sanitized[0] + assert "extra_content" not in sanitized[0] + assert sanitized[0]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}} + + +@pytest.mark.asyncio +async def test_openai_compat_stream_watchdog_returns_error_on_stall(monkeypatch) -> None: + monkeypatch.setenv("NANOBOT_STREAM_IDLE_TIMEOUT_S", "0") + mock_create = AsyncMock(return_value=_StalledStream()) + spec = 
find_by_name("openai") + + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, + ) + result = await provider.chat_stream( + messages=[{"role": "user", "content": "hello"}], + model="gpt-4o", + ) + + assert result.finish_reason == "error" + assert result.content is not None + assert "stream stalled" in result.content diff --git a/tests/providers/test_openai_responses.py b/tests/providers/test_openai_responses.py new file mode 100644 index 000000000..ce4220655 --- /dev/null +++ b/tests/providers/test_openai_responses.py @@ -0,0 +1,522 @@ +"""Tests for the shared openai_responses converters and parsers.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.providers.openai_responses.converters import ( + convert_messages, + convert_tools, + convert_user_message, + split_tool_call_id, +) +from nanobot.providers.openai_responses.parsing import ( + consume_sdk_stream, + map_finish_reason, + parse_response_output, +) + + +# ====================================================================== +# converters - split_tool_call_id +# ====================================================================== + + +class TestSplitToolCallId: + def test_plain_id(self): + assert split_tool_call_id("call_abc") == ("call_abc", None) + + def test_compound_id(self): + assert split_tool_call_id("call_abc|fc_1") == ("call_abc", "fc_1") + + def test_compound_empty_item_id(self): + assert split_tool_call_id("call_abc|") == ("call_abc", None) + + def test_none(self): + assert split_tool_call_id(None) == ("call_0", None) + + def test_empty_string(self): + assert split_tool_call_id("") == ("call_0", None) + + def test_non_string(self): + assert split_tool_call_id(42) == 
("call_0", None) + + +# ====================================================================== +# converters - convert_user_message +# ====================================================================== + + +class TestConvertUserMessage: + def test_string_content(self): + result = convert_user_message("hello") + assert result == {"role": "user", "content": [{"type": "input_text", "text": "hello"}]} + + def test_text_block(self): + result = convert_user_message([{"type": "text", "text": "hi"}]) + assert result["content"] == [{"type": "input_text", "text": "hi"}] + + def test_image_url_block(self): + result = convert_user_message([ + {"type": "image_url", "image_url": {"url": "https://img.example/a.png"}}, + ]) + assert result["content"] == [ + {"type": "input_image", "image_url": "https://img.example/a.png", "detail": "auto"}, + ] + + def test_mixed_text_and_image(self): + result = convert_user_message([ + {"type": "text", "text": "what's this?"}, + {"type": "image_url", "image_url": {"url": "https://img.example/b.png"}}, + ]) + assert len(result["content"]) == 2 + assert result["content"][0]["type"] == "input_text" + assert result["content"][1]["type"] == "input_image" + + def test_empty_list_falls_back(self): + result = convert_user_message([]) + assert result["content"] == [{"type": "input_text", "text": ""}] + + def test_none_falls_back(self): + result = convert_user_message(None) + assert result["content"] == [{"type": "input_text", "text": ""}] + + def test_image_without_url_skipped(self): + result = convert_user_message([{"type": "image_url", "image_url": {}}]) + assert result["content"] == [{"type": "input_text", "text": ""}] + + def test_meta_fields_not_leaked(self): + """_meta on content blocks must never appear in converted output.""" + result = convert_user_message([ + {"type": "text", "text": "hi", "_meta": {"path": "/tmp/x"}}, + ]) + assert "_meta" not in result["content"][0] + + def test_non_dict_items_skipped(self): + result = 
convert_user_message(["just a string", 42]) + assert result["content"] == [{"type": "input_text", "text": ""}] + + +# ====================================================================== +# converters - convert_messages +# ====================================================================== + + +class TestConvertMessages: + def test_system_extracted_as_instructions(self): + msgs = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hi"}, + ] + instructions, items = convert_messages(msgs) + assert instructions == "You are helpful." + assert len(items) == 1 + assert items[0]["role"] == "user" + + def test_multiple_system_messages_last_wins(self): + msgs = [ + {"role": "system", "content": "first"}, + {"role": "system", "content": "second"}, + {"role": "user", "content": "x"}, + ] + instructions, _ = convert_messages(msgs) + assert instructions == "second" + + def test_user_message_converted(self): + _, items = convert_messages([{"role": "user", "content": "hello"}]) + assert items[0]["role"] == "user" + assert items[0]["content"][0]["type"] == "input_text" + + def test_assistant_text_message(self): + _, items = convert_messages([ + {"role": "assistant", "content": "I'll help"}, + ]) + assert items[0]["type"] == "message" + assert items[0]["role"] == "assistant" + assert items[0]["content"][0]["type"] == "output_text" + assert items[0]["content"][0]["text"] == "I'll help" + + def test_assistant_empty_content_skipped(self): + _, items = convert_messages([{"role": "assistant", "content": ""}]) + assert len(items) == 0 + + def test_assistant_with_tool_calls(self): + _, items = convert_messages([{ + "role": "assistant", + "content": None, + "tool_calls": [{ + "id": "call_abc|fc_1", + "function": {"name": "get_weather", "arguments": '{"city":"SF"}'}, + }], + }]) + assert items[0]["type"] == "function_call" + assert items[0]["call_id"] == "call_abc" + assert items[0]["id"] == "fc_1" + assert items[0]["name"] == "get_weather" + + def 
test_assistant_with_tool_calls_no_id(self): + """Fallback IDs when tool_call.id is missing.""" + _, items = convert_messages([{ + "role": "assistant", + "content": None, + "tool_calls": [{"function": {"name": "f1", "arguments": "{}"}}], + }]) + assert items[0]["call_id"] == "call_0" + assert items[0]["id"].startswith("fc_") + + def test_tool_message(self): + _, items = convert_messages([{ + "role": "tool", + "tool_call_id": "call_abc", + "content": "result text", + }]) + assert items[0]["type"] == "function_call_output" + assert items[0]["call_id"] == "call_abc" + assert items[0]["output"] == "result text" + + def test_tool_message_dict_content(self): + _, items = convert_messages([{ + "role": "tool", + "tool_call_id": "call_1", + "content": {"key": "value"}, + }]) + assert items[0]["output"] == '{"key": "value"}' + + def test_non_standard_keys_not_leaked(self): + """Extra keys on messages must not appear in converted items.""" + _, items = convert_messages([{ + "role": "user", + "content": "hi", + "extra_field": "should vanish", + "_meta": {"path": "/tmp"}, + }]) + item = items[0] + assert "extra_field" not in str(item) + assert "_meta" not in str(item) + + def test_full_conversation_roundtrip(self): + """System + user + assistant(tool_call) + tool -> correct structure.""" + msgs = [ + {"role": "system", "content": "Be concise."}, + {"role": "user", "content": "Weather in SF?"}, + { + "role": "assistant", "content": None, + "tool_calls": [{ + "id": "c1|fc1", + "function": {"name": "get_weather", "arguments": '{"city":"SF"}'}, + }], + }, + {"role": "tool", "tool_call_id": "c1", "content": '{"temp":72}'}, + ] + instructions, items = convert_messages(msgs) + assert instructions == "Be concise." 
+ assert len(items) == 3 # user, function_call, function_call_output + assert items[0]["role"] == "user" + assert items[1]["type"] == "function_call" + assert items[2]["type"] == "function_call_output" + + +# ====================================================================== +# converters - convert_tools +# ====================================================================== + + +class TestConvertTools: + def test_standard_function_tool(self): + tools = [{"type": "function", "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}, + }}] + result = convert_tools(tools) + assert len(result) == 1 + assert result[0]["type"] == "function" + assert result[0]["name"] == "get_weather" + assert result[0]["description"] == "Get weather" + assert "properties" in result[0]["parameters"] + + def test_tool_without_name_skipped(self): + tools = [{"type": "function", "function": {"parameters": {}}}] + assert convert_tools(tools) == [] + + def test_tool_without_function_wrapper(self): + """Direct dict without type=function wrapper.""" + tools = [{"name": "f1", "description": "d", "parameters": {}}] + result = convert_tools(tools) + assert result[0]["name"] == "f1" + + def test_missing_optional_fields_default(self): + tools = [{"type": "function", "function": {"name": "f"}}] + result = convert_tools(tools) + assert result[0]["description"] == "" + assert result[0]["parameters"] == {} + + def test_multiple_tools(self): + tools = [ + {"type": "function", "function": {"name": "a", "parameters": {}}}, + {"type": "function", "function": {"name": "b", "parameters": {}}}, + ] + assert len(convert_tools(tools)) == 2 + + +# ====================================================================== +# parsing - map_finish_reason +# ====================================================================== + + +class TestMapFinishReason: + def test_completed(self): + assert 
map_finish_reason("completed") == "stop" + + def test_incomplete(self): + assert map_finish_reason("incomplete") == "length" + + def test_failed(self): + assert map_finish_reason("failed") == "error" + + def test_cancelled(self): + assert map_finish_reason("cancelled") == "error" + + def test_none_defaults_to_stop(self): + assert map_finish_reason(None) == "stop" + + def test_unknown_defaults_to_stop(self): + assert map_finish_reason("some_new_status") == "stop" + + +# ====================================================================== +# parsing - parse_response_output +# ====================================================================== + + +class TestParseResponseOutput: + def test_text_response(self): + resp = { + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "Hello!"}]}], + "status": "completed", + "usage": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + } + result = parse_response_output(resp) + assert result.content == "Hello!" 
+ assert result.finish_reason == "stop" + assert result.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + assert result.tool_calls == [] + + def test_tool_call_response(self): + resp = { + "output": [{ + "type": "function_call", + "call_id": "call_1", "id": "fc_1", + "name": "get_weather", + "arguments": '{"city": "SF"}', + }], + "status": "completed", + "usage": {}, + } + result = parse_response_output(resp) + assert result.content is None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0].name == "get_weather" + assert result.tool_calls[0].arguments == {"city": "SF"} + assert result.tool_calls[0].id == "call_1|fc_1" + + def test_malformed_tool_arguments_logged(self): + """Malformed JSON arguments should log a warning and fallback.""" + resp = { + "output": [{ + "type": "function_call", + "call_id": "c1", "id": "fc1", + "name": "f", "arguments": "{bad json", + }], + "status": "completed", "usage": {}, + } + with patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger: + result = parse_response_output(resp) + assert result.tool_calls[0].arguments == {"raw": "{bad json"} + mock_logger.warning.assert_called_once() + assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args) + + def test_reasoning_content_extracted(self): + resp = { + "output": [ + {"type": "reasoning", "summary": [ + {"type": "summary_text", "text": "I think "}, + {"type": "summary_text", "text": "therefore I am."}, + ]}, + {"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "42"}]}, + ], + "status": "completed", "usage": {}, + } + result = parse_response_output(resp) + assert result.content == "42" + assert result.reasoning_content == "I think therefore I am." 
+ + def test_empty_output(self): + resp = {"output": [], "status": "completed", "usage": {}} + result = parse_response_output(resp) + assert result.content is None + assert result.tool_calls == [] + + def test_incomplete_status(self): + resp = {"output": [], "status": "incomplete", "usage": {}} + result = parse_response_output(resp) + assert result.finish_reason == "length" + + def test_sdk_model_object(self): + """parse_response_output should handle SDK objects with model_dump().""" + mock = MagicMock() + mock.model_dump.return_value = { + "output": [{"type": "message", "role": "assistant", + "content": [{"type": "output_text", "text": "sdk"}]}], + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 2, "total_tokens": 3}, + } + result = parse_response_output(mock) + assert result.content == "sdk" + assert result.usage["prompt_tokens"] == 1 + + def test_usage_maps_responses_api_keys(self): + """Responses API uses input_tokens/output_tokens, not prompt_tokens/completion_tokens.""" + resp = { + "output": [], + "status": "completed", + "usage": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + } + result = parse_response_output(resp) + assert result.usage["prompt_tokens"] == 100 + assert result.usage["completion_tokens"] == 50 + assert result.usage["total_tokens"] == 150 + + +# ====================================================================== +# parsing - consume_sdk_stream +# ====================================================================== + + +class TestConsumeSdkStream: + @pytest.mark.asyncio + async def test_text_stream(self): + ev1 = MagicMock(type="response.output_text.delta", delta="Hello") + ev2 = MagicMock(type="response.output_text.delta", delta=" world") + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev3 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3]: + yield e + + content, tool_calls, finish_reason, usage, reasoning = await 
consume_sdk_stream(stream()) + assert content == "Hello world" + assert tool_calls == [] + assert finish_reason == "stop" + + @pytest.mark.asyncio + async def test_on_content_delta_called(self): + ev1 = MagicMock(type="response.output_text.delta", delta="hi") + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev2 = MagicMock(type="response.completed", response=resp_obj) + deltas = [] + + async def cb(text): + deltas.append(text) + + async def stream(): + for e in [ev1, ev2]: + yield e + + await consume_sdk_stream(stream(), on_content_delta=cb) + assert deltas == ["hi"] + + @pytest.mark.asyncio + async def test_tool_call_stream(self): + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "get_weather" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + ev2 = MagicMock(type="response.function_call_arguments.delta", call_id="c1", delta='{"ci') + ev3 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments='{"city":"SF"}') + item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments='{"city":"SF"}') + item_done.name = "get_weather" + ev4 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev5 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3, ev4, ev5]: + yield e + + content, tool_calls, finish_reason, usage, reasoning = await consume_sdk_stream(stream()) + assert content == "" + assert len(tool_calls) == 1 + assert tool_calls[0].name == "get_weather" + assert tool_calls[0].arguments == {"city": "SF"} + + @pytest.mark.asyncio + async def test_usage_extracted(self): + usage_obj = MagicMock(input_tokens=10, output_tokens=5, total_tokens=15) + resp_obj = MagicMock(status="completed", usage=usage_obj, output=[]) + ev = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + yield ev + 
+ _, _, _, usage, _ = await consume_sdk_stream(stream()) + assert usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + @pytest.mark.asyncio + async def test_reasoning_extracted(self): + summary_item = MagicMock(type="summary_text", text="thinking...") + reasoning_item = MagicMock(type="reasoning", summary=[summary_item]) + resp_obj = MagicMock(status="completed", usage=None, output=[reasoning_item]) + ev = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + yield ev + + _, _, _, _, reasoning = await consume_sdk_stream(stream()) + assert reasoning == "thinking..." + + @pytest.mark.asyncio + async def test_error_event_raises(self): + ev = MagicMock(type="error", error="rate_limit_exceeded") + + async def stream(): + yield ev + + with pytest.raises(RuntimeError, match="Response failed.*rate_limit_exceeded"): + await consume_sdk_stream(stream()) + + @pytest.mark.asyncio + async def test_failed_event_raises(self): + ev = MagicMock(type="response.failed", error="server_error") + + async def stream(): + yield ev + + with pytest.raises(RuntimeError, match="Response failed.*server_error"): + await consume_sdk_stream(stream()) + + @pytest.mark.asyncio + async def test_malformed_tool_args_logged(self): + """Malformed JSON in streaming tool args should log a warning.""" + item_added = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="") + item_added.name = "f" + ev1 = MagicMock(type="response.output_item.added", item=item_added) + ev2 = MagicMock(type="response.function_call_arguments.done", call_id="c1", arguments="{bad") + item_done = MagicMock(type="function_call", call_id="c1", id="fc1", arguments="{bad") + item_done.name = "f" + ev3 = MagicMock(type="response.output_item.done", item=item_done) + resp_obj = MagicMock(status="completed", usage=None, output=[]) + ev4 = MagicMock(type="response.completed", response=resp_obj) + + async def stream(): + for e in [ev1, ev2, ev3, ev4]: + yield e + + with 
patch("nanobot.providers.openai_responses.parsing.logger") as mock_logger: + _, tool_calls, _, _, _ = await consume_sdk_stream(stream()) + assert tool_calls[0].arguments == {"raw": "{bad"} + mock_logger.warning.assert_called_once() + assert "Failed to parse tool call arguments" in str(mock_logger.warning.call_args) diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index d732054d5..1d8facf52 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -211,3 +211,56 @@ async def test_image_fallback_without_meta_uses_default_placeholder() -> None: content = msg.get("content") if isinstance(content, list): assert any("[image omitted]" in (b.get("text") or "") for b in content) + + +@pytest.mark.asyncio +async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None: + provider = ScriptedProvider([ + LLMResponse(content="429 rate limit, retry after 7s", finish_reason="error"), + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + progress: list[str] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + async def _progress(msg: str) -> None: + progress.append(msg) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_with_retry( + messages=[{"role": "user", "content": "hello"}], + on_retry_wait=_progress, + ) + + assert response.content == "ok" + assert delays == [7.0] + assert progress and "7s" in progress[0] + + +@pytest.mark.asyncio +async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None: + provider = ScriptedProvider([ + *[LLMResponse(content="429 rate limit", finish_reason="error") for _ in range(10)], + LLMResponse(content="ok"), + ]) + delays: list[float] = [] + + async def _fake_sleep(delay: float) -> None: + delays.append(delay) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response 
= await provider.chat_with_retry( + messages=[{"role": "user", "content": "hello"}], + retry_mode="persistent", + ) + + assert response.finish_reason == "error" + assert response.content == "429 rate limit" + assert provider.calls == 10 + assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4] + + diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index 42fec33ed..2d4ae8580 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -347,6 +347,8 @@ async def test_empty_response_retry_then_success(aiohttp_client) -> None: @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed") @pytest.mark.asyncio async def test_empty_response_falls_back(aiohttp_client) -> None: + from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE + call_count = 0 async def always_empty(content, session_key="", channel="", chat_id=""): @@ -367,5 +369,5 @@ async def test_empty_response_falls_back(aiohttp_client) -> None: ) assert resp.status == 200 body = await resp.json() - assert body["choices"][0]["message"]["content"] == "I've completed processing but have no response to give." + assert body["choices"][0]["message"]["content"] == EMPTY_FINAL_RESPONSE_MESSAGE assert call_count == 2 diff --git a/tests/tools/test_mcp_tool.py b/tests/tools/test_mcp_tool.py index 28666f05f..9c1320251 100644 --- a/tests/tools/test_mcp_tool.py +++ b/tests/tools/test_mcp_tool.py @@ -196,7 +196,7 @@ async def test_execute_re_raises_external_cancellation() -> None: wrapper = _make_wrapper(SimpleNamespace(call_tool=call_tool), timeout=10) task = asyncio.create_task(wrapper.execute()) - await started.wait() + await asyncio.wait_for(started.wait(), timeout=1.0) task.cancel()