mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-23 03:15:58 +00:00
- Use import importlib.util (not bare importlib) for find_spec - Warn and fall back to standard openai instead of crashing with ImportError when LANGFUSE_SECRET_KEY is set but langfuse is missing Made-with: Cursor
702 lines
27 KiB
Python
702 lines
27 KiB
Python
"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import hashlib
|
|
import importlib.util
|
|
import os
|
|
import secrets
|
|
import string
|
|
import uuid
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import json_repair
|
|
|
|
if os.environ.get("LANGFUSE_SECRET_KEY") and importlib.util.find_spec("langfuse"):
|
|
from langfuse.openai import AsyncOpenAI
|
|
else:
|
|
if os.environ.get("LANGFUSE_SECRET_KEY"):
|
|
import logging
|
|
logging.getLogger(__name__).warning(
|
|
"LANGFUSE_SECRET_KEY is set but langfuse is not installed; "
|
|
"install with `pip install langfuse` to enable tracing"
|
|
)
|
|
from openai import AsyncOpenAI
|
|
|
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
|
|
|
if TYPE_CHECKING:
|
|
from nanobot.providers.registry import ProviderSpec
|
|
|
|
# Allow-list of message keys handed to the base-class message sanitizer;
# any other key on a message is expected to be dropped before the request.
_ALLOWED_MSG_KEYS = frozenset({
    "role", "content", "tool_calls", "tool_call_id", "name",
    "reasoning_content", "extra_content",
})
# Alphabet for generated tool-call IDs (ASCII letters and digits only).
_ALNUM = string.ascii_letters + string.digits

# Keys considered "standard" on a tool call and on its nested function;
# everything else found there is captured as provider-specific data.
_STANDARD_TC_KEYS = frozenset({"id", "type", "index", "function"})
_STANDARD_FN_KEYS = frozenset({"name", "arguments"})
# Attribution headers merged into the client's default headers for
# OpenRouter requests (caller-supplied extra headers are applied afterwards).
_DEFAULT_OPENROUTER_HEADERS = {
    "HTTP-Referer": "https://github.com/HKUDS/nanobot",
    "X-OpenRouter-Title": "nanobot",
    "X-OpenRouter-Categories": "cli-agent,personal-agent",
}
|
|
|
|
|
|
def _short_tool_id() -> str:
    """Return a random 9-character alphanumeric tool-call ID.

    Characters are drawn from ``_ALNUM`` with ``secrets.choice``; the short
    alphanumeric form is compatible with all providers (incl. Mistral).
    """
    picked = [secrets.choice(_ALNUM) for _ in range(9)]
    return "".join(picked)
|
|
|
|
|
|
def _get(obj: Any, key: str) -> Any:
|
|
"""Get a value from dict or object attribute, returning None if absent."""
|
|
if isinstance(obj, dict):
|
|
return obj.get(key)
|
|
return getattr(obj, key, None)
|
|
|
|
|
|
def _coerce_dict(value: Any) -> dict[str, Any] | None:
|
|
"""Try to coerce *value* to a dict; return None if not possible or empty."""
|
|
if value is None:
|
|
return None
|
|
if isinstance(value, dict):
|
|
return value if value else None
|
|
model_dump = getattr(value, "model_dump", None)
|
|
if callable(model_dump):
|
|
dumped = model_dump()
|
|
if isinstance(dumped, dict) and dumped:
|
|
return dumped
|
|
return None
|
|
|
|
|
|
def _extract_tc_extras(tc: Any) -> tuple[
    dict[str, Any] | None,
    dict[str, Any] | None,
    dict[str, Any] | None,
]:
    """Return ``(extra_content, provider_specific_fields, fn_provider_specific_fields)``.

    Accepts SDK objects as well as plain dicts. ``extra_content`` (used by
    Gemini) is captured verbatim. When the tool call can be viewed as a dict,
    every non-``None`` key outside the standard tool-call / function key sets
    is collected; otherwise the ``provider_specific_fields`` attributes of the
    tool call and of its function are used.
    """
    extra_content = _coerce_dict(_get(tc, "extra_content"))
    as_dict = _coerce_dict(tc)

    if as_dict is None:
        # Not dict-like: fall back to explicit provider_specific_fields.
        tc_fields = _coerce_dict(_get(tc, "provider_specific_fields"))
        function_obj = _get(tc, "function")
        fn_fields = None
        if function_obj is not None:
            fn_fields = _coerce_dict(_get(function_obj, "provider_specific_fields"))
        return extra_content, tc_fields, fn_fields

    # Dict-like: whatever is not a standard key counts as provider-specific.
    tc_fields = {
        key: val
        for key, val in as_dict.items()
        if val is not None and key != "extra_content" and key not in _STANDARD_TC_KEYS
    } or None

    fn_fields = None
    function_map = _coerce_dict(as_dict.get("function"))
    if function_map is not None:
        fn_fields = {
            key: val
            for key, val in function_map.items()
            if val is not None and key not in _STANDARD_FN_KEYS
        } or None

    return extra_content, tc_fields, fn_fields
|
|
|
|
|
|
def _uses_openrouter_attribution(spec: "ProviderSpec | None", api_base: str | None) -> bool:
|
|
"""Apply Nanobot attribution headers to OpenRouter requests by default."""
|
|
if spec and spec.name == "openrouter":
|
|
return True
|
|
return bool(api_base and "openrouter" in api_base.lower())
|
|
|
|
|
|
class OpenAICompatProvider(LLMProvider):
    """Unified provider for all OpenAI-compatible APIs.

    Receives a resolved ``ProviderSpec`` from the caller — no internal
    registry lookups needed.
    """

    def __init__(
        self,
        api_key: str | None = None,
        api_base: str | None = None,
        default_model: str = "gpt-4o",
        extra_headers: dict[str, str] | None = None,
        spec: ProviderSpec | None = None,
    ):
        """Create the provider and its ``AsyncOpenAI`` client.

        Args:
            api_key: API key; ``"no-key"`` is passed to the client when empty.
            api_base: Base URL; falls back to ``spec.default_api_base``.
            default_model: Model used when a call does not name one.
            extra_headers: Headers merged last into the client's defaults,
                so they override the built-in ones.
            spec: Resolved provider spec driving provider-specific behavior.
        """
        super().__init__(api_key, api_base)
        self.default_model = default_model
        self.extra_headers = extra_headers or {}
        self._spec = spec

        if api_key and spec and spec.env_key:
            self._setup_env(api_key, api_base)

        effective_base = api_base or (spec.default_api_base if spec else None) or None
        # One random affinity value per provider instance.
        default_headers = {"x-session-affinity": uuid.uuid4().hex}
        if _uses_openrouter_attribution(spec, effective_base):
            default_headers.update(_DEFAULT_OPENROUTER_HEADERS)
        if extra_headers:
            default_headers.update(extra_headers)

        self._client = AsyncOpenAI(
            api_key=api_key or "no-key",
            base_url=effective_base,
            default_headers=default_headers,
            max_retries=0,
        )

    def _setup_env(self, api_key: str, api_base: str | None) -> None:
        """Set environment variables based on provider spec.

        A gateway spec overwrites ``spec.env_key``; any other spec only sets
        it when it is not already present. Each ``(name, template)`` pair in
        ``spec.env_extras`` has ``{api_key}`` / ``{api_base}`` substituted and
        is set only if the variable is not already defined.
        """
        spec = self._spec
        if not spec or not spec.env_key:
            return
        if spec.is_gateway:
            os.environ[spec.env_key] = api_key
        else:
            os.environ.setdefault(spec.env_key, api_key)
        # str.replace() rejects None, so fall back to "" when no base is known.
        effective_base = api_base or spec.default_api_base or ""
        for env_name, env_val in spec.env_extras:
            resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
            os.environ.setdefault(env_name, resolved)

    @classmethod
    def _apply_cache_control(
        cls,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
    ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
        """Inject cache_control markers for prompt caching.

        Marks a leading system message, the second-to-last message (when there
        are at least three messages) and the tools selected by
        ``_tool_cache_marker_indices``. Inputs are not mutated; shallow copies
        of the lists and of each marked entry are returned.
        """
        cache_marker = {"type": "ephemeral"}
        new_messages = list(messages)

        def _mark(msg: dict[str, Any]) -> dict[str, Any]:
            content = msg.get("content")
            if isinstance(content, str):
                # Plain-string content becomes a single text part carrying the marker.
                return {**msg, "content": [
                    {"type": "text", "text": content, "cache_control": cache_marker},
                ]}
            if isinstance(content, list) and content:
                # Multi-part content: only the last part is marked.
                nc = list(content)
                nc[-1] = {**nc[-1], "cache_control": cache_marker}
                return {**msg, "content": nc}
            return msg

        if new_messages and new_messages[0].get("role") == "system":
            new_messages[0] = _mark(new_messages[0])
        if len(new_messages) >= 3:
            new_messages[-2] = _mark(new_messages[-2])

        new_tools = tools
        if tools:
            new_tools = list(tools)
            for idx in cls._tool_cache_marker_indices(new_tools):
                new_tools[idx] = {**new_tools[idx], "cache_control": cache_marker}
        return new_messages, new_tools

    @staticmethod
    def _normalize_tool_call_id(tool_call_id: Any) -> Any:
        """Normalize to a provider-safe 9-char alphanumeric form.

        Non-string values and IDs that are already 9 alphanumeric characters
        are returned unchanged; anything else is replaced by the first 9 hex
        digits of its SHA-1, so the same input always maps to the same ID.
        """
        if not isinstance(tool_call_id, str):
            return tool_call_id
        if len(tool_call_id) == 9 and tool_call_id.isalnum():
            return tool_call_id
        return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]

    def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Strip non-standard keys, normalize tool_call IDs.

        Key filtering is delegated to ``LLMProvider._sanitize_request_messages``
        with ``_ALLOWED_MSG_KEYS``. Tool-call IDs on assistant ``tool_calls``
        and on ``tool_call_id`` go through one shared map so that a call and
        its result keep matching IDs after normalization.
        """
        sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
        id_map: dict[str, str] = {}

        def map_id(value: Any) -> Any:
            if not isinstance(value, str):
                return value
            return id_map.setdefault(value, self._normalize_tool_call_id(value))

        for clean in sanitized:
            if isinstance(clean.get("tool_calls"), list):
                normalized = []
                for tc in clean["tool_calls"]:
                    if not isinstance(tc, dict):
                        normalized.append(tc)
                        continue
                    tc_clean = dict(tc)
                    tc_clean["id"] = map_id(tc_clean.get("id"))
                    normalized.append(tc_clean)
                clean["tool_calls"] = normalized
            if "tool_call_id" in clean and clean["tool_call_id"]:
                clean["tool_call_id"] = map_id(clean["tool_call_id"])
        return sanitized

    # ------------------------------------------------------------------
    # Build kwargs
    # ------------------------------------------------------------------

    @staticmethod
    def _supports_temperature(
        model_name: str,
        reasoning_effort: str | None = None,
    ) -> bool:
        """Return True when a temperature parameter should be sent.

        Returns False when ``reasoning_effort`` is set to anything other than
        ``"none"`` (case-insensitive), and also whenever the lower-cased model
        name contains ``gpt-5``, ``o1``, ``o3`` or ``o4`` (GPT-5 family and
        reasoning models).
        """
        if reasoning_effort and reasoning_effort.lower() != "none":
            return False
        name = model_name.lower()
        return not any(token in name for token in ("gpt-5", "o1", "o3", "o4"))

    def _build_kwargs(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None,
        model: str | None,
        max_tokens: int,
        temperature: float,
        reasoning_effort: str | None,
        tool_choice: str | dict[str, Any] | None,
    ) -> dict[str, Any]:
        """Assemble keyword arguments for ``chat.completions.create``.

        Applies, in order: prompt-cache markers (only for specs that support
        caching and model names starting with ``anthropic/`` or ``claude``),
        model-prefix stripping, message sanitizing, temperature (when
        supported), the token-limit key, the first matching entry of
        ``spec.model_overrides``, ``reasoning_effort`` and tool settings.
        """
        model_name = model or self.default_model
        spec = self._spec

        if spec and spec.supports_prompt_caching:
            if any(model_name.lower().startswith(k) for k in ("anthropic/", "claude")):
                messages, tools = self._apply_cache_control(messages, tools)

        if spec and spec.strip_model_prefix:
            model_name = model_name.split("/")[-1]

        kwargs: dict[str, Any] = {
            "model": model_name,
            "messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
        }

        # GPT-5 and reasoning models (o1/o3/o4) reject temperature when
        # reasoning_effort is active. Only include it when safe.
        if self._supports_temperature(model_name, reasoning_effort):
            kwargs["temperature"] = temperature

        # The token limit is clamped to at least 1 under either key.
        if spec and getattr(spec, "supports_max_completion_tokens", False):
            kwargs["max_completion_tokens"] = max(1, max_tokens)
        else:
            kwargs["max_tokens"] = max(1, max_tokens)

        if spec:
            model_lower = model_name.lower()
            for pattern, overrides in spec.model_overrides:
                if pattern in model_lower:
                    # Only the first matching pattern is applied.
                    kwargs.update(overrides)
                    break

        if reasoning_effort:
            kwargs["reasoning_effort"] = reasoning_effort

        if tools:
            kwargs["tools"] = tools
            kwargs["tool_choice"] = tool_choice or "auto"

        return kwargs

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    @staticmethod
    def _maybe_mapping(value: Any) -> dict[str, Any] | None:
        """Return *value* as a dict (possibly empty), or ``None``.

        Dicts are returned unchanged; objects with a callable ``model_dump``
        are dumped. Unlike ``_coerce_dict``, an empty dict is preserved.
        """
        if isinstance(value, dict):
            return value
        model_dump = getattr(value, "model_dump", None)
        if callable(model_dump):
            dumped = model_dump()
            if isinstance(dumped, dict):
                return dumped
        return None

    @staticmethod
    def _load_tool_arguments(raw: Any) -> dict[str, Any]:
        """Decode tool-call arguments into a dict.

        String input is parsed with ``json_repair.loads``. Anything that does
        not end up as a dict (missing arguments, lists, scalars, unparseable
        text) becomes ``{}`` so callers always receive a mapping.
        """
        if isinstance(raw, str):
            raw = json_repair.loads(raw) if raw else {}
        return raw if isinstance(raw, dict) else {}

    @classmethod
    def _extract_text_content(cls, value: Any) -> str | None:
        """Flatten message content to text.

        ``None`` stays ``None`` and strings pass through. For a list, the
        ``text`` of each mapping/object part and any bare string items are
        concatenated (``None`` if nothing was collected). Any other value is
        converted with ``str()``.
        """
        if value is None:
            return None
        if isinstance(value, str):
            return value
        if isinstance(value, list):
            parts: list[str] = []
            for item in value:
                item_map = cls._maybe_mapping(item)
                if item_map:
                    text = item_map.get("text")
                    if isinstance(text, str):
                        parts.append(text)
                        continue
                text = getattr(item, "text", None)
                if isinstance(text, str):
                    parts.append(text)
                    continue
                if isinstance(item, str):
                    parts.append(item)
            return "".join(parts) or None
        return str(value)

    @classmethod
    def _extract_usage(cls, response: Any) -> dict[str, int]:
        """Extract token usage from an OpenAI-compatible response.

        Handles both dict-based (raw JSON) and object-based (SDK Pydantic)
        responses. Returns ``{}`` when no usage is present. Provider-specific
        cached-token fields are normalised under ``cached_tokens``; the first
        non-zero field in the priority chain below wins.
        """
        # --- resolve usage object ---
        usage_obj = None
        response_map = cls._maybe_mapping(response)
        if response_map is not None:
            usage_obj = response_map.get("usage")
        elif hasattr(response, "usage") and response.usage:
            usage_obj = response.usage

        usage_map = cls._maybe_mapping(usage_obj)
        if usage_map is not None:
            result = {
                "prompt_tokens": int(usage_map.get("prompt_tokens") or 0),
                "completion_tokens": int(usage_map.get("completion_tokens") or 0),
                "total_tokens": int(usage_map.get("total_tokens") or 0),
            }
        elif usage_obj:
            result = {
                "prompt_tokens": getattr(usage_obj, "prompt_tokens", 0) or 0,
                "completion_tokens": getattr(usage_obj, "completion_tokens", 0) or 0,
                "total_tokens": getattr(usage_obj, "total_tokens", 0) or 0,
            }
        else:
            return {}

        # --- cached_tokens (normalised across providers) ---
        # Try nested paths first (dict), fall back to attribute (SDK object).
        # Priority order ensures the most specific field wins.
        for path in (
            ("prompt_tokens_details", "cached_tokens"),  # OpenAI/Zhipu/MiniMax/Qwen/Mistral/xAI
            ("cached_tokens",),  # StepFun/Moonshot (top-level)
            ("prompt_cache_hit_tokens",),  # DeepSeek/SiliconFlow
        ):
            cached = cls._get_nested_int(usage_map, path)
            if not cached and usage_obj:
                cached = cls._get_nested_int(usage_obj, path)
            if cached:
                result["cached_tokens"] = cached
                break

        return result

    @staticmethod
    def _get_nested_int(obj: Any, path: tuple[str, ...]) -> int:
        """Drill into *obj* by *path* segments and return an ``int`` value.

        Supports both dict-key access and attribute access so it works
        uniformly with raw JSON dicts **and** SDK Pydantic models. Missing
        segments and values that cannot be converted to ``int`` yield ``0``.
        """
        current = obj
        for segment in path:
            if current is None:
                return 0
            if isinstance(current, dict):
                current = current.get(segment)
            else:
                current = getattr(current, segment, None)
        try:
            return int(current or 0)
        except (TypeError, ValueError):
            # A malformed provider field must not fail the whole response.
            return 0

    def _parse(self, response: Any) -> LLMResponse:
        """Convert a non-streaming completion into an ``LLMResponse``.

        Accepts a bare string, a dict / ``model_dump``-able object, or an
        attribute-style SDK object. Tool calls are gathered from every choice
        and always receive a fresh short ID. An empty ``choices`` list yields
        an error response unless top-level ``content`` / ``output_text`` is
        present (mapping path only).
        """
        if isinstance(response, str):
            return LLMResponse(content=response, finish_reason="stop")

        response_map = self._maybe_mapping(response)
        if response_map is not None:
            choices = response_map.get("choices") or []
            if not choices:
                content = self._extract_text_content(
                    response_map.get("content") or response_map.get("output_text")
                )
                reasoning_content = self._extract_text_content(
                    response_map.get("reasoning_content")
                )
                if content is not None:
                    return LLMResponse(
                        content=content,
                        reasoning_content=reasoning_content,
                        finish_reason=str(response_map.get("finish_reason") or "stop"),
                        usage=self._extract_usage(response_map),
                    )
                return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

            choice0 = self._maybe_mapping(choices[0]) or {}
            msg0 = self._maybe_mapping(choice0.get("message")) or {}
            content = self._extract_text_content(msg0.get("content"))
            finish_reason = str(choice0.get("finish_reason") or "stop")

            raw_tool_calls: list[Any] = []
            reasoning_content = msg0.get("reasoning_content")
            for ch in choices:
                ch_map = self._maybe_mapping(ch) or {}
                m = self._maybe_mapping(ch_map.get("message")) or {}
                tool_calls = m.get("tool_calls")
                if isinstance(tool_calls, list) and tool_calls:
                    raw_tool_calls.extend(tool_calls)
                    # A choice carrying tool calls may override the finish reason.
                    if ch_map.get("finish_reason") in ("tool_calls", "stop"):
                        finish_reason = str(ch_map["finish_reason"])
                if not content:
                    content = self._extract_text_content(m.get("content"))
                if not reasoning_content:
                    reasoning_content = m.get("reasoning_content")

            parsed_tool_calls = []
            for tc in raw_tool_calls:
                tc_map = self._maybe_mapping(tc) or {}
                fn = self._maybe_mapping(tc_map.get("function")) or {}
                ec, prov, fn_prov = _extract_tc_extras(tc)
                parsed_tool_calls.append(ToolCallRequest(
                    id=_short_tool_id(),
                    name=str(fn.get("name") or ""),
                    arguments=self._load_tool_arguments(fn.get("arguments", {})),
                    extra_content=ec,
                    provider_specific_fields=prov,
                    function_provider_specific_fields=fn_prov,
                ))

            return LLMResponse(
                content=content,
                tool_calls=parsed_tool_calls,
                finish_reason=finish_reason,
                usage=self._extract_usage(response_map),
                reasoning_content=reasoning_content if isinstance(reasoning_content, str) else None,
            )

        # Attribute-style object without model_dump().
        if not response.choices:
            return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")

        choice = response.choices[0]
        msg = choice.message
        content = msg.content
        finish_reason = choice.finish_reason

        raw_tool_calls = []
        for ch in response.choices:
            m = ch.message
            if hasattr(m, "tool_calls") and m.tool_calls:
                raw_tool_calls.extend(m.tool_calls)
                if ch.finish_reason in ("tool_calls", "stop"):
                    finish_reason = ch.finish_reason
            if not content and m.content:
                content = m.content

        tool_calls = []
        for tc in raw_tool_calls:
            ec, prov, fn_prov = _extract_tc_extras(tc)
            tool_calls.append(ToolCallRequest(
                id=_short_tool_id(),
                name=tc.function.name,
                # Same dict-or-{} coercion as the mapping path above.
                arguments=self._load_tool_arguments(tc.function.arguments),
                extra_content=ec,
                provider_specific_fields=prov,
                function_provider_specific_fields=fn_prov,
            ))

        return LLMResponse(
            content=content,
            tool_calls=tool_calls,
            finish_reason=finish_reason or "stop",
            usage=self._extract_usage(response),
            reasoning_content=getattr(msg, "reasoning_content", None) or None,
        )

    @classmethod
    def _parse_chunks(cls, chunks: list[Any]) -> LLMResponse:
        """Fold a list of streamed chunks into one ``LLMResponse``.

        Chunks may be bare strings, dicts / ``model_dump``-able objects, or
        attribute-style objects. Content and reasoning deltas are concatenated,
        tool-call deltas are accumulated per index, and the last reported
        finish reason and usage win.
        """
        content_parts: list[str] = []
        reasoning_parts: list[str] = []
        tc_bufs: dict[int, dict[str, Any]] = {}
        finish_reason = "stop"
        usage: dict[str, int] = {}

        def _accum_tc(tc: Any, idx_hint: int) -> None:
            """Accumulate one streaming tool-call delta into *tc_bufs*."""
            explicit_index = _get(tc, "index")
            tc_index: int = explicit_index if explicit_index is not None else idx_hint
            buf = tc_bufs.setdefault(tc_index, {
                "id": "", "name": "", "arguments": "",
                "extra_content": None, "prov": None, "fn_prov": None,
            })
            tc_id = _get(tc, "id")
            if tc_id:
                buf["id"] = str(tc_id)
            fn = _get(tc, "function")
            if fn is not None:
                fn_name = _get(fn, "name")
                if fn_name:
                    buf["name"] = str(fn_name)
                fn_args = _get(fn, "arguments")
                if fn_args:
                    # Argument JSON arrives in fragments; concatenate them.
                    buf["arguments"] += str(fn_args)
            ec, prov, fn_prov = _extract_tc_extras(tc)
            if ec:
                buf["extra_content"] = ec
            if prov:
                buf["prov"] = prov
            if fn_prov:
                buf["fn_prov"] = fn_prov

        for chunk in chunks:
            if isinstance(chunk, str):
                content_parts.append(chunk)
                continue

            chunk_map = cls._maybe_mapping(chunk)
            if chunk_map is not None:
                choices = chunk_map.get("choices") or []
                if not choices:
                    usage = cls._extract_usage(chunk_map) or usage
                    text = cls._extract_text_content(
                        chunk_map.get("content") or chunk_map.get("output_text")
                    )
                    if text:
                        content_parts.append(text)
                    continue
                choice = cls._maybe_mapping(choices[0]) or {}
                if choice.get("finish_reason"):
                    finish_reason = str(choice["finish_reason"])
                delta = cls._maybe_mapping(choice.get("delta")) or {}
                text = cls._extract_text_content(delta.get("content"))
                if text:
                    content_parts.append(text)
                text = cls._extract_text_content(delta.get("reasoning_content"))
                if text:
                    reasoning_parts.append(text)
                for idx, tc in enumerate(delta.get("tool_calls") or []):
                    _accum_tc(tc, idx)
                usage = cls._extract_usage(chunk_map) or usage
                continue

            # Attribute-style chunk without model_dump().
            if not chunk.choices:
                usage = cls._extract_usage(chunk) or usage
                continue
            choice = chunk.choices[0]
            if choice.finish_reason:
                finish_reason = choice.finish_reason
            delta = choice.delta
            if delta and delta.content:
                content_parts.append(delta.content)
            if delta:
                reasoning = getattr(delta, "reasoning_content", None)
                if reasoning:
                    reasoning_parts.append(reasoning)
            # Position is the fallback index, mirroring the mapping path.
            for idx, tc in enumerate((delta.tool_calls or []) if delta else []):
                _accum_tc(tc, idx)

        return LLMResponse(
            content="".join(content_parts) or None,
            tool_calls=[
                ToolCallRequest(
                    id=b["id"] or _short_tool_id(),
                    name=b["name"],
                    arguments=cls._load_tool_arguments(b["arguments"]),
                    extra_content=b.get("extra_content"),
                    provider_specific_fields=b.get("prov"),
                    function_provider_specific_fields=b.get("fn_prov"),
                )
                for b in tc_bufs.values()
            ],
            finish_reason=finish_reason,
            usage=usage,
            reasoning_content="".join(reasoning_parts) or None,
        )

    @staticmethod
    def _handle_error(e: Exception) -> LLMResponse:
        """Turn an exception into an error ``LLMResponse``.

        Prefers the exception's ``doc`` or its response ``text`` (truncated to
        500 characters) as the message. ``retry_after`` is taken from the
        response headers when available, otherwise parsed from the message.
        """
        response = getattr(e, "response", None)
        body = getattr(e, "doc", None) or getattr(response, "text", None)
        body_text = str(body).strip() if body is not None else ""
        msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}"
        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
        if retry_after is None:
            retry_after = LLMProvider._extract_retry_after(msg)
        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)

    @staticmethod
    def _stream_idle_timeout_s() -> int:
        """Read ``NANOBOT_STREAM_IDLE_TIMEOUT_S``; fall back to 90 if unset or not an integer."""
        default = 90
        raw = os.environ.get("NANOBOT_STREAM_IDLE_TIMEOUT_S", str(default))
        try:
            return int(raw)
        except ValueError:
            return default

    @classmethod
    def _stream_delta_text(cls, chunk: Any) -> str | None:
        """Return the content text carried by one streamed chunk, if any.

        Accepts the same chunk shapes as ``_parse_chunks`` (string, dict or
        attribute-style object) so the delta callback never fails on a chunk
        that the final parse would accept.
        """
        if isinstance(chunk, str):
            return chunk or None
        choices = _get(chunk, "choices")
        if not choices:
            if isinstance(chunk, dict):
                return cls._extract_text_content(
                    chunk.get("content") or chunk.get("output_text")
                ) or None
            return None
        content = _get(_get(choices[0], "delta"), "content")
        if isinstance(content, str):
            return content or None
        return cls._extract_text_content(content) or None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
    ) -> LLMResponse:
        """Send one non-streaming chat completion request.

        Exceptions raised by the client call or by parsing are converted to an
        error ``LLMResponse`` rather than propagated.
        """
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        try:
            return self._parse(await self._client.chat.completions.create(**kwargs))
        except Exception as e:
            return self._handle_error(e)

    async def chat_stream(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: int = 4096,
        temperature: float = 0.7,
        reasoning_effort: str | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Send a streaming chat completion and return the assembled response.

        Each content delta is forwarded to ``on_content_delta`` as it arrives.
        If no chunk arrives within the idle timeout
        (``NANOBOT_STREAM_IDLE_TIMEOUT_S``, default 90 seconds) an error
        response is returned; other exceptions go through ``_handle_error``.
        """
        kwargs = self._build_kwargs(
            messages, tools, model, max_tokens, temperature,
            reasoning_effort, tool_choice,
        )
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        idle_timeout_s = self._stream_idle_timeout_s()
        try:
            stream = await self._client.chat.completions.create(**kwargs)
            chunks: list[Any] = []
            stream_iter = stream.__aiter__()
            while True:
                try:
                    # The timeout applies per chunk, not to the whole stream.
                    chunk = await asyncio.wait_for(
                        stream_iter.__anext__(),
                        timeout=idle_timeout_s,
                    )
                except StopAsyncIteration:
                    break
                chunks.append(chunk)
                if on_content_delta:
                    text = self._stream_delta_text(chunk)
                    if text:
                        await on_content_delta(text)
            return self._parse_chunks(chunks)
        except asyncio.TimeoutError:
            return LLMResponse(
                content=(
                    f"Error calling LLM: stream stalled for more than "
                    f"{idle_timeout_s} seconds"
                ),
                finish_reason="error",
            )
        except Exception as e:
            return self._handle_error(e)

    def get_default_model(self) -> str:
        """Return the model name used when a call does not specify one."""
        return self.default_model
|