2026-04-03 19:07:30 +00:00

487 lines
18 KiB
Python

"""Base LLM provider interface."""
import asyncio
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any
from loguru import logger
from nanobot.utils.helpers import image_placeholder_text
@dataclass
class ToolCallRequest:
"""A tool call request from the LLM."""
id: str
name: str
arguments: dict[str, Any]
extra_content: dict[str, Any] | None = None
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
def to_openai_tool_call(self) -> dict[str, Any]:
"""Serialize to an OpenAI-style tool_call payload."""
tool_call = {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
},
}
if self.extra_content:
tool_call["extra_content"] = self.extra_content
if self.provider_specific_fields:
tool_call["provider_specific_fields"] = self.provider_specific_fields
if self.function_provider_specific_fields:
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
return tool_call
@dataclass
class LLMResponse:
    """Normalized result of a single LLM provider call."""

    content: str | None
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    # "error" marks a failed call; LLMProvider's retry loop keys off this value.
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    retry_after: float | None = None  # Provider supplied retry wait in seconds.
    reasoning_content: str | None = None  # Kimi, DeepSeek-R1, MiMo etc.
    thinking_blocks: list[dict] | None = None  # Anthropic extended thinking

    @property
    def has_tool_calls(self) -> bool:
        """True when the model requested at least one tool call."""
        return bool(len(self.tool_calls))
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation settings."""
temperature: float = 0.7
max_tokens: int = 4096
reasoning_effort: str | None = None
class LLMProvider(ABC):
    """Base class for LLM providers.

    Subclasses implement :meth:`chat` and :meth:`get_default_model`, and may
    override :meth:`chat_stream` for native streaming. This base class adds
    message-sanitizing helpers and retry wrappers (:meth:`chat_with_retry`,
    :meth:`chat_stream_with_retry`) that retry transient provider errors.
    """

    # Waits (seconds) between attempts in "standard" retry mode; its length is
    # also the number of retries standard mode makes before giving up.
    _CHAT_RETRY_DELAYS = (1, 2, 4)
    # Cap (seconds) on any single wait in "persistent" retry mode.
    _PERSISTENT_MAX_DELAY = 60
    # Persistent mode stops after this many consecutive identical error messages.
    _PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
    # Longest sleep (seconds) between on_retry_wait heartbeat notifications.
    _RETRY_HEARTBEAT_CHUNK = 30
    # Lowercase markers that classify an error message as transient. Purely
    # numeric entries (HTTP status codes) only match when they are not part of a
    # longer number; see _is_transient_error.
    _TRANSIENT_ERROR_MARKERS = (
        "429",
        "rate limit",
        "500",
        "502",
        "503",
        "504",
        "overloaded",
        "timeout",
        "timed out",
        "connection",
        "server error",
        "temporarily unavailable",
    )
    # Default for "argument not passed", so callers can still pass None explicitly.
    _SENTINEL = object()

    def __init__(self, api_key: str | None = None, api_base: str | None = None):
        """Store connection settings and start with default generation settings.

        Args:
            api_key: Optional credential for the provider.
            api_base: Optional API base URL; how it is used is up to the subclass.
        """
        self.api_key = api_key
        self.api_base = api_base
        # Fallback values used by chat_with_retry / chat_stream_with_retry.
        self.generation: GenerationSettings = GenerationSettings()

    @staticmethod
    def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Sanitize message content: fix empty blocks, strip internal _meta fields.

        - Empty string content becomes ``None`` for assistant messages that carry
          ``tool_calls``, and ``"(empty)"`` otherwise.
        - In list content, empty text blocks (``text``/``input_text``/
          ``output_text``) are dropped and ``_meta`` keys are removed from dict
          blocks. A list left with no items is replaced as for empty strings.
        - A bare dict content block is wrapped in a one-element list and then
          gets the same list treatment, so its ``_meta`` is stripped too.

        Messages needing no change are returned as-is; changed messages are
        shallow copies, so the input messages are not mutated.
        """
        result: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")
            # Replacement for content that is (or becomes) empty.
            empty_value = (
                None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
            )
            if isinstance(content, str) and not content:
                clean = dict(msg)
                clean["content"] = empty_value
                result.append(clean)
                continue
            # Normalize a single dict block to a list so it shares the list path.
            wrapped = isinstance(content, dict)
            if wrapped:
                content = [content]
            if isinstance(content, list):
                new_items: list[Any] = []
                changed = wrapped
                for item in content:
                    if (
                        isinstance(item, dict)
                        and item.get("type") in ("text", "input_text", "output_text")
                        and not item.get("text")
                    ):
                        changed = True
                        continue
                    if isinstance(item, dict) and "_meta" in item:
                        new_items.append({k: v for k, v in item.items() if k != "_meta"})
                        changed = True
                    else:
                        new_items.append(item)
                if changed:
                    clean = dict(msg)
                    clean["content"] = new_items if new_items else empty_value
                    result.append(clean)
                    continue
            result.append(msg)
        return result

    @staticmethod
    def _sanitize_request_messages(
        messages: list[dict[str, Any]],
        allowed_keys: frozenset[str],
    ) -> list[dict[str, Any]]:
        """Keep only provider-safe message keys and normalize assistant content.

        Each message is copied with only the keys in *allowed_keys*; assistant
        messages left without ``content`` get an explicit ``content: None``.
        """
        sanitized = []
        for msg in messages:
            clean = {k: v for k, v in msg.items() if k in allowed_keys}
            if clean.get("role") == "assistant" and "content" not in clean:
                clean["content"] = None
            sanitized.append(clean)
        return sanitized

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """
        Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions.
            model: Model identifier (provider-specific).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.
            reasoning_effort: Optional reasoning-effort setting (provider-specific).
            tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).

        Returns:
            LLMResponse with content and/or tool calls. The retry helpers treat
            ``finish_reason == "error"`` as a failed call.
        """
        pass

    @classmethod
    def _is_transient_error(cls, content: str | None) -> bool:
        """Return True if the error text *content* looks retryable.

        Text markers match as substrings of the lowercased message. Numeric
        markers (HTTP status codes) must not be adjacent to other digits, so a
        message mentioning e.g. "15000 tokens" is not mistaken for an HTTP 500.
        """
        err = (content or "").lower()
        for marker in cls._TRANSIENT_ERROR_MARKERS:
            if marker.isdigit():
                if re.search(rf"(?<!\d){re.escape(marker)}(?!\d)", err):
                    return True
            elif marker in err:
                return True
        return False

    @staticmethod
    def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
        """Replace image_url blocks with text placeholder. Returns None if no images found.

        The placeholder text comes from ``image_placeholder_text`` using the
        block's ``_meta["path"]`` (``"[image omitted]"`` when empty). Messages
        with list content are shallow-copied; others are reused unchanged.
        """
        found = False
        result = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                new_content = []
                for b in content:
                    if isinstance(b, dict) and b.get("type") == "image_url":
                        path = (b.get("_meta") or {}).get("path", "")
                        placeholder = image_placeholder_text(path, empty="[image omitted]")
                        new_content.append({"type": "text", "text": placeholder})
                        found = True
                    else:
                        new_content.append(b)
                result.append({**msg, "content": new_content})
            else:
                result.append(msg)
        return result if found else None

    @staticmethod
    def _exception_to_error_response(exc: Exception) -> LLMResponse:
        """Wrap *exc* in an error ``LLMResponse``.

        Falls back to the exception class name when ``str(exc)`` is empty (e.g.
        a bare ``TimeoutError()``), so the message stays informative and can
        still match the transient-error markers.
        """
        detail = str(exc) or type(exc).__name__
        return LLMResponse(content=f"Error calling LLM: {detail}", finish_reason="error")

    async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
        """Call chat() and convert unexpected exceptions to error responses.

        Cancellation is re-raised so task cancellation still propagates.
        """
        try:
            return await self.chat(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return self._exception_to_error_response(exc)

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Stream a chat completion, calling *on_content_delta* for each text chunk.

        Returns the same ``LLMResponse`` as :meth:`chat`. The default
        implementation falls back to a non-streaming call and delivers the
        full content as a single delta. Providers that support native
        streaming should override this method.
        """
        response = await self.chat(
            messages=messages, tools=tools, model=model,
            max_tokens=max_tokens, temperature=temperature,
            reasoning_effort=reasoning_effort, tool_choice=tool_choice,
        )
        if on_content_delta and response.content:
            await on_content_delta(response.content)
        return response

    async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Call chat_stream() and convert unexpected exceptions to error responses.

        Cancellation is re-raised so task cancellation still propagates.
        """
        try:
            return await self.chat_stream(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return self._exception_to_error_response(exc)

    def _resolve_generation_defaults(
        self,
        max_tokens: object,
        temperature: object,
        reasoning_effort: object,
    ) -> tuple[Any, Any, Any]:
        """Replace each ``_SENTINEL`` argument with the matching ``self.generation`` value."""
        settings = self.generation
        if max_tokens is self._SENTINEL:
            max_tokens = settings.max_tokens
        if temperature is self._SENTINEL:
            temperature = settings.temperature
        if reasoning_effort is self._SENTINEL:
            reasoning_effort = settings.reasoning_effort
        return max_tokens, temperature, reasoning_effort

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat_stream() with retry on transient provider failures.

        Unpassed generation parameters default to ``self.generation``. See
        :meth:`_run_with_retry` for ``retry_mode`` / ``on_retry_wait``.
        """
        max_tokens, temperature, reasoning_effort = self._resolve_generation_defaults(
            max_tokens, temperature, reasoning_effort
        )
        kw: dict[str, Any] = dict(
            messages=messages, tools=tools, model=model,
            max_tokens=max_tokens, temperature=temperature,
            reasoning_effort=reasoning_effort, tool_choice=tool_choice,
            on_content_delta=on_content_delta,
        )
        return await self._run_with_retry(
            self._safe_chat_stream,
            kw,
            messages,
            retry_mode=retry_mode,
            on_retry_wait=on_retry_wait,
        )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat() with retry on transient provider failures.

        Parameters default to ``self.generation`` when not explicitly passed,
        so callers no longer need to thread temperature / max_tokens /
        reasoning_effort through every layer. See :meth:`_run_with_retry` for
        ``retry_mode`` / ``on_retry_wait``.
        """
        max_tokens, temperature, reasoning_effort = self._resolve_generation_defaults(
            max_tokens, temperature, reasoning_effort
        )
        kw: dict[str, Any] = dict(
            messages=messages, tools=tools, model=model,
            max_tokens=max_tokens, temperature=temperature,
            reasoning_effort=reasoning_effort, tool_choice=tool_choice,
        )
        return await self._run_with_retry(
            self._safe_chat,
            kw,
            messages,
            retry_mode=retry_mode,
            on_retry_wait=on_retry_wait,
        )

    @classmethod
    def _extract_retry_after(cls, content: str | None) -> float | None:
        """Parse a retry hint from an error message; return seconds or None.

        Recognizes "retry after N[unit]", "try again in N<unit>",
        "wait N<unit> before retry" and "retry-after: N" (seconds). Units may be
        ms / s / min variants; the value is converted by :meth:`_to_retry_seconds`.
        """
        text = (content or "").lower()
        unit = r"(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)"
        # (?![a-z]) stops a unit from matching just the first letter of a longer
        # word (e.g. "m" in "messages") while still accepting forms like "1m30s".
        patterns = (
            rf"retry after\s+(\d+(?:\.\d+)?)\s*(?:{unit}(?![a-z]))?",
            rf"try again in\s+(\d+(?:\.\d+)?)\s*{unit}(?![a-z])",
            rf"wait\s+(\d+(?:\.\d+)?)\s*{unit}(?![a-z])\s*before retry",
            r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)",
        )
        for idx, pattern in enumerate(patterns):
            match = re.search(pattern, text)
            if not match:
                continue
            value = float(match.group(1))
            # Only the first three patterns capture a unit; the last is seconds.
            unit_text = match.group(2) if idx < 3 else "s"
            return cls._to_retry_seconds(value, unit_text)
        return None

    @classmethod
    def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float:
        """Convert *value* in *unit* (ms / s / min variants; default s) to seconds, min 0.1."""
        normalized_unit = (unit or "s").lower()
        if normalized_unit in {"ms", "milliseconds"}:
            return max(0.1, value / 1000.0)
        if normalized_unit in {"m", "min", "minutes"}:
            return max(0.1, value * 60.0)
        return max(0.1, value)

    @classmethod
    def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
        """Return the Retry-After header as seconds (min 0.1), or None.

        Accepts any mapping with ``get`` (header names checked as "retry-after"
        and "Retry-After", plus a case-insensitive scan for plain dicts). The
        value may be delta-seconds or an HTTP date; unparseable values give None.
        """
        if not headers:
            return None
        retry_after: Any = None
        if hasattr(headers, "get"):
            retry_after = headers.get("retry-after") or headers.get("Retry-After")
        if retry_after is None and isinstance(headers, dict):
            for key, value in headers.items():
                if isinstance(key, str) and key.lower() == "retry-after":
                    retry_after = value
                    break
        if retry_after is None:
            return None
        retry_after_text = str(retry_after).strip()
        if not retry_after_text:
            return None
        if re.fullmatch(r"\d+(?:\.\d+)?", retry_after_text):
            return cls._to_retry_seconds(float(retry_after_text), "s")
        try:
            retry_at = parsedate_to_datetime(retry_after_text)
        except Exception:
            return None
        # A date without zone info is interpreted as UTC.
        if retry_at.tzinfo is None:
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
        return max(0.1, remaining)

    async def _sleep_with_heartbeat(
        self,
        delay: float,
        *,
        attempt: int,
        persistent: bool,
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> None:
        """Sleep *delay* seconds in chunks of at most ``_RETRY_HEARTBEAT_CHUNK``.

        Before each chunk, *on_retry_wait* (if given) receives a status message
        with the remaining wait rounded to whole seconds (at least 1).
        """
        remaining = max(0.0, delay)
        while remaining > 0:
            if on_retry_wait:
                kind = "persistent retry" if persistent else "retry"
                await on_retry_wait(
                    f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
                    f"(attempt {attempt})."
                )
            chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
            await asyncio.sleep(chunk)
            remaining -= chunk

    async def _run_with_retry(
        self,
        call: Callable[..., Awaitable[LLMResponse]],
        kw: dict[str, Any],
        original_messages: list[dict[str, Any]],
        *,
        retry_mode: str,
        on_retry_wait: Callable[[str], Awaitable[None]] | None,
    ) -> LLMResponse:
        """Invoke ``call(**kw)`` until it succeeds or retrying is pointless.

        - A response whose ``finish_reason`` is not ``"error"`` is returned at once.
        - Non-transient errors are returned immediately, except that a request
          containing images is re-sent once with images replaced by placeholders.
        - Transient errors are retried. ``"standard"`` mode (any value other
          than ``"persistent"``) retries ``len(_CHAT_RETRY_DELAYS)`` times.
          ``"persistent"`` mode retries until the same error repeats
          ``_PERSISTENT_IDENTICAL_ERROR_LIMIT`` times in a row, capping each
          wait at ``_PERSISTENT_MAX_DELAY``.
        - The wait prefers ``response.retry_after``, then a hint parsed from the
          error text, then the backoff table.
        """
        attempt = 0
        delays = list(self._CHAT_RETRY_DELAYS)
        persistent = retry_mode == "persistent"
        last_response: LLMResponse | None = None
        last_error_key: str | None = None
        identical_error_count = 0
        while True:
            attempt += 1
            response = await call(**kw)
            if response.finish_reason != "error":
                return response
            last_response = response
            # Count consecutive identical errors (compared after strip/lower) so
            # persistent mode can stop when retrying clearly isn't helping.
            error_key = ((response.content or "").strip().lower() or None)
            if error_key and error_key == last_error_key:
                identical_error_count += 1
            else:
                last_error_key = error_key
                identical_error_count = 1 if error_key else 0
            if not self._is_transient_error(response.content):
                # Not retryable as-is; in case the images caused the failure,
                # re-send once with image blocks replaced by text placeholders.
                stripped = self._strip_image_content(original_messages)
                if stripped is not None and stripped != kw["messages"]:
                    logger.warning(
                        "Non-transient LLM error with image content, retrying without images"
                    )
                    retry_kw = dict(kw)
                    retry_kw["messages"] = stripped
                    return await call(**retry_kw)
                return response
            if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
                logger.warning(
                    "Stopping persistent retry after {} identical transient errors: {}",
                    identical_error_count,
                    (response.content or "")[:120].lower(),
                )
                return response
            if not persistent and attempt > len(delays):
                break
            # Persistent mode keeps reusing the last backoff entry.
            base_delay = delays[min(attempt - 1, len(delays) - 1)]
            delay = response.retry_after or self._extract_retry_after(response.content) or base_delay
            if persistent:
                delay = min(delay, self._PERSISTENT_MAX_DELAY)
            logger.warning(
                "LLM transient error (attempt {}{}), retrying in {}s: {}",
                attempt,
                "+" if persistent and attempt > len(delays) else f"/{len(delays)}",
                int(round(delay)),
                (response.content or "")[:120].lower(),
            )
            await self._sleep_with_heartbeat(
                delay,
                attempt=attempt,
                persistent=persistent,
                on_retry_wait=on_retry_wait,
            )
        # Reached only via the standard-mode break, where last_response is set.
        return last_response if last_response is not None else await call(**kw)

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
        pass