"""Base LLM provider interface.""" import asyncio import json import re from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime, timezone from email.utils import parsedate_to_datetime from typing import Any from loguru import logger from nanobot.utils.helpers import image_placeholder_text @dataclass class ToolCallRequest: """A tool call request from the LLM.""" id: str name: str arguments: dict[str, Any] extra_content: dict[str, Any] | None = None provider_specific_fields: dict[str, Any] | None = None function_provider_specific_fields: dict[str, Any] | None = None def to_openai_tool_call(self) -> dict[str, Any]: """Serialize to an OpenAI-style tool_call payload.""" tool_call = { "id": self.id, "type": "function", "function": { "name": self.name, "arguments": json.dumps(self.arguments, ensure_ascii=False), }, } if self.extra_content: tool_call["extra_content"] = self.extra_content if self.provider_specific_fields: tool_call["provider_specific_fields"] = self.provider_specific_fields if self.function_provider_specific_fields: tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields return tool_call @dataclass class LLMResponse: """Response from an LLM provider.""" content: str | None tool_calls: list[ToolCallRequest] = field(default_factory=list) finish_reason: str = "stop" usage: dict[str, int] = field(default_factory=dict) retry_after: float | None = None # Provider supplied retry wait in seconds. reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc. thinking_blocks: list[dict] | None = None # Anthropic extended thinking @property def has_tool_calls(self) -> bool: """Check if response contains tool calls.""" return len(self.tool_calls) > 0 @dataclass(frozen=True) class GenerationSettings: """Default generation settings.""" temperature: float = 0.7 max_tokens: int = 4096 reasoning_effort: str | None = None class LLMProvider(ABC): """Base class for LLM providers.""" _CHAT_RETRY_DELAYS = (1, 2, 4) _PERSISTENT_MAX_DELAY = 60 _PERSISTENT_IDENTICAL_ERROR_LIMIT = 10 _RETRY_HEARTBEAT_CHUNK = 30 _TRANSIENT_ERROR_MARKERS = ( "429", "rate limit", "500", "502", "503", "504", "overloaded", "timeout", "timed out", "connection", "server error", "temporarily unavailable", ) _SENTINEL = object() def __init__(self, api_key: str | None = None, api_base: str | None = None): self.api_key = api_key self.api_base = api_base self.generation: GenerationSettings = GenerationSettings() @staticmethod def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: """Sanitize message content: fix empty blocks, strip internal _meta fields.""" result: list[dict[str, Any]] = [] for msg in messages: content = msg.get("content") if isinstance(content, str) and not content: clean = dict(msg) clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)" result.append(clean) continue if isinstance(content, list): new_items: list[Any] = [] changed = False for item in content: if ( isinstance(item, dict) and item.get("type") in ("text", "input_text", "output_text") and not item.get("text") ): changed = True continue if isinstance(item, dict) and "_meta" in item: new_items.append({k: v for k, v in item.items() if k != "_meta"}) changed = True else: new_items.append(item) if changed: clean = dict(msg) if new_items: clean["content"] = new_items elif msg.get("role") == "assistant" and msg.get("tool_calls"): clean["content"] = None else: clean["content"] = "(empty)" result.append(clean) continue if isinstance(content, dict): clean = dict(msg) clean["content"] = [content] result.append(clean) continue result.append(msg) return result @staticmethod def _sanitize_request_messages( messages: list[dict[str, Any]], allowed_keys: frozenset[str], ) -> list[dict[str, Any]]: """Keep only provider-safe message keys and normalize assistant content.""" sanitized = [] for msg in messages: clean = {k: v for k, v in msg.items() if k in allowed_keys} if clean.get("role") == "assistant" and "content" not in clean: clean["content"] = None sanitized.append(clean) return sanitized @abstractmethod async def chat( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, ) -> LLMResponse: """ Send a chat completion request. Args: messages: List of message dicts with 'role' and 'content'. tools: Optional list of tool definitions. model: Model identifier (provider-specific). max_tokens: Maximum tokens in response. temperature: Sampling temperature. tool_choice: Tool selection strategy ("auto", "required", or specific tool dict). Returns: LLMResponse with content and/or tool calls. """ pass @classmethod def _is_transient_error(cls, content: str | None) -> bool: err = (content or "").lower() return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS) @staticmethod def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None: """Replace image_url blocks with text placeholder. Returns None if no images found.""" found = False result = [] for msg in messages: content = msg.get("content") if isinstance(content, list): new_content = [] for b in content: if isinstance(b, dict) and b.get("type") == "image_url": path = (b.get("_meta") or {}).get("path", "") placeholder = image_placeholder_text(path, empty="[image omitted]") new_content.append({"type": "text", "text": placeholder}) found = True else: new_content.append(b) result.append({**msg, "content": new_content}) else: result.append(msg) return result if found else None async def _safe_chat(self, **kwargs: Any) -> LLMResponse: """Call chat() and convert unexpected exceptions to error responses.""" try: return await self.chat(**kwargs) except asyncio.CancelledError: raise except Exception as exc: return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") async def chat_stream( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, reasoning_effort: str | None = None, tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Stream a chat completion, calling *on_content_delta* for each text chunk. Returns the same ``LLMResponse`` as :meth:`chat`. The default implementation falls back to a non-streaming call and delivers the full content as a single delta. Providers that support native streaming should override this method. """ response = await self.chat( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, ) if on_content_delta and response.content: await on_content_delta(response.content) return response async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse: """Call chat_stream() and convert unexpected exceptions to error responses.""" try: return await self.chat_stream(**kwargs) except asyncio.CancelledError: raise except Exception as exc: return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error") async def chat_stream_with_retry( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: object = _SENTINEL, temperature: object = _SENTINEL, reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None, retry_mode: str = "standard", on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Call chat_stream() with retry on transient provider failures.""" if max_tokens is self._SENTINEL: max_tokens = self.generation.max_tokens if temperature is self._SENTINEL: temperature = self.generation.temperature if reasoning_effort is self._SENTINEL: reasoning_effort = self.generation.reasoning_effort kw: dict[str, Any] = dict( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, on_content_delta=on_content_delta, ) return await self._run_with_retry( self._safe_chat_stream, kw, messages, retry_mode=retry_mode, on_retry_wait=on_retry_wait, ) async def chat_with_retry( self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, model: str | None = None, max_tokens: object = _SENTINEL, temperature: object = _SENTINEL, reasoning_effort: object = _SENTINEL, tool_choice: str | dict[str, Any] | None = None, retry_mode: str = "standard", on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: """Call chat() with retry on transient provider failures. Parameters default to ``self.generation`` when not explicitly passed, so callers no longer need to thread temperature / max_tokens / reasoning_effort through every layer. """ if max_tokens is self._SENTINEL: max_tokens = self.generation.max_tokens if temperature is self._SENTINEL: temperature = self.generation.temperature if reasoning_effort is self._SENTINEL: reasoning_effort = self.generation.reasoning_effort kw: dict[str, Any] = dict( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, reasoning_effort=reasoning_effort, tool_choice=tool_choice, ) return await self._run_with_retry( self._safe_chat, kw, messages, retry_mode=retry_mode, on_retry_wait=on_retry_wait, ) @classmethod def _extract_retry_after(cls, content: str | None) -> float | None: text = (content or "").lower() patterns = ( r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", r"try again in\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)", r"wait\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)\s*before retry", r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)", ) for idx, pattern in enumerate(patterns): match = re.search(pattern, text) if not match: continue value = float(match.group(1)) unit = match.group(2) if idx < 3 else "s" return cls._to_retry_seconds(value, unit) return None @classmethod def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float: normalized_unit = (unit or "s").lower() if normalized_unit in {"ms", "milliseconds"}: return max(0.1, value / 1000.0) if normalized_unit in {"m", "min", "minutes"}: return max(0.1, value * 60.0) return max(0.1, value) @classmethod def _extract_retry_after_from_headers(cls, headers: Any) -> float | None: if not headers: return None retry_after: Any = None if hasattr(headers, "get"): retry_after = headers.get("retry-after") or headers.get("Retry-After") if retry_after is None and isinstance(headers, dict): for key, value in headers.items(): if isinstance(key, str) and key.lower() == "retry-after": retry_after = value break if retry_after is None: return None retry_after_text = str(retry_after).strip() if not retry_after_text: return None if re.fullmatch(r"\d+(?:\.\d+)?", retry_after_text): return cls._to_retry_seconds(float(retry_after_text), "s") try: retry_at = parsedate_to_datetime(retry_after_text) except Exception: return None if retry_at.tzinfo is None: retry_at = retry_at.replace(tzinfo=timezone.utc) remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds() return max(0.1, remaining) async def _sleep_with_heartbeat( self, delay: float, *, attempt: int, persistent: bool, on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> None: remaining = max(0.0, delay) while remaining > 0: if on_retry_wait: kind = "persistent retry" if persistent else "retry" await on_retry_wait( f"Model request failed, {kind} in {max(1, int(round(remaining)))}s " f"(attempt {attempt})." ) chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK) await asyncio.sleep(chunk) remaining -= chunk async def _run_with_retry( self, call: Callable[..., Awaitable[LLMResponse]], kw: dict[str, Any], original_messages: list[dict[str, Any]], *, retry_mode: str, on_retry_wait: Callable[[str], Awaitable[None]] | None, ) -> LLMResponse: attempt = 0 delays = list(self._CHAT_RETRY_DELAYS) persistent = retry_mode == "persistent" last_response: LLMResponse | None = None last_error_key: str | None = None identical_error_count = 0 while True: attempt += 1 response = await call(**kw) if response.finish_reason != "error": return response last_response = response error_key = ((response.content or "").strip().lower() or None) if error_key and error_key == last_error_key: identical_error_count += 1 else: last_error_key = error_key identical_error_count = 1 if error_key else 0 if not self._is_transient_error(response.content): stripped = self._strip_image_content(original_messages) if stripped is not None and stripped != kw["messages"]: logger.warning( "Non-transient LLM error with image content, retrying without images" ) retry_kw = dict(kw) retry_kw["messages"] = stripped return await call(**retry_kw) return response if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT: logger.warning( "Stopping persistent retry after {} identical transient errors: {}", identical_error_count, (response.content or "")[:120].lower(), ) return response if not persistent and attempt > len(delays): break base_delay = delays[min(attempt - 1, len(delays) - 1)] delay = response.retry_after or self._extract_retry_after(response.content) or base_delay if persistent: delay = min(delay, self._PERSISTENT_MAX_DELAY) logger.warning( "LLM transient error (attempt {}{}), retrying in {}s: {}", attempt, "+" if persistent and attempt > len(delays) else f"/{len(delays)}", int(round(delay)), (response.content or "")[:120].lower(), ) await self._sleep_with_heartbeat( delay, attempt=attempt, persistent=persistent, on_retry_wait=on_retry_wait, ) return last_response if last_response is not None else await call(**kw) @abstractmethod def get_default_model(self) -> str: """Get the default model for this provider.""" pass