mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
refactor: replace litellm with native openai + anthropic SDKs
- Remove litellm dependency entirely (supply chain risk mitigation) - Add AnthropicProvider (native SDK) and OpenAICompatProvider (unified) - Merge CustomProvider into OpenAICompatProvider, delete custom_provider.py - Add ProviderSpec.backend field for declarative provider routing - Remove _resolve_model, find_gateway, find_by_model (dead heuristics) - Pass resolved spec directly into provider — zero internal lookups - Stub out litellm-dependent model database (cli/models.py) - Add anthropic>=0.45.0 to dependencies, remove litellm - 593 tests passed, net -1034 lines
This commit is contained in:
parent
38ce054b31
commit
3dfdab704e
16
README.md
16
README.md
@ -842,7 +842,7 @@ Config file: `~/.nanobot/config.json`
|
||||
|
||||
| Provider | Purpose | Get API Key |
|
||||
|----------|---------|-------------|
|
||||
| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — |
|
||||
| `custom` | Any OpenAI-compatible endpoint | — |
|
||||
| `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
|
||||
| `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) |
|
||||
| `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) |
|
||||
@ -943,7 +943,7 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
|
||||
<details>
|
||||
<summary><b>Custom Provider (Any OpenAI-compatible API)</b></summary>
|
||||
|
||||
Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; model name is passed as-is.
|
||||
Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is.
|
||||
|
||||
```json
|
||||
{
|
||||
@ -1120,10 +1120,9 @@ Adding a new provider only takes **2 steps** — no if-elif chains to touch.
|
||||
ProviderSpec(
|
||||
name="myprovider", # config field name
|
||||
keywords=("myprovider", "mymodel"), # model-name keywords for auto-matching
|
||||
env_key="MYPROVIDER_API_KEY", # env var for LiteLLM
|
||||
env_key="MYPROVIDER_API_KEY", # env var name
|
||||
display_name="My Provider", # shown in `nanobot status`
|
||||
litellm_prefix="myprovider", # auto-prefix: model → myprovider/model
|
||||
skip_prefixes=("myprovider/",), # don't double-prefix
|
||||
default_api_base="https://api.myprovider.com/v1", # OpenAI-compatible endpoint
|
||||
)
|
||||
```
|
||||
|
||||
@ -1135,20 +1134,19 @@ class ProvidersConfig(BaseModel):
|
||||
myprovider: ProviderConfig = ProviderConfig()
|
||||
```
|
||||
|
||||
That's it! Environment variables, model prefixing, config matching, and `nanobot status` display will all work automatically.
|
||||
That's it! Environment variables, model routing, config matching, and `nanobot status` display will all work automatically.
|
||||
|
||||
**Common `ProviderSpec` options:**
|
||||
|
||||
| Field | Description | Example |
|
||||
|-------|-------------|---------|
|
||||
| `litellm_prefix` | Auto-prefix model names for LiteLLM | `"dashscope"` → `dashscope/qwen-max` |
|
||||
| `skip_prefixes` | Don't prefix if model already starts with these | `("dashscope/", "openrouter/")` |
|
||||
| `default_api_base` | OpenAI-compatible base URL | `"https://api.deepseek.com"` |
|
||||
| `env_extras` | Additional env vars to set | `(("ZHIPUAI_API_KEY", "{api_key}"),)` |
|
||||
| `model_overrides` | Per-model parameter overrides | `(("kimi-k2.5", {"temperature": 1.0}),)` |
|
||||
| `is_gateway` | Can route any model (like OpenRouter) | `True` |
|
||||
| `detect_by_key_prefix` | Detect gateway by API key prefix | `"sk-or-"` |
|
||||
| `detect_by_base_keyword` | Detect gateway by API base URL | `"openrouter"` |
|
||||
| `strip_model_prefix` | Strip existing prefix before re-prefixing | `True` (for AiHubMix) |
|
||||
| `strip_model_prefix` | Strip provider prefix before sending to gateway | `True` (for AiHubMix) |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
@ -376,61 +376,61 @@ def _onboard_plugins(config_path: Path) -> None:
|
||||
|
||||
|
||||
def _make_provider(config: Config):
|
||||
"""Create the appropriate LLM provider from config."""
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
"""Create the appropriate LLM provider from config.
|
||||
|
||||
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
||||
"""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
# OpenAI Codex (OAuth)
|
||||
if provider_name == "openai_codex" or model.startswith("openai-codex/"):
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||
elif provider_name == "custom":
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
provider = CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
# Azure OpenAI: direct Azure OpenAI endpoint with deployment name
|
||||
elif provider_name == "azure_openai":
|
||||
# --- validation ---
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# --- instantiation by backend ---
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
# OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3
|
||||
elif provider_name == "ovms":
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
provider = CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v3",
|
||||
default_model=model,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
spec = find_by_name(provider_name)
|
||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
provider = LiteLLMProvider(
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
provider_name=provider_name,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
@ -1203,11 +1203,20 @@ def _login_openai_codex() -> None:
|
||||
def _login_github_copilot() -> None:
|
||||
import asyncio
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
|
||||
|
||||
async def _trigger():
|
||||
from litellm import acompletion
|
||||
await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1)
|
||||
client = AsyncOpenAI(
|
||||
api_key="dummy",
|
||||
base_url="https://api.githubcopilot.com",
|
||||
)
|
||||
await client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(_trigger())
|
||||
|
||||
@ -1,229 +1,29 @@
|
||||
"""Model information helpers for the onboard wizard.
|
||||
|
||||
Provides model context window lookup and autocomplete suggestions using litellm.
|
||||
Model database / autocomplete is temporarily disabled while litellm is
|
||||
being replaced. All public function signatures are preserved so callers
|
||||
continue to work without changes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _litellm():
|
||||
"""Lazy accessor for litellm (heavy import deferred until actually needed)."""
|
||||
import litellm as _ll
|
||||
|
||||
return _ll
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_model_cost_map() -> dict[str, Any]:
|
||||
"""Get litellm's model cost map (cached)."""
|
||||
return getattr(_litellm(), "model_cost", {})
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_all_models() -> list[str]:
|
||||
"""Get all known model names from litellm.
|
||||
"""
|
||||
models = set()
|
||||
|
||||
# From model_cost (has pricing info)
|
||||
cost_map = _get_model_cost_map()
|
||||
for k in cost_map.keys():
|
||||
if k != "sample_spec":
|
||||
models.add(k)
|
||||
|
||||
# From models_by_provider (more complete provider coverage)
|
||||
for provider_models in getattr(_litellm(), "models_by_provider", {}).values():
|
||||
if isinstance(provider_models, (set, list)):
|
||||
models.update(provider_models)
|
||||
|
||||
return sorted(models)
|
||||
|
||||
|
||||
def _normalize_model_name(model: str) -> str:
|
||||
"""Normalize model name for comparison."""
|
||||
return model.lower().replace("-", "_").replace(".", "")
|
||||
return []
|
||||
|
||||
|
||||
def find_model_info(model_name: str) -> dict[str, Any] | None:
|
||||
"""Find model info with fuzzy matching.
|
||||
|
||||
Args:
|
||||
model_name: Model name in any common format
|
||||
|
||||
Returns:
|
||||
Model info dict or None if not found
|
||||
"""
|
||||
cost_map = _get_model_cost_map()
|
||||
if not cost_map:
|
||||
return None
|
||||
|
||||
# Direct match
|
||||
if model_name in cost_map:
|
||||
return cost_map[model_name]
|
||||
|
||||
# Extract base name (without provider prefix)
|
||||
base_name = model_name.split("/")[-1] if "/" in model_name else model_name
|
||||
base_normalized = _normalize_model_name(base_name)
|
||||
|
||||
candidates = []
|
||||
|
||||
for key, info in cost_map.items():
|
||||
if key == "sample_spec":
|
||||
continue
|
||||
|
||||
key_base = key.split("/")[-1] if "/" in key else key
|
||||
key_base_normalized = _normalize_model_name(key_base)
|
||||
|
||||
# Score the match
|
||||
score = 0
|
||||
|
||||
# Exact base name match (highest priority)
|
||||
if base_normalized == key_base_normalized:
|
||||
score = 100
|
||||
# Base name contains model
|
||||
elif base_normalized in key_base_normalized:
|
||||
score = 80
|
||||
# Model contains base name
|
||||
elif key_base_normalized in base_normalized:
|
||||
score = 70
|
||||
# Partial match
|
||||
elif base_normalized[:10] in key_base_normalized:
|
||||
score = 50
|
||||
|
||||
if score > 0:
|
||||
# Prefer models with max_input_tokens
|
||||
if info.get("max_input_tokens"):
|
||||
score += 10
|
||||
candidates.append((score, key, info))
|
||||
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
# Return the best match
|
||||
candidates.sort(key=lambda x: (-x[0], x[1]))
|
||||
return candidates[0][2]
|
||||
|
||||
|
||||
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
|
||||
"""Get the maximum input context tokens for a model.
|
||||
|
||||
Args:
|
||||
model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o")
|
||||
provider: Provider name for informational purposes (not yet used for filtering)
|
||||
|
||||
Returns:
|
||||
Maximum input tokens, or None if unknown
|
||||
|
||||
Note:
|
||||
The provider parameter is currently informational only. Future versions may
|
||||
use it to prefer provider-specific model variants in the lookup.
|
||||
"""
|
||||
# First try fuzzy search in model_cost (has more accurate max_input_tokens)
|
||||
info = find_model_info(model)
|
||||
if info:
|
||||
# Prefer max_input_tokens (this is what we want for context window)
|
||||
max_input = info.get("max_input_tokens")
|
||||
if max_input and isinstance(max_input, int):
|
||||
return max_input
|
||||
|
||||
# Fall back to litellm's get_max_tokens (returns max_output_tokens typically)
|
||||
try:
|
||||
result = _litellm().get_max_tokens(model)
|
||||
if result and result > 0:
|
||||
return result
|
||||
except (KeyError, ValueError, AttributeError):
|
||||
# Model not found in litellm's database or invalid response
|
||||
pass
|
||||
|
||||
# Last resort: use max_tokens from model_cost
|
||||
if info:
|
||||
max_tokens = info.get("max_tokens")
|
||||
if max_tokens and isinstance(max_tokens, int):
|
||||
return max_tokens
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_provider_keywords() -> dict[str, list[str]]:
|
||||
"""Build provider keywords mapping from nanobot's provider registry.
|
||||
|
||||
Returns:
|
||||
Dict mapping provider name to list of keywords for model filtering.
|
||||
"""
|
||||
try:
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
|
||||
mapping = {}
|
||||
for spec in PROVIDERS:
|
||||
if spec.keywords:
|
||||
mapping[spec.name] = list(spec.keywords)
|
||||
return mapping
|
||||
except ImportError:
|
||||
return {}
|
||||
def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
|
||||
"""Get autocomplete suggestions for model names.
|
||||
|
||||
Args:
|
||||
partial: Partial model name typed by user
|
||||
provider: Provider name for filtering (e.g., "openrouter", "minimax")
|
||||
limit: Maximum number of suggestions to return
|
||||
|
||||
Returns:
|
||||
List of matching model names
|
||||
"""
|
||||
all_models = get_all_models()
|
||||
if not all_models:
|
||||
return []
|
||||
|
||||
partial_lower = partial.lower()
|
||||
partial_normalized = _normalize_model_name(partial)
|
||||
|
||||
# Get provider keywords from registry
|
||||
provider_keywords = _get_provider_keywords()
|
||||
|
||||
# Filter by provider if specified
|
||||
allowed_keywords = None
|
||||
if provider and provider != "auto":
|
||||
allowed_keywords = provider_keywords.get(provider.lower())
|
||||
|
||||
matches = []
|
||||
|
||||
for model in all_models:
|
||||
model_lower = model.lower()
|
||||
|
||||
# Apply provider filter
|
||||
if allowed_keywords:
|
||||
if not any(kw in model_lower for kw in allowed_keywords):
|
||||
continue
|
||||
|
||||
# Match against partial input
|
||||
if not partial:
|
||||
matches.append(model)
|
||||
continue
|
||||
|
||||
if partial_lower in model_lower:
|
||||
# Score by position of match (earlier = better)
|
||||
pos = model_lower.find(partial_lower)
|
||||
score = 100 - pos
|
||||
matches.append((score, model))
|
||||
elif partial_normalized in _normalize_model_name(model):
|
||||
score = 50
|
||||
matches.append((score, model))
|
||||
|
||||
# Sort by score if we have scored matches
|
||||
if matches and isinstance(matches[0], tuple):
|
||||
matches.sort(key=lambda x: (-x[0], x[1]))
|
||||
matches = [m[1] for m in matches]
|
||||
else:
|
||||
matches.sort()
|
||||
|
||||
return matches[:limit]
|
||||
return []
|
||||
|
||||
|
||||
def format_token_count(tokens: int) -> str:
|
||||
|
||||
@ -249,8 +249,7 @@ class Config(BaseSettings):
|
||||
if p and p.api_base:
|
||||
return p.api_base
|
||||
# Only gateways get a default api_base here. Standard providers
|
||||
# (like Moonshot) set their base URL via env vars in _setup_env
|
||||
# to avoid polluting the global litellm.api_base.
|
||||
# resolve their base URL from the registry in the provider constructor.
|
||||
if name:
|
||||
spec = find_by_name(name)
|
||||
if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
|
||||
|
||||
@ -7,17 +7,26 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||
|
||||
__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"]
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
|
||||
_LAZY_IMPORTS = {
|
||||
"LiteLLMProvider": ".litellm_provider",
|
||||
"AnthropicProvider": ".anthropic_provider",
|
||||
"OpenAICompatProvider": ".openai_compat_provider",
|
||||
"OpenAICodexProvider": ".openai_codex_provider",
|
||||
"AzureOpenAIProvider": ".azure_openai_provider",
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
|
||||
|
||||
441
nanobot/providers/anthropic_provider.py
Normal file
441
nanobot/providers/anthropic_provider.py
Normal file
@ -0,0 +1,441 @@
|
||||
"""Anthropic provider — direct SDK integration for Claude models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _gen_tool_id() -> str:
|
||||
return "toolu_" + "".join(secrets.choice(_ALNUM) for _ in range(22))
|
||||
|
||||
|
||||
class AnthropicProvider(LLMProvider):
|
||||
"""LLM provider using the native Anthropic SDK for Claude models.
|
||||
|
||||
Handles message format conversion (OpenAI → Anthropic Messages API),
|
||||
prompt caching, extended thinking, tool calls, and streaming.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "claude-sonnet-4-20250514",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
from anthropic import AsyncAnthropic
|
||||
|
||||
client_kw: dict[str, Any] = {}
|
||||
if api_key:
|
||||
client_kw["api_key"] = api_key
|
||||
if api_base:
|
||||
client_kw["base_url"] = api_base
|
||||
if extra_headers:
|
||||
client_kw["default_headers"] = extra_headers
|
||||
self._client = AsyncAnthropic(**client_kw)
|
||||
|
||||
@staticmethod
|
||||
def _strip_prefix(model: str) -> str:
|
||||
if model.startswith("anthropic/"):
|
||||
return model[len("anthropic/"):]
|
||||
return model
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Message conversion: OpenAI chat format → Anthropic Messages API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _convert_messages(
|
||||
self, messages: list[dict[str, Any]],
|
||||
) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Return ``(system, anthropic_messages)``."""
|
||||
system: str | list[dict[str, Any]] = ""
|
||||
raw: list[dict[str, Any]] = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content")
|
||||
|
||||
if role == "system":
|
||||
system = content if isinstance(content, (str, list)) else str(content or "")
|
||||
continue
|
||||
|
||||
if role == "tool":
|
||||
block = self._tool_result_block(msg)
|
||||
if raw and raw[-1]["role"] == "user":
|
||||
prev_c = raw[-1]["content"]
|
||||
if isinstance(prev_c, list):
|
||||
prev_c.append(block)
|
||||
else:
|
||||
raw[-1]["content"] = [
|
||||
{"type": "text", "text": prev_c or ""}, block,
|
||||
]
|
||||
else:
|
||||
raw.append({"role": "user", "content": [block]})
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
raw.append({"role": "assistant", "content": self._assistant_blocks(msg)})
|
||||
continue
|
||||
|
||||
if role == "user":
|
||||
raw.append({
|
||||
"role": "user",
|
||||
"content": self._convert_user_content(content),
|
||||
})
|
||||
continue
|
||||
|
||||
return system, self._merge_consecutive(raw)
|
||||
|
||||
@staticmethod
|
||||
def _tool_result_block(msg: dict[str, Any]) -> dict[str, Any]:
|
||||
content = msg.get("content")
|
||||
block: dict[str, Any] = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": msg.get("tool_call_id", ""),
|
||||
}
|
||||
if isinstance(content, (str, list)):
|
||||
block["content"] = content
|
||||
else:
|
||||
block["content"] = str(content) if content else ""
|
||||
return block
|
||||
|
||||
@staticmethod
|
||||
def _assistant_blocks(msg: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
blocks: list[dict[str, Any]] = []
|
||||
content = msg.get("content")
|
||||
|
||||
for tb in msg.get("thinking_blocks") or []:
|
||||
if isinstance(tb, dict) and tb.get("type") == "thinking":
|
||||
blocks.append({
|
||||
"type": "thinking",
|
||||
"thinking": tb.get("thinking", ""),
|
||||
"signature": tb.get("signature", ""),
|
||||
})
|
||||
|
||||
if isinstance(content, str) and content:
|
||||
blocks.append({"type": "text", "text": content})
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
blocks.append(item if isinstance(item, dict) else {"type": "text", "text": str(item)})
|
||||
|
||||
for tc in msg.get("tool_calls") or []:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
func = tc.get("function", {})
|
||||
args = func.get("arguments", "{}")
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
blocks.append({
|
||||
"type": "tool_use",
|
||||
"id": tc.get("id") or _gen_tool_id(),
|
||||
"name": func.get("name", ""),
|
||||
"input": args,
|
||||
})
|
||||
|
||||
return blocks or [{"type": "text", "text": ""}]
|
||||
|
||||
def _convert_user_content(self, content: Any) -> Any:
|
||||
"""Convert user message content, translating image_url blocks."""
|
||||
if isinstance(content, str) or content is None:
|
||||
return content or "(empty)"
|
||||
if not isinstance(content, list):
|
||||
return str(content)
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
result.append({"type": "text", "text": str(item)})
|
||||
continue
|
||||
if item.get("type") == "image_url":
|
||||
converted = self._convert_image_block(item)
|
||||
if converted:
|
||||
result.append(converted)
|
||||
continue
|
||||
result.append(item)
|
||||
return result or "(empty)"
|
||||
|
||||
@staticmethod
|
||||
def _convert_image_block(block: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Convert OpenAI image_url block to Anthropic image block."""
|
||||
url = (block.get("image_url") or {}).get("url", "")
|
||||
if not url:
|
||||
return None
|
||||
m = re.match(r"data:(image/\w+);base64,(.+)", url, re.DOTALL)
|
||||
if m:
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {"type": "base64", "media_type": m.group(1), "data": m.group(2)},
|
||||
}
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {"type": "url", "url": url},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _merge_consecutive(msgs: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Anthropic requires alternating user/assistant roles."""
|
||||
merged: list[dict[str, Any]] = []
|
||||
for msg in msgs:
|
||||
if merged and merged[-1]["role"] == msg["role"]:
|
||||
prev_c = merged[-1]["content"]
|
||||
cur_c = msg["content"]
|
||||
if isinstance(prev_c, str):
|
||||
prev_c = [{"type": "text", "text": prev_c}]
|
||||
if isinstance(cur_c, str):
|
||||
cur_c = [{"type": "text", "text": cur_c}]
|
||||
if isinstance(cur_c, list):
|
||||
prev_c.extend(cur_c)
|
||||
merged[-1]["content"] = prev_c
|
||||
else:
|
||||
merged.append(msg)
|
||||
return merged
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool definition conversion
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _convert_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None:
|
||||
if not tools:
|
||||
return None
|
||||
result = []
|
||||
for tool in tools:
|
||||
func = tool.get("function", tool)
|
||||
entry: dict[str, Any] = {
|
||||
"name": func.get("name", ""),
|
||||
"input_schema": func.get("parameters", {"type": "object", "properties": {}}),
|
||||
}
|
||||
desc = func.get("description")
|
||||
if desc:
|
||||
entry["description"] = desc
|
||||
if "cache_control" in tool:
|
||||
entry["cache_control"] = tool["cache_control"]
|
||||
result.append(entry)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_choice(
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
thinking_enabled: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
if thinking_enabled:
|
||||
return {"type": "auto"}
|
||||
if tool_choice is None or tool_choice == "auto":
|
||||
return {"type": "auto"}
|
||||
if tool_choice == "required":
|
||||
return {"type": "any"}
|
||||
if tool_choice == "none":
|
||||
return None
|
||||
if isinstance(tool_choice, dict):
|
||||
name = tool_choice.get("function", {}).get("name")
|
||||
if name:
|
||||
return {"type": "tool", "name": name}
|
||||
return {"type": "auto"}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Prompt caching
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _apply_cache_control(
|
||||
system: str | list[dict[str, Any]],
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||
marker = {"type": "ephemeral"}
|
||||
|
||||
if isinstance(system, str) and system:
|
||||
system = [{"type": "text", "text": system, "cache_control": marker}]
|
||||
elif isinstance(system, list) and system:
|
||||
system = list(system)
|
||||
system[-1] = {**system[-1], "cache_control": marker}
|
||||
|
||||
new_msgs = list(messages)
|
||||
if len(new_msgs) >= 3:
|
||||
m = new_msgs[-2]
|
||||
c = m.get("content")
|
||||
if isinstance(c, str):
|
||||
new_msgs[-2] = {**m, "content": [{"type": "text", "text": c, "cache_control": marker}]}
|
||||
elif isinstance(c, list) and c:
|
||||
nc = list(c)
|
||||
nc[-1] = {**nc[-1], "cache_control": marker}
|
||||
new_msgs[-2] = {**m, "content": nc}
|
||||
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": marker}
|
||||
|
||||
return system, new_msgs, new_tools
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build API kwargs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_kwargs(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
supports_caching: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
model_name = self._strip_prefix(model or self.default_model)
|
||||
system, anthropic_msgs = self._convert_messages(self._sanitize_empty_content(messages))
|
||||
anthropic_tools = self._convert_tools(tools)
|
||||
|
||||
if supports_caching:
|
||||
system, anthropic_msgs, anthropic_tools = self._apply_cache_control(
|
||||
system, anthropic_msgs, anthropic_tools,
|
||||
)
|
||||
|
||||
max_tokens = max(1, max_tokens)
|
||||
thinking_enabled = bool(reasoning_effort)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": anthropic_msgs,
|
||||
"max_tokens": max_tokens,
|
||||
}
|
||||
|
||||
if system:
|
||||
kwargs["system"] = system
|
||||
|
||||
if thinking_enabled:
|
||||
budget_map = {"low": 1024, "medium": 4096, "high": max(8192, max_tokens)}
|
||||
budget = budget_map.get(reasoning_effort.lower(), 4096) # type: ignore[union-attr]
|
||||
kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget}
|
||||
kwargs["max_tokens"] = max(max_tokens, budget + 4096)
|
||||
kwargs["temperature"] = 1.0
|
||||
else:
|
||||
kwargs["temperature"] = temperature
|
||||
|
||||
if anthropic_tools:
|
||||
kwargs["tools"] = anthropic_tools
|
||||
tc = self._convert_tool_choice(tool_choice, thinking_enabled)
|
||||
if tc:
|
||||
kwargs["tool_choice"] = tc
|
||||
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
return kwargs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response parsing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _parse_response(response: Any) -> LLMResponse:
|
||||
content_parts: list[str] = []
|
||||
tool_calls: list[ToolCallRequest] = []
|
||||
thinking_blocks: list[dict[str, Any]] = []
|
||||
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
content_parts.append(block.text)
|
||||
elif block.type == "tool_use":
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=block.id,
|
||||
name=block.name,
|
||||
arguments=block.input if isinstance(block.input, dict) else {},
|
||||
))
|
||||
elif block.type == "thinking":
|
||||
thinking_blocks.append({
|
||||
"type": "thinking",
|
||||
"thinking": block.thinking,
|
||||
"signature": getattr(block, "signature", ""),
|
||||
})
|
||||
|
||||
stop_map = {"tool_use": "tool_calls", "end_turn": "stop", "max_tokens": "length"}
|
||||
finish_reason = stop_map.get(response.stop_reason or "", response.stop_reason or "stop")
|
||||
|
||||
usage: dict[str, int] = {}
|
||||
if response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.input_tokens,
|
||||
"completion_tokens": response.usage.output_tokens,
|
||||
"total_tokens": response.usage.input_tokens + response.usage.output_tokens,
|
||||
}
|
||||
for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"):
|
||||
val = getattr(response.usage, attr, 0)
|
||||
if val:
|
||||
usage[attr] = val
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
thinking_blocks=thinking_blocks or None,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
response = await self._client.messages.create(**kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
async with self._client.messages.stream(**kwargs) as stream:
|
||||
if on_content_delta:
|
||||
async for text in stream.text_stream:
|
||||
await on_content_delta(text)
|
||||
response = await stream.get_final_message()
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
@ -1,152 +0,0 @@
|
||||
"""Direct OpenAI-compatible provider — bypasses LiteLLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class CustomProvider(LLMProvider):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str = "no-key",
|
||||
api_base: str = "http://localhost:8000/v1",
|
||||
default_model: str = "default",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
default_headers={
|
||||
"x-session-affinity": uuid.uuid4().hex,
|
||||
**(extra_headers or {}),
|
||||
},
|
||||
)
|
||||
|
||||
def _build_kwargs(
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None,
|
||||
model: str | None, max_tokens: int, temperature: float,
|
||||
reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model or self.default_model,
|
||||
"messages": self._sanitize_empty_content(messages),
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
if tools:
|
||||
kwargs.update(tools=tools, tool_choice=tool_choice or "auto")
|
||||
return kwargs
|
||||
|
||||
def _handle_error(self, e: Exception) -> LLMResponse:
|
||||
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
|
||||
msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}"
|
||||
return LLMResponse(content=msg, finish_reason="error")
|
||||
|
||||
async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
async def chat_stream(
|
||||
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice)
|
||||
kwargs["stream"] = True
|
||||
try:
|
||||
stream = await self._client.chat.completions.create(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
if on_content_delta and chunk.choices:
|
||||
text = getattr(chunk.choices[0].delta, "content", None)
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
return self._parse_chunks(chunks)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if not response.choices:
|
||||
return LLMResponse(
|
||||
content="Error: API returned empty choices.",
|
||||
finish_reason="error",
|
||||
)
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
tool_calls = [
|
||||
ToolCallRequest(
|
||||
id=tc.id, name=tc.function.name,
|
||||
arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments,
|
||||
)
|
||||
for tc in (msg.tool_calls or [])
|
||||
]
|
||||
u = response.usage
|
||||
return LLMResponse(
|
||||
content=msg.content, tool_calls=tool_calls,
|
||||
finish_reason=choice.finish_reason or "stop",
|
||||
usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {},
|
||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||
)
|
||||
|
||||
def _parse_chunks(self, chunks: list[Any]) -> LLMResponse:
|
||||
"""Reassemble streamed chunks into a single LLMResponse."""
|
||||
content_parts: list[str] = []
|
||||
tc_bufs: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
usage: dict[str, int] = {}
|
||||
|
||||
for chunk in chunks:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
u = chunk.usage
|
||||
usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0,
|
||||
"total_tokens": u.total_tokens or 0}
|
||||
continue
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
delta = choice.delta
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
for tc in (delta.tool_calls or []) if delta else []:
|
||||
buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.id:
|
||||
buf["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
buf["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
buf["arguments"] += tc.function.arguments
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {})
|
||||
for b in tc_bufs.values()
|
||||
],
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
@ -1,413 +0,0 @@
|
||||
"""LiteLLM provider implementation for multi-provider support."""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import json_repair
|
||||
import litellm
|
||||
from litellm import acompletion
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.providers.registry import find_by_model, find_gateway
|
||||
|
||||
# Standard chat-completion message keys.
|
||||
_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"})
|
||||
_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"})
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class LiteLLMProvider(LLMProvider):
|
||||
"""
|
||||
LLM provider using LiteLLM for multi-provider support.
|
||||
|
||||
Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through
|
||||
a unified interface. Provider-specific logic is driven by the registry
|
||||
(see providers/registry.py) — no if-elif chains needed here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "anthropic/claude-opus-4-5",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
provider_name: str | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
|
||||
# Detect gateway / local deployment.
|
||||
# provider_name (from config key) is the primary signal;
|
||||
# api_key / api_base are fallback for auto-detection.
|
||||
self._gateway = find_gateway(provider_name, api_key, api_base)
|
||||
|
||||
# Configure environment variables
|
||||
if api_key:
|
||||
self._setup_env(api_key, api_base, default_model)
|
||||
|
||||
if api_base:
|
||||
litellm.api_base = api_base
|
||||
|
||||
# Disable LiteLLM logging noise
|
||||
litellm.suppress_debug_info = True
|
||||
# Drop unsupported parameters for providers (e.g., gpt-5 rejects some params)
|
||||
litellm.drop_params = True
|
||||
|
||||
self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY"))
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None:
|
||||
"""Set environment variables based on detected provider."""
|
||||
spec = self._gateway or find_by_model(model)
|
||||
if not spec:
|
||||
return
|
||||
if not spec.env_key:
|
||||
# OAuth/provider-only specs (for example: openai_codex)
|
||||
return
|
||||
|
||||
# Gateway/local overrides existing env; standard provider doesn't
|
||||
if self._gateway:
|
||||
os.environ[spec.env_key] = api_key
|
||||
else:
|
||||
os.environ.setdefault(spec.env_key, api_key)
|
||||
|
||||
# Resolve env_extras placeholders:
|
||||
# {api_key} → user's API key
|
||||
# {api_base} → user's api_base, falling back to spec.default_api_base
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key)
|
||||
resolved = resolved.replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
def _resolve_model(self, model: str) -> str:
|
||||
"""Resolve model name by applying provider/gateway prefixes."""
|
||||
if self._gateway:
|
||||
prefix = self._gateway.litellm_prefix
|
||||
if self._gateway.strip_model_prefix:
|
||||
model = model.split("/")[-1]
|
||||
if prefix:
|
||||
model = f"{prefix}/{model}"
|
||||
return model
|
||||
|
||||
# Standard mode: auto-prefix for known providers
|
||||
spec = find_by_model(model)
|
||||
if spec and spec.litellm_prefix:
|
||||
model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix)
|
||||
if not any(model.startswith(s) for s in spec.skip_prefixes):
|
||||
model = f"{spec.litellm_prefix}/{model}"
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str:
|
||||
"""Normalize explicit provider prefixes like `github-copilot/...`."""
|
||||
if "/" not in model:
|
||||
return model
|
||||
prefix, remainder = model.split("/", 1)
|
||||
if prefix.lower().replace("-", "_") != spec_name:
|
||||
return model
|
||||
return f"{canonical_prefix}/{remainder}"
|
||||
|
||||
def _supports_cache_control(self, model: str) -> bool:
|
||||
"""Return True when the provider supports cache_control on content blocks."""
|
||||
if self._gateway is not None:
|
||||
return self._gateway.supports_prompt_caching
|
||||
spec = find_by_model(model)
|
||||
return spec is not None and spec.supports_prompt_caching
|
||||
|
||||
def _apply_cache_control(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||
"""Return copies of messages and tools with cache_control injected.
|
||||
|
||||
Two breakpoints are placed:
|
||||
1. System message — caches the static system prompt
|
||||
2. Second-to-last message — caches the conversation history prefix
|
||||
This maximises cache hits across multi-turn conversations.
|
||||
"""
|
||||
cache_marker = {"type": "ephemeral"}
|
||||
new_messages = list(messages)
|
||||
|
||||
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
return {**msg, "content": [
|
||||
{"type": "text", "text": content, "cache_control": cache_marker}
|
||||
]}
|
||||
elif isinstance(content, list) and content:
|
||||
new_content = list(content)
|
||||
new_content[-1] = {**new_content[-1], "cache_control": cache_marker}
|
||||
return {**msg, "content": new_content}
|
||||
return msg
|
||||
|
||||
# Breakpoint 1: system message
|
||||
if new_messages and new_messages[0].get("role") == "system":
|
||||
new_messages[0] = _mark(new_messages[0])
|
||||
|
||||
# Breakpoint 2: second-to-last message (caches conversation history prefix)
|
||||
if len(new_messages) >= 3:
|
||||
new_messages[-2] = _mark(new_messages[-2])
|
||||
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
|
||||
|
||||
return new_messages, new_tools
|
||||
|
||||
def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None:
|
||||
"""Apply model-specific parameter overrides from the registry."""
|
||||
model_lower = model.lower()
|
||||
spec = find_by_model(model)
|
||||
if spec:
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]:
|
||||
"""Return provider-specific extra keys to preserve in request messages."""
|
||||
spec = find_by_model(original_model) or find_by_model(resolved_model)
|
||||
if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"):
|
||||
return _ANTHROPIC_EXTRA_KEYS
|
||||
return frozenset()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
||||
"""Normalize tool_call_id to a provider-safe 9-char alphanumeric form."""
|
||||
if not isinstance(tool_call_id, str):
|
||||
return tool_call_id
|
||||
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
||||
return tool_call_id
|
||||
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]:
|
||||
"""Strip non-standard keys and ensure assistant messages have a content key."""
|
||||
allowed = _ALLOWED_MSG_KEYS | extra_keys
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, allowed)
|
||||
id_map: dict[str, str] = {}
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value))
|
||||
|
||||
for clean in sanitized:
|
||||
# Keep assistant tool_calls[].id and tool tool_call_id in sync after
|
||||
# shortening, otherwise strict providers reject the broken linkage.
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized_tool_calls = []
|
||||
for tc in clean["tool_calls"]:
|
||||
if not isinstance(tc, dict):
|
||||
normalized_tool_calls.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
normalized_tool_calls.append(tc_clean)
|
||||
clean["tool_calls"] = normalized_tool_calls
|
||||
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
|
||||
def _build_chat_kwargs(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> tuple[dict[str, Any], str]:
|
||||
"""Build the kwargs dict for ``acompletion``.
|
||||
|
||||
Returns ``(kwargs, original_model)`` so callers can reuse the
|
||||
original model string for downstream logic.
|
||||
"""
|
||||
original_model = model or self.default_model
|
||||
resolved = self._resolve_model(original_model)
|
||||
extra_msg_keys = self._extra_msg_keys(original_model, resolved)
|
||||
|
||||
if self._supports_cache_control(original_model):
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
max_tokens = max(1, max_tokens)
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": resolved,
|
||||
"messages": self._sanitize_messages(
|
||||
self._sanitize_empty_content(messages), extra_keys=extra_msg_keys,
|
||||
),
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if self._gateway:
|
||||
kwargs.update(self._gateway.litellm_kwargs)
|
||||
|
||||
self._apply_model_overrides(resolved, kwargs)
|
||||
|
||||
if self._langsmith_enabled:
|
||||
kwargs.setdefault("callbacks", []).append("langsmith")
|
||||
|
||||
if self.api_key:
|
||||
kwargs["api_key"] = self.api_key
|
||||
if self.api_base:
|
||||
kwargs["api_base"] = self.api_base
|
||||
if self.extra_headers:
|
||||
kwargs["extra_headers"] = self.extra_headers
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
kwargs["drop_params"] = True
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return kwargs, original_model
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Send a chat completion request via LiteLLM."""
|
||||
kwargs, _ = self._build_chat_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
response = await acompletion(**kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
"""Stream a chat completion via LiteLLM, forwarding text deltas."""
|
||||
kwargs, _ = self._build_chat_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
|
||||
try:
|
||||
stream = await acompletion(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
if on_content_delta:
|
||||
delta = chunk.choices[0].delta if chunk.choices else None
|
||||
text = getattr(delta, "content", None) if delta else None
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
|
||||
full_response = litellm.stream_chunk_builder(
|
||||
chunks, messages=kwargs["messages"],
|
||||
)
|
||||
return self._parse_response(full_response)
|
||||
except Exception as e:
|
||||
return LLMResponse(
|
||||
content=f"Error calling LLM: {str(e)}",
|
||||
finish_reason="error",
|
||||
)
|
||||
|
||||
def _parse_response(self, response: Any) -> LLMResponse:
|
||||
"""Parse LiteLLM response into our standard format."""
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
content = message.content
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
# Some providers (e.g. GitHub Copilot) split content and tool_calls
|
||||
# across multiple choices. Merge them so tool_calls are not lost.
|
||||
raw_tool_calls = []
|
||||
for ch in response.choices:
|
||||
msg = ch.message
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
raw_tool_calls.extend(msg.tool_calls)
|
||||
if ch.finish_reason in ("tool_calls", "stop"):
|
||||
finish_reason = ch.finish_reason
|
||||
if not content and msg.content:
|
||||
content = msg.content
|
||||
|
||||
if len(response.choices) > 1:
|
||||
logger.debug("LiteLLM response has {} choices, merged {} tool_calls",
|
||||
len(response.choices), len(raw_tool_calls))
|
||||
|
||||
tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
# Parse arguments from JSON string if needed
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
|
||||
provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None
|
||||
function_provider_specific_fields = (
|
||||
getattr(tc.function, "provider_specific_fields", None) or None
|
||||
)
|
||||
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
function_provider_specific_fields=function_provider_specific_fields,
|
||||
))
|
||||
|
||||
usage = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = {
|
||||
"prompt_tokens": response.usage.prompt_tokens,
|
||||
"completion_tokens": response.usage.completion_tokens,
|
||||
"total_tokens": response.usage.total_tokens,
|
||||
}
|
||||
|
||||
reasoning_content = getattr(message, "reasoning_content", None) or None
|
||||
thinking_blocks = getattr(message, "thinking_blocks", None) or None
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason or "stop",
|
||||
usage=usage,
|
||||
reasoning_content=reasoning_content,
|
||||
thinking_blocks=thinking_blocks,
|
||||
)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
"""Get the default model."""
|
||||
return self.default_model
|
||||
349
nanobot/providers/openai_compat_provider.py
Normal file
349
nanobot/providers/openai_compat_provider.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""OpenAI-compatible provider for all non-Anthropic LLM APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import json_repair
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
|
||||
_ALLOWED_MSG_KEYS = frozenset({
|
||||
"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content",
|
||||
})
|
||||
_ALNUM = string.ascii_letters + string.digits
|
||||
|
||||
|
||||
def _short_tool_id() -> str:
|
||||
"""9-char alphanumeric ID compatible with all providers (incl. Mistral)."""
|
||||
return "".join(secrets.choice(_ALNUM) for _ in range(9))
|
||||
|
||||
|
||||
class OpenAICompatProvider(LLMProvider):
|
||||
"""Unified provider for all OpenAI-compatible APIs.
|
||||
|
||||
Receives a resolved ``ProviderSpec`` from the caller — no internal
|
||||
registry lookups needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
default_model: str = "gpt-4o",
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
spec: ProviderSpec | None = None,
|
||||
):
|
||||
super().__init__(api_key, api_base)
|
||||
self.default_model = default_model
|
||||
self.extra_headers = extra_headers or {}
|
||||
self._spec = spec
|
||||
|
||||
if api_key and spec and spec.env_key:
|
||||
self._setup_env(api_key, api_base)
|
||||
|
||||
effective_base = api_base or (spec.default_api_base if spec else None) or None
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key or "no-key",
|
||||
base_url=effective_base,
|
||||
default_headers={
|
||||
"x-session-affinity": uuid.uuid4().hex,
|
||||
**(extra_headers or {}),
|
||||
},
|
||||
)
|
||||
|
||||
def _setup_env(self, api_key: str, api_base: str | None) -> None:
|
||||
"""Set environment variables based on provider spec."""
|
||||
spec = self._spec
|
||||
if not spec or not spec.env_key:
|
||||
return
|
||||
if spec.is_gateway:
|
||||
os.environ[spec.env_key] = api_key
|
||||
else:
|
||||
os.environ.setdefault(spec.env_key, api_key)
|
||||
effective_base = api_base or spec.default_api_base
|
||||
for env_name, env_val in spec.env_extras:
|
||||
resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base)
|
||||
os.environ.setdefault(env_name, resolved)
|
||||
|
||||
@staticmethod
|
||||
def _apply_cache_control(
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]:
|
||||
"""Inject cache_control markers for prompt caching."""
|
||||
cache_marker = {"type": "ephemeral"}
|
||||
new_messages = list(messages)
|
||||
|
||||
def _mark(msg: dict[str, Any]) -> dict[str, Any]:
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
return {**msg, "content": [
|
||||
{"type": "text", "text": content, "cache_control": cache_marker},
|
||||
]}
|
||||
if isinstance(content, list) and content:
|
||||
nc = list(content)
|
||||
nc[-1] = {**nc[-1], "cache_control": cache_marker}
|
||||
return {**msg, "content": nc}
|
||||
return msg
|
||||
|
||||
if new_messages and new_messages[0].get("role") == "system":
|
||||
new_messages[0] = _mark(new_messages[0])
|
||||
if len(new_messages) >= 3:
|
||||
new_messages[-2] = _mark(new_messages[-2])
|
||||
|
||||
new_tools = tools
|
||||
if tools:
|
||||
new_tools = list(tools)
|
||||
new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker}
|
||||
return new_messages, new_tools
|
||||
|
||||
@staticmethod
|
||||
def _normalize_tool_call_id(tool_call_id: Any) -> Any:
|
||||
"""Normalize to a provider-safe 9-char alphanumeric form."""
|
||||
if not isinstance(tool_call_id, str):
|
||||
return tool_call_id
|
||||
if len(tool_call_id) == 9 and tool_call_id.isalnum():
|
||||
return tool_call_id
|
||||
return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9]
|
||||
|
||||
def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""Strip non-standard keys, normalize tool_call IDs."""
|
||||
sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS)
|
||||
id_map: dict[str, str] = {}
|
||||
|
||||
def map_id(value: Any) -> Any:
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
return id_map.setdefault(value, self._normalize_tool_call_id(value))
|
||||
|
||||
for clean in sanitized:
|
||||
if isinstance(clean.get("tool_calls"), list):
|
||||
normalized = []
|
||||
for tc in clean["tool_calls"]:
|
||||
if not isinstance(tc, dict):
|
||||
normalized.append(tc)
|
||||
continue
|
||||
tc_clean = dict(tc)
|
||||
tc_clean["id"] = map_id(tc_clean.get("id"))
|
||||
normalized.append(tc_clean)
|
||||
clean["tool_calls"] = normalized
|
||||
if "tool_call_id" in clean and clean["tool_call_id"]:
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return sanitized
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_kwargs(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None,
|
||||
model: str | None,
|
||||
max_tokens: int,
|
||||
temperature: float,
|
||||
reasoning_effort: str | None,
|
||||
tool_choice: str | dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
model_name = model or self.default_model
|
||||
spec = self._spec
|
||||
|
||||
if spec and spec.supports_prompt_caching:
|
||||
messages, tools = self._apply_cache_control(messages, tools)
|
||||
|
||||
if spec and spec.strip_model_prefix:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||
"max_tokens": max(1, max_tokens),
|
||||
"temperature": temperature,
|
||||
}
|
||||
|
||||
if spec:
|
||||
model_lower = model_name.lower()
|
||||
for pattern, overrides in spec.model_overrides:
|
||||
if pattern in model_lower:
|
||||
kwargs.update(overrides)
|
||||
break
|
||||
|
||||
if reasoning_effort:
|
||||
kwargs["reasoning_effort"] = reasoning_effort
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = tools
|
||||
kwargs["tool_choice"] = tool_choice or "auto"
|
||||
|
||||
return kwargs
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Response parsing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _parse(self, response: Any) -> LLMResponse:
|
||||
if not response.choices:
|
||||
return LLMResponse(content="Error: API returned empty choices.", finish_reason="error")
|
||||
|
||||
choice = response.choices[0]
|
||||
msg = choice.message
|
||||
content = msg.content
|
||||
finish_reason = choice.finish_reason
|
||||
|
||||
raw_tool_calls: list[Any] = []
|
||||
for ch in response.choices:
|
||||
m = ch.message
|
||||
if hasattr(m, "tool_calls") and m.tool_calls:
|
||||
raw_tool_calls.extend(m.tool_calls)
|
||||
if ch.finish_reason in ("tool_calls", "stop"):
|
||||
finish_reason = ch.finish_reason
|
||||
if not content and m.content:
|
||||
content = m.content
|
||||
|
||||
tool_calls = []
|
||||
for tc in raw_tool_calls:
|
||||
args = tc.function.arguments
|
||||
if isinstance(args, str):
|
||||
args = json_repair.loads(args)
|
||||
tool_calls.append(ToolCallRequest(
|
||||
id=_short_tool_id(),
|
||||
name=tc.function.name,
|
||||
arguments=args,
|
||||
))
|
||||
|
||||
usage: dict[str, int] = {}
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
u = response.usage
|
||||
usage = {
|
||||
"prompt_tokens": u.prompt_tokens or 0,
|
||||
"completion_tokens": u.completion_tokens or 0,
|
||||
"total_tokens": u.total_tokens or 0,
|
||||
}
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason or "stop",
|
||||
usage=usage,
|
||||
reasoning_content=getattr(msg, "reasoning_content", None) or None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_chunks(chunks: list[Any]) -> LLMResponse:
|
||||
content_parts: list[str] = []
|
||||
tc_bufs: dict[int, dict[str, str]] = {}
|
||||
finish_reason = "stop"
|
||||
usage: dict[str, int] = {}
|
||||
|
||||
for chunk in chunks:
|
||||
if not chunk.choices:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
u = chunk.usage
|
||||
usage = {
|
||||
"prompt_tokens": u.prompt_tokens or 0,
|
||||
"completion_tokens": u.completion_tokens or 0,
|
||||
"total_tokens": u.total_tokens or 0,
|
||||
}
|
||||
continue
|
||||
choice = chunk.choices[0]
|
||||
if choice.finish_reason:
|
||||
finish_reason = choice.finish_reason
|
||||
delta = choice.delta
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
for tc in (delta.tool_calls or []) if delta else []:
|
||||
buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""})
|
||||
if tc.id:
|
||||
buf["id"] = tc.id
|
||||
if tc.function and tc.function.name:
|
||||
buf["name"] = tc.function.name
|
||||
if tc.function and tc.function.arguments:
|
||||
buf["arguments"] += tc.function.arguments
|
||||
|
||||
return LLMResponse(
|
||||
content="".join(content_parts) or None,
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id=b["id"] or _short_tool_id(),
|
||||
name=b["name"],
|
||||
arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {},
|
||||
)
|
||||
for b in tc_bufs.values()
|
||||
],
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _handle_error(e: Exception) -> LLMResponse:
|
||||
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
|
||||
msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
|
||||
return LLMResponse(content=msg, finish_reason="error")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
try:
|
||||
return self._parse(await self._client.chat.completions.create(**kwargs))
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
model: str | None = None,
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.7,
|
||||
reasoning_effort: str | None = None,
|
||||
tool_choice: str | dict[str, Any] | None = None,
|
||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||
) -> LLMResponse:
|
||||
kwargs = self._build_kwargs(
|
||||
messages, tools, model, max_tokens, temperature,
|
||||
reasoning_effort, tool_choice,
|
||||
)
|
||||
kwargs["stream"] = True
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
try:
|
||||
stream = await self._client.chat.completions.create(**kwargs)
|
||||
chunks: list[Any] = []
|
||||
async for chunk in stream:
|
||||
chunks.append(chunk)
|
||||
if on_content_delta and chunk.choices:
|
||||
text = getattr(chunk.choices[0].delta, "content", None)
|
||||
if text:
|
||||
await on_content_delta(text)
|
||||
return self._parse_chunks(chunks)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
def get_default_model(self) -> str:
|
||||
return self.default_model
|
||||
@ -4,7 +4,7 @@ Provider Registry — single source of truth for LLM provider metadata.
|
||||
Adding a new provider:
|
||||
1. Add a ProviderSpec to PROVIDERS below.
|
||||
2. Add a field to ProvidersConfig in config/schema.py.
|
||||
Done. Env vars, prefixing, config matching, status display all derive from here.
|
||||
Done. Env vars, config matching, status display all derive from here.
|
||||
|
||||
Order matters — it controls match priority and fallback. Gateways first.
|
||||
Every entry writes out all fields so you can copy-paste as a template.
|
||||
@ -12,7 +12,7 @@ Every entry writes out all fields so you can copy-paste as a template.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from pydantic.alias_generators import to_snake
|
||||
@ -30,12 +30,12 @@ class ProviderSpec:
|
||||
# identity
|
||||
name: str # config field name, e.g. "dashscope"
|
||||
keywords: tuple[str, ...] # model-name keywords for matching (lowercase)
|
||||
env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY"
|
||||
env_key: str # env var for API key, e.g. "DASHSCOPE_API_KEY"
|
||||
display_name: str = "" # shown in `nanobot status`
|
||||
|
||||
# model prefixing
|
||||
litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}"
|
||||
skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these
|
||||
# which provider implementation to use
|
||||
# "openai_compat" | "anthropic" | "azure_openai" | "openai_codex"
|
||||
backend: str = "openai_compat"
|
||||
|
||||
# extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),)
|
||||
env_extras: tuple[tuple[str, str], ...] = ()
|
||||
@ -45,19 +45,18 @@ class ProviderSpec:
|
||||
is_local: bool = False # local deployment (vLLM, Ollama)
|
||||
detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-"
|
||||
detect_by_base_keyword: str = "" # match substring in api_base URL
|
||||
default_api_base: str = "" # fallback base URL
|
||||
default_api_base: str = "" # OpenAI-compatible base URL for this provider
|
||||
|
||||
# gateway behavior
|
||||
strip_model_prefix: bool = False # strip "provider/" before re-prefixing
|
||||
litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM
|
||||
strip_model_prefix: bool = False # strip "provider/" before sending to gateway
|
||||
|
||||
# per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),)
|
||||
model_overrides: tuple[tuple[str, dict[str, Any]], ...] = ()
|
||||
|
||||
# OAuth-based providers (e.g., OpenAI Codex) don't use API keys
|
||||
is_oauth: bool = False # if True, uses OAuth flow instead of API key
|
||||
is_oauth: bool = False
|
||||
|
||||
# Direct providers bypass LiteLLM entirely (e.g., CustomProvider)
|
||||
# Direct providers skip API-key validation (user supplies everything)
|
||||
is_direct: bool = False
|
||||
|
||||
# Provider supports cache_control on content blocks (e.g. Anthropic prompt caching)
|
||||
@ -73,13 +72,13 @@ class ProviderSpec:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
# === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ======
|
||||
# === Custom (direct OpenAI-compatible endpoint) ========================
|
||||
ProviderSpec(
|
||||
name="custom",
|
||||
keywords=(),
|
||||
env_key="",
|
||||
display_name="Custom",
|
||||
litellm_prefix="",
|
||||
backend="openai_compat",
|
||||
is_direct=True,
|
||||
),
|
||||
|
||||
@ -89,7 +88,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("azure", "azure-openai"),
|
||||
env_key="",
|
||||
display_name="Azure OpenAI",
|
||||
litellm_prefix="",
|
||||
backend="azure_openai",
|
||||
is_direct=True,
|
||||
),
|
||||
# === Gateways (detected by api_key / api_base, not model name) =========
|
||||
@ -100,36 +99,26 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("openrouter",),
|
||||
env_key="OPENROUTER_API_KEY",
|
||||
display_name="OpenRouter",
|
||||
litellm_prefix="openrouter", # anthropic/claude-3 → openrouter/anthropic/claude-3
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="sk-or-",
|
||||
detect_by_base_keyword="openrouter",
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
# AiHubMix: global gateway, OpenAI-compatible interface.
|
||||
# strip_model_prefix=True: it doesn't understand "anthropic/claude-3",
|
||||
# so we strip to bare "claude-3" then re-prefix as "openai/claude-3".
|
||||
# strip_model_prefix=True: doesn't understand "anthropic/claude-3",
|
||||
# strips to bare "claude-3".
|
||||
ProviderSpec(
|
||||
name="aihubmix",
|
||||
keywords=("aihubmix",),
|
||||
env_key="OPENAI_API_KEY", # OpenAI-compatible
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="AiHubMix",
|
||||
litellm_prefix="openai", # → openai/{model}
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="aihubmix",
|
||||
default_api_base="https://aihubmix.com/v1",
|
||||
strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3
|
||||
model_overrides=(),
|
||||
strip_model_prefix=True,
|
||||
),
|
||||
# SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix
|
||||
ProviderSpec(
|
||||
@ -137,16 +126,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("siliconflow",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="SiliconFlow",
|
||||
litellm_prefix="openai",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="siliconflow",
|
||||
default_api_base="https://api.siliconflow.cn/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models
|
||||
@ -155,16 +138,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("volcengine", "volces", "ark"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="VolcEngine",
|
||||
litellm_prefix="volcengine",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="volces",
|
||||
default_api_base="https://ark.cn-beijing.volces.com/api/v3",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine
|
||||
@ -173,16 +150,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("volcengine-plan",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="VolcEngine Coding Plan",
|
||||
litellm_prefix="volcengine",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3",
|
||||
strip_model_prefix=True,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# BytePlus: VolcEngine international, pay-per-use models
|
||||
@ -191,16 +162,11 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("byteplus",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="BytePlus",
|
||||
litellm_prefix="volcengine",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="bytepluses",
|
||||
default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3",
|
||||
strip_model_prefix=True,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
# BytePlus Coding Plan: same key as byteplus
|
||||
@ -209,250 +175,137 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("byteplus-plan",),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="BytePlus Coding Plan",
|
||||
litellm_prefix="volcengine",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
backend="openai_compat",
|
||||
is_gateway=True,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3",
|
||||
strip_model_prefix=True,
|
||||
model_overrides=(),
|
||||
),
|
||||
|
||||
|
||||
# === Standard providers (matched by model-name keywords) ===============
|
||||
# Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed.
|
||||
# Anthropic: native Anthropic SDK
|
||||
ProviderSpec(
|
||||
name="anthropic",
|
||||
keywords=("anthropic", "claude"),
|
||||
env_key="ANTHROPIC_API_KEY",
|
||||
display_name="Anthropic",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
backend="anthropic",
|
||||
supports_prompt_caching=True,
|
||||
),
|
||||
# OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed.
|
||||
# OpenAI: SDK default base URL (no override needed)
|
||||
ProviderSpec(
|
||||
name="openai",
|
||||
keywords=("openai", "gpt"),
|
||||
env_key="OPENAI_API_KEY",
|
||||
display_name="OpenAI",
|
||||
litellm_prefix="",
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
backend="openai_compat",
|
||||
),
|
||||
# OpenAI Codex: uses OAuth, not API key.
|
||||
# OpenAI Codex: OAuth-based, dedicated provider
|
||||
ProviderSpec(
|
||||
name="openai_codex",
|
||||
keywords=("openai-codex",),
|
||||
env_key="", # OAuth-based, no API key
|
||||
env_key="",
|
||||
display_name="OpenAI Codex",
|
||||
litellm_prefix="", # Not routed through LiteLLM
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
backend="openai_codex",
|
||||
detect_by_base_keyword="codex",
|
||||
default_api_base="https://chatgpt.com/backend-api",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
is_oauth=True,
|
||||
),
|
||||
# Github Copilot: uses OAuth, not API key.
|
||||
# GitHub Copilot: OAuth-based
|
||||
ProviderSpec(
|
||||
name="github_copilot",
|
||||
keywords=("github_copilot", "copilot"),
|
||||
env_key="", # OAuth-based, no API key
|
||||
env_key="",
|
||||
display_name="Github Copilot",
|
||||
litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model
|
||||
skip_prefixes=("github_copilot/",),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
is_oauth=True, # OAuth-based authentication
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.githubcopilot.com",
|
||||
is_oauth=True,
|
||||
),
|
||||
# DeepSeek: needs "deepseek/" prefix for LiteLLM routing.
|
||||
# DeepSeek: OpenAI-compatible at api.deepseek.com
|
||||
ProviderSpec(
|
||||
name="deepseek",
|
||||
keywords=("deepseek",),
|
||||
env_key="DEEPSEEK_API_KEY",
|
||||
display_name="DeepSeek",
|
||||
litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat
|
||||
skip_prefixes=("deepseek/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.deepseek.com",
|
||||
),
|
||||
# Gemini: needs "gemini/" prefix for LiteLLM.
|
||||
# Gemini: Google's OpenAI-compatible endpoint
|
||||
ProviderSpec(
|
||||
name="gemini",
|
||||
keywords=("gemini",),
|
||||
env_key="GEMINI_API_KEY",
|
||||
display_name="Gemini",
|
||||
litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro
|
||||
skip_prefixes=("gemini/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
backend="openai_compat",
|
||||
default_api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
),
|
||||
# Zhipu: LiteLLM uses "zai/" prefix.
|
||||
# Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that).
|
||||
# skip_prefixes: don't add "zai/" when already routed via gateway.
|
||||
# Zhipu (智谱): OpenAI-compatible at open.bigmodel.cn
|
||||
ProviderSpec(
|
||||
name="zhipu",
|
||||
keywords=("zhipu", "glm", "zai"),
|
||||
env_key="ZAI_API_KEY",
|
||||
display_name="Zhipu AI",
|
||||
litellm_prefix="zai", # glm-4 → zai/glm-4
|
||||
skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"),
|
||||
backend="openai_compat",
|
||||
env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
default_api_base="https://open.bigmodel.cn/api/paas/v4",
|
||||
),
|
||||
# DashScope: Qwen models, needs "dashscope/" prefix.
|
||||
# DashScope (通义): Qwen models, OpenAI-compatible endpoint
|
||||
ProviderSpec(
|
||||
name="dashscope",
|
||||
keywords=("qwen", "dashscope"),
|
||||
env_key="DASHSCOPE_API_KEY",
|
||||
display_name="DashScope",
|
||||
litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max
|
||||
skip_prefixes=("dashscope/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
backend="openai_compat",
|
||||
default_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
),
|
||||
# Moonshot: Kimi models, needs "moonshot/" prefix.
|
||||
# LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint.
|
||||
# Kimi K2.5 API enforces temperature >= 1.0.
|
||||
# Moonshot (月之暗面): Kimi models. K2.5 enforces temperature >= 1.0.
|
||||
ProviderSpec(
|
||||
name="moonshot",
|
||||
keywords=("moonshot", "kimi"),
|
||||
env_key="MOONSHOT_API_KEY",
|
||||
display_name="Moonshot",
|
||||
litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5
|
||||
skip_prefixes=("moonshot/", "openrouter/"),
|
||||
env_extras=(("MOONSHOT_API_BASE", "{api_base}"),),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China
|
||||
strip_model_prefix=False,
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.moonshot.ai/v1",
|
||||
model_overrides=(("kimi-k2.5", {"temperature": 1.0}),),
|
||||
),
|
||||
# MiniMax: needs "minimax/" prefix for LiteLLM routing.
|
||||
# Uses OpenAI-compatible API at api.minimax.io/v1.
|
||||
# MiniMax: OpenAI-compatible API
|
||||
ProviderSpec(
|
||||
name="minimax",
|
||||
keywords=("minimax",),
|
||||
env_key="MINIMAX_API_KEY",
|
||||
display_name="MiniMax",
|
||||
litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1
|
||||
skip_prefixes=("minimax/", "openrouter/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.minimax.io/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# Mistral AI: OpenAI-compatible API at api.mistral.ai/v1.
|
||||
# Mistral AI: OpenAI-compatible API
|
||||
ProviderSpec(
|
||||
name="mistral",
|
||||
keywords=("mistral",),
|
||||
env_key="MISTRAL_API_KEY",
|
||||
display_name="Mistral",
|
||||
litellm_prefix="mistral", # mistral-large-latest → mistral/mistral-large-latest
|
||||
skip_prefixes=("mistral/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.mistral.ai/v1",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# === Local deployment (matched by config key, NOT by api_base) =========
|
||||
# vLLM / any OpenAI-compatible local server.
|
||||
# Detected when config key is "vllm" (provider_name="vllm").
|
||||
# vLLM / any OpenAI-compatible local server
|
||||
ProviderSpec(
|
||||
name="vllm",
|
||||
keywords=("vllm",),
|
||||
env_key="HOSTED_VLLM_API_KEY",
|
||||
display_name="vLLM/Local",
|
||||
litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B
|
||||
skip_prefixes=(),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
backend="openai_compat",
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="", # user must provide in config
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
),
|
||||
# === Ollama (local, OpenAI-compatible) ===================================
|
||||
# Ollama (local, OpenAI-compatible)
|
||||
ProviderSpec(
|
||||
name="ollama",
|
||||
keywords=("ollama", "nemotron"),
|
||||
env_key="OLLAMA_API_KEY",
|
||||
display_name="Ollama",
|
||||
litellm_prefix="ollama_chat", # model → ollama_chat/model
|
||||
skip_prefixes=("ollama/", "ollama_chat/"),
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
backend="openai_compat",
|
||||
is_local=True,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="11434",
|
||||
default_api_base="http://localhost:11434",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
default_api_base="http://localhost:11434/v1",
|
||||
),
|
||||
# === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) ===
|
||||
ProviderSpec(
|
||||
@ -460,29 +313,20 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
keywords=("openvino", "ovms"),
|
||||
env_key="",
|
||||
display_name="OpenVINO Model Server",
|
||||
litellm_prefix="",
|
||||
backend="openai_compat",
|
||||
is_direct=True,
|
||||
is_local=True,
|
||||
default_api_base="http://localhost:8000/v3",
|
||||
),
|
||||
# === Auxiliary (not a primary LLM provider) ============================
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM.
|
||||
# Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback.
|
||||
# Groq: mainly used for Whisper voice transcription, also usable for LLM
|
||||
ProviderSpec(
|
||||
name="groq",
|
||||
keywords=("groq",),
|
||||
env_key="GROQ_API_KEY",
|
||||
display_name="Groq",
|
||||
litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192
|
||||
skip_prefixes=("groq/",), # avoid double-prefix
|
||||
env_extras=(),
|
||||
is_gateway=False,
|
||||
is_local=False,
|
||||
detect_by_key_prefix="",
|
||||
detect_by_base_keyword="",
|
||||
default_api_base="",
|
||||
strip_model_prefix=False,
|
||||
model_overrides=(),
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.groq.com/openai/v1",
|
||||
),
|
||||
)
|
||||
|
||||
@ -492,59 +336,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def find_by_model(model: str) -> ProviderSpec | None:
|
||||
"""Match a standard provider by model-name keyword (case-insensitive).
|
||||
Skips gateways/local — those are matched by api_key/api_base instead."""
|
||||
model_lower = model.lower()
|
||||
model_normalized = model_lower.replace("-", "_")
|
||||
model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
|
||||
normalized_prefix = model_prefix.replace("-", "_")
|
||||
std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local]
|
||||
|
||||
# Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex.
|
||||
for spec in std_specs:
|
||||
if model_prefix and normalized_prefix == spec.name:
|
||||
return spec
|
||||
|
||||
for spec in std_specs:
|
||||
if any(
|
||||
kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords
|
||||
):
|
||||
return spec
|
||||
return None
|
||||
|
||||
|
||||
def find_gateway(
|
||||
provider_name: str | None = None,
|
||||
api_key: str | None = None,
|
||||
api_base: str | None = None,
|
||||
) -> ProviderSpec | None:
|
||||
"""Detect gateway/local provider.
|
||||
|
||||
Priority:
|
||||
1. provider_name — if it maps to a gateway/local spec, use it directly.
|
||||
2. api_key prefix — e.g. "sk-or-" → OpenRouter.
|
||||
3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix.
|
||||
|
||||
A standard provider with a custom api_base (e.g. DeepSeek behind a proxy)
|
||||
will NOT be mistaken for vLLM — the old fallback is gone.
|
||||
"""
|
||||
# 1. Direct match by config key
|
||||
if provider_name:
|
||||
spec = find_by_name(provider_name)
|
||||
if spec and (spec.is_gateway or spec.is_local):
|
||||
return spec
|
||||
|
||||
# 2. Auto-detect by api_key prefix / api_base keyword
|
||||
for spec in PROVIDERS:
|
||||
if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix):
|
||||
return spec
|
||||
if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base:
|
||||
return spec
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def find_by_name(name: str) -> ProviderSpec | None:
|
||||
"""Find a provider spec by config field name, e.g. "dashscope"."""
|
||||
normalized = to_snake(name.replace("-", "_"))
|
||||
|
||||
@ -19,7 +19,7 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
"typer>=0.20.0,<1.0.0",
|
||||
"litellm>=1.82.1,<=1.82.6",
|
||||
"anthropic>=0.45.0,<1.0.0",
|
||||
"pydantic>=2.12.0,<3.0.0",
|
||||
"pydantic-settings>=2.12.0,<3.0.0",
|
||||
"websockets>=16.0,<17.0",
|
||||
|
||||
@ -1,40 +1,6 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.providers.base import ToolCallRequest
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
|
||||
|
||||
def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None:
|
||||
provider = LiteLLMProvider(default_model="gemini/gemini-3-flash")
|
||||
|
||||
response = SimpleNamespace(
|
||||
choices=[
|
||||
SimpleNamespace(
|
||||
finish_reason="tool_calls",
|
||||
message=SimpleNamespace(
|
||||
content=None,
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
id="call_123",
|
||||
function=SimpleNamespace(
|
||||
name="read_file",
|
||||
arguments='{"path":"todo.md"}',
|
||||
provider_specific_fields={"inner": "value"},
|
||||
),
|
||||
provider_specific_fields={"thought_signature": "signed-token"},
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=None,
|
||||
)
|
||||
|
||||
parsed = provider._parse_response(response)
|
||||
|
||||
assert len(parsed.tool_calls) == 1
|
||||
assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"}
|
||||
assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"}
|
||||
|
||||
|
||||
def test_tool_call_request_serializes_provider_fields() -> None:
|
||||
|
||||
@ -380,7 +380,7 @@ class TestMemoryConsolidationTypeHandling:
|
||||
"""Forced tool_choice rejected by provider -> retry with auto and succeed."""
|
||||
store = MemoryStore(tmp_path)
|
||||
error_resp = LLMResponse(
|
||||
content="Error calling LLM: litellm.BadRequestError: "
|
||||
content="Error calling LLM: BadRequestError: "
|
||||
"The tool_choice parameter does not support being set to required or object",
|
||||
finish_reason="error",
|
||||
tool_calls=[],
|
||||
|
||||
@ -9,9 +9,8 @@ from typer.testing import CliRunner
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.cli.commands import _make_provider, app
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||
from nanobot.providers.registry import find_by_model, find_by_name
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
@ -228,7 +227,7 @@ def test_config_matches_explicit_ollama_prefix_without_api_key():
|
||||
config.agents.defaults.model = "ollama/llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||
@ -237,7 +236,7 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base():
|
||||
config.agents.defaults.model = "llama3.2"
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan():
|
||||
@ -272,12 +271,12 @@ def test_config_auto_detects_ollama_from_local_api_base():
|
||||
config = Config.model_validate(
|
||||
{
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {"ollama": {"apiBase": "http://localhost:11434"}},
|
||||
"providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
||||
@ -286,13 +285,13 @@ def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured():
|
||||
"agents": {"defaults": {"provider": "auto", "model": "llama3.2"}},
|
||||
"providers": {
|
||||
"vllm": {"apiBase": "http://localhost:8000"},
|
||||
"ollama": {"apiBase": "http://localhost:11434"},
|
||||
"ollama": {"apiBase": "http://localhost:11434/v1"},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.get_provider_name() == "ollama"
|
||||
assert config.get_api_base() == "http://localhost:11434"
|
||||
assert config.get_api_base() == "http://localhost:11434/v1"
|
||||
|
||||
|
||||
def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
||||
@ -309,19 +308,13 @@ def test_config_falls_back_to_vllm_when_ollama_not_configured():
|
||||
assert config.get_api_base() == "http://localhost:8000"
|
||||
|
||||
|
||||
def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword():
|
||||
spec = find_by_model("github-copilot/gpt-5.3-codex")
|
||||
def test_openai_compat_provider_passes_model_through():
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
assert spec is not None
|
||||
assert spec.name == "github_copilot"
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex")
|
||||
|
||||
|
||||
def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix():
|
||||
provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex")
|
||||
|
||||
resolved = provider._resolve_model("github-copilot/gpt-5.3-codex")
|
||||
|
||||
assert resolved == "github_copilot/gpt-5.3-codex"
|
||||
assert provider.get_default_model() == "github-copilot/gpt-5.3-codex"
|
||||
|
||||
|
||||
def test_openai_codex_strip_prefix_supports_hyphen_and_underscore():
|
||||
@ -346,7 +339,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider():
|
||||
}
|
||||
)
|
||||
|
||||
with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai:
|
||||
_make_provider(config)
|
||||
|
||||
kwargs = mock_async_openai.call_args.kwargs
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
from types import SimpleNamespace
|
||||
"""Tests for OpenAICompatProvider handling custom/direct endpoints."""
|
||||
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
def test_custom_provider_parse_handles_empty_choices() -> None:
|
||||
provider = CustomProvider()
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
response = SimpleNamespace(choices=[])
|
||||
|
||||
result = provider._parse(response)
|
||||
|
||||
@ -1,161 +1,122 @@
|
||||
"""Regression tests for PR #2026 — litellm_kwargs injection from ProviderSpec.
|
||||
"""Tests for OpenAICompatProvider spec-driven behavior.
|
||||
|
||||
Validates that:
|
||||
- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing.
|
||||
- The litellm_kwargs mechanism works correctly for providers that declare it.
|
||||
- Non-gateway providers are unaffected.
|
||||
- OpenRouter (no strip) keeps model names intact.
|
||||
- AiHubMix (strip_model_prefix=True) strips provider prefixes.
|
||||
- Standard providers pass model names through as-is.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
|
||||
def _fake_response(content: str = "ok") -> SimpleNamespace:
|
||||
"""Build a minimal acompletion-shaped response object."""
|
||||
def _fake_chat_response(content: str = "ok") -> SimpleNamespace:
|
||||
"""Build a minimal OpenAI chat completion response."""
|
||||
message = SimpleNamespace(
|
||||
content=content,
|
||||
tool_calls=None,
|
||||
reasoning_content=None,
|
||||
thinking_blocks=None,
|
||||
)
|
||||
choice = SimpleNamespace(message=message, finish_reason="stop")
|
||||
usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15)
|
||||
return SimpleNamespace(choices=[choice], usage=usage)
|
||||
|
||||
|
||||
def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None:
|
||||
"""OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg.
|
||||
|
||||
LiteLLM internally adds a provider/ prefix when custom_llm_provider is set,
|
||||
which double-prefixes models (openrouter/anthropic/model) and breaks the API.
|
||||
"""
|
||||
def test_openrouter_spec_is_gateway() -> None:
|
||||
spec = find_by_name("openrouter")
|
||||
assert spec is not None
|
||||
assert spec.litellm_prefix == "openrouter"
|
||||
assert "custom_llm_provider" not in spec.litellm_kwargs, (
|
||||
"custom_llm_provider causes LiteLLM to double-prefix the model name"
|
||||
)
|
||||
assert spec.is_gateway is True
|
||||
assert spec.default_api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_prefixes_model_correctly() -> None:
|
||||
"""OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing."""
|
||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||
async def test_openrouter_keeps_model_name_intact() -> None:
|
||||
"""OpenRouter gateway keeps the full model name (gateway does its own routing)."""
|
||||
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||
spec = find_by_name("openrouter")
|
||||
|
||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
||||
provider = LiteLLMProvider(
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-or-test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
default_model="anthropic/claude-sonnet-4-5",
|
||||
provider_name="openrouter",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="anthropic/claude-sonnet-4-5",
|
||||
)
|
||||
|
||||
call_kwargs = mock_acompletion.call_args.kwargs
|
||||
assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
|
||||
"LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call"
|
||||
)
|
||||
assert "custom_llm_provider" not in call_kwargs
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_gateway_provider_no_extra_kwargs() -> None:
|
||||
"""Standard (non-gateway) providers must NOT inject any litellm_kwargs."""
|
||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||
async def test_aihubmix_strips_model_prefix() -> None:
|
||||
"""AiHubMix strips the provider prefix (strip_model_prefix=True)."""
|
||||
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||
spec = find_by_name("aihubmix")
|
||||
|
||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
||||
provider = LiteLLMProvider(
|
||||
api_key="sk-ant-test-key",
|
||||
default_model="claude-sonnet-4-5",
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="claude-sonnet-4-5",
|
||||
)
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
call_kwargs = mock_acompletion.call_args.kwargs
|
||||
assert "custom_llm_provider" not in call_kwargs, (
|
||||
"Standard Anthropic provider should NOT inject custom_llm_provider"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None:
|
||||
"""Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys."""
|
||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||
|
||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
||||
provider = LiteLLMProvider(
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-aihub-test-key",
|
||||
api_base="https://aihubmix.com/v1",
|
||||
default_model="claude-sonnet-4-5",
|
||||
provider_name="aihubmix",
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="claude-sonnet-4-5",
|
||||
)
|
||||
|
||||
call_kwargs = mock_acompletion.call_args.kwargs
|
||||
assert "custom_llm_provider" not in call_kwargs
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_autodetect_by_key_prefix() -> None:
|
||||
"""OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name."""
|
||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||
|
||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
||||
provider = LiteLLMProvider(
|
||||
api_key="sk-or-auto-detect-key",
|
||||
default_model="anthropic/claude-sonnet-4-5",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="anthropic/claude-sonnet-4-5",
|
||||
)
|
||||
|
||||
call_kwargs = mock_acompletion.call_args.kwargs
|
||||
assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", (
|
||||
"Auto-detected OpenRouter should prefix model for LiteLLM routing"
|
||||
)
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "claude-sonnet-4-5"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_native_model_id_gets_double_prefixed() -> None:
|
||||
"""Models like openrouter/free must be double-prefixed so LiteLLM strips one layer.
|
||||
async def test_standard_provider_passes_model_through() -> None:
|
||||
"""Standard provider (e.g. deepseek) passes model name through as-is."""
|
||||
mock_create = AsyncMock(return_value=_fake_chat_response())
|
||||
spec = find_by_name("deepseek")
|
||||
|
||||
openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first
|
||||
openrouter/ for routing, so we must send openrouter/openrouter/free to ensure
|
||||
the API receives openrouter/free.
|
||||
"""
|
||||
mock_acompletion = AsyncMock(return_value=_fake_response())
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient:
|
||||
client_instance = MockClient.return_value
|
||||
client_instance.chat.completions.create = mock_create
|
||||
|
||||
with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion):
|
||||
provider = LiteLLMProvider(
|
||||
api_key="sk-or-test-key",
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
default_model="openrouter/free",
|
||||
provider_name="openrouter",
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-deepseek-test-key",
|
||||
default_model="deepseek-chat",
|
||||
spec=spec,
|
||||
)
|
||||
await provider.chat(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
model="openrouter/free",
|
||||
model="deepseek-chat",
|
||||
)
|
||||
|
||||
call_kwargs = mock_acompletion.call_args.kwargs
|
||||
assert call_kwargs["model"] == "openrouter/openrouter/free", (
|
||||
"openrouter/free must become openrouter/openrouter/free — "
|
||||
"LiteLLM strips one layer so the API receives openrouter/free"
|
||||
)
|
||||
call_kwargs = mock_create.call_args.kwargs
|
||||
assert call_kwargs["model"] == "deepseek-chat"
|
||||
|
||||
|
||||
def test_openai_model_passthrough() -> None:
|
||||
"""OpenAI models pass through unchanged."""
|
||||
spec = find_by_name("openai")
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test-key",
|
||||
default_model="gpt-4o",
|
||||
spec=spec,
|
||||
)
|
||||
assert provider.get_default_model() == "gpt-4o"
|
||||
|
||||
@ -17,6 +17,4 @@ def test_mistral_provider_in_registry():
|
||||
|
||||
mistral = specs["mistral"]
|
||||
assert mistral.env_key == "MISTRAL_API_KEY"
|
||||
assert mistral.litellm_prefix == "mistral"
|
||||
assert mistral.default_api_base == "https://api.mistral.ai/v1"
|
||||
assert "mistral/" in mistral.skip_prefixes
|
||||
|
||||
@ -8,19 +8,22 @@ import sys
|
||||
|
||||
def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False)
|
||||
|
||||
providers = importlib.import_module("nanobot.providers")
|
||||
|
||||
assert "nanobot.providers.litellm_provider" not in sys.modules
|
||||
assert "nanobot.providers.anthropic_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_compat_provider" not in sys.modules
|
||||
assert "nanobot.providers.openai_codex_provider" not in sys.modules
|
||||
assert "nanobot.providers.azure_openai_provider" not in sys.modules
|
||||
assert providers.__all__ == [
|
||||
"LLMProvider",
|
||||
"LLMResponse",
|
||||
"LiteLLMProvider",
|
||||
"AnthropicProvider",
|
||||
"OpenAICompatProvider",
|
||||
"OpenAICodexProvider",
|
||||
"AzureOpenAIProvider",
|
||||
]
|
||||
@ -28,10 +31,10 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None:
|
||||
|
||||
def test_explicit_provider_import_still_works(monkeypatch) -> None:
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False)
|
||||
monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False)
|
||||
|
||||
namespace: dict[str, object] = {}
|
||||
exec("from nanobot.providers import LiteLLMProvider", namespace)
|
||||
exec("from nanobot.providers import AnthropicProvider", namespace)
|
||||
|
||||
assert namespace["LiteLLMProvider"].__name__ == "LiteLLMProvider"
|
||||
assert "nanobot.providers.litellm_provider" in sys.modules
|
||||
assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider"
|
||||
assert "nanobot.providers.anthropic_provider" in sys.modules
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user