# Snapshot: 2026-04-17 20:39:46 +08:00
#
# 807 lines
# 30 KiB
# Python

"""Base LLM provider interface."""
import asyncio
import json
import re
from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any
from loguru import logger
from nanobot.utils.helpers import image_placeholder_text
@dataclass
class ToolCallRequest:
"""A tool call request from the LLM."""
id: str
name: str
arguments: dict[str, Any]
extra_content: dict[str, Any] | None = None
provider_specific_fields: dict[str, Any] | None = None
function_provider_specific_fields: dict[str, Any] | None = None
def to_openai_tool_call(self) -> dict[str, Any]:
"""Serialize to an OpenAI-style tool_call payload."""
tool_call = {
"id": self.id,
"type": "function",
"function": {
"name": self.name,
"arguments": json.dumps(self.arguments, ensure_ascii=False),
},
}
if self.extra_content:
tool_call["extra_content"] = self.extra_content
if self.provider_specific_fields:
tool_call["provider_specific_fields"] = self.provider_specific_fields
if self.function_provider_specific_fields:
tool_call["function"]["provider_specific_fields"] = (
self.function_provider_specific_fields
)
return tool_call
@dataclass
class LLMResponse:
    """Result of one provider call: text, tool calls and error metadata."""

    content: str | None
    tool_calls: list[ToolCallRequest] = field(default_factory=list)
    finish_reason: str = "stop"
    usage: dict[str, int] = field(default_factory=dict)
    # Provider-supplied retry wait, in seconds.
    retry_after: float | None = None
    # Reasoning text returned by e.g. Kimi, DeepSeek-R1, MiMo.
    reasoning_content: str | None = None
    # Anthropic extended-thinking blocks.
    thinking_blocks: list[dict] | None = None

    # Structured error metadata consulted by the retry policy when
    # finish_reason == "error".
    error_status_code: int | None = None
    # e.g. "timeout", "connection"
    error_kind: str | None = None
    # Provider "type" semantic, e.g. insufficient_quota.
    error_type: str | None = None
    # Provider "code" semantic, e.g. rate_limit_exceeded.
    error_code: str | None = None
    error_retry_after_s: float | None = None
    error_should_retry: bool | None = None

    @property
    def has_tool_calls(self) -> bool:
        """Whether at least one tool call is attached to the response."""
        return len(self.tool_calls) != 0

    @property
    def should_execute_tools(self) -> bool:
        """Whether the attached tool calls are safe to run.

        True only when tool calls are present AND ``finish_reason`` is a
        known-good signal (``tool_calls`` or ``stop``). This blocks
        gateway-injected calls that arrive under ``refusal`` /
        ``content_filter`` / ``error``.
        """
        trusted_finish = self.finish_reason in ("tool_calls", "stop")
        return self.has_tool_calls and trusted_finish
@dataclass(frozen=True)
class GenerationSettings:
"""Default generation settings."""
temperature: float = 0.7
max_tokens: int = 4096
reasoning_effort: str | None = None
# Content of the placeholder user message that
# LLMProvider._enforce_role_alternation inserts when the first non-system
# message would otherwise be an assistant message without tool calls.
_SYNTHETIC_USER_CONTENT = "(conversation continued)"
class LLMProvider(ABC):
    """Base class for LLM providers.

    Concrete providers implement :meth:`chat` and :meth:`get_default_model`.
    This base class contributes message sanitisation helpers, a
    non-streaming fallback for :meth:`chat_stream`, and the retry policy
    behind :meth:`chat_with_retry` / :meth:`chat_stream_with_retry`.
    """

    # Waits (seconds) before successive retries in "standard" retry mode; the
    # tuple length is also the number of retries made before giving up.
    _CHAT_RETRY_DELAYS = (1, 2, 4)
    # Cap (seconds) applied to every wait in "persistent" retry mode.
    _PERSISTENT_MAX_DELAY = 60
    # Persistent mode stops after this many consecutive identical errors.
    _PERSISTENT_IDENTICAL_ERROR_LIMIT = 10
    # Longest single sleep (seconds) between two on_retry_wait notifications.
    _RETRY_HEARTBEAT_CHUNK = 30
    # Substrings (matched case-insensitively) that mark an error message as
    # transient when no structured error metadata is available.
    _TRANSIENT_ERROR_MARKERS = (
        "429",
        "rate limit",
        "500",
        "502",
        "503",
        "504",
        "overloaded",
        "timeout",
        "timed out",
        "connection",
        "server error",
        "temporarily unavailable",
    )
    # Status codes below 500 that are still retried (every 5xx is retried).
    _RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})
    # Values of LLMResponse.error_kind that are always retried.
    _TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"})
    # Normalized error type/code tokens that make a 429 permanent
    # (quota or billing problems that waiting will not fix).
    _NON_RETRYABLE_429_ERROR_TOKENS = frozenset(
        {
            "insufficient_quota",
            "quota_exceeded",
            "quota_exhausted",
            "billing_hard_limit_reached",
            "insufficient_balance",
            "credit_balance_too_low",
            "billing_not_active",
            "payment_required",
        }
    )
    # Normalized error type/code tokens that identify an ordinary rate limit.
    _RETRYABLE_429_ERROR_TOKENS = frozenset(
        {
            "rate_limit_exceeded",
            "rate_limit_error",
            "too_many_requests",
            "request_limit_exceeded",
            "requests_limit_exceeded",
            "overloaded_error",
        }
    )
    # Free-text counterparts of the two token sets above, matched against the
    # lower-cased response content.
    _NON_RETRYABLE_429_TEXT_MARKERS = (
        "insufficient_quota",
        "insufficient quota",
        "quota exceeded",
        "quota exhausted",
        "billing hard limit",
        "billing_hard_limit_reached",
        "billing not active",
        "insufficient balance",
        "insufficient_balance",
        "credit balance too low",
        "payment required",
        "out of credits",
        "out of quota",
        "exceeded your current quota",
    )
    _RETRYABLE_429_TEXT_MARKERS = (
        "rate limit",
        "rate_limit",
        "too many requests",
        "retry after",
        "try again in",
        "temporarily unavailable",
        "overloaded",
        "concurrency limit",
    )
    # Default-argument marker meaning "not passed"; lets the retry wrappers
    # tell an omitted argument apart from an explicit None.
    _SENTINEL = object()

    def __init__(self, api_key: str | None = None, api_base: str | None = None):
        """Store credentials and start with default generation settings."""
        self.api_key = api_key
        self.api_base = api_base
        self.generation: GenerationSettings = GenerationSettings()

    @staticmethod
    def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Sanitize message content: fix empty blocks, strip internal _meta fields.

        Returns a new list. Messages that need no change are passed through
        as the same dict objects; changed messages are shallow copies.

        * An empty string content becomes ``"(empty)"``, or ``None`` for an
          assistant message that carries ``tool_calls``.
        * In list content, text blocks without text are dropped and the
          ``_meta`` key is removed from dict blocks; if nothing is left the
          same placeholder rule applies.
        * A bare dict content is wrapped in a one-element list.
        """
        result: list[dict[str, Any]] = []
        for msg in messages:
            content = msg.get("content")
            # Placeholder used when a message ends up with no content at all.
            is_tool_calling_assistant = bool(
                msg.get("role") == "assistant" and msg.get("tool_calls")
            )
            empty_placeholder = None if is_tool_calling_assistant else "(empty)"

            if isinstance(content, str) and not content:
                clean = dict(msg)
                clean["content"] = empty_placeholder
                result.append(clean)
                continue

            if isinstance(content, list):
                new_items: list[Any] = []
                changed = False
                for item in content:
                    if (
                        isinstance(item, dict)
                        and item.get("type") in ("text", "input_text", "output_text")
                        and not item.get("text")
                    ):
                        # Empty text block: drop it.
                        changed = True
                        continue
                    if isinstance(item, dict) and "_meta" in item:
                        new_items.append({k: v for k, v in item.items() if k != "_meta"})
                        changed = True
                    else:
                        new_items.append(item)
                if changed:
                    clean = dict(msg)
                    clean["content"] = new_items if new_items else empty_placeholder
                    result.append(clean)
                    continue

            if isinstance(content, dict):
                clean = dict(msg)
                clean["content"] = [content]
                result.append(clean)
                continue

            result.append(msg)
        return result

    @staticmethod
    def _tool_name(tool: dict[str, Any]) -> str:
        """Extract tool name from either OpenAI or Anthropic-style tool schemas.

        Looks at the top-level ``name`` first, then ``function.name``;
        returns ``""`` when neither is a string.
        """
        name = tool.get("name")
        if isinstance(name, str):
            return name
        fn = tool.get("function")
        if isinstance(fn, dict):
            fname = fn.get("name")
            if isinstance(fname, str):
                return fname
        return ""

    @classmethod
    def _tool_cache_marker_indices(cls, tools: list[dict[str, Any]]) -> list[int]:
        """Return cache marker indices: builtin/MCP boundary and tail index.

        The boundary is the index of the last tool whose name does not start
        with ``"mcp_"``. The result holds the boundary (if any) followed by
        the last index, without duplicates; it is empty for an empty list.
        """
        if not tools:
            return []
        tail_idx = len(tools) - 1
        last_builtin_idx: int | None = None
        for i in range(tail_idx, -1, -1):
            if not cls._tool_name(tools[i]).startswith("mcp_"):
                last_builtin_idx = i
                break
        ordered_unique: list[int] = []
        for idx in (last_builtin_idx, tail_idx):
            if idx is not None and idx not in ordered_unique:
                ordered_unique.append(idx)
        return ordered_unique

    @staticmethod
    def _sanitize_request_messages(
        messages: list[dict[str, Any]],
        allowed_keys: frozenset[str],
    ) -> list[dict[str, Any]]:
        """Keep only provider-safe message keys and normalize assistant content.

        Every message is copied with only the keys in *allowed_keys*; an
        assistant message left without a ``content`` key gets
        ``content=None``.
        """
        sanitized = []
        for msg in messages:
            clean = {k: v for k, v in msg.items() if k in allowed_keys}
            if clean.get("role") == "assistant" and "content" not in clean:
                clean["content"] = None
            sanitized.append(clean)
        return sanitized

    @abstractmethod
    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """
        Send a chat completion request.

        Args:
            messages: List of message dicts with 'role' and 'content'.
            tools: Optional list of tool definitions.
            model: Model identifier (provider-specific).
            max_tokens: Maximum tokens in response.
            temperature: Sampling temperature.
            reasoning_effort: Optional reasoning-effort setting; its
                interpretation is left to the concrete provider.
            tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).

        Returns:
            LLMResponse with content and/or tool calls.
        """
        pass

    @classmethod
    def _is_transient_error(cls, content: str | None) -> bool:
        """Return True if *content* contains any transient-error text marker."""
        err = (content or "").lower()
        return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)

    @classmethod
    def _is_transient_response(cls, response: LLMResponse) -> bool:
        """Prefer structured error metadata, fallback to text markers for legacy providers.

        Order of precedence: an explicit ``error_should_retry`` flag; the
        status code (429 gets its own quota-vs-rate-limit check, other
        retryable codes and all 5xx retry); the ``error_kind``; finally the
        text markers in ``content``. A status code that is not retryable
        does not end the check; the kind and text checks still run.
        """
        if response.error_should_retry is not None:
            return bool(response.error_should_retry)
        if response.error_status_code is not None:
            status = int(response.error_status_code)
            if status == 429:
                return cls._is_retryable_429_response(response)
            if status in cls._RETRYABLE_STATUS_CODES or status >= 500:
                return True
        kind = (response.error_kind or "").strip().lower()
        if kind in cls._TRANSIENT_ERROR_KINDS:
            return True
        return cls._is_transient_error(response.content)

    @staticmethod
    def _normalize_error_token(value: Any) -> str | None:
        """Return *value* as a stripped, lower-cased string, or None if empty."""
        if value is None:
            return None
        token = str(value).strip().lower()
        return token or None

    @classmethod
    def _extract_error_type_code(cls, payload: Any) -> tuple[str | None, str | None]:
        """Pull normalized ``(type, code)`` tokens out of an error payload.

        *payload* may be a dict or a JSON string encoding a dict; anything
        else yields ``(None, None)``. Values inside a nested ``error`` object
        take precedence over top-level ``type`` / ``code``.
        """
        data: dict[str, Any] | None = None
        if isinstance(payload, dict):
            data = payload
        elif isinstance(payload, str):
            text = payload.strip()
            if text:
                try:
                    parsed = json.loads(text)
                except (ValueError, RecursionError):
                    # Not JSON (or pathologically nested): no structured info.
                    parsed = None
                if isinstance(parsed, dict):
                    data = parsed
        if not isinstance(data, dict):
            return None, None
        error_obj = data.get("error")
        type_value = data.get("type")
        code_value = data.get("code")
        if isinstance(error_obj, dict):
            type_value = error_obj.get("type") or type_value
            code_value = error_obj.get("code") or code_value
        return cls._normalize_error_token(type_value), cls._normalize_error_token(code_value)

    @classmethod
    def _is_retryable_429_response(cls, response: LLMResponse) -> bool:
        """Decide whether a 429 is a rate limit (retry) or a quota error (stop).

        Non-retryable tokens and text markers are checked first and win over
        any retryable signal. Everything else, including a 429 with no
        recognizable signal, is retried.
        """
        type_token = cls._normalize_error_token(response.error_type)
        code_token = cls._normalize_error_token(response.error_code)
        semantic_tokens = {token for token in (type_token, code_token) if token is not None}
        if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
            return False
        content = (response.content or "").lower()
        if any(marker in content for marker in cls._NON_RETRYABLE_429_TEXT_MARKERS):
            return False
        if any(token in cls._RETRYABLE_429_ERROR_TOKENS for token in semantic_tokens):
            return True
        if any(marker in content for marker in cls._RETRYABLE_429_TEXT_MARKERS):
            return True
        # Unknown 429 defaults to WAIT+retry.
        return True

    @staticmethod
    def _enforce_role_alternation(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Merge consecutive same-role messages and drop trailing assistant messages.

        Some providers (OpenAI-compat, Azure, vLLM, Ollama, etc.) reject requests
        where the last message is 'assistant' (prefill not supported) or two
        consecutive non-system messages share the same role.

        Merge rules for two adjacent ``user`` or two adjacent ``assistant``
        messages:

        * assistant + assistant where the later one has ``tool_calls``: the
          later one replaces the earlier one;
        * assistant + assistant where only the earlier one has
          ``tool_calls``: the later one is dropped;
        * both contents are strings: joined with a blank line;
        * the earlier content is non-empty and both contents are strings or
          block lists: concatenated as one block list (strings become
          ``{"type": "text"}`` blocks), so that e.g. an image message followed
          by a text message keeps the image;
        * anything else: the later message replaces the earlier one.

        The input list and its message dicts are not mutated (an empty input
        list is returned as-is).
        """
        if not messages:
            return messages

        def _as_blocks(content: str | list[Any]) -> list[Any]:
            # Normalize string-or-list content to a fresh list of blocks.
            if isinstance(content, str):
                return [{"type": "text", "text": content}] if content else []
            return list(content)

        merged: list[dict[str, Any]] = []
        for msg in messages:
            role = msg.get("role")
            continues_same_role = (
                bool(merged)
                and role in ("user", "assistant")
                and merged[-1].get("role") == role
            )
            if not continues_same_role:
                merged.append(dict(msg))
                continue

            prev = merged[-1]
            if role == "assistant":
                if msg.get("tool_calls"):
                    merged[-1] = dict(msg)
                    continue
                if prev.get("tool_calls"):
                    continue

            prev_content = prev.get("content") or ""
            curr_content = msg.get("content") or ""
            if isinstance(prev_content, str) and isinstance(curr_content, str):
                prev["content"] = (prev_content + "\n\n" + curr_content).strip()
            elif (
                prev_content
                and isinstance(prev_content, (str, list))
                and isinstance(curr_content, (str, list))
            ):
                # Mixed string / block-list content: keep both instead of
                # discarding the earlier message. A new list is assigned so
                # the caller's original content list is left untouched.
                prev["content"] = _as_blocks(prev_content) + _as_blocks(curr_content)
            else:
                merged[-1] = dict(msg)

        last_popped = None
        while merged and merged[-1].get("role") == "assistant":
            last_popped = merged.pop()

        # If removing trailing assistant messages left only system messages,
        # the request would be invalid for most providers (e.g. Zhipu/GLM
        # error 1214). Recover by converting the last popped assistant
        # message to a user message so the LLM can still see the content.
        if (
            merged
            and last_popped is not None
            and not any(m.get("role") in ("user", "tool") for m in merged)
        ):
            recovered = dict(last_popped)
            recovered["role"] = "user"
            merged.append(recovered)

        # Safety net: ensure the first non-system message is not a bare
        # ``assistant`` message. Providers like GLM reject system→assistant
        # with error 1214. This can happen when upstream truncation (e.g.
        # _snip_history) drops the only user message. Insert a synthetic
        # user message to keep the sequence valid.
        for i, msg in enumerate(merged):
            if msg.get("role") != "system":
                if msg.get("role") == "assistant" and not msg.get("tool_calls"):
                    merged.insert(i, {"role": "user", "content": _SYNTHETIC_USER_CONTENT})
                break
        return merged

    @staticmethod
    def _image_placeholder_block(block: dict[str, Any]) -> dict[str, Any]:
        """Build the text block that stands in for one ``image_url`` block.

        The optional ``_meta.path`` of *block* is handed to
        ``image_placeholder_text`` together with the fallback text
        ``"[image omitted]"``.
        """
        path = (block.get("_meta") or {}).get("path", "")
        placeholder = image_placeholder_text(path, empty="[image omitted]")
        return {"type": "text", "text": placeholder}

    @staticmethod
    def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
        """Replace image_url blocks with text placeholder. Returns None if no images found.

        The input is not modified; messages with list content are copied.
        """
        found = False
        result = []
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                new_content = []
                for b in content:
                    if isinstance(b, dict) and b.get("type") == "image_url":
                        new_content.append(LLMProvider._image_placeholder_block(b))
                        found = True
                    else:
                        new_content.append(b)
                result.append({**msg, "content": new_content})
            else:
                result.append(msg)
        return result if found else None

    @staticmethod
    def _strip_image_content_inplace(messages: list[dict[str, Any]]) -> bool:
        """Replace image_url blocks with text placeholder *in-place*.

        Mutates the content lists of the original message dicts so that
        callers holding references to those dicts also see the stripped
        version. Returns True if at least one block was replaced.
        """
        found = False
        for msg in messages:
            content = msg.get("content")
            if isinstance(content, list):
                for i, b in enumerate(content):
                    if isinstance(b, dict) and b.get("type") == "image_url":
                        content[i] = LLMProvider._image_placeholder_block(b)
                        found = True
        return found

    @staticmethod
    def _classify_exception_kind(exc: BaseException) -> str | None:
        """Map an exception to a transient ``error_kind``, or None.

        Timeouts and connection failures are recognized by type because
        their message can be empty (``str(asyncio.TimeoutError())`` is
        ``""``), in which case the text-marker fallback would not see them.
        """
        if isinstance(exc, (TimeoutError, asyncio.TimeoutError)):
            return "timeout"
        if isinstance(exc, ConnectionError):
            return "connection"
        return None

    @classmethod
    def _error_response_from_exception(cls, exc: Exception) -> LLMResponse:
        """Wrap *exc* in an error ``LLMResponse`` tagged with its error kind."""
        return LLMResponse(
            content=f"Error calling LLM: {exc}",
            finish_reason="error",
            error_kind=cls._classify_exception_kind(exc),
        )

    async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
        """Call chat() and convert unexpected exceptions to error responses.

        Cancellation is re-raised untouched. Timeout and connection
        exceptions are tagged via ``error_kind`` so the retry policy treats
        them as transient even when their message is empty.
        """
        try:
            return await self.chat(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return self._error_response_from_exception(exc)

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Stream a chat completion, calling *on_content_delta* for each text chunk.

        Returns the same ``LLMResponse`` as :meth:`chat`.  The default
        implementation falls back to a non-streaming call and delivers the
        full content as a single delta.  Providers that support native
        streaming should override this method.
        """
        response = await self.chat(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )
        if on_content_delta and response.content:
            await on_content_delta(response.content)
        return response

    async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Call chat_stream() and convert unexpected exceptions to error responses.

        Same exception handling as :meth:`_safe_chat`.
        """
        try:
            return await self.chat_stream(**kwargs)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            return self._error_response_from_exception(exc)

    def _resolve_generation_defaults(
        self,
        max_tokens: object,
        temperature: object,
        reasoning_effort: object,
    ) -> tuple[Any, Any, Any]:
        """Fill omitted generation parameters from ``self.generation``.

        ``max_tokens`` and ``temperature`` fall back to the defaults when
        omitted OR explicitly ``None``; ``reasoning_effort`` only when
        omitted, so an explicit ``None`` is passed through unchanged.
        """
        if max_tokens is self._SENTINEL or max_tokens is None:
            max_tokens = self.generation.max_tokens
        if temperature is self._SENTINEL or temperature is None:
            temperature = self.generation.temperature
        if reasoning_effort is self._SENTINEL:
            reasoning_effort = self.generation.reasoning_effort
        return max_tokens, temperature, reasoning_effort

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat_stream() with retry on transient provider failures.

        Generation parameters default to ``self.generation`` exactly as in
        :meth:`chat_with_retry`. ``retry_mode="persistent"`` keeps retrying
        beyond the standard schedule; any other value uses the standard
        schedule. *on_retry_wait* receives a status line before each wait.
        """
        max_tokens, temperature, reasoning_effort = self._resolve_generation_defaults(
            max_tokens, temperature, reasoning_effort
        )
        kw: dict[str, Any] = dict(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
            on_content_delta=on_content_delta,
        )
        return await self._run_with_retry(
            self._safe_chat_stream,
            kw,
            messages,
            retry_mode=retry_mode,
            on_retry_wait=on_retry_wait,
        )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = _SENTINEL,
        temperature: object = _SENTINEL,
        reasoning_effort: object = _SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Call chat() with retry on transient provider failures.

        Parameters default to ``self.generation`` when not explicitly passed,
        so callers no longer need to thread temperature / max_tokens /
        reasoning_effort through every layer.  Explicit ``None`` is also
        normalized to the provider's generation defaults so that downstream
        ``_build_kwargs`` never sees ``None`` for ``max_tokens`` / ``temperature``
        (which would crash ``max(1, max_tokens)``).
        """
        max_tokens, temperature, reasoning_effort = self._resolve_generation_defaults(
            max_tokens, temperature, reasoning_effort
        )
        kw: dict[str, Any] = dict(
            messages=messages,
            tools=tools,
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            reasoning_effort=reasoning_effort,
            tool_choice=tool_choice,
        )
        return await self._run_with_retry(
            self._safe_chat,
            kw,
            messages,
            retry_mode=retry_mode,
            on_retry_wait=on_retry_wait,
        )

    @classmethod
    def _extract_retry_after(cls, content: str | None) -> float | None:
        """Parse a retry wait (seconds) out of free-text error content.

        Recognizes "retry after N [unit]", "try again in N unit",
        "wait N unit before retry" and key/value forms such as
        ``retry_after: N`` (always seconds). The first pattern that matches
        wins. Returns None when nothing matches.
        """
        text = (content or "").lower()
        patterns = (
            r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?",
            r"try again in\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)",
            r"wait\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)\s*before retry",
            r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)",
        )
        for idx, pattern in enumerate(patterns):
            match = re.search(pattern, text)
            if not match:
                continue
            value = float(match.group(1))
            # Only the first three patterns capture a unit group.
            unit = match.group(2) if idx < 3 else "s"
            return cls._to_retry_seconds(value, unit)
        return None

    @classmethod
    def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float:
        """Convert *value* in *unit* to seconds, never less than 0.1.

        ``ms`` / ``milliseconds`` and ``m`` / ``min`` / ``minutes`` are
        converted; any other unit (or none) is taken as seconds.
        """
        normalized_unit = (unit or "s").lower()
        if normalized_unit in {"ms", "milliseconds"}:
            return max(0.1, value / 1000.0)
        if normalized_unit in {"m", "min", "minutes"}:
            return max(0.1, value * 60.0)
        return max(0.1, value)

    @classmethod
    def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
        """Read a retry wait (seconds) from response headers.

        A positive ``retry-after-ms`` wins. Otherwise ``retry-after`` is
        read either as a plain number of seconds or as an HTTP date; a date
        in the past yields the 0.1 s minimum. Returns None when no usable
        header is present.
        """
        if not headers:
            return None

        def _header_value(name: str) -> Any:
            # Try the mapping's own lookup first (lower-case, then Title-Case),
            # then fall back to a case-insensitive scan for plain dicts.
            if hasattr(headers, "get"):
                value = headers.get(name) or headers.get(name.title())
                if value is not None:
                    return value
            if isinstance(headers, dict):
                for key, value in headers.items():
                    if isinstance(key, str) and key.lower() == name.lower():
                        return value
            return None

        try:
            retry_ms = _header_value("retry-after-ms")
            if retry_ms is not None:
                value = float(retry_ms) / 1000.0
                if value > 0:
                    return value
        except (TypeError, ValueError):
            # Malformed millisecond header: fall through to retry-after.
            pass
        retry_after = _header_value("retry-after")
        if retry_after is None:
            return None
        retry_after_text = str(retry_after).strip()
        if not retry_after_text:
            return None
        if re.fullmatch(r"\d+(?:\.\d+)?", retry_after_text):
            return cls._to_retry_seconds(float(retry_after_text), "s")
        try:
            retry_at = parsedate_to_datetime(retry_after_text)
        except Exception:
            return None
        if retry_at.tzinfo is None:
            # A date without zone information is interpreted as UTC.
            retry_at = retry_at.replace(tzinfo=timezone.utc)
        remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
        return max(0.1, remaining)

    @classmethod
    def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None:
        """Return the retry wait for *response*, in seconds.

        Preference order: positive ``error_retry_after_s``, positive
        ``retry_after``, then whatever can be parsed from ``content``.
        """
        if response.error_retry_after_s is not None and response.error_retry_after_s > 0:
            return response.error_retry_after_s
        if response.retry_after is not None and response.retry_after > 0:
            return response.retry_after
        return cls._extract_retry_after(response.content)

    async def _sleep_with_heartbeat(
        self,
        delay: float,
        *,
        attempt: int,
        persistent: bool,
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> None:
        """Sleep for *delay* seconds, reporting progress along the way.

        The wait is split into chunks of at most ``_RETRY_HEARTBEAT_CHUNK``
        seconds; before each chunk *on_retry_wait* (if given) receives a
        message with the remaining time. A non-positive delay returns
        immediately without notifying.
        """
        remaining = max(0.0, delay)
        while remaining > 0:
            if on_retry_wait:
                kind = "persistent retry" if persistent else "retry"
                await on_retry_wait(
                    f"Model request failed, {kind} in {max(1, int(round(remaining)))}s "
                    f"(attempt {attempt})."
                )
            chunk = min(remaining, self._RETRY_HEARTBEAT_CHUNK)
            await asyncio.sleep(chunk)
            remaining -= chunk

    async def _run_with_retry(
        self,
        call: Callable[..., Awaitable[LLMResponse]],
        kw: dict[str, Any],
        original_messages: list[dict[str, Any]],
        *,
        retry_mode: str,
        on_retry_wait: Callable[[str], Awaitable[None]] | None,
    ) -> LLMResponse:
        """Invoke *call* with *kw* until it succeeds or the policy gives up.

        * A response whose ``finish_reason`` is not ``"error"`` is returned
          immediately.
        * A non-transient error is retried once with image blocks replaced
          by text placeholders (if there are any images); on success the
          images are also stripped from *original_messages* in place. No
          further retries follow.
        * A transient error is retried after a wait taken from the response
          or, failing that, from ``_CHAT_RETRY_DELAYS``. Standard mode stops
          once the schedule is exhausted; persistent mode caps each wait at
          ``_PERSISTENT_MAX_DELAY`` and stops after
          ``_PERSISTENT_IDENTICAL_ERROR_LIMIT`` identical errors in a row.

        The last error response is returned when giving up.
        """
        attempt = 0
        delays = list(self._CHAT_RETRY_DELAYS)
        persistent = retry_mode == "persistent"
        last_response: LLMResponse | None = None
        last_error_key: str | None = None
        identical_error_count = 0
        while True:
            attempt += 1
            response = await call(**kw)
            if response.finish_reason != "error":
                return response
            last_response = response

            # Track how many times in a row the exact same error text came back.
            error_key = (response.content or "").strip().lower() or None
            if error_key and error_key == last_error_key:
                identical_error_count += 1
            else:
                last_error_key = error_key
                identical_error_count = 1 if error_key else 0

            if not self._is_transient_response(response):
                stripped = self._strip_image_content(original_messages)
                if stripped is not None and stripped != kw["messages"]:
                    logger.warning(
                        "Non-transient LLM error with image content, retrying without images"
                    )
                    retry_kw = dict(kw)
                    retry_kw["messages"] = stripped
                    result = await call(**retry_kw)
                    # Permanently strip images from the original messages so
                    # subsequent iterations do not repeat the error-retry cycle.
                    if result.finish_reason != "error":
                        self._strip_image_content_inplace(original_messages)
                    return result
                return response

            if persistent and identical_error_count >= self._PERSISTENT_IDENTICAL_ERROR_LIMIT:
                logger.warning(
                    "Stopping persistent retry after {} identical transient errors: {}",
                    identical_error_count,
                    (response.content or "")[:120].lower(),
                )
                if on_retry_wait:
                    await on_retry_wait(
                        f"Persistent retry stopped after {identical_error_count} identical errors."
                    )
                return response

            if not persistent and attempt > len(delays):
                # NOTE: `attempt` counts calls made, i.e. retries + 1.
                logger.warning(
                    "LLM request failed after {} retries, giving up: {}",
                    attempt,
                    (response.content or "")[:120].lower(),
                )
                if on_retry_wait:
                    await on_retry_wait(f"Model request failed after {attempt} retries, giving up.")
                break

            # Past the end of the schedule (persistent mode) the last delay repeats.
            base_delay = delays[min(attempt - 1, len(delays) - 1)]
            delay = self._extract_retry_after_from_response(response) or base_delay
            if persistent:
                delay = min(delay, self._PERSISTENT_MAX_DELAY)
            logger.warning(
                "LLM transient error (attempt {}{}), retrying in {}s: {}",
                attempt,
                "+" if persistent and attempt > len(delays) else f"/{len(delays)}",
                int(round(delay)),
                (response.content or "")[:120].lower(),
            )
            await self._sleep_with_heartbeat(
                delay,
                attempt=attempt,
                persistent=persistent,
                on_retry_wait=on_retry_wait,
            )
        return last_response if last_response is not None else await call(**kw)

    @abstractmethod
    def get_default_model(self) -> str:
        """Get the default model for this provider."""
        pass