mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
Adds ProviderConfig.extra_query, threaded into AsyncOpenAI(default_query) so that Azure-style gateways requiring query params like api-version can be configured without URL hacks. Also updates provider_signature to track extra_query changes so per-turn refresh rebuilds the provider when the value changes. Addresses the extra_query portion of #4204. The max_completion_tokens model-awareness enhancement is intentionally left separate.
1490 lines
60 KiB
Python
1490 lines
60 KiB
Python
"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import importlib.util
|
|
import json
|
|
import os
|
|
import secrets
|
|
import string
|
|
import time
|
|
import uuid
|
|
from collections import deque
|
|
from collections.abc import Awaitable, Callable
|
|
from ipaddress import ip_address
|
|
from typing import TYPE_CHECKING, Any
|
|
from urllib.parse import urlparse
|
|
|
|
import json_repair
|
|
from loguru import logger
|
|
|
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|
from nanobot.providers.openai_responses import (
|
|
consume_sdk_stream,
|
|
convert_messages,
|
|
convert_tools,
|
|
parse_response_output,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from openai import AsyncOpenAI as AsyncOpenAIType
|
|
|
|
from nanobot.providers.registry import ProviderSpec
|
|
|
|
# Placeholder for the OpenAI client class. ``_ensure_client`` assigns the real
# class (``openai.AsyncOpenAI`` or the Langfuse wrapper) on first use, and tests
# may swap it with ``unittest.mock.patch``. It must stay a plain module-level
# name so ``patch`` can find and replace it.
AsyncOpenAI: Any = None

# Message keys kept when outgoing messages are passed through
# ``_sanitize_request_messages``.
_ALLOWED_MSG_KEYS = frozenset(
    (
        "role",
        "content",
        "tool_calls",
        "tool_call_id",
        "name",
        "reasoning_content",
        "extra_content",
    )
)
# Alphabet used for generated tool-call IDs (ASCII letters + digits).
_ALNUM = string.ascii_letters + string.digits

# Keys considered "standard" on a tool call / its function; anything else is
# treated as provider-specific by ``_extract_tc_extras``.
_STANDARD_TC_KEYS = frozenset(("id", "type", "index", "function"))
_STANDARD_FN_KEYS = frozenset(("name", "arguments"))

# Attribution headers sent to OpenRouter by default.
_DEFAULT_OPENROUTER_HEADERS = {
    "HTTP-Referer": "https://github.com/HKUDS/nanobot",
    "X-OpenRouter-Title": "nanobot",
    "X-OpenRouter-Categories": "cli-agent,personal-agent",
}

_KIMI_THINKING_MODELS: frozenset[str] = frozenset(
    ("kimi-k2.5", "kimi-k2.6", "k2.6-code-preview")
)
# MiMo models that support thinking, per Xiaomi's docs (see
# tests/providers/test_xiaomi_mimo_thinking.py). mimo-v2-flash is left out
# because it has no thinking mode.
_MIMO_THINKING_MODELS: frozenset[str] = frozenset(
    ("mimo-v2.5-pro", "mimo-v2.5", "mimo-v2-pro", "mimo-v2-omni")
)

# Default request timeout in seconds; overridable via
# NANOBOT_OPENAI_COMPAT_TIMEOUT_S.
_OPENAI_COMPAT_REQUEST_TIMEOUT_S = 120.0

# ProviderSpec.thinking_style -> builder(enabled: bool) returning the fields to
# merge into extra_body. Keeps the style-to-wire-format mapping in one place.
_THINKING_STYLE_MAP: dict[str, Any] = {
    "thinking_type": lambda enabled: {
        "thinking": {"type": "enabled" if enabled else "disabled"}
    },
    "enable_thinking": lambda enabled: {"enable_thinking": enabled},
    "reasoning_split": lambda enabled: {"reasoning_split": enabled},
}

# Gateway reasoning style -> builder(effort: str) returning extra_body fields.
_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = {
    "reasoning_effort": lambda effort: {"reasoning": {"effort": effort}},
}

# Model slug -> thinking style, for models whose style does not come from the
# provider spec.
_MODEL_THINKING_STYLES: dict[str, str] = {
    slug: "thinking_type"
    for slug in (*_KIMI_THINKING_MODELS, *_MIMO_THINKING_MODELS)
}
|
|
|
|
|
|
def _model_slug(model_name: str) -> str:
    """Return the lower-cased last ``/``-separated segment of *model_name*.

    ``"Moonshot/Kimi-K2.5"`` becomes ``"kimi-k2.5"``; a name without ``/`` is
    just lower-cased.
    """
    _, _, tail = model_name.lower().rpartition("/")
    return tail
|
|
|
|
|
|
def _model_thinking_style(model_name: str) -> str:
    """Return the model-specific thinking style for *model_name*, or ``""``.

    The lookup uses the model's slug (last path segment, lower-cased).
    """
    slug = _model_slug(model_name)
    return _MODEL_THINKING_STYLES.get(slug, "")
|
|
|
|
|
|
def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]:
    """List the thinking styles that apply to *spec* and *model_name*.

    The provider-level style (if any) comes first, followed by the
    model-specific style when it is set and different. Empty values are skipped.
    """
    candidates = (
        spec.thinking_style if spec else "",
        _model_thinking_style(model_name),
    )
    styles: list[str] = []
    for style in candidates:
        if style and style not in styles:
            styles.append(style)
    return styles
|
|
|
|
|
|
def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None:
    """Build the extra_body fragment that toggles thinking for *style*.

    Returns None when *style* has no entry in ``_THINKING_STYLE_MAP``.
    """
    builder = _THINKING_STYLE_MAP.get(style)
    if builder is None:
        return None
    return builder(thinking_enabled)
|
|
|
|
|
|
def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None:
    """Build gateway-level reasoning fields for *style* at *effort*.

    Returns None when *effort* is empty/None or when *style* has no builder in
    ``_GATEWAY_REASONING_STYLE_MAP``.
    """
    if effort:
        builder = _GATEWAY_REASONING_STYLE_MAP.get(style)
        if builder:
            return builder(effort)
    return None
|
|
|
|
|
|
def _openai_compat_timeout_s() -> float:
    """Return the request timeout (seconds) for OpenAI-compatible providers.

    Read from ``NANOBOT_OPENAI_COMPAT_TIMEOUT_S``; falls back to
    ``_OPENAI_COMPAT_REQUEST_TIMEOUT_S`` when unset or invalid.
    """
    env_name = "NANOBOT_OPENAI_COMPAT_TIMEOUT_S"
    return _float_env(env_name, default=_OPENAI_COMPAT_REQUEST_TIMEOUT_S)
|
|
|
|
|
|
def _float_env(name: str, default: float) -> float:
    """Read a positive float from the environment variable *name*.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset, blank,
            unparsable, or not a positive number.

    Returns:
        The parsed value, or *default*. Invalid values are logged as warnings
        and ignored rather than raised.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = float(raw)
    except (TypeError, ValueError):
        logger.warning("Ignoring invalid {}={!r}; using {}", name, raw, default)
        return default
    # float() accepts "nan". Every ordered comparison with NaN is False, so
    # ``value <= 0`` would let NaN through. ``not value > 0`` rejects both
    # non-positive values and NaN.
    if not value > 0:
        logger.warning("Ignoring non-positive {}={!r}; using {}", name, raw, default)
        return default
    return value
|
|
|
|
|
|
def _short_tool_id() -> str:
    """Return a random 9-character alphanumeric tool-call ID.

    This format is accepted by every supported provider, including Mistral.
    Characters are drawn with ``secrets.choice``.
    """
    picks = [secrets.choice(_ALNUM) for _ in range(9)]
    return "".join(picks)
|
|
|
|
|
|
def _get(obj: Any, key: str) -> Any:
    """Read *key* from a dict (``.get``) or from an object (attribute access).

    Returns None when the key or attribute is missing.
    """
    return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
|
|
|
|
|
|
def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
|
"""Try to coerce *value* to a dict; return None if not possible or empty."""
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, dict):
|
|
return value if value else None
|
|
model_dump = getattr(value, "model_dump", None)
|
|
if callable(model_dump):
|
|
dumped = model_dump()
|
|
if isinstance(dumped, dict) and dumped:
|
|
return dumped
|
|
return None
|
|
|
|
|
|
def _extract_tc_extras(tc: Any) -> tuple[
    dict[str, Any] | None,
    dict[str, Any] | None,
    dict[str, Any] | None,
]:
    """Extract (extra_content, provider_specific_fields, fn_provider_specific_fields).

    Works for dicts and SDK objects. Gemini's ``extra_content`` is captured
    verbatim. When the tool call can be turned into a dict, every non-None key
    outside the standard tool-call / function key sets is collected as a
    provider-specific field. Otherwise the explicit
    ``provider_specific_fields`` attributes are read instead. Empty
    collections are reported as None.
    """
    extra_content = _coerce_dict(_get(tc, "extra_content"))
    tc_dict = _coerce_dict(tc)

    if tc_dict is None:
        # Not dict-like: fall back to explicitly named attributes.
        prov = _coerce_dict(_get(tc, "provider_specific_fields"))
        fn_obj = _get(tc, "function")
        fn_prov = (
            None
            if fn_obj is None
            else _coerce_dict(_get(fn_obj, "provider_specific_fields"))
        )
        return extra_content, prov, fn_prov

    tc_leftover = {
        key: val
        for key, val in tc_dict.items()
        if val is not None and key not in _STANDARD_TC_KEYS and key != "extra_content"
    }
    prov = tc_leftover or None

    fn_prov = None
    fn = _coerce_dict(tc_dict.get("function"))
    if fn is not None:
        fn_leftover = {
            key: val
            for key, val in fn.items()
            if val is not None and key not in _STANDARD_FN_KEYS
        }
        fn_prov = fn_leftover or None

    return extra_content, prov, fn_prov
|
|
|
|
|
|
def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | None) -> bool:
|
|
"""Apply Nanobot attribution headers to OpenRouter requests by default."""
|
|
if spec and spec.name == "openrouter":
|
|
return True
|
|
return bool(api_base and "openrouter" in api_base.lower())
|
|
|
|
|
|
# Responses API circuit breaker: failures (per model/effort key) before the
# breaker opens, and the cool-down before the Responses API is tried again.
_RESPONSES_FAILURE_THRESHOLD = 3
_RESPONSES_PROBE_INTERVAL_S = 5 * 60  # seconds
|
|
|
|
|
|
def _is_local_endpoint(
|
|
spec: "ProviderSpec | None",
|
|
api_base: str | None,
|
|
) -> bool:
|
|
"""Return True when the endpoint is a local or LAN model server.
|
|
|
|
Matches either the provider spec's ``is_local`` flag or common private-
|
|
network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x,
|
|
172.16-31.x, Docker ``host.docker.internal``).
|
|
"""
|
|
if spec and spec.is_local:
|
|
return True
|
|
if not api_base:
|
|
return False
|
|
raw = api_base.strip().lower()
|
|
parsed = urlparse(raw if "://" in raw else f"//{raw}")
|
|
try:
|
|
host = parsed.hostname
|
|
except ValueError:
|
|
return False
|
|
if host in {"localhost", "host.docker.internal"}:
|
|
return True
|
|
if not host:
|
|
return False
|
|
try:
|
|
addr = ip_address(host)
|
|
except ValueError:
|
|
return False
|
|
return addr.is_loopback or addr.is_private
|
|
|
|
|
|
def _is_direct_openai_base(api_base: str | None) -> bool:
|
|
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
|
if not api_base:
|
|
return True
|
|
normalized = api_base.strip().lower().rstrip("/")
|
|
return "api.openai.com" in normalized and "openrouter" not in normalized
|
|
|
|
|
|
def _responses_circuit_key(
|
|
model: str | None,
|
|
default_model: str,
|
|
reasoning_effort: str | None,
|
|
) -> str:
|
|
model_name = (model or default_model).lower()
|
|
effort = reasoning_effort.lower() if isinstance(reasoning_effort, str) else ""
|
|
return f"{model_name}:{effort}"
|
|
|
|
|
|
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with *override* recursively merged over *base*.

    When both sides hold a dict under the same key, the two dicts are merged
    key by key. Any other value in *override* replaces the one in *base*.
    Neither input is mutated. Note that only the top level (and the recursively
    merged dicts) is copied: values taken unchanged from either side are shared
    with the inputs.
    """
    result = dict(base)
    for key, incoming in override.items():
        current = result.get(key)
        if isinstance(current, dict) and isinstance(incoming, dict):
            result[key] = _deep_merge(current, incoming)
        else:
            result[key] = incoming
    return result
|
|
|
|
|
|
def _merge_unique_list(base: Any, override: Any) -> Any:
    """Concatenate two lists, dropping duplicates and keeping first occurrences.

    Equality is decided by a canonical JSON dump of each item (``sort_keys``),
    with ``repr`` as a fallback for items JSON cannot encode. If either
    argument is not a list, *override* is returned unchanged.
    """
    if not (isinstance(base, list) and isinstance(override, list)):
        return override
    merged: list[Any] = []
    fingerprints: set[str] = set()
    for item in (*base, *override):
        try:
            fingerprint = json.dumps(item, sort_keys=True, ensure_ascii=False)
        except Exception:
            fingerprint = repr(item)
        if fingerprint not in fingerprints:
            fingerprints.add(fingerprint)
            merged.append(item)
    return merged
|
|
|
|
|
|
def _merge_responses_extra_body(
    body: dict[str, Any],
    extra_body: dict[str, Any],
) -> dict[str, Any]:
    """Merge configured fields into a Responses API body without clobbering tools.

    Regular keys are deep-merged over *body*. ``include`` lists are unioned
    (order kept, duplicates removed). Configured ``tools`` are appended to the
    body's tools when both are lists, and replace them otherwise.
    """
    passthrough = {
        key: value
        for key, value in extra_body.items()
        if key not in {"include", "tools"}
    }
    merged = _deep_merge(body, passthrough)

    if "include" in extra_body:
        merged["include"] = _merge_unique_list(body.get("include"), extra_body["include"])

    if "tools" in extra_body:
        existing_tools = body.get("tools")
        extra_tools = extra_body["tools"]
        if isinstance(existing_tools, list) and isinstance(extra_tools, list):
            merged["tools"] = existing_tools + extra_tools
        else:
            merged["tools"] = extra_tools

    return merged
|
|
|
|
|
|
class OpenAICompatProvider(LLMProvider):
|
|
"""Unified provider for all OpenAI-compatible APIs.
|
|
|
|
Receives a resolved ``ProviderSpec`` from the caller — no internal
|
|
registry lookups needed.
|
|
"""
|
|
|
|
def __init__(
    self,
    api_key: str | None = None,
    api_base: str | None = None,
    default_model: str = "gpt-4o",
    extra_headers: dict[str, str] | None = None,
    spec: ProviderSpec | None = None,
    extra_body: dict[str, Any] | None = None,
    api_type: str = "auto",
    extra_query: dict[str, str] | None = None,
):
    """Store configuration and derive connection settings.

    No OpenAI client is created here; ``_ensure_client`` builds it on first use.

    Args:
        api_key: Key sent to the endpoint (``"no-key"`` when omitted). With a
            spec that declares ``env_key`` it is also exported via
            ``_setup_env``.
        api_base: Base URL override; falls back to ``spec.default_api_base``.
        default_model: Model used when a request does not name one.
        extra_headers: Extra HTTP headers, applied after the built-in ones so
            they can override them.
        spec: Resolved provider spec, or None for a generic endpoint.
        extra_body: Fields merged into every request body.
        api_type: ``"auto"``, ``"chat_completions"`` or ``"responses"``. Only
            honoured when ``spec.name == "openai"``; otherwise ``"auto"``.
        extra_query: Query parameters added to every request (e.g. an
            ``api-version``), passed to the client as ``default_query``.
    """
    super().__init__(api_key, api_base)
    self.default_model = default_model
    self.extra_headers = extra_headers or {}
    self._spec = spec
    self._extra_body = extra_body or {}
    is_openai_spec = bool(spec) and spec.name == "openai"
    self._api_type = api_type if is_openai_spec else "auto"
    self._extra_query = extra_query or {}

    if api_key and spec and spec.env_key:
        self._setup_env(api_key, api_base)

    fallback_base = spec.default_api_base if spec else None
    self._effective_base = api_base or fallback_base or None

    # Built-in headers first, then OpenRouter attribution, then user headers
    # last so they win on conflicts.
    headers = {"x-session-affinity": uuid.uuid4().hex}
    if _uses_openrouter_attribution(spec, self._effective_base):
        headers.update(_DEFAULT_OPENROUTER_HEADERS)
    headers.update(extra_headers or {})
    self._default_headers = headers

    self._api_key_for_client = api_key or "no-key"
    self._is_local = _is_local_endpoint(spec, self._effective_base)

    # The OpenAI client and its httpx transport are slow to create (~700 ms on
    # Windows), so creation is deferred until first use.
    self._client: AsyncOpenAIType | None = None
    self._client_lock = asyncio.Lock()

    # Responses API circuit breaker state, keyed by model/effort: failure
    # counts and the monotonic time the breaker last tripped.
    self._responses_failures: dict[str, int] = {}
    self._responses_tripped_at: dict[str, float] = {}
|
|
|
|
def _build_client(self) -> None:
    """Create ``self._client`` from the module-level ``AsyncOpenAI`` class.

    SDK retries are disabled (``max_retries=0``) and the timeout comes from
    ``_openai_compat_timeout_s``. Local endpoints get their own httpx client
    with keep-alive disabled.
    """
    import httpx

    timeout_s = _openai_compat_timeout_s()

    # Local model servers (Ollama, llama.cpp, vLLM) often drop idle
    # connections before the client's keep-alive expires. Two calls a few
    # seconds apart (e.g. heartbeat _decide then process_direct) can then
    # reuse a dead pooled connection and hit a transient APIConnectionError.
    # A fresh connection per request is cheap on a LAN, so keep-alive is
    # turned off there. Cloud providers keep the default pool settings.
    http_client: httpx.AsyncClient | None = (
        httpx.AsyncClient(
            limits=httpx.Limits(keepalive_expiry=0),
            timeout=timeout_s,
        )
        if self._is_local
        else None
    )

    self._client = AsyncOpenAI(
        api_key=self._api_key_for_client,
        base_url=self._effective_base,
        default_headers=self._default_headers,
        default_query=self._extra_query or None,
        max_retries=0,
        timeout=timeout_s,
        http_client=http_client,
    )
|
|
|
|
async def _ensure_client(self):
    """Return the shared OpenAI client, building it on first call.

    Creation is guarded by ``self._client_lock`` and re-checked under the
    lock, so concurrent first calls build only one client. The client class
    is resolved once per process into the module-level ``AsyncOpenAI``: the
    Langfuse wrapper when ``LANGFUSE_SECRET_KEY`` is set and ``langfuse`` is
    importable, otherwise ``openai.AsyncOpenAI``. A warning is logged when
    the key is set but the package is missing.
    """
    global AsyncOpenAI

    if self._client is not None:
        return self._client

    async with self._client_lock:
        if self._client is None:
            if AsyncOpenAI is None:
                wants_langfuse = bool(os.environ.get("LANGFUSE_SECRET_KEY"))
                if wants_langfuse and importlib.util.find_spec("langfuse"):
                    from langfuse.openai import AsyncOpenAI as client_cls
                else:
                    if wants_langfuse:
                        logger.warning(
                            "LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
                            "install with `pip install langfuse` to enable tracing"
                        )
                    from openai import AsyncOpenAI as client_cls
                AsyncOpenAI = client_cls
            self._build_client()
        return self._client
|
|
|
|
def _setup_env(self, api_key: str, api_base: str | None) -> None:
    """Export the API key, and any ``spec.env_extras``, to environment variables.

    Gateway specs overwrite ``spec.env_key``. Other specs only set it when it
    is not already defined. Each ``env_extras`` template has ``{api_key}`` and
    ``{api_base}`` substituted and is set only if the variable is not already
    defined.

    Args:
        api_key: Key to export.
        api_base: Base URL used for ``{api_base}``. Falls back to
            ``spec.default_api_base``, then to ``""``, so a spec without a
            base URL never passes None to ``str.replace`` (which would raise
            TypeError).
    """
    spec = self._spec
    if not spec or not spec.env_key:
        return
    if spec.is_gateway:
        os.environ[spec.env_key] = api_key
    else:
        os.environ.setdefault(spec.env_key, api_key)
    effective_base = api_base or spec.default_api_base or ""
    for env_name, env_val in spec.env_extras:
        resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
        os.environ.setdefault(env_name, resolved)
|
|
|
|
@classmethod
def _apply_cache_control(
    cls,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
    """Return copies of *messages*/*tools* carrying ephemeral cache_control markers.

    Marked: the leading system message (if any), the second-to-last message
    when there are at least three, and the tools at the indices returned by
    ``_tool_cache_marker_indices``. String content is wrapped in a single
    text block. For list content, the last block gets the marker. Input lists
    are not mutated (marked entries are shallow copies). Empty/None *tools*
    are returned as-is.
    """
    marker = {"type": "ephemeral"}

    def with_marker(msg: dict[str, Any]) -> dict[str, Any]:
        content = msg.get("content")
        if isinstance(content, str):
            text_block = {"type": "text", "text": content, "cache_control": marker}
            return {**msg, "content": [text_block]}
        if isinstance(content, list) and content:
            *head, last = content
            return {**msg, "content": [*head, {**last, "cache_control": marker}]}
        return msg

    marked = list(messages)
    if marked and marked[0].get("role") == "system":
        marked[0] = with_marker(marked[0])
    if len(marked) >= 3:
        marked[-2] = with_marker(marked[-2])

    if not tools:
        return marked, tools

    marked_tools = list(tools)
    for idx in cls._tool_cache_marker_indices(marked_tools):
        marked_tools[idx] = {**marked_tools[idx], "cache_control": marker}
    return marked, marked_tools
|
|
|
|
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
    """Map a tool-call ID to a 9-character alphanumeric form.

    A string that is already 9 alphanumeric characters (``str.isalnum``) is
    kept. Any other string is replaced by the first 9 hex digits of its SHA-1.
    Non-strings are returned unchanged.
    """
    already_ok = (
        not isinstance(tool_call_id, str)
        or (len(tool_call_id) == 9 and tool_call_id.isalnum())
    )
    if already_ok:
        return tool_call_id
    digest = hashlib.sha1(tool_call_id.encode()).hexdigest()
    return digest[:9]
|
|
|
|
def _should_normalize_tool_call_ids(self) -> bool:
    """Return True for providers that reject standard OpenAI tool-call IDs.

    Currently this is only the ``mistral`` spec.
    """
    spec = self._spec
    return bool(spec) and spec.name == "mistral"
|
|
|
|
@staticmethod
def _normalize_tool_call_arguments(arguments: Any) -> str:
    """Return *arguments* as a JSON-object string.

    Dicts are serialized directly. Strings are parsed with
    ``json_repair.loads`` and re-serialized when they yield a dict. Blank
    strings, unparsable or non-object strings, and any other type all become
    ``"{}"``.
    """
    if isinstance(arguments, dict):
        return json.dumps(arguments, ensure_ascii=False)
    if not isinstance(arguments, str):
        return "{}"
    text = arguments.strip()
    if not text:
        return "{}"
    try:
        parsed = json_repair.loads(text)
    except Exception:
        return "{}"
    if not isinstance(parsed, dict):
        return "{}"
    return json.dumps(parsed, ensure_ascii=False)
|
|
|
|
@staticmethod
def _coerce_content_to_string(content: Any) -> str | None:
    """Flatten message content into a string for string-only APIs.

    None and strings pass through. Otherwise the concatenated text parts (via
    ``_extract_text_content``) are used. If there are none, the JSON dump of
    the content is used (``str()`` when JSON fails), with ``"(empty)"`` as
    the last resort.
    """
    if content is None or isinstance(content, str):
        return content
    text = OpenAICompatProvider._extract_text_content(content)
    if isinstance(text, str) and text:
        return text
    try:
        serialized = json.dumps(content, ensure_ascii=False)
    except Exception:
        serialized = str(content)
    if not serialized:
        return "(empty)"
    return serialized
|
|
|
|
def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Prepare chat messages for an OpenAI-compatible endpoint.

    Processing steps:

    * drop message keys outside ``_ALLOWED_MSG_KEYS`` (via
      ``LLMProvider._sanitize_request_messages``);
    * give every assistant tool call an ID that is unique within its message
      (normalized to 9 alphanumerics for Mistral). Matching ``tool_call_id``
      values on later tool results are rewritten in FIFO order so each
      call/result pair still lines up;
    * force every ``function.arguments`` into a JSON-object string;
    * null the content of assistant messages that actually carry tool calls
      (an empty ``tool_calls`` list leaves the content intact);
    * for DeepSeek, flatten non-string content into plain text;
    * finally apply ``_enforce_role_alternation``.

    Args:
        messages: Outgoing chat messages.

    Returns:
        The sanitized message list.
    """
    sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
    id_map: dict[str, str] = {}
    # Original tool-call ID -> queue of rewritten IDs, consumed by tool results.
    pending_tool_ids: dict[str, deque[str]] = {}
    force_string_content = bool(self._spec and self._spec.name == "deepseek")
    normalize_tool_ids = self._should_normalize_tool_call_ids()

    def map_id(value: Any) -> Any:
        # Stable per-request mapping; identity unless the provider needs
        # normalized IDs.
        if not isinstance(value, str):
            return value
        if not normalize_tool_ids:
            return value
        return id_map.setdefault(value, self._normalize_tool_call_id(value))

    def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
        # Mapped ID, or a fresh random one, de-duplicated within one message
        # by hashing "<seed>:<idx>:<salt>" until an unused ID appears.
        if isinstance(value, str) and value:
            base = map_id(value)
        else:
            base = _short_tool_id()
        if not isinstance(base, str) or not base:
            base = _short_tool_id()
        if base not in used_ids:
            return base
        seed = value if isinstance(value, str) and value else base
        salt = 1
        while True:
            candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
            if isinstance(candidate, str) and candidate not in used_ids:
                return candidate
            salt += 1

    def map_tool_result_id(value: Any) -> Any:
        # Prefer the ID assigned to the earliest pending call with this raw ID.
        if not isinstance(value, str):
            return value
        queue = pending_tool_ids.get(value)
        if queue:
            mapped = queue.popleft()
            if not queue:
                pending_tool_ids.pop(value, None)
            return mapped
        return map_id(value)

    for clean in sanitized:
        tool_calls = clean.get("tool_calls")
        if isinstance(tool_calls, list):
            normalized = []
            used_ids: set[str] = set()
            for idx, tc in enumerate(tool_calls):
                if not isinstance(tc, dict):
                    normalized.append(tc)
                    continue
                tc_clean = dict(tc)
                raw_id = tc_clean.get("id")
                mapped_id = unique_tool_id(raw_id, used_ids, idx)
                tc_clean["id"] = mapped_id
                used_ids.add(mapped_id)
                if isinstance(raw_id, str) and raw_id:
                    pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
                function = tc_clean.get("function")
                if isinstance(function, dict):
                    function_clean = dict(function)
                    if "arguments" in function_clean:
                        function_clean["arguments"] = self._normalize_tool_call_arguments(
                            function_clean.get("arguments")
                        )
                    else:
                        function_clean["arguments"] = "{}"
                    tc_clean["function"] = function_clean
                normalized.append(tc_clean)
            clean["tool_calls"] = normalized
            # Some OpenAI-compatible gateways reject assistant messages that
            # mix non-empty content with tool_calls. Only clear content when
            # tool calls are actually present, so an empty list does not
            # erase the assistant's text.
            if clean.get("role") == "assistant" and normalized:
                clean["content"] = None
        if clean.get("tool_call_id"):
            clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
        if (
            force_string_content
            and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
        ):
            clean["content"] = self._coerce_content_to_string(clean.get("content"))
    return self._enforce_role_alternation(sanitized)
|
|
|
|
# ------------------------------------------------------------------
|
|
# Build kwargs
|
|
# ------------------------------------------------------------------
|
|
|
|
@staticmethod
def _supports_temperature(
    model_name: str,
    reasoning_effort: str | None = None,
) -> bool:
    """Return True when a temperature parameter should be sent.

    Temperature is omitted for any model whenever *reasoning_effort* is set
    to something other than ``"none"``. It is always omitted for model names
    containing ``gpt-5``, ``o1``, ``o3`` or ``o4`` (case-insensitive
    substring match), since those reject it.
    """
    if reasoning_effort and reasoning_effort.lower() != "none":
        return False
    lowered = model_name.lower()
    return all(marker not in lowered for marker in ("gpt-5", "o1", "o3", "o4"))
|
|
|
|
def _build_kwargs(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str | None,
    max_tokens: int,
    temperature: float,
    reasoning_effort: str | None,
    tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
    """Build Chat Completions keyword arguments for one request.

    The steps run in this order:

    1. Prompt-cache markers for Anthropic/Claude models on specs that support
       caching.
    2. Optional model-prefix stripping.
    3. Message sanitization.
    4. temperature (when allowed), then ``max_tokens`` or
       ``max_completion_tokens``.
    5. The first matching ``spec.model_overrides`` entry.
    6. reasoning_effort and thinking controls.
    7. tools.
    8. The DeepSeek/thinking ``reasoning_content`` backfill.
    9. The user ``extra_body``, deep-merged last so it wins.

    Args:
        messages: Chat messages to send.
        tools: Tool definitions, or None.
        model: Model name; falls back to ``self.default_model``.
        max_tokens: Output token limit (clamped to at least 1).
        temperature: Sampling temperature, sent only when
            ``_supports_temperature`` allows it.
        reasoning_effort: Effort level. None leaves every thinking control
            at the provider default.
        tool_choice: Tool choice; defaults to ``"auto"`` when tools are given.

    Returns:
        The kwargs dict for the Chat Completions call.
    """
    model_name = model or self.default_model
    spec = self._spec

    if spec and spec.supports_prompt_caching:
        # NOTE(review): redundant - model_name already holds this value.
        model_name = model or self.default_model
        if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
            messages, tools = self._apply_cache_control(messages, tools)

    if spec and spec.strip_model_prefix:
        model_name = model_name.split("/")[-1]

    kwargs: dict[str, Any] = {
        "model": model_name,
        "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
    }

    # GPT-5 and reasoning models (o1/o3/o4) reject temperature when
    # reasoning_effort is active. Only include it when safe.
    if self._supports_temperature(model_name, reasoning_effort):
        kwargs["temperature"] = temperature

    if spec and getattr(spec, "supports_max_completion_tokens", False):
        kwargs["max_completion_tokens"] = max(1, max_tokens)
    else:
        kwargs["max_tokens"] = max(1, max_tokens)

    if spec:
        model_lower = model_name.lower()
        # Only the first pattern that is a substring of the model name applies.
        # NOTE(review): override values are inserted by reference. If an
        # override ever carries "extra_body", the later
        # ``kwargs.setdefault("extra_body", {}).update(...)`` would mutate the
        # shared spec dict across requests - confirm overrides never include
        # extra_body.
        for pattern, overrides in spec.model_overrides:
            if pattern in model_lower:
                kwargs.update(overrides)
                break

    # Normalize reasoning_effort into a semantic form (OpenAI vocab)
    # used for internal decisions, and a wire form actually sent out.
    # "minimum" is accepted as a DashScope-native alias for "minimal".
    semantic_effort: str | None = None
    if isinstance(reasoning_effort, str):
        semantic_effort = reasoning_effort.lower()
        if semantic_effort == "minimum":
            semantic_effort = "minimal"

    wire_effort = reasoning_effort
    if spec and spec.name == "dashscope" and semantic_effort == "minimal":
        # DashScope accepts none/minimum/low/medium/high/xhigh; "minimal" 400s.
        wire_effort = "minimum"

    # "none" (any case) means "do not send reasoning_effort at all".
    if wire_effort and semantic_effort != "none":
        kwargs["reasoning_effort"] = wire_effort

    # Only send thinking controls when reasoning_effort is explicit so
    # omitting the config preserves each provider's default.
    if reasoning_effort is not None:
        # "none" and "minimal" both switch native thinking off.
        thinking_enabled = semantic_effort not in ("none", "minimal")
        for thinking_style in _thinking_styles_for(spec, model_name):
            extra = _thinking_extra_body(thinking_style, thinking_enabled)
            if extra:
                kwargs.setdefault("extra_body", {}).update(extra)
        # Gateways (e.g. with a "reasoning_effort" gateway style) get a
        # gateway-level reasoning field, but only for models that have a
        # model-specific thinking style.
        gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else ""
        if gateway_style and _model_thinking_style(model_name):
            extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort)
            if extra:
                kwargs.setdefault("extra_body", {}).update(extra)

        # Moonshot rejects requests that carry both 'reasoning_effort'
        # and the native 'thinking' param. We already expressed the
        # user's intent via the provider-native shape, so drop the
        # redundant wire-level kwarg. Only kimi models need this —
        # Xiaomi's API accepts both params.
        if _model_slug(model_name) in _KIMI_THINKING_MODELS:
            kwargs.pop("reasoning_effort", None)

    if tools:
        kwargs["tools"] = tools
        kwargs["tool_choice"] = tool_choice or "auto"

    # Backfill reasoning_content="" on assistants missing it: DeepSeek
    # thinking mode rejects history otherwise (#3554, #3584); "" reads
    # as "no thinking that turn". DeepSeek-V4/reasoner reason natively,
    # so backfill even without explicit reasoning_effort.
    explicit_thinking = (
        reasoning_effort is not None
        and semantic_effort not in ("none", "minimal")
        and (
            (spec and spec.thinking_style)
            or _model_thinking_style(model_name)
        )
    )
    implicit_deepseek_thinking = (
        spec is not None
        and spec.name == "deepseek"
        # "minimum" is already normalized to "minimal" above; listing it
        # here is harmless but redundant.
        and semantic_effort not in ("none", "minimal", "minimum")
        and any(t in model_name.lower() for t in ("deepseek-v4", "deepseek-reasoner"))
    )
    if explicit_thinking or implicit_deepseek_thinking:
        # Mutates the sanitized message dicts held in kwargs.
        for msg in kwargs["messages"]:
            if msg.get("role") == "assistant" and "reasoning_content" not in msg:
                msg["reasoning_content"] = ""

    # Merge user-configured extra_body last so it can override or
    # extend provider-specific defaults (e.g. chat_template_kwargs,
    # guided_json, repetition_penalty). Uses recursive merge so
    # nested dicts like {"chat_template_kwargs": {"enable_thinking": false}}
    # do not clobber sibling keys already set by thinking-style logic.
    if self._extra_body:
        existing = kwargs.get("extra_body", {})
        kwargs["extra_body"] = _deep_merge(existing, self._extra_body)

    return kwargs
|
|
|
|
def _should_use_responses_api(
    self,
    model: str | None,
    reasoning_effort: str | None,
) -> bool:
    """Decide whether a request should go through the Responses API.

    Rules, in order:

    * ``api_type == "chat_completions"`` means never.
    * Specs other than openai/github_copilot mean never.
    * ``api_type == "responses"`` means always (no circuit breaker).
    * Outside Copilot, only a direct OpenAI base URL qualifies.
    * The request must set a non-"none" reasoning effort, or name a
      gpt-5/o1/o3/o4 model.
    * Finally, the per-model circuit breaker must allow the attempt.
    """
    if self._api_type == "chat_completions":
        return False
    spec = self._spec
    if spec and spec.name not in ("openai", "github_copilot"):
        return False
    if self._api_type == "responses":
        # Explicitly configured: Responses is mandatory, so neither the
        # circuit breaker nor a Chat Completions fallback applies.
        return True

    is_copilot = spec is not None and spec.name == "github_copilot"
    if not is_copilot and not _is_direct_openai_base(self._effective_base):
        return False

    lowered_model = (model or self.default_model).lower()
    reasoning_requested = bool(reasoning_effort) and reasoning_effort.lower() != "none"
    reasoning_model = any(tok in lowered_model for tok in ("gpt-5", "o1", "o3", "o4"))
    if not (reasoning_requested or reasoning_model):
        return False

    return self._responses_circuit_allows_probe(model, reasoning_effort)
|
|
|
|
def _responses_circuit_allows_probe(
    self,
    model: str | None,
    reasoning_effort: str | None,
) -> bool:
    """Return False while the Responses API circuit breaker is open for this key.

    The breaker is closed below ``_RESPONSES_FAILURE_THRESHOLD`` failures.
    Once tripped, it stays open for ``_RESPONSES_PROBE_INTERVAL_S`` seconds.
    After that it is half-open: requests may try the Responses API again
    until another recorded failure re-trips it.
    """
    key = _responses_circuit_key(model, self.default_model, reasoning_effort)
    if self._responses_failures.get(key, 0) < _RESPONSES_FAILURE_THRESHOLD:
        return True
    elapsed = time.monotonic() - self._responses_tripped_at.get(key, 0.0)
    return elapsed >= _RESPONSES_PROBE_INTERVAL_S
|
|
|
|
def _record_responses_failure(self, model: str | None, reasoning_effort: str | None) -> None:
    """Count a Responses API failure and trip the breaker at the threshold.

    Every failure at or above the threshold restarts the cool-down timer and
    logs a warning.
    """
    key = _responses_circuit_key(model, self.default_model, reasoning_effort)
    failures = self._responses_failures.get(key, 0) + 1
    self._responses_failures[key] = failures
    if failures < _RESPONSES_FAILURE_THRESHOLD:
        return
    self._responses_tripped_at[key] = time.monotonic()
    logger.warning(
        "Responses API circuit open for {} — falling back to Chat Completions",
        key,
    )
|
|
|
|
def _record_responses_success(self, model: str | None, reasoning_effort: str | None) -> None:
    """Reset the circuit breaker for this model/effort after a successful call."""
    key = _responses_circuit_key(model, self.default_model, reasoning_effort)
    for state in (self._responses_failures, self._responses_tripped_at):
        state.pop(key, None)
|
|
|
|
@staticmethod
def _should_fallback_from_responses_error(e: Exception) -> bool:
    """Return True when a Responses API error looks like an incompatibility.

    Two conditions must hold:

    * The status code is 400, 404 or 422. It is read from ``e.status_code``,
      else from ``e.response.status_code``.
    * The first truthy of ``e.body``, ``e.doc`` or ``e.response.text``
      mentions a known compatibility marker (case-insensitive).
    """
    response = getattr(e, "response", None)
    status_code = getattr(e, "status_code", None)
    if status_code is None and response is not None:
        status_code = getattr(response, "status_code", None)
    if status_code not in {400, 404, 422}:
        return False

    payload = (
        getattr(e, "body", None)
        or getattr(e, "doc", None)
        or getattr(response, "text", None)
    )
    if payload is None:
        return False
    haystack = str(payload).lower()

    markers = (
        "responses",
        "response api",
        "max_output_tokens",
        "instructions",
        "previous_response",
        "unsupported",
        "not supported",
        "unknown parameter",
        "unrecognized request argument",
    )
    return any(marker in haystack for marker in markers)
|
|
|
|
def _build_responses_body(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str | None,
    max_tokens: int,
    temperature: float,
    reasoning_effort: str | None,
    tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
    """Build a Responses API request body for direct OpenAI requests.

    Messages are sanitized and converted with ``convert_messages``, and tools
    with ``convert_tools``. A forced function ``tool_choice`` given in Chat
    Completions shape (``{"type": "function", "function": {"name": ...}}``)
    is flattened to the Responses shape (``{"type": "function", "name": ...}``).
    Other values are passed through, defaulting to ``"auto"``. The configured
    ``extra_body`` is merged last via ``_merge_responses_extra_body``.

    Returns:
        The request body dict (non-streaming, ``store=False``).
    """
    model_name = model or self.default_model
    if self._spec and self._spec.strip_model_prefix:
        model_name = model_name.split("/")[-1]
    sanitized_messages = self._sanitize_messages(self._sanitize_empty_content(messages))
    instructions, input_items = convert_messages(sanitized_messages)

    body: dict[str, Any] = {
        "model": model_name,
        "instructions": instructions or None,
        "input": input_items,
        "max_output_tokens": max(1, max_tokens),
        "store": False,
        "stream": False,
    }

    if self._supports_temperature(model_name, reasoning_effort):
        body["temperature"] = temperature

    if reasoning_effort and reasoning_effort.lower() != "none":
        body["reasoning"] = {"effort": reasoning_effort}
        # With store=False, encrypted reasoning must be requested explicitly.
        body["include"] = ["reasoning.encrypted_content"]

    if tools:
        body["tools"] = convert_tools(tools)
        choice = tool_choice or "auto"
        if isinstance(choice, dict):
            function = choice.get("function")
            if (
                choice.get("type") == "function"
                and isinstance(function, dict)
                and function.get("name")
            ):
                # Chat Completions nests the forced function's name under
                # "function"; the Responses API expects it at the top level.
                choice = {"type": "function", "name": function["name"]}
        body["tool_choice"] = choice

    extra_body = getattr(self, "_extra_body", {})
    if extra_body:
        body = _merge_responses_extra_body(body, extra_body)

    return body
|
|
|
|
# ------------------------------------------------------------------
|
|
# Response parsing
|
|
# ------------------------------------------------------------------
|
|
|
|
@staticmethod
def _maybe_mapping(value: Any) -> dict[str, Any] | None:
    """Return *value* as a dict, or the dict produced by its ``model_dump()``.

    Returns None when *value* is neither a dict nor an object whose callable
    ``model_dump`` yields a dict.
    """
    if isinstance(value, dict):
        return value
    dump = getattr(value, "model_dump", None)
    dumped = dump() if callable(dump) else None
    return dumped if isinstance(dumped, dict) else None
|
|
|
|
@classmethod
def _extract_text_content(cls, value: Any) -> str | None:
    """Normalise a message ``content``-style field into a single string.

    * ``None`` stays ``None``; strings are returned unchanged.
    * A list of content parts is flattened by concatenating each part's
      text. Mapping-like parts (see :meth:`_maybe_mapping`) contribute their
      ``"text"`` entry. Other objects contribute their ``.text`` attribute,
      and bare strings contribute themselves. Parts without string text are
      skipped, and an empty concatenation yields ``None``.
    * Any other value is converted with ``str()``.
    """
    if value is None or isinstance(value, str):
        return value
    if not isinstance(value, list):
        return str(value)

    pieces: list[str] = []
    for part in value:
        mapping = cls._maybe_mapping(part)
        if mapping:
            candidate = mapping.get("text")
        else:
            candidate = getattr(part, "text", None)
            if not isinstance(candidate, str) and isinstance(part, str):
                candidate = part
        if isinstance(candidate, str):
            pieces.append(candidate)
    return "".join(pieces) or None
|
|
|
|
@classmethod
def _extract_usage(cls, response: Any) -> dict[str, int]:
    """Extract token usage from an OpenAI-compatible response.

    Raw JSON dicts and SDK Pydantic objects are both accepted. The result
    holds ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``
    (missing or ``None`` counters become ``0``). When a provider reports
    prompt-cache hits, a single normalised ``cached_tokens`` entry is added.
    An empty dict is returned when the response has no usage block.
    """
    response_map = cls._maybe_mapping(response)
    if response_map is not None:
        usage_obj = response_map.get("usage")
    elif hasattr(response, "usage") and response.usage:
        usage_obj = response.usage
    else:
        usage_obj = None

    counters = ("prompt_tokens", "completion_tokens", "total_tokens")
    usage_map = cls._maybe_mapping(usage_obj)
    if usage_map is not None:
        result = {name: int(usage_map.get(name) or 0) for name in counters}
    elif usage_obj:
        result = {name: getattr(usage_obj, name, 0) or 0 for name in counters}
    else:
        return {}

    # Providers expose cache hits under different fields. Check them in
    # priority order: dict lookup on the mapping form first, then attribute
    # access on the original object. The first non-zero value wins.
    cache_paths: tuple[tuple[str, ...], ...] = (
        ("prompt_tokens_details", "cached_tokens"),  # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
        ("cached_tokens",),  # StepFun/Moonshot (top-level)
        ("prompt_cache_hit_tokens",),  # DeepSeek/SiliconFlow
    )
    for path in cache_paths:
        hit = cls._get_nested_int(usage_map, path)
        if not hit and usage_obj:
            hit = cls._get_nested_int(usage_obj, path)
        if hit:
            result["cached_tokens"] = hit
            break
    return result
|
|
|
|
@staticmethod
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
    """Drill into *obj* by *path* segments and return the leaf as an ``int``.

    Each segment is resolved with ``dict.get`` when the current node is a
    ``dict`` and with ``getattr`` otherwise. This lets raw JSON dicts and SDK
    Pydantic models be handled uniformly.

    Returns ``0`` in any of these cases:

    * a segment is missing or ``None``;
    * the leaf is falsy;
    * the leaf cannot be converted with ``int()`` (for example a non-numeric
      string or a nested object from a misbehaving provider).

    Usage counters are treated as best-effort, so a malformed value does not
    abort response parsing.
    """
    current = obj
    for segment in path:
        if current is None:
            return 0
        if isinstance(current, dict):
            current = current.get(segment)
        else:
            current = getattr(current, segment, None)
    if not current:
        return 0
    try:
        return int(current)
    except (TypeError, ValueError, OverflowError):
        return 0
|
|
|
|
def _parse(self, response: Any) -> LLMResponse:
    """Convert a non-streaming chat-completions result into an ``LLMResponse``.

    Three input shapes are accepted:

    * a bare string, returned as final content with ``finish_reason="stop"``;
    * anything :meth:`_maybe_mapping` can turn into a dict (raw JSON or SDK
      Pydantic objects), read via dict lookups. A dict without ``choices``
      may still carry top-level ``content``/``output_text``;
    * any other object exposing ``.choices``, read via attribute access.

    Tool calls are collected from *every* choice so that tool calls spread
    over several choices are not lost. In both branches, tool-call arguments
    that do not decode to a JSON object become ``{}``, and a missing
    function name becomes ``""``. A missing or empty choices list yields an
    error response.
    """
    if isinstance(response, str):
        return LLMResponse(content=response, finish_reason="stop")

    response_map = self._maybe_mapping(response)
    if response_map is not None:
        choices = response_map.get("choices") or []
        if not choices:
            content = self._extract_text_content(
                response_map.get("content") or response_map.get("output_text")
            )
            reasoning_content = self._extract_text_content(
                response_map.get("reasoning_content")
            )
            if content is not None:
                return LLMResponse(
                    content=content,
                    reasoning_content=reasoning_content,
                    finish_reason=str(response_map.get("finish_reason") or "stop"),
                    usage=self._extract_usage(response_map),
                )
            return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

        choice0 = self._maybe_mapping(choices[0]) or {}
        msg0 = self._maybe_mapping(choice0.get("message")) or {}
        content = self._extract_text_content(msg0.get("content"))
        finish_reason = str(choice0.get("finish_reason") or "stop")

        raw_tool_calls: list[Any] = []
        # StepFun: fallback to reasoning field when content is empty
        if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content:
            content = self._extract_text_content(msg0.get("reasoning"))
        reasoning_content = msg0.get("reasoning_content")
        if reasoning_content is None and msg0.get("reasoning"):
            reasoning_content = self._extract_text_content(msg0.get("reasoning"))
        for ch in choices:
            ch_map = self._maybe_mapping(ch) or {}
            m = self._maybe_mapping(ch_map.get("message")) or {}
            tool_calls = m.get("tool_calls")
            if isinstance(tool_calls, list) and tool_calls:
                raw_tool_calls.extend(tool_calls)
                # A choice that carries tool calls decides the finish reason.
                if ch_map.get("finish_reason") in ("tool_calls", "stop"):
                    finish_reason = str(ch_map["finish_reason"])
            if not content:
                content = self._extract_text_content(m.get("content"))
            if reasoning_content is None:
                reasoning_content = m.get("reasoning_content")

        parsed_tool_calls = []
        for tc in raw_tool_calls:
            tc_map = self._maybe_mapping(tc) or {}
            fn = self._maybe_mapping(tc_map.get("function")) or {}
            args = fn.get("arguments", {})
            if isinstance(args, str):
                args = json_repair.loads(args)
            ec, prov, fn_prov = _extract_tc_extras(tc)
            parsed_tool_calls.append(ToolCallRequest(
                id=str(tc_map.get("id") or _short_tool_id()),
                name=str(fn.get("name") or ""),
                arguments=args if isinstance(args, dict) else {},
                extra_content=ec,
                provider_specific_fields=prov,
                function_provider_specific_fields=fn_prov,
            ))

        return LLMResponse(
            content=content,
            tool_calls=parsed_tool_calls,
            finish_reason=finish_reason,
            usage=self._extract_usage(response_map),
            reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
        )

    if not response.choices:
        return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

    choice = response.choices[0]
    msg = choice.message
    content = msg.content
    finish_reason = choice.finish_reason

    raw_tool_calls: list[Any] = []
    for ch in response.choices:
        m = ch.message
        if hasattr(m, "tool_calls") and m.tool_calls:
            raw_tool_calls.extend(m.tool_calls)
            # A choice that carries tool calls decides the finish reason.
            if ch.finish_reason in ("tool_calls", "stop"):
                finish_reason = ch.finish_reason
        if not content and m.content:
            content = m.content
        # StepFun: fallback to reasoning field when content is empty
        if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content:
            content = m.reasoning

    tool_calls = []
    for tc in raw_tool_calls:
        args = tc.function.arguments
        if isinstance(args, str):
            # json_repair may yield "" (empty input) or a non-object value;
            # normalise below so callers always receive a dict.
            args = json_repair.loads(args)
        ec, prov, fn_prov = _extract_tc_extras(tc)
        tool_calls.append(ToolCallRequest(
            id=str(getattr(tc, "id", None) or _short_tool_id()),
            name=str(tc.function.name or ""),
            arguments=args if isinstance(args, dict) else {},
            extra_content=ec,
            provider_specific_fields=prov,
            function_provider_specific_fields=fn_prov,
        ))

    reasoning_content = getattr(msg, "reasoning_content", None)
    if reasoning_content is None and getattr(msg, "reasoning", None):
        reasoning_content = msg.reasoning

    return LLMResponse(
        content=content,
        tool_calls=tool_calls,
        finish_reason=finish_reason or "stop",
        usage=self._extract_usage(response),
        reasoning_content=reasoning_content,
    )
|
|
|
|
@classmethod
def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
    """Fold buffered chat-completions stream chunks into one ``LLMResponse``.

    Each chunk may be one of:

    * a bare string, appended to the content;
    * anything :meth:`_maybe_mapping` can turn into a dict (raw JSON or SDK
      Pydantic chunks);
    * another object exposing ``.choices``, read via attribute access.

    Only ``choices[0]`` of each chunk is read. Content and reasoning deltas
    are concatenated. Tool-call deltas are accumulated per ``index``, and the
    legacy ``delta.function_call`` field accumulates into index 0. The last
    truthy ``finish_reason`` and the last non-empty usage block win.

    Accumulated argument strings are decoded with ``json_repair``. Anything
    that does not decode to a JSON object becomes ``{}``, so callers always
    receive a dict.
    """
    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    tc_bufs: dict[int, dict[str, Any]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}

    def _accum_tc(tc: Any, idx_hint: int) -> None:
        """Accumulate one streaming tool-call delta into *tc_bufs*."""
        tc_index: int = _get(tc, "index") if _get(tc, "index") is not None else idx_hint
        buf = tc_bufs.setdefault(tc_index, {
            "id": "", "name": "", "arguments": "",
            "extra_content": None, "prov": None, "fn_prov": None,
        })
        tc_id = _get(tc, "id")
        if tc_id:
            buf["id"] = str(tc_id)
        fn = _get(tc, "function")
        if fn is not None:
            fn_name = _get(fn, "name")
            if fn_name:
                buf["name"] = str(fn_name)
            fn_args = _get(fn, "arguments")
            if fn_args:
                # Argument JSON arrives in fragments; concatenate them.
                buf["arguments"] += str(fn_args)
        ec, prov, fn_prov = _extract_tc_extras(tc)
        if ec:
            buf["extra_content"] = ec
        if prov:
            buf["prov"] = prov
        if fn_prov:
            buf["fn_prov"] = fn_prov

    def _accum_legacy_function_call(function_call: Any) -> None:
        """Accumulate legacy ``delta.function_call`` streaming chunks."""
        if not function_call:
            return
        buf = tc_bufs.setdefault(0, {
            "id": "", "name": "", "arguments": "",
            "extra_content": None, "prov": None, "fn_prov": None,
        })
        fn_name = _get(function_call, "name")
        if fn_name:
            buf["name"] = str(fn_name)
        fn_args = _get(function_call, "arguments")
        if fn_args:
            buf["arguments"] += str(fn_args)

    for chunk in chunks:
        if isinstance(chunk, str):
            content_parts.append(chunk)
            continue

        chunk_map = cls._maybe_mapping(chunk)
        if chunk_map is not None:
            choices = chunk_map.get("choices") or []
            if not choices:
                # Choice-less chunks typically carry usage and/or plain text.
                usage = cls._extract_usage(chunk_map) or usage
                text = cls._extract_text_content(
                    chunk_map.get("content") or chunk_map.get("output_text")
                )
                if text:
                    content_parts.append(text)
                continue
            choice = cls._maybe_mapping(choices[0]) or {}
            if choice.get("finish_reason"):
                finish_reason = str(choice["finish_reason"])
            delta = cls._maybe_mapping(choice.get("delta")) or {}
            text = cls._extract_text_content(delta.get("content"))
            if text:
                content_parts.append(text)
            text = cls._extract_text_content(delta.get("reasoning_content"))
            if not text:
                text = cls._extract_text_content(delta.get("reasoning"))
            if text:
                reasoning_parts.append(text)
            for idx, tc in enumerate(delta.get("tool_calls") or []):
                _accum_tc(tc, idx)
            _accum_legacy_function_call(delta.get("function_call"))
            usage = cls._extract_usage(chunk_map) or usage
            continue

        if not chunk.choices:
            usage = cls._extract_usage(chunk) or usage
            continue
        choice = chunk.choices[0]
        if choice.finish_reason:
            finish_reason = choice.finish_reason
        delta = choice.delta
        if delta and delta.content:
            content_parts.append(delta.content)
        if delta:
            reasoning = getattr(delta, "reasoning_content", None)
            if not reasoning:
                reasoning = getattr(delta, "reasoning", None)
            if reasoning:
                reasoning_parts.append(reasoning)
        for tc in (getattr(delta, "tool_calls", None) or []) if delta else []:
            _accum_tc(tc, getattr(tc, "index", 0))
        if delta:
            _accum_legacy_function_call(getattr(delta, "function_call", None))

    # Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for
    # parallel tool calls in streaming mode. Deduplicate before building
    # the response so downstream tool messages don't collide.
    seen_ids: set[str] = set()
    for buf in tc_bufs.values():
        if not buf["id"] or buf["id"] in seen_ids:
            buf["id"] = _short_tool_id()
        seen_ids.add(buf["id"])

    tool_calls: list[ToolCallRequest] = []
    for buf in tc_bufs.values():
        raw_args = buf["arguments"]
        args = json_repair.loads(raw_args) if raw_args else {}
        tool_calls.append(ToolCallRequest(
            id=buf["id"] or _short_tool_id(),
            name=buf["name"],
            # json_repair can return "" or a non-object value for
            # malformed fragments; always hand callers a dict.
            arguments=args if isinstance(args, dict) else {},
            extra_content=buf.get("extra_content"),
            provider_specific_fields=buf.get("prov"),
            function_provider_specific_fields=buf.get("fn_prov"),
        ))

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=tool_calls,
        finish_reason=finish_reason,
        usage=usage,
        reasoning_content="".join(reasoning_parts) or None,
    )
|
|
|
|
@classmethod
def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
    """Collect structured details about a failed LLM call.

    The exception's ``response`` (if any) supplies headers, a status code and
    a payload. The payload is the first truthy value among ``e.body``,
    ``e.doc`` and ``response.text``; failing that, ``response.json()`` is
    tried and any error from it is ignored. The returned keys are:

    * ``error_status_code`` — an int, or ``None``;
    * ``error_kind`` — ``"timeout"``/``"connection"`` inferred from the
      exception class name, or ``None``;
    * ``error_type`` / ``error_code`` — parsed from the payload;
    * ``error_retry_after_s`` — parsed from the headers;
    * ``error_should_retry`` — from an ``x-should-retry: true|false`` header,
      else ``None``.
    """
    response = getattr(e, "response", None)
    headers = getattr(response, "headers", None)

    payload = getattr(e, "body", None) or getattr(e, "doc", None) or getattr(response, "text", None)
    if payload is None and response is not None:
        json_getter = getattr(response, "json", None)
        if callable(json_getter):
            try:
                payload = json_getter()
            except Exception:
                # The body may not be JSON; treat it as absent.
                payload = None
    error_type, error_code = LLMProvider._extract_error_type_code(payload)

    status_code = getattr(e, "status_code", None)
    if status_code is None and response is not None:
        status_code = getattr(response, "status_code", None)

    should_retry: bool | None = None
    retry_flag = headers.get("x-should-retry") if headers is not None else None
    if isinstance(retry_flag, str):
        should_retry = {"true": True, "false": False}.get(retry_flag.strip().lower())

    error_kind: str | None = None
    class_name = type(e).__name__.lower()
    # "timeout" is checked first so e.g. a timeout subclass of a
    # connection error is still classified as a timeout.
    if "timeout" in class_name:
        error_kind = "timeout"
    elif "connection" in class_name:
        error_kind = "connection"

    return {
        "error_status_code": int(status_code) if status_code is not None else None,
        "error_kind": error_kind,
        "error_type": error_type,
        "error_code": error_code,
        "error_retry_after_s": cls._extract_retry_after_from_headers(headers),
        "error_should_retry": should_retry,
    }
|
|
|
|
@staticmethod
def _handle_error(
    e: Exception,
    *,
    spec: ProviderSpec | None = None,
    api_base: str | None = None,
) -> LLMResponse:
    """Turn an exception raised during an LLM call into an error ``LLMResponse``.

    The message quotes up to 500 characters of the error body (``e.doc``,
    ``e.body`` or ``e.response.text``, whichever is first truthy). Without a
    body it falls back to ``str(e)``. For local providers (``spec.is_local``),
    a hint about reachability and proxies is appended when the error text
    mentions ``502``, ``connection`` or ``refused``. ``retry_after`` comes
    from the response headers, or is parsed from the message otherwise.
    Structured metadata from :meth:`_extract_error_metadata` is attached.
    """
    raw_body = (
        getattr(e, "doc", None)
        or getattr(e, "body", None)
        or getattr(getattr(e, "response", None), "text", None)
    )
    if isinstance(raw_body, str):
        body_text = raw_body
    elif raw_body is None:
        body_text = ""
    else:
        body_text = str(raw_body)

    stripped = body_text.strip()
    msg = f"Error: {stripped[:500]}" if stripped else f"Error calling LLM: {e}"

    haystack = f"{body_text} {e}".lower()
    looks_unreachable = any(marker in haystack for marker in ("502", "connection", "refused"))
    if spec and spec.is_local and looks_unreachable:
        msg += (
            "\nHint: this is a local model endpoint. Check that the local server is reachable at "
            f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
            "can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
        )

    header_source = getattr(e, "response", None)
    retry_after = LLMProvider._extract_retry_after_from_headers(getattr(header_source, "headers", None))
    if retry_after is None:
        retry_after = LLMProvider._extract_retry_after(msg)

    metadata = OpenAICompatProvider._extract_error_metadata(e)
    return LLMResponse(
        content=msg,
        finish_reason="error",
        retry_after=retry_after,
        **metadata,
    )
|
|
|
|
# ------------------------------------------------------------------
|
|
# Public API
|
|
# ------------------------------------------------------------------
|
|
|
|
async def chat(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
    """Run one non-streaming completion and return the parsed result.

    When :meth:`_should_use_responses_api` selects it, the Responses API is
    tried first. A failure there is re-raised in three cases: for GitHub
    Copilot, when ``self._api_type == "responses"``, or when
    :meth:`_should_fallback_from_responses_error` rejects it. Otherwise the
    failure is recorded and the request goes through Chat Completions.
    Exceptions raised after client initialisation are converted by
    :meth:`_handle_error` into an error ``LLMResponse``.
    """
    await self._ensure_client()
    try:
        if self._should_use_responses_api(model, reasoning_effort):
            try:
                request = self._build_responses_body(
                    messages, tools, model, max_tokens, temperature,
                    reasoning_effort, tool_choice,
                )
                raw = await self._client.responses.create(**request)
                parsed = parse_response_output(raw)
                self._record_responses_success(model, reasoning_effort)
                return parsed
            except Exception as responses_error:
                # Copilot gateway exposes GPT-5/o-series only via /responses;
                # falling back to /chat/completions cannot succeed and would
                # hide the real error. An explicit "responses" api_type also
                # disables the fallback.
                must_reraise = (
                    (self._spec and self._spec.name == "github_copilot")
                    or self._api_type == "responses"
                    or not self._should_fallback_from_responses_error(responses_error)
                )
                if must_reraise:
                    raise
                self._record_responses_failure(model, reasoning_effort)

        completion_kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        completion = await self._client.chat.completions.create(**completion_kwargs)
        return self._parse(completion)
    except Exception as e:
        return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
|
|
|
async def chat_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
    on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Stream one completion, forwarding deltas to the optional callbacks.

    The Responses API is tried first when selected, with the same fallback
    rules as :meth:`chat`. Otherwise, or after an allowed fallback, Chat
    Completions is streamed. On the Chat Completions path:

    * ``on_content_delta`` receives text deltas;
    * ``on_thinking_delta`` receives reasoning deltas;
    * ``on_tool_call_delta`` receives dicts with ``index``, ``call_id``,
      ``name`` and ``arguments_delta``.

    On the Responses path, the content and tool-call callbacks are handed to
    ``consume_sdk_stream`` and ``on_thinking_delta`` is not used.

    If no chunk arrives within ``NANOBOT_STREAM_IDLE_TIMEOUT_S`` seconds
    (default 90), an error response with ``error_kind="timeout"`` is
    returned. A malformed or non-positive value for that variable falls back
    to the default instead of breaking every streaming call.
    """
    await self._ensure_client()
    raw_idle_timeout = os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")
    try:
        idle_timeout_s = int(raw_idle_timeout)
    except ValueError:
        idle_timeout_s = 0
    if idle_timeout_s <= 0:
        # A zero/negative timeout would fail every stream instantly.
        logger.warning(
            "Invalid NANOBOT_STREAM_IDLE_TIMEOUT_S={!r}; falling back to 90 seconds",
            raw_idle_timeout,
        )
        idle_timeout_s = 90
    try:
        if self._should_use_responses_api(model, reasoning_effort):
            try:
                body = self._build_responses_body(
                    messages, tools, model, max_tokens, temperature,
                    reasoning_effort, tool_choice,
                )
                body["stream"] = True
                stream = await self._client.responses.create(**body)

                async def _timed_stream():
                    # Re-yield stream events, enforcing the per-event idle timeout.
                    stream_iter = stream.__aiter__()
                    while True:
                        try:
                            yield await asyncio.wait_for(
                                stream_iter.__anext__(),
                                timeout=idle_timeout_s,
                            )
                        except StopAsyncIteration:
                            break

                (
                    content,
                    tool_calls,
                    finish_reason,
                    usage,
                    reasoning_content,
                ) = await consume_sdk_stream(
                    _timed_stream(),
                    on_content_delta,
                    on_tool_call_delta=on_tool_call_delta,
                )
                self._record_responses_success(model, reasoning_effort)
                return LLMResponse(
                    content=content or None,
                    tool_calls=tool_calls,
                    finish_reason=finish_reason,
                    usage=usage,
                    reasoning_content=reasoning_content,
                )
            except asyncio.TimeoutError:
                # An idle stall is not an API-compatibility failure, and
                # deltas may already have reached the callbacks. Never
                # retry through Chat Completions; report the stall instead.
                raise
            except Exception as responses_error:
                if self._spec and self._spec.name == "github_copilot":
                    # Copilot gateway exposes GPT-5/o-series only via /responses;
                    # falling back to /chat/completions cannot succeed and would
                    # hide the real error.
                    raise
                if self._api_type == "responses":
                    raise
                if not self._should_fallback_from_responses_error(responses_error):
                    raise
                self._record_responses_failure(model, reasoning_effort)

        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        if self._spec and self._spec.name == "zhipu" and tools and on_tool_call_delta:
            # Z.AI/GLM keeps streaming tool-call arguments behind an
            # explicit provider flag. Pass it through the OpenAI SDK's
            # extra_body escape hatch so the usual delta.tool_calls path
            # can surface live file-edit progress.
            kwargs.setdefault("extra_body", {})["tool_stream"] = True
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        stream = await self._client.chat.completions.create(**kwargs)
        chunks: list[Any] = []
        stream_iter = stream.__aiter__()
        while True:
            try:
                chunk = await asyncio.wait_for(
                    stream_iter.__anext__(),
                    timeout=idle_timeout_s,
                )
            except StopAsyncIteration:
                break
            chunks.append(chunk)
            if chunk.choices:
                delta_obj = chunk.choices[0].delta
                if on_content_delta:
                    text = getattr(delta_obj, "content", None)
                    if text:
                        await on_content_delta(text)
                if on_thinking_delta:
                    reasoning = getattr(delta_obj, "reasoning_content", None) or getattr(
                        delta_obj, "reasoning", None,
                    )
                    r_text = self._extract_text_content(reasoning)
                    if r_text:
                        await on_thinking_delta(r_text)
                if on_tool_call_delta:
                    for idx, tool_delta in enumerate(
                        getattr(delta_obj, "tool_calls", None) or []
                    ):
                        fn = _get(tool_delta, "function")
                        tool_index = _get(tool_delta, "index")
                        await on_tool_call_delta({
                            "index": tool_index if tool_index is not None else idx,
                            "call_id": str(_get(tool_delta, "id") or ""),
                            "name": str(_get(fn, "name") or "") if fn is not None else "",
                            "arguments_delta": (
                                str(_get(fn, "arguments") or "") if fn is not None else ""
                            ),
                        })
                    function_call = getattr(delta_obj, "function_call", None)
                    if function_call:
                        await on_tool_call_delta({
                            "index": 0,
                            "call_id": "",
                            "name": str(_get(function_call, "name") or ""),
                            "arguments_delta": str(_get(function_call, "arguments") or ""),
                        })
        return self._parse_chunks(chunks)
    except asyncio.TimeoutError:
        return LLMResponse(
            content=(
                f"Error calling LLM: stream stalled for more than "
                f"{idle_timeout_s} seconds"
            ),
            finish_reason="error",
            error_kind="timeout",
        )
    except Exception as e:
        return self._handle_error(e, spec=self._spec, api_base=self.api_base)
|
|
|
|
def get_default_model(self) -> str:
    """Return the model name used when a request does not name a model."""
    fallback_model: str = self.default_model
    return fallback_model
|