"""Base LLM provider interface."""

import asyncio
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any

from loguru import logger

from nanobot.utils.helpers import image_placeholder_text


@dataclass
class ToolCallRequest:
    """A tool call request from the LLM."""

    id: str
    name: str
    arguments: dict[str, Any]
    extra_content: dict[str, Any] | None = None
    provider_specific_fields: dict[str, Any] | None = None
    function_provider_specific_fields: dict[str, Any] | None = None

    def to_openai_tool_call(self) -> dict[str, Any]:
        """Serialize to an OpenAI-style tool_call payload."""
        tool_call = {
            "id": self.id,
            "type": "function",
            "function": {
                "name": self.name,
                "arguments": json.dumps(self.arguments, ensure_ascii=False),
            },
        }
        if self.extra_content:
            tool_call["extra_content"] = self.extra_content
        if self.provider_specific_fields:
            tool_call["provider_specific_fields"] = self.provider_specific_fields
        if self.function_provider_specific_fields:
            tool_call["function"]["provider_specific_fields"] = (
                self.function_provider_specific_fields
            )
        return tool_call
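
    # Illustrative sketch (hypothetical values) of the payload shape produced
    # by ``to_openai_tool_call``:
    #
    #     ToolCallRequest(
    #         id="call_1", name="search", arguments={"query": "nanobot"}
    #     ).to_openai_tool_call()
    #     # -> {"id": "call_1", "type": "function",
    #     #     "function": {"name": "search", "arguments": '{"query": "nanobot"}'}}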


@dataclass
class LLMResponse:
    """Response from an LLM provider."""

    content: str | None
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    retry_after: float | None = None  # Provider supplied retry wait in seconds.
    reasoning_content: str | None = None  # Kimi, DeepSeek-R1, MiMo etc.
    thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
    # Structured error metadata used by retry policy when finish_reason == "error".
    error_status_code: int | None = None
    error_kind: str | None = None  # e.g. "timeout", "connection"
    error_type: str | None = None  # Provider/type semantic, e.g. insufficient_quota.
    error_code: str | None = None  # Provider/code semantic, e.g. rate_limit_exceeded.
    error_retry_after_s: float | None = None
    error_should_retry: bool | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Check if response contains tool calls."""
        return len(self.tool_calls) > 0

    @property
    def should_execute_tools(self) -> bool:
        """Return True only when the tool calls should actually be executed.

        Tool calls run only if they are present AND ``finish_reason`` is a
        known-good signal (``tool_calls`` or ``stop``); this blocks
        gateway-injected calls arriving under ``refusal`` / ``content_filter``
        / ``error``.
        """
        if not self.has_tool_calls:
            return False
        return self.finish_reason in ("tool_calls", "stop")
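
    # Illustrative gating behavior (hypothetical values): a tool call that
    # arrives under a blocked finish_reason is reported but not executed.
    #
    #     call = ToolCallRequest(id="c1", name="exec", arguments={})
    #     LLMResponse(content=None, tool_calls=[call]).should_execute_tools        # True
    #     LLMResponse(content=None, tool_calls=[call],
    #                 finish_reason="content_filter").should_execute_tools         # False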


@dataclass(frozen=True)
class GenerationSettings:
    """Default generation settings."""

    temperature: float = 0.7
    max_tokens: int = 4096
    reasoning_effort: str | None = None


_SYNTHETIC_USER_CONTENT = "(conversation continued)"


class LLMProvider(ABC):
    """Base class for LLM providers."""

    _CHAT_RETRY_DELAYS = (1, 2, 4)
    _PERSISTENT_MAX_DELAY = 60
    _PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
    _RETRY_HEARTBEAT_CHUNK = 30
    _TRANSIENT_ERROR_MARKERS = (
        "429",
        "rate limit",
        "500",
        "502",
        "503",
        "504",
        "overloaded",
        "timeout",
        "timed out",
        "connection",
        "server error",
        "temporarily unavailable",
    )
    _RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})
    _TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"})
    _NON_RETRYABLE_429_ERROR_TOKENS = frozenset(
        {
            "insufficient_quota",
            "quota_exceeded",
            "quota_exhausted",
            "billing_hard_limit_reached",
            "insufficient_balance",
            "credit_balance_too_low",
            "billing_not_active",
            "payment_required",
        }
    )
    _RETRYABLE_429_ERROR_TOKENS = frozenset(
        {
            "rate_limit_exceeded",
            "rate_limit_error",
            "too_many_requests",
            "request_limit_exceeded",
            "requests_limit_exceeded",
            "overloaded_error",
        }
    )
    _NON_RETRYABLE_429_TEXT_MARKERS = (
        "insufficient_quota",
        "insufficient quota",
        "quota exceeded",
        "quota exhausted",
        "billing hard limit",
        "billing_hard_limit_reached",
        "billing not active",
        "insufficient balance",
        "insufficient_balance",
        "credit balance too low",
        "payment required",
        "out of credits",
        "out of quota",
        "exceeded your current quota",
    )
    _RETRYABLE_429_TEXT_MARKERS = (
        "rate limit",
        "rate_limit",
        "too many requests",
        "retry after",
        "try again in",
        "temporarily unavailable",
        "overloaded",
        "concurrency limit",
    )

    _SENTINEL = object()

    def __init__(self, api_key: str | None = None, api_base: str | None = None):
        self.api_key = api_key
        self.api_base = api_base
        self.generation: GenerationSettings = GenerationSettings()

    @staticmethod
    def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Sanitize message content: fix empty blocks, strip internal _meta fields."""
        result: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")

            if isinstance(content, str) and not content:
                clean = dict(msg)
                clean["content"] = (
                    None
                    if (msg.get("role") == "assistant" and msg.get("tool_calls"))
                    else "(empty)"
                )
                result.append(clean)
                continue

            if isinstance(content, list):
                new_items: list[Any] = []
                changed = False
                for item in content:
                    if (
                        isinstance(item, dict)
                        and item.get("type") in ("text", "input_text", "output_text")
                        and not item.get("text")
                    ):
                        changed = True
                        continue
                    if isinstance(item, dict) and "_meta" in item:
                        new_items.append({k: v for k, v in item.items() if k != "_meta"})
                        changed = True
                    else:
                        new_items.append(item)
                if changed:
                    clean = dict(msg)
                    if new_items:
                        clean["content"] = new_items
                    elif msg.get("role") == "assistant" and msg.get("tool_calls"):
                        clean["content"] = None
                    else:
                        clean["content"] = "(empty)"
                    result.append(clean)
                    continue

            if isinstance(content, dict):
                clean = dict(msg)
                clean["content"] = [content]
                result.append(clean)
                continue

            result.append(msg)
        return result
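
    # Illustrative behavior (hypothetical message), following the rules above:
    # an empty user message becomes "(empty)", while an assistant message that
    # carries tool calls and an empty string body is sent with ``content=None``.
    #
    #     LLMProvider._sanitize_empty_content([{"role": "user", "content": ""}])
    #     # -> [{"role": "user", "content": "(empty)"}]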

    @staticmethod
    def _tool_name(tool: dict[str, Any]) -> str:
        """Extract tool name from either OpenAI or Anthropic-style tool schemas."""
        name = tool.get("name")
        if isinstance(name, str):
            return name
        fn = tool.get("function")
        if isinstance(fn, dict):
            fname = fn.get("name")
            if isinstance(fname, str):
                return fname
        return ""
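
    # Both schema shapes resolve to the same name (hypothetical tool "search"):
    #
    #     LLMProvider._tool_name({"name": "search", "input_schema": {}})                # Anthropic-style
    #     LLMProvider._tool_name({"type": "function", "function": {"name": "search"}})  # OpenAI-style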

    @classmethod
    def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
        """Return cache marker indices: builtin/MCP boundary and tail index."""
        if not tools:
            return []

        tail_idx = len(tools) - 1
        last_builtin_idx: int | None = None
        for i in range(tail_idx, -1, -1):
            if not cls._tool_name(tools[i]).startswith("mcp_"):
                last_builtin_idx = i
                break

        ordered_unique: list[int] = []
        for idx in (last_builtin_idx, tail_idx):
            if idx is not None and idx not in ordered_unique:
                ordered_unique.append(idx)
        return ordered_unique
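
    # Illustrative sketch (hypothetical tool names): with two builtin tools
    # followed by two MCP tools, the markers land on the last builtin tool and
    # on the final tool.
    #
    #     tools = [{"name": "read_file"}, {"name": "exec"},
    #              {"name": "mcp_search"}, {"name": "mcp_fetch"}]
    #     LLMProvider._tool_cache_marker_indices(tools)
    #     # -> [1, 3]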

    @staticmethod
    def _sanitize_request_messages(
        messages: list[dict[str, Any]],
        allowed_keys: frozenset[str],
    ) -> list[dict[str, Any]]:
        """Keep only provider-safe message keys and normalize assistant content."""
        sanitized = []
        for msg in messages:
            clean = {k: v for k, v in msg.items() if k in allowed_keys}
            if clean.get("role") == "assistant" and "content" not in clean:
                clean["content"] = None
            sanitized.append(clean)
        return sanitized

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """
        Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions.
            model: Model identifier (provider-specific).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.
            reasoning_effort: Optional reasoning-effort hint for reasoning-capable models.
            tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).

        Returns:
            LLMResponse with content and/or tool calls.
        """
        pass

    @classmethod
    def _is_transient_error(cls, content: str | None) -> bool:
        err = (content or "").lower()
        return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)

    @classmethod
    def _is_transient_response(cls, response: LLMResponse) -> bool:
        """Prefer structured error metadata, fall back to text markers for legacy providers."""
        if response.error_should_retry is not None:
            return bool(response.error_should_retry)

        if response.error_status_code is not None:
            status = int(response.error_status_code)
            if status == 429:
                return cls._is_retryable_429_response(response)
            if status in cls._RETRYABLE_STATUS_CODES or status >= 500:
                return True

        kind = (response.error_kind or "").strip().lower()
        if kind in cls._TRANSIENT_ERROR_KINDS:
            return True

        return cls._is_transient_error(response.content)
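
    # Classification sketch (hypothetical responses): a structured 5xx is
    # retried, while a 429 whose error type signals exhausted quota is not.
    #
    #     LLMProvider._is_transient_response(
    #         LLMResponse(content="boom", finish_reason="error", error_status_code=503)
    #     )                                                                     # True
    #     LLMProvider._is_transient_response(
    #         LLMResponse(content="...", finish_reason="error",
    #                     error_status_code=429, error_type="insufficient_quota")
    #     )                                                                     # False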

    @staticmethod
    def _normalize_error_token(value: Any) -> str | None:
        if value is None:
            return None
        token = str(value).strip().lower()
        return token or None

    @classmethod
    def _extract_error_type_code(cls, payload: Any) -> tuple[str | None, str | None]:
        data: dict[str, Any] | None = None
        if isinstance(payload, dict):
            data = payload
        elif isinstance(payload, str):
            text = payload.strip()
            if text:
                try:
                    parsed = json.loads(text)
                except Exception:
                    parsed = None
                if isinstance(parsed, dict):
                    data = parsed
        if not isinstance(data, dict):
            return None, None

        error_obj = data.get("error")
        type_value = data.get("type")
        code_value = data.get("code")
        if isinstance(error_obj, dict):
            type_value = error_obj.get("type") or type_value
            code_value = error_obj.get("code") or code_value

        return cls._normalize_error_token(type_value), cls._normalize_error_token(code_value)
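
    # Parsing sketch for an OpenAI-style error body (hypothetical payload):
    #
    #     LLMProvider._extract_error_type_code(
    #         '{"error": {"type": "insufficient_quota", "code": "insufficient_quota"}}'
    #     )
    #     # -> ("insufficient_quota", "insufficient_quota")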

    @classmethod
    def _is_retryable_429_response(cls, response: LLMResponse) -> bool:
        type_token = cls._normalize_error_token(response.error_type)
        code_token = cls._normalize_error_token(response.error_code)
        semantic_tokens = {token for token in (type_token, code_token) if token is not None}
        if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
            return False

        content = (response.content or "").lower()
        if any(marker in content for marker in cls._NON_RETRYABLE_429_TEXT_MARKERS):
            return False

        if any(token in cls._RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
            return True
        if any(marker in content for marker in cls._RETRYABLE_429_TEXT_MARKERS):
            return True
        # Unknown 429 defaults to WAIT+retry.
        return True
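
    # Decision sketch (hypothetical error bodies): quota exhaustion is treated
    # as permanent, a plain rate limit is retried.
    #
    #     quota = LLMResponse(content="You exceeded your current quota.",
    #                         finish_reason="error")
    #     limited = LLMResponse(content="Rate limit reached, retry after 2s.",
    #                           finish_reason="error")
    #     LLMProvider._is_retryable_429_response(quota)    # False
    #     LLMProvider._is_retryable_429_response(limited)  # True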

    @staticmethod
    def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Merge consecutive same-role messages and drop trailing assistant messages.

        Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
        where the last message is 'assistant' (prefill not supported) or two
        consecutive non-system messages share the same role.
        """
        if not messages:
            return messages

        merged: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role")
            if (
                merged
                and role != "system"
                and role not in ("tool",)
                and merged[-1].get("role") == role
                and role in ("user", "assistant")
            ):
                prev = merged[-1]
                if role == "assistant":
                    prev_has_tools = bool(prev.get("tool_calls"))
                    curr_has_tools = bool(msg.get("tool_calls"))
                    if curr_has_tools:
                        merged[-1] = dict(msg)
                        continue
                    if prev_has_tools:
                        continue
                prev_content = prev.get("content") or ""
                curr_content = msg.get("content") or ""
                if isinstance(prev_content, str) and isinstance(curr_content, str):
                    prev["content"] = (prev_content + "\n\n" + curr_content).strip()
                else:
                    merged[-1] = dict(msg)
            else:
                merged.append(dict(msg))

        last_popped = None
        while merged and merged[-1].get("role") == "assistant":
            last_popped = merged.pop()

        # If removing trailing assistant messages left only system messages,
        # the request would be invalid for most providers (e.g. Zhipu/GLM
        # error 1214). Recover by converting the last popped assistant
        # message to a user message so the LLM can still see the content.
        if (
            merged
            and last_popped is not None
            and not any(m.get("role") in ("user", "tool") for m in merged)
        ):
            recovered = dict(last_popped)
            recovered["role"] = "user"
            merged.append(recovered)

        # Safety net: ensure the first non-system message is not a bare
        # ``assistant`` message. Providers like GLM reject system→assistant
        # with error 1214. This can happen when upstream truncation (e.g.
        # _snip_history) drops the only user message. Insert a synthetic
        # user message to keep the sequence valid.
        for i, msg in enumerate(merged):
            if msg.get("role") != "system":
                if msg.get("role") == "assistant" and not msg.get("tool_calls"):
                    merged.insert(i, {"role": "user", "content": _SYNTHETIC_USER_CONTENT})
                break

        return merged
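
    # Illustrative sketch (hypothetical conversation): consecutive user turns
    # are merged and a trailing assistant prefill is dropped.
    #
    #     LLMProvider._enforce_role_alternation([
    #         {"role": "user", "content": "hi"},
    #         {"role": "user", "content": "are you there?"},
    #         {"role": "assistant", "content": "draft"},
    #     ])
    #     # -> [{"role": "user", "content": "hi\n\nare you there?"}]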

    @staticmethod
    def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
        """Replace image_url blocks with text placeholder. Returns None if no images found."""
        found = False
        result = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                new_content = []
                for b in content:
                    if isinstance(b, dict) and b.get("type") == "image_url":
                        path = (b.get("_meta") or {}).get("path", "")
                        placeholder = image_placeholder_text(path, empty="[image omitted]")
                        new_content.append({"type": "text", "text": placeholder})
                        found = True
                    else:
                        new_content.append(b)
                result.append({**msg, "content": new_content})
            else:
                result.append(msg)
        return result if found else None

    @staticmethod
    def _strip_image_content_inplace(messages: list[dict[str, Any]]) -> bool:
        """Replace image_url blocks with text placeholder *in-place*.

        Mutates the content lists of the original message dicts so that
        callers holding references to those dicts also see the stripped
        version.
        """
        found = False
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                for i, b in enumerate(content):
                    if isinstance(b, dict) and b.get("type") == "image_url":
                        path = (b.get("_meta") or {}).get("path", "")
                        placeholder = image_placeholder_text(path, empty="[image omitted]")
                        content[i] = {"type": "text", "text": placeholder}
                        found = True
        return found

    async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
        """Call chat() and convert unexpected exceptions to error responses."""
        try:
            return await self.chat(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Stream a chat completion, calling *on_content_delta* for each text chunk.

        Returns the same ``LLMResponse`` as :meth:`chat`. The default
        implementation falls back to a non-streaming call and delivers the
        full content as a single delta. Providers that support native
        streaming should override this method.
        """
        response = await self.chat(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )
        if on_content_delta and response.content:
            await on_content_delta(response.content)
        return response
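
    # Usage sketch (hypothetical provider instance, inside a coroutine):
    #
    #     async def print_delta(chunk: str) -> None:
    #         print(chunk, end="", flush=True)
    #
    #     response = await provider.chat_stream(
    #         messages=[{"role": "user", "content": "hello"}],
    #         on_content_delta=print_delta,
    #     )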

    async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Call chat_stream() and convert unexpected exceptions to error responses."""
        try:
            return await self.chat_stream(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat_stream() with retry on transient provider failures."""
        if max_tokens is self._SENTINEL or max_tokens is None:
            max_tokens = self.generation.max_tokens
        if temperature is self._SENTINEL or temperature is None:
            temperature = self.generation.temperature
        if reasoning_effort is self._SENTINEL:
            reasoning_effort = self.generation.reasoning_effort

        kw: dict[str, Any] = dict(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
            on_content_delta=on_content_delta,
        )
        return await self._run_with_retry(
            self._safe_chat_stream,
            kw,
            messages,
            retry_mode=retry_mode,
            on_retry_wait=on_retry_wait,
        )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat() with retry on transient provider failures.

        Parameters default to ``self.generation`` when not explicitly passed,
        so callers no longer need to thread temperature / max_tokens /
        reasoning_effort through every layer. Explicit ``None`` is also
        normalized to the provider's generation defaults so that downstream
        ``_build_kwargs`` never sees ``None`` for ``max_tokens`` / ``temperature``
        (which would crash ``max(1, max_tokens)``).
        """
        if max_tokens is self._SENTINEL or max_tokens is None:
            max_tokens = self.generation.max_tokens
        if temperature is self._SENTINEL or temperature is None:
            temperature = self.generation.temperature
        if reasoning_effort is self._SENTINEL:
            reasoning_effort = self.generation.reasoning_effort

        kw: dict[str, Any] = dict(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )
        return await self._run_with_retry(
            self._safe_chat,
            kw,
            messages,
            retry_mode=retry_mode,
            on_retry_wait=on_retry_wait,
        )

    @classmethod
    def _extract_retry_after(cls, content: str | None) -> float | None:
        text = (content or "").lower()
        patterns = (
            r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?",
            r"try again in\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)",
            r"wait\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)\s*before retry",
            r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)",
        )
        for idx, pattern in enumerate(patterns):
            match = re.search(pattern, text)
            if not match:
                continue
            value = float(match.group(1))
            unit = match.group(2) if idx < 3 else "s"
            return cls._to_retry_seconds(value, unit)
        return None
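
    # Parsing sketch (hypothetical provider error text):
    #
    #     LLMProvider._extract_retry_after("Rate limit hit, try again in 2.5 seconds")
    #     # -> 2.5
    #     LLMProvider._extract_retry_after("retry after 500 ms")
    #     # -> 0.5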

    @classmethod
    def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float:
        normalized_unit = (unit or "s").lower()
        if normalized_unit in {"ms", "milliseconds"}:
            return max(0.1, value / 1000.0)
        if normalized_unit in {"m", "min", "minutes"}:
            return max(0.1, value * 60.0)
        return max(0.1, value)

    @classmethod
    def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
        if not headers:
            return None

        def _header_value(name: str) -> Any:
            if hasattr(headers, "get"):
                value = headers.get(name) or headers.get(name.title())
                if value is not None:
                    return value
            if isinstance(headers, dict):
                for key, value in headers.items():
                    if isinstance(key, str) and key.lower() == name.lower():
                        return value
            return None

        try:
            retry_ms = _header_value("retry-after-ms")
            if retry_ms is not None:
                value = float(retry_ms) / 1000.0
                if value > 0:
                    return value
        except (TypeError, ValueError):
            pass

        retry_after = _header_value("retry-after")
        if retry_after is None:
            return None
        retry_after_text = str(retry_after).strip()
        if not retry_after_text:
            return None
        if re.fullmatch(r"\d+(?:\.\d+)?", retry_after_text):
            return cls._to_retry_seconds(float(retry_after_text), "s")
        try:
            retry_at = parsedate_to_datetime(retry_after_text)
        except Exception:
            return None
        if retry_at.tzinfo is None:
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
        return max(0.1, remaining)
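
    # Header handling sketch (hypothetical header values): numeric seconds,
    # millisecond headers, and HTTP-date values are all normalized to seconds.
    #
    #     LLMProvider._extract_retry_after_from_headers({"Retry-After": "3"})
    #     # -> 3.0
    #     LLMProvider._extract_retry_after_from_headers({"retry-after-ms": "1500"})
    #     # -> 1.5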

    @classmethod
    def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None:
        if response.error_retry_after_s is not None and response.error_retry_after_s > 0:
            return response.error_retry_after_s
        if response.retry_after is not None and response.retry_after > 0:
            return response.retry_after
        return cls._extract_retry_after(response.content)

    async def _sleep_with_heartbeat(
        self,
        delay: float,
        *,
        attempt: int,
        persistent: bool,
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> None:
        remaining = max(0.0, delay)
        while remaining > 0:
            if on_retry_wait:
                kind = "persistent retry" if persistent else "retry"
                await on_retry_wait(
                    f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
                    f"(attempt {attempt})."
                )
            chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
            await asyncio.sleep(chunk)
            remaining -= chunk

    async def _run_with_retry(
        self,
        call: Callable[..., Awaitable[LLMResponse]],
        kw: dict[str, Any],
        original_messages: list[dict[str, Any]],
        *,
        retry_mode: str,
        on_retry_wait: Callable[[str], Awaitable[None]] | None,
    ) -> LLMResponse:
        attempt = 0
        delays = list(self._CHAT_RETRY_DELAYS)
        persistent = retry_mode == "persistent"
        last_response: LLMResponse | None = None
        last_error_key: str | None = None
        identical_error_count = 0
        while True:
            attempt += 1
            response = await call(**kw)
            if response.finish_reason != "error":
                return response
            last_response = response
            error_key = (response.content or "").strip().lower() or None
            if error_key and error_key == last_error_key:
                identical_error_count += 1
            else:
                last_error_key = error_key
                identical_error_count = 1 if error_key else 0

            if not self._is_transient_response(response):
                stripped = self._strip_image_content(original_messages)
                if stripped is not None and stripped != kw["messages"]:
                    logger.warning(
                        "Non-transient LLM error with image content, retrying without images"
                    )
                    retry_kw = dict(kw)
                    retry_kw["messages"] = stripped
                    result = await call(**retry_kw)
                    # Permanently strip images from the original messages so
                    # subsequent iterations do not repeat the error-retry cycle.
                    if result.finish_reason != "error":
                        self._strip_image_content_inplace(original_messages)
                    return result
                return response

            if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
                logger.warning(
                    "Stopping persistent retry after {} identical transient errors: {}",
                    identical_error_count,
                    (response.content or "")[:120].lower(),
                )
                if on_retry_wait:
                    await on_retry_wait(
                        f"Persistent retry stopped after {identical_error_count} identical errors."
                    )
                return response

            if not persistent and attempt > len(delays):
                logger.warning(
                    "LLM request failed after {} retries, giving up: {}",
                    attempt,
                    (response.content or "")[:120].lower(),
                )
                if on_retry_wait:
                    await on_retry_wait(f"Model request failed after {attempt} retries, giving up.")
                break

            base_delay = delays[min(attempt - 1, len(delays) - 1)]
            delay = self._extract_retry_after_from_response(response) or base_delay
            if persistent:
                delay = min(delay, self._PERSISTENT_MAX_DELAY)

            logger.warning(
                "LLM transient error (attempt {}{}), retrying in {}s: {}",
                attempt,
                "+" if persistent and attempt > len(delays) else f"/{len(delays)}",
                int(round(delay)),
                (response.content or "")[:120].lower(),
            )
            await self._sleep_with_heartbeat(
                delay,
                attempt=attempt,
                persistent=persistent,
                on_retry_wait=on_retry_wait,
            )

        return last_response if last_response is not None else await call(**kw)

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
        pass
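

# Usage sketch (illustrative only): a minimal concrete provider and a retried
# call. ``EchoProvider`` is a hypothetical subclass, not part of nanobot.
#
#     class EchoProvider(LLMProvider):
#         async def chat(self, messages, tools=None, model=None, max_tokens=4096,
#                        temperature=0.7, reasoning_effort=None, tool_choice=None):
#             return LLMResponse(content=messages[-1]["content"])
#
#         def get_default_model(self) -> str:
#             return "echo-1"
#
#     # Inside a coroutine:
#     # response = await EchoProvider().chat_with_retry(
#     #     messages=[{"role": "user", "content": "ping"}], retry_mode="persistent"
#     # )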