mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-21 18:40:06 +00:00
Replace the flatten/unflatten approach (merging extra_content.google.* into provider_specific_fields then reconstructing) with direct pass-through: parse extra_content as-is, store on ToolCallRequest.extra_content, serialize back untouched. This is lossless, requires no hardcoded field names, and covers all three parsing branches (str, dict, SDK object) plus streaming.
572 lines
21 KiB
Python
572 lines
21 KiB
Python
"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import os
|
|
import secrets
|
|
import string
|
|
import uuid
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import json_repair
|
|
from openai import AsyncOpenAI
|
|
|
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|
|
|
if TYPE_CHECKING:
|
|
from nanobot.providers.registry import ProviderSpec
|
|
|
|
# Message keys forwarded to OpenAI-compatible endpoints; everything else is
# stripped by _sanitize_messages (via LLMProvider._sanitize_request_messages).
_ALLOWED_MSG_KEYS = frozenset({
    "role", "content", "tool_calls", "tool_call_id", "name",
    "reasoning_content", "extra_content",
})
# Alphabet for provider-safe tool-call IDs (letters + digits only).
_ALNUM = string.ascii_letters + string.digits

# Standard OpenAI tool-call / function fields; anything outside these sets is
# treated as provider-specific and preserved by _extract_tc_extras.
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
_STANDARD_FN_KEYS = frozenset({"name", "arguments"})
|
|
def _short_tool_id() -> str:
|
|
"""9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
|
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
|
|
|
|
|
def _get(obj: Any, key: str) -> Any:
|
|
"""Get a value from dict or object attribute, returning None if absent."""
|
|
if isinstance(obj, dict):
|
|
return obj.get(key)
|
|
return getattr(obj, key, None)
|
|
|
|
|
|
def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
|
"""Try to coerce *value* to a dict; return None if not possible or empty."""
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, dict):
|
|
return value if value else None
|
|
model_dump = getattr(value, "model_dump", None)
|
|
if callable(model_dump):
|
|
dumped = model_dump()
|
|
if isinstance(dumped, dict) and dumped:
|
|
return dumped
|
|
return None
|
|
|
|
|
|
def _extract_tc_extras(tc: Any) -> tuple[
    dict[str, Any] | None,
    dict[str, Any] | None,
    dict[str, Any] | None,
]:
    """Pull provider extras off a tool call.

    Returns ``(extra_content, provider_specific_fields,
    function_provider_specific_fields)``. Works for both raw dicts and SDK
    objects: Gemini ``extra_content`` is captured verbatim, and any keys
    outside the standard OpenAI tool-call schema are preserved.
    """
    extra_content = _coerce_dict(_get(tc, "extra_content"))

    tc_dict = _coerce_dict(tc)
    if tc_dict is None:
        # Not dict-like: fall back to the SDK object's explicit attributes.
        prov = _coerce_dict(_get(tc, "provider_specific_fields"))
        fn_prov = None
        fn_obj = _get(tc, "function")
        if fn_obj is not None:
            fn_prov = _coerce_dict(_get(fn_obj, "provider_specific_fields"))
        return extra_content, prov, fn_prov

    # Dict-like: everything outside the standard schema is provider-specific.
    prov = {
        k: v for k, v in tc_dict.items()
        if k not in _STANDARD_TC_KEYS and k != "extra_content" and v is not None
    } or None
    fn_prov = None
    fn = _coerce_dict(tc_dict.get("function"))
    if fn is not None:
        fn_prov = {
            k: v for k, v in fn.items()
            if k not in _STANDARD_FN_KEYS and v is not None
        } or None
    return extra_content, prov, fn_prov
|
|
|
class OpenAICompatProvider(LLMProvider):
    """Unified provider for all OpenAI-compatible APIs.

    Receives a resolved ``ProviderSpec`` from the caller — no internal
    registry lookups needed.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "gpt-4o",
        extra_headers: dict[str, str] | None = None,
        spec: ProviderSpec | None = None,
    ):
        """Create a client for one OpenAI-compatible endpoint.

        Args:
            api_key: API key; a placeholder is sent when absent.
            api_base: Base URL override; falls back to ``spec.default_api_base``.
            default_model: Model used when a call does not specify one.
            extra_headers: Extra HTTP headers merged into every request.
            spec: Resolved provider spec driving env setup and model quirks.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}
        self._spec = spec

        # Only touch the environment when the spec declares an env var to set.
        if api_key and spec and spec.env_key:
            self._setup_env(api_key, api_base)

        effective_base = api_base or (spec.default_api_base if spec else None) or None

        self._client = AsyncOpenAI(
            # The SDK requires a non-empty key even for keyless endpoints.
            api_key=api_key or "no-key",
            base_url=effective_base,
            default_headers={
                # Random per-instance value; presumably used for sticky routing
                # by gateway load balancers — confirm against deployment.
                "x-session-affinity": uuid.uuid4().hex,
                **(extra_headers or {}),
            },
        )

    def _setup_env(self, api_key: str, api_base: str | None) -> None:
        """Set environment variables based on provider spec."""
        spec = self._spec
        if not spec or not spec.env_key:
            return
        if spec.is_gateway:
            # Gateways always overwrite so the routed key wins.
            os.environ[spec.env_key] = api_key
        else:
            # Direct providers respect a key already present in the env.
            os.environ.setdefault(spec.env_key, api_key)
        effective_base = api_base or spec.default_api_base
        # NOTE(review): if both api_base and spec.default_api_base are None,
        # the .replace below raises TypeError — confirm specs with env_extras
        # always define a base URL.
        for env_name, env_val in spec.env_extras:
            resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
            os.environ.setdefault(env_name, resolved)

    @staticmethod
    def _apply_cache_control(
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
        """Inject cache_control markers for prompt caching.

        Marks the system message, the second-to-last message, and the last
        tool definition with an ``{"type": "ephemeral"}`` cache marker.
        Inputs are not mutated; marked copies are returned.
        """
        cache_marker = {"type": "ephemeral"}
        new_messages = list(messages)

        def _mark(msg: dict[str, Any]) -> dict[str, Any]:
            # String content is wrapped into a one-part list so the marker
            # can be attached; list content gets the marker on its last part.
            content = msg.get("content")
            if isinstance(content, str):
                return {**msg, "content": [
                    {"type": "text", "text": content, "cache_control": cache_marker},
                ]}
            if isinstance(content, list) and content:
                nc = list(content)
                nc[-1] = {**nc[-1], "cache_control": cache_marker}
                return {**msg, "content": nc}
            return msg

        if new_messages and new_messages[0].get("role") == "system":
            new_messages[0] = _mark(new_messages[0])
        if len(new_messages) >= 3:
            new_messages[-2] = _mark(new_messages[-2])

        new_tools = tools
        if tools:
            new_tools = list(tools)
            new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
        return new_messages, new_tools

    @staticmethod
    def _normalize_tool_call_id(tool_call_id: Any) -> Any:
        """Normalize to a provider-safe 9-char alphanumeric form.

        Already-conforming IDs pass through; others are hashed (sha1) and
        truncated so the same input always maps to the same short ID.
        Non-string values are returned unchanged.
        """
        if not isinstance(tool_call_id, str):
            return tool_call_id
        if len(tool_call_id) == 9 and tool_call_id.isalnum():
            return tool_call_id
        return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]

    def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Strip non-standard keys, normalize tool_call IDs."""
        sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
        id_map: dict[str, str] = {}

        def map_id(value: Any) -> Any:
            # Memoized so the same original ID maps consistently across the
            # assistant tool_calls and the matching tool result messages.
            if not isinstance(value, str):
                return value
            return id_map.setdefault(value, self._normalize_tool_call_id(value))

        for clean in sanitized:
            if isinstance(clean.get("tool_calls"), list):
                normalized = []
                for tc in clean["tool_calls"]:
                    if not isinstance(tc, dict):
                        normalized.append(tc)
                        continue
                    tc_clean = dict(tc)
                    tc_clean["id"] = map_id(tc_clean.get("id"))
                    normalized.append(tc_clean)
                clean["tool_calls"] = normalized
            if "tool_call_id" in clean and clean["tool_call_id"]:
                clean["tool_call_id"] = map_id(clean["tool_call_id"])
        return sanitized

    # ------------------------------------------------------------------
    # Build kwargs
    # ------------------------------------------------------------------

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Assemble the kwargs dict for ``chat.completions.create``.

        Applies spec-driven tweaks: prompt-caching markers, model-prefix
        stripping, and per-model overrides (first matching pattern wins).
        """
        model_name = model or self.default_model
        spec = self._spec

        if spec and spec.supports_prompt_caching:
            messages, tools = self._apply_cache_control(messages, tools)

        if spec and spec.strip_model_prefix:
            # e.g. "openrouter/foo/bar" -> "bar"
            model_name = model_name.split("/")[-1]

        kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
            "max_tokens": max(1, max_tokens),
            "temperature": temperature,
        }

        if spec:
            model_lower = model_name.lower()
            for pattern, overrides in spec.model_overrides:
                if pattern in model_lower:
                    kwargs.update(overrides)
                    break

        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice or "auto"

        return kwargs

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _maybe_mapping(value: Any) -> dict[str, Any] | None:
        """Return *value* as a dict if it is one (or model_dumps to one), else None."""
        if isinstance(value, dict):
            return value
        model_dump = getattr(value, "model_dump", None)
        if callable(model_dump):
            dumped = model_dump()
            if isinstance(dumped, dict):
                return dumped
        return None

    @classmethod
    def _extract_text_content(cls, value: Any) -> str | None:
        """Flatten message content (str, list of parts, or other) to plain text.

        List parts contribute their ``text`` field (dict key or attribute) or
        themselves when they are bare strings; other values are str()-ified.
        Returns None for None input or an empty result.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            parts: list[str] = []
            for item in value:
                item_map = cls._maybe_mapping(item)
                if item_map:
                    text = item_map.get("text")
                    if isinstance(text, str):
                        parts.append(text)
                        continue
                text = getattr(item, "text", None)
                if isinstance(text, str):
                    parts.append(text)
                    continue
                if isinstance(item, str):
                    parts.append(item)
            return "".join(parts) or None
        return str(value)

    @classmethod
    def _extract_usage(cls, response: Any) -> dict[str, int]:
        """Extract token usage from a response (dict-like or SDK object).

        Returns a dict with prompt/completion/total token counts, or an
        empty dict when no usage is present.
        """
        usage_obj = None
        response_map = cls._maybe_mapping(response)
        if response_map is not None:
            usage_obj = response_map.get("usage")
        elif hasattr(response, "usage") and response.usage:
            usage_obj = response.usage

        usage_map = cls._maybe_mapping(usage_obj)
        if usage_map is not None:
            return {
                "prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
                "completion_tokens": int(usage_map.get("completion_tokens") or 0),
                "total_tokens": int(usage_map.get("total_tokens") or 0),
            }

        if usage_obj:
            return {
                "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
                "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
            }
        return {}

    def _parse(self, response: Any) -> LLMResponse:
        """Normalize a completion response into an LLMResponse.

        Handles three shapes: a bare string, a dict-like mapping, and an
        OpenAI SDK object. Tool calls from all choices are collected and
        re-keyed with fresh short IDs; Gemini/provider extras are preserved
        via _extract_tc_extras.
        """
        if isinstance(response, str):
            return LLMResponse(content=response, finish_reason="stop")

        response_map = self._maybe_mapping(response)
        if response_map is not None:
            choices = response_map.get("choices") or []
            if not choices:
                # Some endpoints return content at the top level instead of
                # inside choices.
                content = self._extract_text_content(
                    response_map.get("content") or response_map.get("output_text")
                )
                if content is not None:
                    return LLMResponse(
                        content=content,
                        finish_reason=str(response_map.get("finish_reason") or "stop"),
                        usage=self._extract_usage(response_map),
                    )
                return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

            choice0 = self._maybe_mapping(choices[0]) or {}
            msg0 = self._maybe_mapping(choice0.get("message")) or {}
            content = self._extract_text_content(msg0.get("content"))
            finish_reason = str(choice0.get("finish_reason") or "stop")

            # Collect tool calls / content across all choices, preferring the
            # first non-empty value for content and reasoning_content.
            raw_tool_calls: list[Any] = []
            reasoning_content = msg0.get("reasoning_content")
            for ch in choices:
                ch_map = self._maybe_mapping(ch) or {}
                m = self._maybe_mapping(ch_map.get("message")) or {}
                tool_calls = m.get("tool_calls")
                if isinstance(tool_calls, list) and tool_calls:
                    raw_tool_calls.extend(tool_calls)
                if ch_map.get("finish_reason") in ("tool_calls", "stop"):
                    finish_reason = str(ch_map["finish_reason"])
                if not content:
                    content = self._extract_text_content(m.get("content"))
                if not reasoning_content:
                    reasoning_content = m.get("reasoning_content")

            parsed_tool_calls = []
            for tc in raw_tool_calls:
                tc_map = self._maybe_mapping(tc) or {}
                fn = self._maybe_mapping(tc_map.get("function")) or {}
                args = fn.get("arguments", {})
                if isinstance(args, str):
                    # json_repair tolerates malformed JSON argument strings.
                    args = json_repair.loads(args)
                ec, prov, fn_prov = _extract_tc_extras(tc)
                parsed_tool_calls.append(ToolCallRequest(
                    # Original IDs are discarded in favor of provider-safe ones.
                    id=_short_tool_id(),
                    name=str(fn.get("name") or ""),
                    arguments=args if isinstance(args, dict) else {},
                    extra_content=ec,
                    provider_specific_fields=prov,
                    function_provider_specific_fields=fn_prov,
                ))

            return LLMResponse(
                content=content,
                tool_calls=parsed_tool_calls,
                finish_reason=finish_reason,
                usage=self._extract_usage(response_map),
                reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
            )

        # SDK-object branch (no mapping available).
        if not response.choices:
            return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

        choice = response.choices[0]
        msg = choice.message
        content = msg.content
        finish_reason = choice.finish_reason

        raw_tool_calls: list[Any] = []
        for ch in response.choices:
            m = ch.message
            if hasattr(m, "tool_calls") and m.tool_calls:
                raw_tool_calls.extend(m.tool_calls)
            if ch.finish_reason in ("tool_calls", "stop"):
                finish_reason = ch.finish_reason
            if not content and m.content:
                content = m.content

        tool_calls = []
        for tc in raw_tool_calls:
            args = tc.function.arguments
            if isinstance(args, str):
                args = json_repair.loads(args)
            ec, prov, fn_prov = _extract_tc_extras(tc)
            tool_calls.append(ToolCallRequest(
                id=_short_tool_id(),
                name=tc.function.name,
                arguments=args,
                extra_content=ec,
                provider_specific_fields=prov,
                function_provider_specific_fields=fn_prov,
            ))

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason or "stop",
            usage=self._extract_usage(response),
            reasoning_content=getattr(msg, "reasoning_content", None) or None,
        )

    @classmethod
    def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
        """Reassemble streamed chunks into a single LLMResponse.

        Concatenates content deltas, accumulates tool-call deltas keyed by
        stream index, and keeps the last seen finish_reason and usage.
        Chunks may be strings, dict-like mappings, or SDK objects.
        """
        content_parts: list[str] = []
        tc_bufs: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"
        usage: dict[str, int] = {}

        def _accum_tc(tc: Any, idx_hint: int) -> None:
            """Accumulate one streaming tool-call delta into *tc_bufs*."""
            tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
            buf = tc_bufs.setdefault(tc_index, {
                "id": "", "name": "", "arguments": "",
                "extra_content": None, "prov": None, "fn_prov": None,
            })
            tc_id = _get(tc, "id")
            if tc_id:
                buf["id"] = str(tc_id)
            fn = _get(tc, "function")
            if fn is not None:
                fn_name = _get(fn, "name")
                if fn_name:
                    buf["name"] = str(fn_name)
                fn_args = _get(fn, "arguments")
                if fn_args:
                    # Argument JSON arrives in fragments; concatenate them.
                    buf["arguments"] += str(fn_args)
            ec, prov, fn_prov = _extract_tc_extras(tc)
            # Last non-empty extras win for each buffer.
            if ec:
                buf["extra_content"] = ec
            if prov:
                buf["prov"] = prov
            if fn_prov:
                buf["fn_prov"] = fn_prov

        for chunk in chunks:
            if isinstance(chunk, str):
                content_parts.append(chunk)
                continue

            chunk_map = cls._maybe_mapping(chunk)
            if chunk_map is not None:
                choices = chunk_map.get("choices") or []
                if not choices:
                    # Usage-only (or top-level-content) chunk.
                    usage = cls._extract_usage(chunk_map) or usage
                    text = cls._extract_text_content(
                        chunk_map.get("content") or chunk_map.get("output_text")
                    )
                    if text:
                        content_parts.append(text)
                    continue
                choice = cls._maybe_mapping(choices[0]) or {}
                if choice.get("finish_reason"):
                    finish_reason = str(choice["finish_reason"])
                delta = cls._maybe_mapping(choice.get("delta")) or {}
                text = cls._extract_text_content(delta.get("content"))
                if text:
                    content_parts.append(text)
                for idx, tc in enumerate(delta.get("tool_calls") or []):
                    _accum_tc(tc, idx)
                usage = cls._extract_usage(chunk_map) or usage
                continue

            # SDK-object chunk.
            if not chunk.choices:
                usage = cls._extract_usage(chunk) or usage
                continue
            choice = chunk.choices[0]
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            delta = choice.delta
            if delta and delta.content:
                content_parts.append(delta.content)
            for tc in (delta.tool_calls or []) if delta else []:
                _accum_tc(tc, getattr(tc, "index", 0))

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=[
                ToolCallRequest(
                    id=b["id"] or _short_tool_id(),
                    name=b["name"],
                    arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
                    extra_content=b.get("extra_content"),
                    provider_specific_fields=b.get("prov"),
                    function_provider_specific_fields=b.get("fn_prov"),
                )
                for b in tc_bufs.values()
            ],
            finish_reason=finish_reason,
            usage=usage,
        )

    @staticmethod
    def _handle_error(e: Exception) -> LLMResponse:
        """Convert an exception into an error LLMResponse, preferring the HTTP body."""
        body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
        msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
        return LLMResponse(content=msg, finish_reason="error")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """Send a non-streaming chat completion request.

        Exceptions are converted into an error LLMResponse rather than raised.
        """
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        try:
            return self._parse(await self._client.chat.completions.create(**kwargs))
        except Exception as e:
            return self._handle_error(e)

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Send a streaming chat completion request.

        Args:
            on_content_delta: Optional async callback invoked with each text
                delta as it arrives.

        Returns the fully reassembled LLMResponse; exceptions become an
        error LLMResponse.
        """
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        kwargs["stream"] = True
        # Ask the API to attach token usage to the final stream chunk.
        kwargs["stream_options"] = {"include_usage": True}
        try:
            stream = await self._client.chat.completions.create(**kwargs)
            chunks: list[Any] = []
            async for chunk in stream:
                chunks.append(chunk)
                if on_content_delta and chunk.choices:
                    text = getattr(chunk.choices[0].delta, "content", None)
                    if text:
                        await on_content_delta(text)
            return self._parse_chunks(chunks)
        except Exception as e:
            return self._handle_error(e)

    def get_default_model(self) -> str:
        """Return the model name used when a call does not specify one."""
        return self.default_model
|