mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-06 19:23:39 +00:00
Replace the flatten/unflatten approach (merging extra_content.google.* into provider_specific_fields then reconstructing) with direct pass-through: parse extra_content as-is, store on ToolCallRequest.extra_content, serialize back untouched. This is lossless, requires no hardcoded field names, and covers all three parsing branches (str, dict, SDK object) plus streaming.
369 lines
14 KiB
Python
369 lines
14 KiB
Python
"""Base LLM provider interface."""
|
|
|
|
import asyncio
|
|
import json
|
|
from abc import ABC, abstractmethod
|
|
from collections.abc import Awaitable, Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
|
|
@dataclass
|
|
class ToolCallRequest:
|
|
"""A tool call request from the LLM."""
|
|
id: str
|
|
name: str
|
|
arguments: dict[str, Any]
|
|
extra_content: dict[str, Any] | None = None
|
|
provider_specific_fields: dict[str, Any] | None = None
|
|
function_provider_specific_fields: dict[str, Any] | None = None
|
|
|
|
def to_openai_tool_call(self) -> dict[str, Any]:
|
|
"""Serialize to an OpenAI-style tool_call payload."""
|
|
tool_call = {
|
|
"id": self.id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": self.name,
|
|
"arguments": json.dumps(self.arguments, ensure_ascii=False),
|
|
},
|
|
}
|
|
if self.extra_content:
|
|
tool_call["extra_content"] = self.extra_content
|
|
if self.provider_specific_fields:
|
|
tool_call["provider_specific_fields"] = self.provider_specific_fields
|
|
if self.function_provider_specific_fields:
|
|
tool_call["function"]["provider_specific_fields"] = self.function_provider_specific_fields
|
|
return tool_call
|
|
|
|
|
|
@dataclass
|
|
class LLMResponse:
|
|
"""Response from an LLM provider."""
|
|
content: str | None
|
|
tool_calls: list[ToolCallRequest] = field(default_factory=list)
|
|
finish_reason: str = "stop"
|
|
usage: dict[str, int] = field(default_factory=dict)
|
|
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
|
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
|
|
|
|
@property
|
|
def has_tool_calls(self) -> bool:
|
|
"""Check if response contains tool calls."""
|
|
return len(self.tool_calls) > 0
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class GenerationSettings:
|
|
"""Default generation parameters for LLM calls.
|
|
|
|
Stored on the provider so every call site inherits the same defaults
|
|
without having to pass temperature / max_tokens / reasoning_effort
|
|
through every layer. Individual call sites can still override by
|
|
passing explicit keyword arguments to chat() / chat_with_retry().
|
|
"""
|
|
|
|
temperature: float = 0.7
|
|
max_tokens: int = 4096
|
|
reasoning_effort: str | None = None
|
|
|
|
|
|
class LLMProvider(ABC):
|
|
"""
|
|
Abstract base class for LLM providers.
|
|
|
|
Implementations should handle the specifics of each provider's API
|
|
while maintaining a consistent interface.
|
|
"""
|
|
|
|
_CHAT_RETRY_DELAYS = (1, 2, 4)
|
|
_TRANSIENT_ERROR_MARKERS = (
|
|
"429",
|
|
"rate limit",
|
|
"500",
|
|
"502",
|
|
"503",
|
|
"504",
|
|
"overloaded",
|
|
"timeout",
|
|
"timed out",
|
|
"connection",
|
|
"server error",
|
|
"temporarily unavailable",
|
|
)
|
|
|
|
_SENTINEL = object()
|
|
|
|
def __init__(self, api_key: str | None = None, api_base: str | None = None):
|
|
self.api_key = api_key
|
|
self.api_base = api_base
|
|
self.generation: GenerationSettings = GenerationSettings()
|
|
|
|
@staticmethod
|
|
def _sanitize_empty_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
|
"""Sanitize message content: fix empty blocks, strip internal _meta fields."""
|
|
result: list[dict[str, Any]] = []
|
|
for msg in messages:
|
|
content = msg.get("content")
|
|
|
|
if isinstance(content, str) and not content:
|
|
clean = dict(msg)
|
|
clean["content"] = None if (msg.get("role") == "assistant" and msg.get("tool_calls")) else "(empty)"
|
|
result.append(clean)
|
|
continue
|
|
|
|
if isinstance(content, list):
|
|
new_items: list[Any] = []
|
|
changed = False
|
|
for item in content:
|
|
if (
|
|
isinstance(item, dict)
|
|
and item.get("type") in ("text", "input_text", "output_text")
|
|
and not item.get("text")
|
|
):
|
|
changed = True
|
|
continue
|
|
if isinstance(item, dict) and "_meta" in item:
|
|
new_items.append({k: v for k, v in item.items() if k != "_meta"})
|
|
changed = True
|
|
else:
|
|
new_items.append(item)
|
|
if changed:
|
|
clean = dict(msg)
|
|
if new_items:
|
|
clean["content"] = new_items
|
|
elif msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
clean["content"] = None
|
|
else:
|
|
clean["content"] = "(empty)"
|
|
result.append(clean)
|
|
continue
|
|
|
|
if isinstance(content, dict):
|
|
clean = dict(msg)
|
|
clean["content"] = [content]
|
|
result.append(clean)
|
|
continue
|
|
|
|
result.append(msg)
|
|
return result
|
|
|
|
@staticmethod
|
|
def _sanitize_request_messages(
|
|
messages: list[dict[str, Any]],
|
|
allowed_keys: frozenset[str],
|
|
) -> list[dict[str, Any]]:
|
|
"""Keep only provider-safe message keys and normalize assistant content."""
|
|
sanitized = []
|
|
for msg in messages:
|
|
clean = {k: v for k, v in msg.items() if k in allowed_keys}
|
|
if clean.get("role") == "assistant" and "content" not in clean:
|
|
clean["content"] = None
|
|
sanitized.append(clean)
|
|
return sanitized
|
|
|
|
@abstractmethod
|
|
async def chat(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
model: str | None = None,
|
|
max_tokens: int = 4096,
|
|
temperature: float = 0.7,
|
|
reasoning_effort: str | None = None,
|
|
tool_choice: str | dict[str, Any] | None = None,
|
|
) -> LLMResponse:
|
|
"""
|
|
Send a chat completion request.
|
|
|
|
Args:
|
|
messages: List of message dicts with 'role' and 'content'.
|
|
tools: Optional list of tool definitions.
|
|
model: Model identifier (provider-specific).
|
|
max_tokens: Maximum tokens in response.
|
|
temperature: Sampling temperature.
|
|
tool_choice: Tool selection strategy ("auto", "required", or specific tool dict).
|
|
|
|
Returns:
|
|
LLMResponse with content and/or tool calls.
|
|
"""
|
|
pass
|
|
|
|
@classmethod
|
|
def _is_transient_error(cls, content: str | None) -> bool:
|
|
err = (content or "").lower()
|
|
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
|
|
|
@staticmethod
|
|
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
|
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
|
found = False
|
|
result = []
|
|
for msg in messages:
|
|
content = msg.get("content")
|
|
if isinstance(content, list):
|
|
new_content = []
|
|
for b in content:
|
|
if isinstance(b, dict) and b.get("type") == "image_url":
|
|
path = (b.get("_meta") or {}).get("path", "")
|
|
placeholder = f"[image: {path}]" if path else "[image omitted]"
|
|
new_content.append({"type": "text", "text": placeholder})
|
|
found = True
|
|
else:
|
|
new_content.append(b)
|
|
result.append({**msg, "content": new_content})
|
|
else:
|
|
result.append(msg)
|
|
return result if found else None
|
|
|
|
async def _safe_chat(self, **kwargs: Any) -> LLMResponse:
|
|
"""Call chat() and convert unexpected exceptions to error responses."""
|
|
try:
|
|
return await self.chat(**kwargs)
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception as exc:
|
|
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
|
|
|
async def chat_stream(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
model: str | None = None,
|
|
max_tokens: int = 4096,
|
|
temperature: float = 0.7,
|
|
reasoning_effort: str | None = None,
|
|
tool_choice: str | dict[str, Any] | None = None,
|
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
|
) -> LLMResponse:
|
|
"""Stream a chat completion, calling *on_content_delta* for each text chunk.
|
|
|
|
Returns the same ``LLMResponse`` as :meth:`chat`. The default
|
|
implementation falls back to a non-streaming call and delivers the
|
|
full content as a single delta. Providers that support native
|
|
streaming should override this method.
|
|
"""
|
|
response = await self.chat(
|
|
messages=messages, tools=tools, model=model,
|
|
max_tokens=max_tokens, temperature=temperature,
|
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
|
)
|
|
if on_content_delta and response.content:
|
|
await on_content_delta(response.content)
|
|
return response
|
|
|
|
async def _safe_chat_stream(self, **kwargs: Any) -> LLMResponse:
|
|
"""Call chat_stream() and convert unexpected exceptions to error responses."""
|
|
try:
|
|
return await self.chat_stream(**kwargs)
|
|
except asyncio.CancelledError:
|
|
raise
|
|
except Exception as exc:
|
|
return LLMResponse(content=f"Error calling LLM: {exc}", finish_reason="error")
|
|
|
|
async def chat_stream_with_retry(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
model: str | None = None,
|
|
max_tokens: object = _SENTINEL,
|
|
temperature: object = _SENTINEL,
|
|
reasoning_effort: object = _SENTINEL,
|
|
tool_choice: str | dict[str, Any] | None = None,
|
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
|
) -> LLMResponse:
|
|
"""Call chat_stream() with retry on transient provider failures."""
|
|
if max_tokens is self._SENTINEL:
|
|
max_tokens = self.generation.max_tokens
|
|
if temperature is self._SENTINEL:
|
|
temperature = self.generation.temperature
|
|
if reasoning_effort is self._SENTINEL:
|
|
reasoning_effort = self.generation.reasoning_effort
|
|
|
|
kw: dict[str, Any] = dict(
|
|
messages=messages, tools=tools, model=model,
|
|
max_tokens=max_tokens, temperature=temperature,
|
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
|
on_content_delta=on_content_delta,
|
|
)
|
|
|
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
|
response = await self._safe_chat_stream(**kw)
|
|
|
|
if response.finish_reason != "error":
|
|
return response
|
|
|
|
if not self._is_transient_error(response.content):
|
|
stripped = self._strip_image_content(messages)
|
|
if stripped is not None:
|
|
logger.warning("Non-transient LLM error with image content, retrying without images")
|
|
return await self._safe_chat_stream(**{**kw, "messages": stripped})
|
|
return response
|
|
|
|
logger.warning(
|
|
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
|
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
|
(response.content or "")[:120].lower(),
|
|
)
|
|
await asyncio.sleep(delay)
|
|
|
|
return await self._safe_chat_stream(**kw)
|
|
|
|
async def chat_with_retry(
|
|
self,
|
|
messages: list[dict[str, Any]],
|
|
tools: list[dict[str, Any]] | None = None,
|
|
model: str | None = None,
|
|
max_tokens: object = _SENTINEL,
|
|
temperature: object = _SENTINEL,
|
|
reasoning_effort: object = _SENTINEL,
|
|
tool_choice: str | dict[str, Any] | None = None,
|
|
) -> LLMResponse:
|
|
"""Call chat() with retry on transient provider failures.
|
|
|
|
Parameters default to ``self.generation`` when not explicitly passed,
|
|
so callers no longer need to thread temperature / max_tokens /
|
|
reasoning_effort through every layer.
|
|
"""
|
|
if max_tokens is self._SENTINEL:
|
|
max_tokens = self.generation.max_tokens
|
|
if temperature is self._SENTINEL:
|
|
temperature = self.generation.temperature
|
|
if reasoning_effort is self._SENTINEL:
|
|
reasoning_effort = self.generation.reasoning_effort
|
|
|
|
kw: dict[str, Any] = dict(
|
|
messages=messages, tools=tools, model=model,
|
|
max_tokens=max_tokens, temperature=temperature,
|
|
reasoning_effort=reasoning_effort, tool_choice=tool_choice,
|
|
)
|
|
|
|
for attempt, delay in enumerate(self._CHAT_RETRY_DELAYS, start=1):
|
|
response = await self._safe_chat(**kw)
|
|
|
|
if response.finish_reason != "error":
|
|
return response
|
|
|
|
if not self._is_transient_error(response.content):
|
|
stripped = self._strip_image_content(messages)
|
|
if stripped is not None:
|
|
logger.warning("Non-transient LLM error with image content, retrying without images")
|
|
return await self._safe_chat(**{**kw, "messages": stripped})
|
|
return response
|
|
|
|
logger.warning(
|
|
"LLM transient error (attempt {}/{}), retrying in {}s: {}",
|
|
attempt, len(self._CHAT_RETRY_DELAYS), delay,
|
|
(response.content or "")[:120].lower(),
|
|
)
|
|
await asyncio.sleep(delay)
|
|
|
|
return await self._safe_chat(**kw)
|
|
|
|
@abstractmethod
|
|
def get_default_model(self) -> str:
|
|
"""Get the default model for this provider."""
|
|
pass
|