nanobot/nanobot/providers/openai_compat_provider.py
axelray-dev 28f3a20d64 feat(providers): add extra_query config for OpenAI-compatible providers
Adds ProviderConfig.extra_query, threaded into AsyncOpenAI(default_query)
so that Azure-style gateways requiring query params like api-version can
be configured without URL hacks.

Also updates provider_signature to track extra_query changes so per-turn
refresh rebuilds the provider when the value changes.

Addresses the extra_query portion of #4204. The max_completion_tokens
model-awareness enhancement is intentionally left separate.
2026-06-09 03:18:14 +08:00

1490 lines
60 KiB
Python

"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
from __future__ import annotations
import asyncio
import hashlib
import importlib.util
import json
import os
import secrets
import string
import time
import uuid
from collections import deque
from collections.abc import Awaitable, Callable
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
import json_repair
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.openai_responses import (
consume_sdk_stream,
convert_messages,
convert_tools,
parse_response_output,
)
if TYPE_CHECKING:
from openai import AsyncOpenAI as AsyncOpenAIType
from nanobot.providers.registry import ProviderSpec
# Module-level placeholder — set lazily by _ensure_client on first real
# use, or replaced by tests via ``patch(...)``. Kept as a plain name so
# that ``unittest.mock.patch`` can find and replace it.
AsyncOpenAI: Any = None

# Allow-list of message keys handed to the request-message sanitizer
# (see OpenAICompatProvider._sanitize_messages).
_ALLOWED_MSG_KEYS = frozenset({
    "role", "content", "tool_calls", "tool_call_id", "name",
    "reasoning_content", "extra_content",
})

# Alphabet used by _short_tool_id when generating random tool-call IDs.
_ALNUM = string.ascii_letters + string.digits

# Keys treated as standard on a tool call and on its ``function`` payload;
# _extract_tc_extras reports any other non-None key as a provider extra.
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
_STANDARD_FN_KEYS = frozenset({"name", "arguments"})

# Attribution headers merged into the client's default headers when
# _uses_openrouter_attribution() is true.
_DEFAULT_OPENROUTER_HEADERS = {
    "HTTP-Referer": "https://github.com/HKUDS/nanobot",
    "X-OpenRouter-Title": "nanobot",
    "X-OpenRouter-Categories": "cli-agent,personal-agent",
}

# Kimi model slugs (lower-cased, prefix-stripped — see _model_slug) that are
# mapped to the "thinking_type" style in _MODEL_THINKING_STYLES below.
_KIMI_THINKING_MODELS: frozenset[str] = frozenset({
    "kimi-k2.5",
    "kimi-k2.6",
    "k2.6-code-preview",
})

# Thinking-capable MiMo models per Xiaomi docs (see
# tests/providers/test_xiaomi_mimo_thinking.py). mimo-v2-flash is omitted
# because it does not support thinking.
_MIMO_THINKING_MODELS: frozenset[str] = frozenset({
    "mimo-v2.5-pro",
    "mimo-v2.5",
    "mimo-v2-pro",
    "mimo-v2-omni",
})

# Default request timeout in seconds; _openai_compat_timeout_s() lets the
# NANOBOT_OPENAI_COMPAT_TIMEOUT_S environment variable override it.
_OPENAI_COMPAT_REQUEST_TIMEOUT_S = 120.0

# Maps ProviderSpec.thinking_style → extra_body builder.
# Each builder takes a bool (thinking_enabled) and returns the dict to
# merge into extra_body, keeping the style→wire-format mapping in one place.
_THINKING_STYLE_MAP: dict[str, Any] = {
    "thinking_type": lambda on: {"thinking": {"type": "enabled" if on else "disabled"}},
    "enable_thinking": lambda on: {"enable_thinking": on},
    "reasoning_split": lambda on: {"reasoning_split": on},
}

# Maps a gateway reasoning style → extra_body builder. Each builder takes the
# effort string and returns the dict to merge into extra_body.
_GATEWAY_REASONING_STYLE_MAP: dict[str, Any] = {
    "reasoning_effort": lambda effort: {"reasoning": {"effort": effort}},
}

# Per-model thinking style keyed by model slug; every Kimi and MiMo
# thinking model above uses the "thinking_type" wire format.
_MODEL_THINKING_STYLES: dict[str, str] = {
    **dict.fromkeys(_KIMI_THINKING_MODELS, "thinking_type"),
    **dict.fromkeys(_MIMO_THINKING_MODELS, "thinking_type"),
}
def _model_slug(model_name: str) -> str:
    """Return the lower-cased text after the last ``/`` in *model_name*.

    Names without a ``/`` are returned lower-cased as a whole.
    """
    lowered = model_name.lower()
    _prefix, _sep, slug = lowered.rpartition("/")
    return slug
def _model_thinking_style(model_name: str) -> str:
    """Return the thinking style registered for the model's slug, or ``""``."""
    slug = _model_slug(model_name)
    if slug in _MODEL_THINKING_STYLES:
        return _MODEL_THINKING_STYLES[slug]
    return ""
def _thinking_styles_for(spec: ProviderSpec | None, model_name: str) -> list[str]:
    """List the thinking styles to apply for *spec* and *model_name*.

    The spec's own ``thinking_style`` (when set) comes first, followed by
    the model-specific style when one exists and is not already listed.
    """
    collected: list[str] = []
    if spec and spec.thinking_style:
        collected = [spec.thinking_style]
    per_model = _model_thinking_style(model_name)
    if per_model and per_model not in collected:
        collected.append(per_model)
    return collected
def _thinking_extra_body(style: str, thinking_enabled: bool) -> dict[str, Any] | None:
    """Build the extra_body fragment for *style*, or None for an unknown style."""
    build = _THINKING_STYLE_MAP.get(style)
    if not build:
        return None
    return build(thinking_enabled)
def _gateway_reasoning_extra_body(style: str, effort: str | None) -> dict[str, Any] | None:
    """Build the gateway reasoning fragment for *style* and *effort*.

    Returns None when *effort* is empty/None or *style* has no builder.
    """
    if effort:
        build = _GATEWAY_REASONING_STYLE_MAP.get(style)
        if build:
            return build(effort)
    return None
def _openai_compat_timeout_s() -> float:
    """Return the request timeout (seconds) for OpenAI-compatible providers.

    Reads ``NANOBOT_OPENAI_COMPAT_TIMEOUT_S`` and falls back to the module
    default when the variable is unset or rejected by ``_float_env``.
    """
    env_name = "NANOBOT_OPENAI_COMPAT_TIMEOUT_S"
    fallback = _OPENAI_COMPAT_REQUEST_TIMEOUT_S
    return _float_env(env_name, fallback)
def _float_env(name: str, default: float) -> float:
    """Read environment variable *name* as a positive float.

    Returns *default* when the variable is unset or blank. Unparseable or
    non-positive values are logged as warnings and also yield *default*.
    """
    raw = os.environ.get(name)
    if not (raw or "").strip():
        return default
    try:
        parsed = float(raw)
    except (TypeError, ValueError):
        logger.warning("Ignoring invalid {}={!r}; using {}", name, raw, default)
        return default
    else:
        if parsed <= 0:
            logger.warning("Ignoring non-positive {}={!r}; using {}", name, raw, default)
            return default
        return parsed
def _short_tool_id() -> str:
    """9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
    picked = [secrets.choice(_ALNUM) for _ in range(9)]
    return "".join(picked)
def _get(obj: Any, key: str) -> Any:
    """Look up *key* on a dict (by key) or any other object (by attribute).

    Returns None when the key/attribute is absent.
    """
    return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)
def _coerce_dict(value: Any) -> dict[str, Any] | None:
    """Return *value* as a non-empty dict, or None.

    Dicts are returned as-is; objects exposing a callable ``model_dump``
    are dumped first. Empty dicts and anything else produce None.
    """
    if isinstance(value, dict):
        candidate = value
    else:
        dumper = getattr(value, "model_dump", None)
        candidate = dumper() if callable(dumper) else None
    if isinstance(candidate, dict) and candidate:
        return candidate
    return None
def _extract_tc_extras(tc: Any) -> tuple[
    dict[str, Any] | None,
    dict[str, Any] | None,
    dict[str, Any] | None,
]:
    """Extract (extra_content, provider_specific_fields, fn_provider_specific_fields).

    Accepts SDK objects and dicts alike. ``extra_content`` is captured
    verbatim. When the tool call can be coerced to a dict, the provider
    extras are every non-None key outside the standard tool-call / function
    key sets; otherwise the ``provider_specific_fields`` attributes are read.
    """
    extra_content = _coerce_dict(_get(tc, "extra_content"))
    as_dict = _coerce_dict(tc)

    if as_dict is None:
        # Not dict-like: fall back to the explicit extras attributes.
        tc_extras = _coerce_dict(_get(tc, "provider_specific_fields"))
        function_obj = _get(tc, "function")
        fn_extras = None
        if function_obj is not None:
            fn_extras = _coerce_dict(_get(function_obj, "provider_specific_fields"))
        return extra_content, tc_extras, fn_extras

    skipped_keys = _STANDARD_TC_KEYS | {"extra_content"}
    tc_extras = {
        name: val
        for name, val in as_dict.items()
        if name not in skipped_keys and val is not None
    } or None

    fn_extras = None
    function_dict = _coerce_dict(as_dict.get("function"))
    if function_dict is not None:
        fn_extras = {
            name: val
            for name, val in function_dict.items()
            if name not in _STANDARD_FN_KEYS and val is not None
        } or None
    return extra_content, tc_extras, fn_extras
def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | None) -> bool:
    """Decide whether Nanobot attribution headers apply to this endpoint.

    True when the spec is named ``openrouter`` or the base URL contains
    ``openrouter`` (case-insensitive).
    """
    if spec and spec.name == "openrouter":
        return True
    if not api_base:
        return False
    return "openrouter" in api_base.lower()
# Number of recorded Responses API failures (per circuit key, reset on
# success) at which the circuit breaker opens.
_RESPONSES_FAILURE_THRESHOLD = 3
# Seconds the breaker stays open before another Responses attempt is allowed.
_RESPONSES_PROBE_INTERVAL_S = 300  # 5 minutes
def _is_local_endpoint(
    spec: "ProviderSpec | None",
    api_base: str | None,
) -> bool:
    """Return True when the endpoint is a local or LAN model server.

    Matches either the provider spec's ``is_local`` flag or common private-
    network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x,
    172.16-31.x, Docker ``host.docker.internal``).

    Malformed URLs (for example an unterminated IPv6 bracket) are treated
    as non-local instead of raising.
    """
    if spec and spec.is_local:
        return True
    if not api_base:
        return False
    raw = api_base.strip().lower()
    try:
        # urlparse() itself raises ValueError for malformed bracketed hosts,
        # so the parse must live inside the try, not only the attribute read.
        host = urlparse(raw if "://" in raw else f"//{raw}").hostname
    except ValueError:
        return False
    if not host:
        return False
    if host in {"localhost", "host.docker.internal"}:
        return True
    try:
        addr = ip_address(host)
    except ValueError:
        # A regular DNS name that is not one of the known local aliases.
        return False
    return addr.is_loopback or addr.is_private
def _is_direct_openai_base(api_base: str | None) -> bool:
    """Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways.

    A missing/empty base means the SDK default (direct OpenAI). Otherwise
    the URL's *hostname* must be ``api.openai.com`` or a subdomain of it.
    Matching on the hostname rather than on a substring keeps look-alike
    hosts (``api.openai.com.example``) and proxies that merely carry the
    string in their path from being treated as direct OpenAI. URLs that
    mention ``openrouter`` or cannot be parsed are never direct.
    """
    if not api_base:
        return True
    normalized = api_base.strip().lower().rstrip("/")
    if "openrouter" in normalized:
        return False
    try:
        # Scheme-less values ("api.openai.com/v1") need a leading "//" so the
        # host lands in netloc instead of the path.
        host = urlparse(normalized if "://" in normalized else f"//{normalized}").hostname
    except ValueError:
        return False
    if not host:
        return False
    return host == "api.openai.com" or host.endswith(".api.openai.com")
def _responses_circuit_key(
    model: str | None,
    default_model: str,
    reasoning_effort: str | None,
) -> str:
    """Build the circuit-breaker key ``"<model>:<effort>"`` (both lower-cased).

    *default_model* is used when *model* is empty; a non-string effort
    contributes an empty suffix.
    """
    chosen_model = model or default_model
    effort_part = ""
    if isinstance(reasoning_effort, str):
        effort_part = reasoning_effort.lower()
    return ":".join((chosen_model.lower(), effort_part))
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with *override* merged recursively into *base*.

    When both sides hold a dict under the same key the two are merged
    key-by-key; in every other case the value from *override* wins.
    The top level of *base* is copied, not mutated.
    """
    result = dict(base)
    for name, incoming in override.items():
        current = result.get(name)
        both_dicts = (
            name in result
            and isinstance(current, dict)
            and isinstance(incoming, dict)
        )
        result[name] = _deep_merge(current, incoming) if both_dicts else incoming
    return result
def _merge_unique_list(base: Any, override: Any) -> Any:
    """Concatenate two lists, keeping the first occurrence of each value.

    Values are compared by their sorted-key JSON encoding (``repr`` when
    they cannot be encoded). If either argument is not a list, *override*
    is returned unchanged.
    """
    if not (isinstance(base, list) and isinstance(override, list)):
        return override

    def _fingerprint(item: Any) -> str:
        try:
            return json.dumps(item, sort_keys=True, ensure_ascii=False)
        except Exception:
            return repr(item)

    # Dicts preserve insertion order, so setdefault keeps the first occurrence.
    first_seen: dict[str, Any] = {}
    for item in [*base, *override]:
        first_seen.setdefault(_fingerprint(item), item)
    return list(first_seen.values())
def _merge_responses_extra_body(
    body: dict[str, Any],
    extra_body: dict[str, Any],
) -> dict[str, Any]:
    """Merge configured Responses API body fields without clobbering tools.

    Ordinary keys are deep-merged. ``include`` is merged as a de-duplicated
    list, and ``tools`` is appended to the request's tools when both sides
    are lists (otherwise the configured value replaces it).
    """
    special_keys = ("include", "tools")
    ordinary = {name: val for name, val in extra_body.items() if name not in special_keys}
    merged = _deep_merge(body, ordinary)

    if "include" in extra_body:
        merged["include"] = _merge_unique_list(body.get("include"), extra_body["include"])

    if "tools" in extra_body:
        request_tools = body.get("tools")
        extra_tools = extra_body["tools"]
        both_lists = isinstance(request_tools, list) and isinstance(extra_tools, list)
        merged["tools"] = [*request_tools, *extra_tools] if both_lists else extra_tools

    return merged
class OpenAICompatProvider(LLMProvider):
"""Unified provider for all OpenAI-compatible APIs.
Receives a resolved ``ProviderSpec`` from the caller — no internal
registry lookups needed.
"""
def __init__(
    self,
    api_key: str | None = None,
    api_base: str | None = None,
    default_model: str = "gpt-4o",
    extra_headers: dict[str, str] | None = None,
    spec: ProviderSpec | None = None,
    extra_body: dict[str, Any] | None = None,
    api_type: str = "auto",
    extra_query: dict[str, str] | None = None,
):
    """Store provider configuration; the HTTP client is created lazily.

    ``api_type`` is only honoured when *spec* is the ``openai`` spec and is
    otherwise forced to ``"auto"``. ``extra_query`` is later passed to the
    client as ``default_query``; ``extra_headers`` override the defaults
    built here.
    """
    super().__init__(api_key, api_base)
    self.default_model = default_model
    self.extra_headers = extra_headers or {}
    self._spec = spec
    self._extra_body = extra_body or {}
    is_openai_spec = bool(spec and spec.name == "openai")
    self._api_type = api_type if is_openai_spec else "auto"
    self._extra_query = extra_query or {}

    if api_key and spec and spec.env_key:
        self._setup_env(api_key, api_base)

    # Explicit base wins; otherwise use the spec default. Empty → None.
    effective_base = api_base
    if not effective_base and spec:
        effective_base = spec.default_api_base
    effective_base = effective_base or None
    self._effective_base = effective_base

    headers = {"x-session-affinity": uuid.uuid4().hex}
    if _uses_openrouter_attribution(spec, effective_base):
        headers.update(_DEFAULT_OPENROUTER_HEADERS)
    if extra_headers:
        headers.update(extra_headers)
    self._default_headers = headers

    self._api_key_for_client = api_key or "no-key"
    self._is_local = _is_local_endpoint(spec, effective_base)

    # Lazy-init: the OpenAI client and its httpx transport are expensive
    # to create (~700 ms on Windows). Defer until first use.
    self._client: AsyncOpenAIType | None = None
    self._client_lock = asyncio.Lock()

    # Responses API circuit breaker: skip after repeated failures,
    # probe again after _RESPONSES_PROBE_INTERVAL_S seconds.
    self._responses_failures: dict[str, int] = {}
    self._responses_tripped_at: dict[str, float] = {}
def _build_client(self) -> None:
    """Create the OpenAI client using the current module-level AsyncOpenAI."""
    import httpx

    timeout_s = _openai_compat_timeout_s()
    client_kwargs: dict[str, Any] = {
        "api_key": self._api_key_for_client,
        "base_url": self._effective_base,
        "default_headers": self._default_headers,
        "default_query": self._extra_query or None,
        "max_retries": 0,
        "timeout": timeout_s,
        "http_client": None,
    }
    if self._is_local:
        # Local model servers (Ollama, llama.cpp, vLLM) often close idle
        # HTTP connections before the client-side keepalive expires. When
        # two LLM calls happen seconds apart (e.g. heartbeat _decide then
        # process_direct), the second call may grab a now-dead pooled
        # connection, causing a transient APIConnectionError on every first
        # attempt. Disabling keepalive for local endpoints avoids this by
        # opening a fresh connection for each request, which is cheap on a
        # LAN. Cloud providers benefit from keepalive, so we leave the
        # default pool settings for them.
        client_kwargs["http_client"] = httpx.AsyncClient(
            limits=httpx.Limits(keepalive_expiry=0),
            timeout=timeout_s,
        )
    self._client = AsyncOpenAI(**client_kwargs)
async def _ensure_client(self):
    """Return the shared OpenAI client, creating it on first call.

    The client class is resolved once per process: ``langfuse.openai`` when
    ``LANGFUSE_SECRET_KEY`` is set and langfuse is importable, otherwise the
    plain ``openai`` SDK (with a warning if the key is set but langfuse is
    missing).
    """
    if self._client is not None:
        return self._client
    async with self._client_lock:
        # Re-check: another task may have built the client while we waited.
        if self._client is not None:
            return self._client
        global AsyncOpenAI
        if AsyncOpenAI is None:
            tracing_requested = bool(os.environ.get("LANGFUSE_SECRET_KEY"))
            if tracing_requested and importlib.util.find_spec("langfuse"):
                from langfuse.openai import AsyncOpenAI as client_class
            else:
                if tracing_requested:
                    logger.warning(
                        "LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
                        "install with `pip install langfuse` to enable tracing"
                    )
                from openai import AsyncOpenAI as client_class
            AsyncOpenAI = client_class
        self._build_client()
        return self._client
def _setup_env(self, api_key: str, api_base: str | None) -> None:
"""Set environment variables based on provider spec."""
spec = self._spec
if not spec or not spec.env_key:
return
if spec.is_gateway:
os.environ[spec.env_key] = api_key
else:
os.environ.setdefault(spec.env_key, api_key)
effective_base = api_base or spec.default_api_base
for env_name, env_val in spec.env_extras:
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
os.environ.setdefault(env_name, resolved)
@classmethod
def _apply_cache_control(
cls,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
"""Inject cache_control markers for prompt caching."""
cache_marker = {"type": "ephemeral"}
new_messages = list(messages)
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
content = msg.get("content")
if isinstance(content, str):
return {**msg, "content": [
{"type": "text", "text": content, "cache_control": cache_marker},
]}
if isinstance(content, list) and content:
nc = list(content)
nc[-1] = {**nc[-1], "cache_control": cache_marker}
return {**msg, "content": nc}
return msg
if new_messages and new_messages[0].get("role") == "system":
new_messages[0] = _mark(new_messages[0])
if len(new_messages) >= 3:
new_messages[-2] = _mark(new_messages[-2])
new_tools = tools
if tools:
new_tools = list(tools)
for idx in cls._tool_cache_marker_indices(new_tools):
new_tools[idx] = {**new_tools[idx], "cache_control": cache_marker}
return new_messages, new_tools
@staticmethod
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
"""Normalize to a provider-safe 9-char alphanumeric form."""
if not isinstance(tool_call_id, str):
return tool_call_id
if len(tool_call_id) == 9 and tool_call_id.isalnum():
return tool_call_id
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
def _should_normalize_tool_call_ids(self) -> bool:
"""Return True for providers that reject normal OpenAI tool call IDs."""
return bool(self._spec and self._spec.name == "mistral")
@staticmethod
def _normalize_tool_call_arguments(arguments: Any) -> str:
"""Force function.arguments into a valid JSON object string."""
if isinstance(arguments, str):
stripped = arguments.strip()
if not stripped:
return "{}"
try:
parsed = json_repair.loads(stripped)
except Exception:
return "{}"
if isinstance(parsed, dict):
return json.dumps(parsed, ensure_ascii=False)
return "{}"
if isinstance(arguments, dict):
return json.dumps(arguments, ensure_ascii=False)
return "{}"
@staticmethod
def _coerce_content_to_string(content: Any) -> str | None:
"""Coerce block/list content into plain text for strict string-only APIs."""
if content is None or isinstance(content, str):
return content
text = OpenAICompatProvider._extract_text_content(content)
if isinstance(text, str) and text:
return text
try:
dumped = json.dumps(content, ensure_ascii=False)
except Exception:
dumped = str(content)
return dumped or "(empty)"
def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Strip non-standard keys, normalize tool_call IDs.

    Steps, in order:
      1. Reduce every message to the keys in ``_ALLOWED_MSG_KEYS`` via the
         base-class sanitizer.
      2. For each ``tool_calls`` list, give every dict entry an ID that is
         unique within that message and force ``function.arguments`` into
         a JSON object string.
      3. Rewrite ``tool_call_id`` on later messages so each result points at
         the ID assigned to the matching call.
      4. For the ``deepseek`` spec, coerce content to a plain string.
      5. Hand the result to ``_enforce_role_alternation``.
    """
    sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
    # original ID -> normalized ID (only populated when normalization is on).
    id_map: dict[str, str] = {}
    # original ID -> queue of IDs assigned to calls that used it, consumed
    # first-in-first-out by the tool-result messages that follow.
    pending_tool_ids: dict[str, deque[str]] = {}
    force_string_content = bool(self._spec and self._spec.name == "deepseek")
    normalize_tool_ids = self._should_normalize_tool_call_ids()

    def map_id(value: Any) -> Any:
        # Identity unless this provider needs normalized IDs; the mapping is
        # memoized so the same input always yields the same output.
        if not isinstance(value, str):
            return value
        if not normalize_tool_ids:
            return value
        return id_map.setdefault(value, self._normalize_tool_call_id(value))

    def unique_tool_id(value: Any, used_ids: set[str], idx: int) -> str:
        # Missing/empty/non-string IDs get a fresh random one.
        if isinstance(value, str) and value:
            base = map_id(value)
        else:
            base = _short_tool_id()
        if not isinstance(base, str) or not base:
            base = _short_tool_id()
        if base not in used_ids:
            return base
        # Collision within this message: derive salted candidates from the
        # original ID until one is free.
        seed = value if isinstance(value, str) and value else base
        salt = 1
        while True:
            candidate = self._normalize_tool_call_id(f"{seed}:{idx}:{salt}")
            if isinstance(candidate, str) and candidate not in used_ids:
                return candidate
            salt += 1

    def map_tool_result_id(value: Any) -> Any:
        if not isinstance(value, str):
            return value
        queue = pending_tool_ids.get(value)
        if queue:
            # Pair this result with the oldest unanswered call of that ID.
            mapped = queue.popleft()
            if not queue:
                pending_tool_ids.pop(value, None)
            return mapped
        # No pending call recorded: fall back to the plain mapping.
        return map_id(value)

    for clean in sanitized:
        if isinstance(clean.get("tool_calls"), list):
            normalized = []
            used_ids: set[str] = set()
            for idx, tc in enumerate(clean["tool_calls"]):
                if not isinstance(tc, dict):
                    # Non-dict entries are passed through untouched.
                    normalized.append(tc)
                    continue
                tc_clean = dict(tc)
                raw_id = tc_clean.get("id")
                mapped_id = unique_tool_id(raw_id, used_ids, idx)
                tc_clean["id"] = mapped_id
                used_ids.add(mapped_id)
                if isinstance(raw_id, str) and raw_id:
                    pending_tool_ids.setdefault(raw_id, deque()).append(mapped_id)
                function = tc_clean.get("function")
                if isinstance(function, dict):
                    function_clean = dict(function)
                    if "arguments" in function_clean:
                        function_clean["arguments"] = self._normalize_tool_call_arguments(
                            function_clean.get("arguments")
                        )
                    else:
                        function_clean["arguments"] = "{}"
                    tc_clean["function"] = function_clean
                normalized.append(tc_clean)
            clean["tool_calls"] = normalized
            if clean.get("role") == "assistant":
                # Some OpenAI-compatible gateways reject assistant messages
                # that mix non-empty content with tool_calls.
                clean["content"] = None
        if "tool_call_id" in clean and clean["tool_call_id"]:
            clean["tool_call_id"] = map_tool_result_id(clean["tool_call_id"])
        if (
            force_string_content
            and not (clean.get("role") == "assistant" and clean.get("tool_calls"))
        ):
            # Assistant tool-call messages keep the content set above.
            clean["content"] = self._coerce_content_to_string(clean.get("content"))
    return self._enforce_role_alternation(sanitized)
# ------------------------------------------------------------------
# Build kwargs
# ------------------------------------------------------------------
@staticmethod
def _supports_temperature(
model_name: str,
reasoning_effort: str | None = None,
) -> bool:
"""Return True when the model accepts a temperature parameter.
GPT-5 family and reasoning models (o1/o3/o4) reject temperature
when reasoning_effort is set to anything other than ``"none"``.
"""
if reasoning_effort and reasoning_effort.lower() != "none":
return False
name = model_name.lower()
return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))
def _build_kwargs(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str | None,
    max_tokens: int,
    temperature: float,
    reasoning_effort: str | None,
    tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
    """Assemble the keyword arguments for a Chat Completions request.

    Resolves the model name, applies prompt-cache markers and per-model
    overrides from the spec, translates *reasoning_effort* into the
    provider's wire format (top-level kwarg and/or ``extra_body`` thinking
    controls), attaches tools, backfills ``reasoning_content`` where
    thinking mode needs it, and finally merges the user-configured
    ``extra_body`` on top.

    Returns:
        A dict of request kwargs containing at least ``model`` and
        ``messages``.
    """
    model_name = model or self.default_model
    spec = self._spec
    if spec and spec.supports_prompt_caching:
        # NOTE(review): re-assigns the value already computed above;
        # appears redundant.
        model_name = model or self.default_model
        if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
            messages, tools = self._apply_cache_control(messages, tools)
    if spec and spec.strip_model_prefix:
        model_name = model_name.split("/")[-1]
    kwargs: dict[str, Any] = {
        "model": model_name,
        "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
    }
    # GPT-5 and reasoning models (o1/o3/o4) reject temperature when
    # reasoning_effort is active. Only include it when safe.
    if self._supports_temperature(model_name, reasoning_effort):
        kwargs["temperature"] = temperature
    # Token limit goes out under whichever name the spec declares; the
    # value is clamped to at least 1.
    if spec and getattr(spec, "supports_max_completion_tokens", False):
        kwargs["max_completion_tokens"] = max(1, max_tokens)
    else:
        kwargs["max_tokens"] = max(1, max_tokens)
    if spec:
        # First substring match wins; its overrides replace kwargs set above.
        # NOTE(review): kwargs.update stores override values by reference, so
        # if an override carries an ``extra_body`` dict the ``.update(extra)``
        # calls below would mutate the spec's dict — confirm overrides never
        # include extra_body.
        model_lower = model_name.lower()
        for pattern, overrides in spec.model_overrides:
            if pattern in model_lower:
                kwargs.update(overrides)
                break
    # Normalize reasoning_effort into a semantic form (OpenAI vocab)
    # used for internal decisions, and a wire form actually sent out.
    # "minimum" is accepted as a DashScope-native alias for "minimal".
    semantic_effort: str | None = None
    if isinstance(reasoning_effort, str):
        semantic_effort = reasoning_effort.lower()
        if semantic_effort == "minimum":
            semantic_effort = "minimal"
    wire_effort = reasoning_effort
    if spec and spec.name == "dashscope" and semantic_effort == "minimal":
        # DashScope accepts none/minimum/low/medium/high/xhigh; "minimal" 400s.
        wire_effort = "minimum"
    if wire_effort and semantic_effort != "none":
        kwargs["reasoning_effort"] = wire_effort
    # Only send thinking controls when reasoning_effort is explicit so
    # omitting the config preserves each provider's default.
    if reasoning_effort is not None:
        thinking_enabled = semantic_effort not in ("none", "minimal")
        for thinking_style in _thinking_styles_for(spec, model_name):
            extra = _thinking_extra_body(thinking_style, thinking_enabled)
            if extra:
                kwargs.setdefault("extra_body", {}).update(extra)
        gateway_style = getattr(spec, "gateway_reasoning_style", "") if spec else ""
        if gateway_style and _model_thinking_style(model_name):
            extra = _gateway_reasoning_extra_body(gateway_style, semantic_effort)
            if extra:
                kwargs.setdefault("extra_body", {}).update(extra)
        # Moonshot rejects requests that carry both 'reasoning_effort'
        # and the native 'thinking' param. We already expressed the
        # user's intent via the provider-native shape, so drop the
        # redundant wire-level kwarg. Only kimi models need this —
        # Xiaomi's API accepts both params.
        if _model_slug(model_name) in _KIMI_THINKING_MODELS:
            kwargs.pop("reasoning_effort", None)
    if tools:
        kwargs["tools"] = tools
        kwargs["tool_choice"] = tool_choice or "auto"
    # Backfill reasoning_content="" on assistants missing it: DeepSeek
    # thinking mode rejects history otherwise (#3554, #3584); "" reads
    # as "no thinking that turn". DeepSeek-V4/reasoner reason natively,
    # so backfill even without explicit reasoning_effort.
    explicit_thinking = (
        reasoning_effort is not None
        and semantic_effort not in ("none", "minimal")
        and (
            (spec and spec.thinking_style)
            or _model_thinking_style(model_name)
        )
    )
    implicit_deepseek_thinking = (
        spec is not None
        and spec.name == "deepseek"
        and semantic_effort not in ("none", "minimal", "minimum")
        and any(t in model_name.lower() for t in ("deepseek-v4", "deepseek-reasoner"))
    )
    if explicit_thinking or implicit_deepseek_thinking:
        for msg in kwargs["messages"]:
            if msg.get("role") == "assistant" and "reasoning_content" not in msg:
                msg["reasoning_content"] = ""
    # Merge user-configured extra_body last so it can override or
    # extend provider-specific defaults (e.g. chat_template_kwargs,
    # guided_json, repetition_penalty). Uses recursive merge so
    # nested dicts like {"chat_template_kwargs": {"enable_thinking": false}}
    # do not clobber sibling keys already set by thinking-style logic.
    if self._extra_body:
        existing = kwargs.get("extra_body", {})
        kwargs["extra_body"] = _deep_merge(existing, self._extra_body)
    return kwargs
def _should_use_responses_api(
self,
model: str | None,
reasoning_effort: str | None,
) -> bool:
"""Use Responses API only for direct OpenAI requests that benefit from it."""
if self._api_type == "chat_completions":
return False
if self._spec and self._spec.name not in ("openai", "github_copilot"):
return False
if self._api_type == "responses":
# Explicit configuration means Responses is mandatory; do not
# consult the circuit breaker or fall back to Chat Completions.
return True
if self._spec is None or self._spec.name != "github_copilot":
if not _is_direct_openai_base(self._effective_base):
return False
model_name = (model or self.default_model).lower()
wants = False
if reasoning_effort and reasoning_effort.lower() != "none":
wants = True
elif any(token in model_name for token in ("gpt-5", "o1", "o3", "o4")):
wants = True
if not wants:
return False
return self._responses_circuit_allows_probe(model, reasoning_effort)
def _responses_circuit_allows_probe(
    self,
    model: str | None,
    reasoning_effort: str | None,
) -> bool:
    """Return False when the Responses API circuit breaker is open.

    The breaker is open once the failure count for this model/effort key
    has reached the threshold and less than the probe interval has elapsed
    since it tripped; after that interval attempts are allowed again.
    """
    key = _responses_circuit_key(model, self.default_model, reasoning_effort)
    if self._responses_failures.get(key, 0) < _RESPONSES_FAILURE_THRESHOLD:
        return True
    opened_at = self._responses_tripped_at.get(key, 0.0)
    open_for = time.monotonic() - opened_at
    # Half-open once the probe interval has passed.
    return not (open_for < _RESPONSES_PROBE_INTERVAL_S)
def _record_responses_failure(self, model: str | None, reasoning_effort: str | None) -> None:
    """Count one Responses API failure and (re)open the breaker at the threshold."""
    key = _responses_circuit_key(model, self.default_model, reasoning_effort)
    failures = self._responses_failures
    failures[key] = failures.get(key, 0) + 1
    if failures[key] < _RESPONSES_FAILURE_THRESHOLD:
        return
    # At or past the threshold: stamp the trip time so the probe interval
    # restarts from this failure.
    self._responses_tripped_at[key] = time.monotonic()
    logger.warning(
        "Responses API circuit open for {} — falling back to Chat Completions",
        key,
    )
def _record_responses_success(self, model: str | None, reasoning_effort: str | None) -> None:
    """Reset the circuit breaker state for this model/effort key."""
    key = _responses_circuit_key(model, self.default_model, reasoning_effort)
    for table in (self._responses_failures, self._responses_tripped_at):
        table.pop(key, None)
@staticmethod
def _should_fallback_from_responses_error(e: Exception) -> bool:
"""Fallback only for likely Responses API compatibility errors."""
response = getattr(e, "response", None)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
if status_code not in {400, 404, 422}:
return False
body = (
getattr(e, "body", None)
or getattr(e, "doc", None)
or getattr(response, "text", None)
)
body_text = str(body).lower() if body is not None else ""
compatibility_markers = (
"responses",
"response api",
"max_output_tokens",
"instructions",
"previous_response",
"unsupported",
"not supported",
"unknown parameter",
"unrecognized request argument",
)
return any(marker in body_text for marker in compatibility_markers)
def _build_responses_body(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None,
    model: str | None,
    max_tokens: int,
    temperature: float,
    reasoning_effort: str | None,
    tool_choice: str | dict[str, Any] | None,
) -> dict[str, Any]:
    """Build a Responses API body for direct OpenAI requests.

    Messages are sanitized and converted to ``instructions`` + ``input``.
    Temperature, reasoning and tools are added only when applicable, and
    the configured ``extra_body`` is merged last.
    """
    resolved_model = model or self.default_model
    if self._spec and self._spec.strip_model_prefix:
        resolved_model = resolved_model.split("/")[-1]

    cleaned = self._sanitize_messages(self._sanitize_empty_content(messages))
    instructions, input_items = convert_messages(cleaned)

    body: dict[str, Any] = {
        "model": resolved_model,
        "instructions": instructions or None,
        "input": input_items,
        "max_output_tokens": max(1, max_tokens),
        "store": False,
        "stream": False,
    }
    if self._supports_temperature(resolved_model, reasoning_effort):
        body["temperature"] = temperature
    if reasoning_effort and reasoning_effort.lower() != "none":
        body.update(
            reasoning={"effort": reasoning_effort},
            include=["reasoning.encrypted_content"],
        )
    if tools:
        body.update(tools=convert_tools(tools), tool_choice=tool_choice or "auto")

    configured = getattr(self, "_extra_body", {})
    if not configured:
        return body
    return _merge_responses_extra_body(body, configured)
# ------------------------------------------------------------------
# Response parsing
# ------------------------------------------------------------------
@staticmethod
def _maybe_mapping(value: Any) -> dict[str, Any] | None:
if isinstance(value, dict):
return value
model_dump = getattr(value, "model_dump", None)
if callable(model_dump):
dumped = model_dump()
if isinstance(dumped, dict):
return dumped
return None
@classmethod
def _extract_text_content(cls, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
if isinstance(value, list):
parts: list[str] = []
for item in value:
item_map = cls._maybe_mapping(item)
if item_map:
text = item_map.get("text")
if isinstance(text, str):
parts.append(text)
continue
text = getattr(item, "text", None)
if isinstance(text, str):
parts.append(text)
continue
if isinstance(item, str):
parts.append(item)
return "".join(parts) or None
return str(value)
@classmethod
def _extract_usage(cls, response: Any) -> dict[str, int]:
"""Extract token usage from an OpenAI-compatible response.
Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
responses. Provider-specific ``cached_tokens`` fields are normalised
under a single key; see the priority chain inside for details.
"""
# --- resolve usage object ---
usage_obj = None
response_map = cls._maybe_mapping(response)
if response_map is not None:
usage_obj = response_map.get("usage")
elif hasattr(response, "usage") and response.usage:
usage_obj = response.usage
usage_map = cls._maybe_mapping(usage_obj)
if usage_map is not None:
result = {
"prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
"completion_tokens": int(usage_map.get("completion_tokens") or 0),
"total_tokens": int(usage_map.get("total_tokens") or 0),
}
elif usage_obj:
result = {
"prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
"completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
"total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
}
else:
return {}
# --- cached_tokens (normalised across providers) ---
# Try nested paths first (dict), fall back to attribute (SDK object).
# Priority order ensures the most specific field wins.
for path in (
("prompt_tokens_details", "cached_tokens"), # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
("cached_tokens",), # StepFun/Moonshot (top-level)
("prompt_cache_hit_tokens",), # DeepSeek/SiliconFlow
):
cached = cls._get_nested_int(usage_map, path)
if not cached and usage_obj:
cached = cls._get_nested_int(usage_obj, path)
if cached:
result["cached_tokens"] = cached
break
return result
@staticmethod
def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
"""Drill into *obj* by *path* segments and return an ``int`` value.
Supports both dict-key access and attribute access so it works
uniformly with raw JSON dicts **and** SDK Pydantic models.
"""
current = obj
for segment in path:
if current is None:
return 0
if isinstance(current, dict):
current = current.get(segment)
else:
current = getattr(current, segment, None)
return int(current or 0) if current is not None else 0
def _parse(self, response: Any) -> LLMResponse:
    """Convert a non-streaming completion response into an ``LLMResponse``.

    Three input shapes are handled, in this order:

    * a plain ``str`` -- returned as content with ``finish_reason="stop"``;
    * anything ``_maybe_mapping`` can view as a mapping -- read by key;
    * any other object -- read by attribute access (``response.choices``).

    In both structured branches the first choice seeds ``content`` and
    ``finish_reason``, tool calls are collected from *every* choice, and
    string tool-call arguments are decoded with ``json_repair.loads``.

    Args:
        response: Raw provider response (string, mapping, or SDK object).

    Returns:
        The parsed response. When there are no choices and no top-level
        content, an ``LLMResponse`` with ``finish_reason="error"`` is
        returned instead of raising.
    """
    # Some gateways hand back the completion text directly.
    if isinstance(response, str):
        return LLMResponse(content=response, finish_reason="stop")
    response_map = self._maybe_mapping(response)
    if response_map is not None:
        # ---- mapping branch ----
        choices = response_map.get("choices") or []
        if not choices:
            # No choices: fall back to top-level ``content`` / ``output_text``.
            content = self._extract_text_content(
                response_map.get("content") or response_map.get("output_text")
            )
            reasoning_content = self._extract_text_content(
                response_map.get("reasoning_content")
            )
            if content is not None:
                return LLMResponse(
                    content=content,
                    reasoning_content=reasoning_content,
                    finish_reason=str(response_map.get("finish_reason") or "stop"),
                    usage=self._extract_usage(response_map),
                )
            return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
        # The first choice provides the initial content and finish_reason.
        choice0 = self._maybe_mapping(choices[0]) or {}
        msg0 = self._maybe_mapping(choice0.get("message")) or {}
        content = self._extract_text_content(msg0.get("content"))
        finish_reason = str(choice0.get("finish_reason") or "stop")
        raw_tool_calls: list[Any] = []
        # StepFun: fallback to reasoning field when content is empty
        if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content:
            content = self._extract_text_content(msg0.get("reasoning"))
        # Prefer ``reasoning_content``; use ``reasoning`` only when it is absent.
        reasoning_content = msg0.get("reasoning_content")
        if reasoning_content is None and msg0.get("reasoning"):
            reasoning_content = self._extract_text_content(msg0.get("reasoning"))
        # Walk every choice: gather tool calls and fill in any field that
        # the first choice left empty.
        for ch in choices:
            ch_map = self._maybe_mapping(ch) or {}
            m = self._maybe_mapping(ch_map.get("message")) or {}
            tool_calls = m.get("tool_calls")
            if isinstance(tool_calls, list) and tool_calls:
                raw_tool_calls.extend(tool_calls)
                # A choice that carries tool calls overrides finish_reason,
                # but only with one of these two values.
                if ch_map.get("finish_reason") in ("tool_calls", "stop"):
                    finish_reason = str(ch_map["finish_reason"])
            if not content:
                content = self._extract_text_content(m.get("content"))
            if reasoning_content is None:
                reasoning_content = m.get("reasoning_content")
        parsed_tool_calls = []
        for tc in raw_tool_calls:
            tc_map = self._maybe_mapping(tc) or {}
            fn = self._maybe_mapping(tc_map.get("function")) or {}
            args = fn.get("arguments", {})
            # Arguments usually arrive as a JSON string; decode leniently.
            if isinstance(args, str):
                args = json_repair.loads(args)
            ec, prov, fn_prov = _extract_tc_extras(tc)
            parsed_tool_calls.append(ToolCallRequest(
                # Generate an id when the provider did not send one.
                id=str(tc_map.get("id") or _short_tool_id()),
                name=str(fn.get("name") or ""),
                # Anything that did not decode to a dict becomes ``{}``.
                arguments=args if isinstance(args, dict) else {},
                extra_content=ec,
                provider_specific_fields=prov,
                function_provider_specific_fields=fn_prov,
            ))
        return LLMResponse(
            content=content,
            tool_calls=parsed_tool_calls,
            finish_reason=finish_reason,
            usage=self._extract_usage(response_map),
            # Non-string reasoning payloads are dropped rather than passed on.
            reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
        )
    # ---- attribute-access branch (not viewable as a mapping) ----
    if not response.choices:
        return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
    choice = response.choices[0]
    msg = choice.message
    content = msg.content
    finish_reason = choice.finish_reason
    raw_tool_calls: list[Any] = []
    for ch in response.choices:
        m = ch.message
        if hasattr(m, "tool_calls") and m.tool_calls:
            raw_tool_calls.extend(m.tool_calls)
            # Same override rule as the mapping branch.
            if ch.finish_reason in ("tool_calls", "stop"):
                finish_reason = ch.finish_reason
        if not content and m.content:
            content = m.content
        # Reasoning text is promoted to content only for provider specs
        # that opt in via ``reasoning_as_content``.
        if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content:
            content = m.reasoning
    tool_calls = []
    for tc in raw_tool_calls:
        args = tc.function.arguments
        if isinstance(args, str):
            args = json_repair.loads(args)
        ec, prov, fn_prov = _extract_tc_extras(tc)
        tool_calls.append(ToolCallRequest(
            id=str(getattr(tc, "id", None) or _short_tool_id()),
            name=tc.function.name,
            # NOTE(review): unlike the mapping branch, the decoded value is
            # passed through without a dict check -- confirm this asymmetry
            # is intended.
            arguments=args,
            extra_content=ec,
            provider_specific_fields=prov,
            function_provider_specific_fields=fn_prov,
        ))
    reasoning_content = getattr(msg, "reasoning_content", None)
    if reasoning_content is None and getattr(msg, "reasoning", None):
        reasoning_content = msg.reasoning
    return LLMResponse(
        content=content,
        tool_calls=tool_calls,
        finish_reason=finish_reason or "stop",
        usage=self._extract_usage(response),
        reasoning_content=reasoning_content,
    )
@classmethod
def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
    """Assemble the chunks of one streamed completion into an ``LLMResponse``.

    Each chunk may be a plain string (treated as content), something
    ``_maybe_mapping`` can view as a mapping, or an object read through
    attribute access. Content and reasoning deltas are concatenated,
    tool-call deltas are merged per index, and the last non-empty usage
    and finish_reason seen win.

    Args:
        chunks: Stream chunks in arrival order.

    Returns:
        The aggregated response. ``content`` and ``reasoning_content`` are
        ``None`` when no text was received; ``finish_reason`` defaults to
        ``"stop"``.
    """
    content_parts: list[str] = []
    reasoning_parts: list[str] = []
    tc_bufs: dict[int, dict[str, Any]] = {}
    finish_reason = "stop"
    usage: dict[str, int] = {}

    def _accum_tc(tc: Any, idx_hint: int) -> None:
        """Accumulate one streaming tool-call delta into *tc_bufs*."""
        explicit_index = _get(tc, "index")
        tc_index: int = explicit_index if explicit_index is not None else idx_hint
        buf = tc_bufs.setdefault(tc_index, {
            "id": "", "name": "", "arguments": "",
            "extra_content": None, "prov": None, "fn_prov": None,
        })
        tc_id = _get(tc, "id")
        if tc_id:
            buf["id"] = str(tc_id)
        fn = _get(tc, "function")
        if fn is not None:
            fn_name = _get(fn, "name")
            if fn_name:
                buf["name"] = str(fn_name)
            fn_args = _get(fn, "arguments")
            if fn_args:
                # Arguments stream in as JSON fragments; join them verbatim.
                buf["arguments"] += str(fn_args)
        ec, prov, fn_prov = _extract_tc_extras(tc)
        if ec:
            buf["extra_content"] = ec
        if prov:
            buf["prov"] = prov
        if fn_prov:
            buf["fn_prov"] = fn_prov

    def _accum_legacy_function_call(function_call: Any) -> None:
        """Accumulate legacy ``delta.function_call`` streaming chunks."""
        if not function_call:
            return
        # The legacy format carries a single call, so it always uses slot 0.
        buf = tc_bufs.setdefault(0, {
            "id": "", "name": "", "arguments": "",
            "extra_content": None, "prov": None, "fn_prov": None,
        })
        fn_name = _get(function_call, "name")
        if fn_name:
            buf["name"] = str(fn_name)
        fn_args = _get(function_call, "arguments")
        if fn_args:
            buf["arguments"] += str(fn_args)

    for chunk in chunks:
        if isinstance(chunk, str):
            content_parts.append(chunk)
            continue

        chunk_map = cls._maybe_mapping(chunk)
        if chunk_map is not None:
            choices = chunk_map.get("choices") or []
            if not choices:
                # Usage-only (or top-level text) chunk.
                usage = cls._extract_usage(chunk_map) or usage
                text = cls._extract_text_content(
                    chunk_map.get("content") or chunk_map.get("output_text")
                )
                if text:
                    content_parts.append(text)
                continue
            choice = cls._maybe_mapping(choices[0]) or {}
            if choice.get("finish_reason"):
                finish_reason = str(choice["finish_reason"])
            delta = cls._maybe_mapping(choice.get("delta")) or {}
            text = cls._extract_text_content(delta.get("content"))
            if text:
                content_parts.append(text)
            text = cls._extract_text_content(delta.get("reasoning_content"))
            if not text:
                text = cls._extract_text_content(delta.get("reasoning"))
            if text:
                reasoning_parts.append(text)
            for idx, tc in enumerate(delta.get("tool_calls") or []):
                _accum_tc(tc, idx)
            _accum_legacy_function_call(delta.get("function_call"))
            usage = cls._extract_usage(chunk_map) or usage
            continue

        # ---- attribute-access chunk ----
        if not chunk.choices:
            usage = cls._extract_usage(chunk) or usage
            continue
        choice = chunk.choices[0]
        if choice.finish_reason:
            finish_reason = choice.finish_reason
        delta = choice.delta
        if delta:
            if delta.content:
                content_parts.append(delta.content)
            reasoning = getattr(delta, "reasoning_content", None)
            if not reasoning:
                reasoning = getattr(delta, "reasoning", None)
            # Normalise like the mapping branch so a non-string reasoning
            # payload cannot break the final "".join().
            reasoning_text = cls._extract_text_content(reasoning)
            if reasoning_text:
                reasoning_parts.append(reasoning_text)
            # Use the position within the delta as the fallback index (as
            # the mapping branch does); falling back to ``tc.index`` itself
            # would merge all index-less parallel calls into one buffer.
            for idx, tc in enumerate(getattr(delta, "tool_calls", None) or []):
                _accum_tc(tc, idx)
            _accum_legacy_function_call(getattr(delta, "function_call", None))
        # Some providers attach usage to a chunk that also carries choices;
        # pick it up here just like the mapping branch does.
        usage = cls._extract_usage(chunk) or usage

    # Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for
    # parallel tool calls in streaming mode. Deduplicate before building
    # the response so downstream tool messages don't collide.
    _seen_tc_ids: set[str] = set()
    for b in tc_bufs.values():
        if not b["id"] or b["id"] in _seen_tc_ids:
            b["id"] = _short_tool_id()
        _seen_tc_ids.add(b["id"])

    return LLMResponse(
        content="".join(content_parts) or None,
        tool_calls=[
            ToolCallRequest(
                # Every buffer has a non-empty, unique id after the loop above.
                id=b["id"],
                name=b["name"],
                arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
                extra_content=b.get("extra_content"),
                provider_specific_fields=b.get("prov"),
                function_provider_specific_fields=b.get("fn_prov"),
            )
            for b in tc_bufs.values()
        ],
        finish_reason=finish_reason,
        usage=usage,
        reasoning_content="".join(reasoning_parts) or None,
    )
@classmethod
def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
    """Collect structured error fields from a provider exception.

    Reads optional attributes of *e* (``response``, ``body``, ``doc``,
    ``status_code``) and of its response (``headers``, ``text``, ``json``,
    ``status_code``); every attribute is looked up defensively.

    Returns:
        A dict with the keys ``error_status_code``, ``error_kind``,
        ``error_type``, ``error_code``, ``error_retry_after_s`` and
        ``error_should_retry``.
    """
    http_response = getattr(e, "response", None)
    response_headers = getattr(http_response, "headers", None)

    # First truthy source wins: exception body, then doc, then response text.
    raw_payload = (
        getattr(e, "body", None)
        or getattr(e, "doc", None)
        or getattr(http_response, "text", None)
    )
    if raw_payload is None and http_response is not None:
        json_loader = getattr(http_response, "json", None)
        if callable(json_loader):
            try:
                raw_payload = json_loader()
            except Exception:
                raw_payload = None
    error_type, error_code = LLMProvider._extract_error_type_code(raw_payload)

    # Prefer the status on the exception; fall back to the response's.
    status = getattr(e, "status_code", None)
    if status is None and http_response is not None:
        status = getattr(http_response, "status_code", None)

    # ``x-should-retry`` is honoured only for the exact values true/false.
    retry_flag: bool | None = None
    if response_headers is not None:
        header_value = response_headers.get("x-should-retry")
        if isinstance(header_value, str):
            retry_flag = {"true": True, "false": False}.get(header_value.strip().lower())

    # Classify by exception class name.
    class_name = e.__class__.__name__.lower()
    kind: str | None
    if "timeout" in class_name:
        kind = "timeout"
    elif "connection" in class_name:
        kind = "connection"
    else:
        kind = None

    return {
        "error_status_code": None if status is None else int(status),
        "error_kind": kind,
        "error_type": error_type,
        "error_code": error_code,
        "error_retry_after_s": cls._extract_retry_after_from_headers(response_headers),
        "error_should_retry": retry_flag,
    }
@staticmethod
def _handle_error(
    e: Exception,
    *,
    spec: ProviderSpec | None = None,
    api_base: str | None = None,
) -> LLMResponse:
    """Turn an exception raised by a provider call into an error ``LLMResponse``.

    Args:
        e: The exception that was raised.
        spec: Provider spec; when it is marked local and the error looks
            like a connectivity problem, a troubleshooting hint is appended.
        api_base: Endpoint shown in that hint; falls back to
            ``spec.default_api_base``.

    Returns:
        An ``LLMResponse`` with ``finish_reason="error"``, a readable
        message, an optional ``retry_after`` and the structured fields from
        ``_extract_error_metadata``.
    """
    # First truthy source wins: doc, then body, then the response text.
    raw_body = (
        getattr(e, "doc", None)
        or getattr(e, "body", None)
        or getattr(getattr(e, "response", None), "text", None)
    )
    if isinstance(raw_body, str):
        body_text = raw_body
    elif raw_body is not None:
        body_text = str(raw_body)
    else:
        body_text = ""

    trimmed = body_text.strip()
    if trimmed:
        # Cap the echoed body so a huge error page does not flood the reply.
        msg = f"Error: {trimmed[:500]}"
    else:
        msg = f"Error calling LLM: {e}"

    haystack = f"{body_text} {e}".lower()
    if spec and spec.is_local and any(
        marker in haystack for marker in ("502", "connection", "refused")
    ):
        msg += (
            "\nHint: this is a local model endpoint. Check that the local server is reachable at "
            f"{api_base or spec.default_api_base}, and if you are using a proxy/tunnel, make sure it "
            "can reach your local Ollama/vLLM service instead of routing localhost through the remote host."
        )

    # Headers take precedence; otherwise try to read a delay from the message.
    http_response = getattr(e, "response", None)
    response_headers = getattr(http_response, "headers", None)
    retry_after = LLMProvider._extract_retry_after_from_headers(response_headers)
    if retry_after is None:
        retry_after = LLMProvider._extract_retry_after(msg)

    metadata = OpenAICompatProvider._extract_error_metadata(e)
    return LLMResponse(
        content=msg,
        finish_reason="error",
        retry_after=retry_after,
        **metadata,
    )
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def chat(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
) -> LLMResponse:
    """Run one non-streaming completion and return the parsed result.

    The Responses API is tried first when ``_should_use_responses_api``
    says so; on an eligible failure the call falls back to Chat
    Completions. Any exception that escapes is converted into an error
    ``LLMResponse`` by ``_handle_error`` instead of being raised.
    """
    await self._ensure_client()
    try:
        if self._should_use_responses_api(model, reasoning_effort):
            try:
                request_body = self._build_responses_body(
                    messages, tools, model, max_tokens, temperature,
                    reasoning_effort, tool_choice,
                )
                raw_response = await self._client.responses.create(**request_body)
                parsed = parse_response_output(raw_response)
                self._record_responses_success(model, reasoning_effort)
                return parsed
            except Exception as responses_error:
                # Copilot gateway exposes GPT-5/o-series only via /responses;
                # falling back to /chat/completions cannot succeed and would
                # hide the real error. The same applies when the Responses
                # API was requested explicitly or the error is not one we
                # fall back from.
                is_copilot = self._spec and self._spec.name == "github_copilot"
                if (
                    is_copilot
                    or self._api_type == "responses"
                    or not self._should_fallback_from_responses_error(responses_error)
                ):
                    raise
                self._record_responses_failure(model, reasoning_effort)

        completion_kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        completion = await self._client.chat.completions.create(**completion_kwargs)
        return self._parse(completion)
    except Exception as e:
        return self._handle_error(e, spec=self._spec, api_base=self.api_base)
async def chat_stream(
    self,
    messages: list[dict[str, Any]],
    tools: list[dict[str, Any]] | None = None,
    model: str | None = None,
    max_tokens: int = 4096,
    temperature: float = 0.7,
    reasoning_effort: str | None = None,
    tool_choice: str | dict[str, Any] | None = None,
    on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
    on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
) -> LLMResponse:
    """Run one streaming completion, forwarding deltas to the callbacks.

    The Responses API is tried first when ``_should_use_responses_api``
    says so; on an eligible failure the call falls back to streaming Chat
    Completions. Each wait for the next chunk is bounded by an idle
    timeout read from ``NANOBOT_STREAM_IDLE_TIMEOUT_S`` (default 90 s).

    Args:
        on_content_delta: Awaited with each piece of content text.
        on_thinking_delta: Awaited with each piece of reasoning text
            (Chat Completions path only).
        on_tool_call_delta: Awaited with a dict holding ``index``,
            ``call_id``, ``name`` and ``arguments_delta`` for each
            tool-call fragment.

    Returns:
        The aggregated response. A stalled stream yields an error response
        with ``error_kind="timeout"``; any other exception is converted by
        ``_handle_error``. Nothing is raised to the caller from the
        request itself.
    """
    await self._ensure_client()
    raw_idle_timeout = os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", "90")
    try:
        idle_timeout_s = int(raw_idle_timeout)
    except ValueError:
        # A malformed override must not make every streaming call raise;
        # fall back to the documented default instead.
        logger.warning(
            "Invalid NANOBOT_STREAM_IDLE_TIMEOUT_S={!r}; falling back to 90 seconds",
            raw_idle_timeout,
        )
        idle_timeout_s = 90
    try:
        if self._should_use_responses_api(model, reasoning_effort):
            try:
                body = self._build_responses_body(
                    messages, tools, model, max_tokens, temperature,
                    reasoning_effort, tool_choice,
                )
                body["stream"] = True
                stream = await self._client.responses.create(**body)

                async def _timed_stream():
                    # Re-yield events, bounding the wait for each one.
                    stream_iter = stream.__aiter__()
                    while True:
                        try:
                            yield await asyncio.wait_for(
                                stream_iter.__anext__(),
                                timeout=idle_timeout_s,
                            )
                        except StopAsyncIteration:
                            break

                (
                    content,
                    tool_calls,
                    finish_reason,
                    usage,
                    reasoning_content,
                ) = await consume_sdk_stream(
                    _timed_stream(),
                    on_content_delta,
                    on_tool_call_delta=on_tool_call_delta,
                )
                self._record_responses_success(model, reasoning_effort)
                return LLMResponse(
                    content=content or None,
                    tool_calls=tool_calls,
                    finish_reason=finish_reason,
                    usage=usage,
                    reasoning_content=reasoning_content,
                )
            except Exception as responses_error:
                if self._spec and self._spec.name == "github_copilot":
                    # Copilot gateway exposes GPT-5/o-series only via /responses;
                    # falling back to /chat/completions cannot succeed and would
                    # hide the real error.
                    raise
                if self._api_type == "responses":
                    raise
                if not self._should_fallback_from_responses_error(responses_error):
                    raise
                self._record_responses_failure(model, reasoning_effort)

        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        if self._spec and self._spec.name == "zhipu" and tools and on_tool_call_delta:
            # Z.AI/GLM keeps streaming tool-call arguments behind an
            # explicit provider flag. Pass it through the OpenAI SDK's
            # extra_body escape hatch so the usual delta.tool_calls path
            # can surface live file-edit progress.
            kwargs.setdefault("extra_body", {})["tool_stream"] = True
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        stream = await self._client.chat.completions.create(**kwargs)

        chunks: list[Any] = []
        stream_iter = stream.__aiter__()
        while True:
            try:
                chunk = await asyncio.wait_for(
                    stream_iter.__anext__(),
                    timeout=idle_timeout_s,
                )
            except StopAsyncIteration:
                break
            # Keep every chunk; the final response is built from all of them.
            chunks.append(chunk)
            if chunk.choices:
                delta_obj = chunk.choices[0].delta
                if on_content_delta:
                    text = getattr(delta_obj, "content", None)
                    if text:
                        await on_content_delta(text)
                if on_thinking_delta:
                    reasoning = getattr(delta_obj, "reasoning_content", None) or getattr(
                        delta_obj, "reasoning", None,
                    )
                    r_text = self._extract_text_content(reasoning)
                    if r_text:
                        await on_thinking_delta(r_text)
                if on_tool_call_delta:
                    for idx, tool_delta in enumerate(
                        getattr(delta_obj, "tool_calls", None) or []
                    ):
                        fn = _get(tool_delta, "function")
                        tool_index = _get(tool_delta, "index")
                        await on_tool_call_delta({
                            # Fall back to the position within this delta
                            # when the provider omits the index.
                            "index": tool_index if tool_index is not None else idx,
                            "call_id": str(_get(tool_delta, "id") or ""),
                            "name": str(_get(fn, "name") or "") if fn is not None else "",
                            "arguments_delta": (
                                str(_get(fn, "arguments") or "") if fn is not None else ""
                            ),
                        })
                    # Legacy single ``function_call`` deltas are reported as index 0.
                    function_call = getattr(delta_obj, "function_call", None)
                    if function_call:
                        await on_tool_call_delta({
                            "index": 0,
                            "call_id": "",
                            "name": str(_get(function_call, "name") or ""),
                            "arguments_delta": str(_get(function_call, "arguments") or ""),
                        })
        return self._parse_chunks(chunks)
    except asyncio.TimeoutError:
        return LLMResponse(
            content=(
                f"Error calling LLM: stream stalled for more than "
                f"{idle_timeout_s} seconds"
            ),
            finish_reason="error",
            error_kind="timeout",
        )
    except Exception as e:
        return self._handle_error(e, spec=self._spec, api_base=self.api_base)
def get_default_model(self) -> str:
    """Return the model name stored in ``self.default_model``."""
    model_name = self.default_model
    return model_name