From 3dfdab704e14b99de3ac93b24642eb9f09daab44 Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Tue, 24 Mar 2026 17:53:35 +0000 Subject: [PATCH] refactor: replace litellm with native openai + anthropic SDKs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove litellm dependency entirely (supply chain risk mitigation) - Add AnthropicProvider (native SDK) and OpenAICompatProvider (unified) - Merge CustomProvider into OpenAICompatProvider, delete custom_provider.py - Add ProviderSpec.backend field for declarative provider routing - Remove _resolve_model, find_gateway, find_by_model (dead heuristics) - Pass resolved spec directly into provider — zero internal lookups - Stub out litellm-dependent model database (cli/models.py) - Add anthropic>=0.45.0 to dependencies, remove litellm - 593 tests passed, net -1034 lines --- README.md | 16 +- nanobot/cli/commands.py | 83 ++-- nanobot/cli/models.py | 214 +-------- nanobot/config/schema.py | 3 +- nanobot/providers/__init__.py | 15 +- nanobot/providers/anthropic_provider.py | 441 ++++++++++++++++++ nanobot/providers/custom_provider.py | 152 ------ nanobot/providers/litellm_provider.py | 413 ---------------- nanobot/providers/openai_compat_provider.py | 349 ++++++++++++++ nanobot/providers/registry.py | 339 +++----------- pyproject.toml | 2 +- tests/agent/test_gemini_thought_signature.py | 34 -- .../agent/test_memory_consolidation_types.py | 2 +- tests/cli/test_commands.py | 33 +- tests/providers/test_custom_provider.py | 10 +- tests/providers/test_litellm_kwargs.py | 157 +++---- tests/providers/test_mistral_provider.py | 2 - tests/providers/test_providers_init.py | 17 +- 18 files changed, 1019 insertions(+), 1263 deletions(-) create mode 100644 nanobot/providers/anthropic_provider.py delete mode 100644 nanobot/providers/custom_provider.py delete mode 100644 nanobot/providers/litellm_provider.py create mode 100644 nanobot/providers/openai_compat_provider.py diff --git a/README.md b/README.md index c9d19a1ca..9f5e0d248 100644 --- a/README.md +++ b/README.md @@ -842,7 +842,7 @@ Config file: `~/.nanobot/config.json` | Provider | Purpose | Get API Key | |----------|---------|-------------| -| `custom` | Any OpenAI-compatible endpoint (direct, no LiteLLM) | — | +| `custom` | Any OpenAI-compatible endpoint | — | | `openrouter` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | | `volcengine` | LLM (VolcEngine, pay-per-use) | [Coding Plan](https://www.volcengine.com/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [volcengine.com](https://www.volcengine.com) | | `byteplus` | LLM (VolcEngine international, pay-per-use) | [Coding Plan](https://www.byteplus.com/en/activity/codingplan?utm_campaign=nanobot&utm_content=nanobot&utm_medium=devrel&utm_source=OWO&utm_term=nanobot) · [byteplus.com](https://www.byteplus.com) | @@ -943,7 +943,7 @@ nanobot agent -c ~/.nanobot-telegram/config.json -w /tmp/nanobot-telegram-test -
Custom Provider (Any OpenAI-compatible API) -Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Bypasses LiteLLM; model name is passed as-is. +Connects directly to any OpenAI-compatible endpoint — LM Studio, llama.cpp, Together AI, Fireworks, Azure OpenAI, or any self-hosted server. Model name is passed as-is. ```json { @@ -1120,10 +1120,9 @@ Adding a new provider only takes **2 steps** — no if-elif chains to touch. ProviderSpec( name="myprovider", # config field name keywords=("myprovider", "mymodel"), # model-name keywords for auto-matching - env_key="MYPROVIDER_API_KEY", # env var for LiteLLM + env_key="MYPROVIDER_API_KEY", # env var name display_name="My Provider", # shown in `nanobot status` - litellm_prefix="myprovider", # auto-prefix: model → myprovider/model - skip_prefixes=("myprovider/",), # don't double-prefix + default_api_base="https://api.myprovider.com/v1", # OpenAI-compatible endpoint ) ``` @@ -1135,20 +1134,19 @@ class ProvidersConfig(BaseModel): myprovider: ProviderConfig = ProviderConfig() ``` -That's it! Environment variables, model prefixing, config matching, and `nanobot status` display will all work automatically. +That's it! Environment variables, model routing, config matching, and `nanobot status` display will all work automatically. **Common `ProviderSpec` options:** | Field | Description | Example | |-------|-------------|---------| -| `litellm_prefix` | Auto-prefix model names for LiteLLM | `"dashscope"` → `dashscope/qwen-max` | -| `skip_prefixes` | Don't prefix if model already starts with these | `("dashscope/", "openrouter/")` | +| `default_api_base` | OpenAI-compatible base URL | `"https://api.deepseek.com"` | | `env_extras` | Additional env vars to set | `(("ZHIPUAI_API_KEY", "{api_key}"),)` | | `model_overrides` | Per-model parameter overrides | `(("kimi-k2.5", {"temperature": 1.0}),)` | | `is_gateway` | Can route any model (like OpenRouter) | `True` | | `detect_by_key_prefix` | Detect gateway by API key prefix | `"sk-or-"` | | `detect_by_base_keyword` | Detect gateway by API base URL | `"openrouter"` | -| `strip_model_prefix` | Strip existing prefix before re-prefixing | `True` (for AiHubMix) | +| `strip_model_prefix` | Strip provider prefix before sending to gateway | `True` (for AiHubMix) |
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 27733239c..91c81d3de 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -376,61 +376,61 @@ def _onboard_plugins(config_path: Path) -> None: def _make_provider(config: Config): - """Create the appropriate LLM provider from config.""" - from nanobot.providers.azure_openai_provider import AzureOpenAIProvider + """Create the appropriate LLM provider from config. + + Routing is driven by ``ProviderSpec.backend`` in the registry. + """ from nanobot.providers.base import GenerationSettings - from nanobot.providers.openai_codex_provider import OpenAICodexProvider + from nanobot.providers.registry import find_by_name model = config.agents.defaults.model provider_name = config.get_provider_name(model) p = config.get_provider(model) + spec = find_by_name(provider_name) if provider_name else None + backend = spec.backend if spec else "openai_compat" - # OpenAI Codex (OAuth) - if provider_name == "openai_codex" or model.startswith("openai-codex/"): - provider = OpenAICodexProvider(default_model=model) - # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM - elif provider_name == "custom": - from nanobot.providers.custom_provider import CustomProvider - provider = CustomProvider( - api_key=p.api_key if p else "no-key", - api_base=config.get_api_base(model) or "http://localhost:8000/v1", - default_model=model, - extra_headers=p.extra_headers if p else None, - ) - # Azure OpenAI: direct Azure OpenAI endpoint with deployment name - elif provider_name == "azure_openai": + # --- validation --- + if backend == "azure_openai": if not p or not p.api_key or not p.api_base: console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]") console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section") console.print("Use the model field to specify the deployment name.") raise typer.Exit(1) + elif backend == "openai_compat" and not model.startswith("bedrock/"): + needs_key = not (p and p.api_key) + exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct) + if needs_key and not exempt: + console.print("[red]Error: No API key configured.[/red]") + console.print("Set one in ~/.nanobot/config.json under providers section") + raise typer.Exit(1) + + # --- instantiation by backend --- + if backend == "openai_codex": + from nanobot.providers.openai_codex_provider import OpenAICodexProvider + provider = OpenAICodexProvider(default_model=model) + elif backend == "azure_openai": + from nanobot.providers.azure_openai_provider import AzureOpenAIProvider provider = AzureOpenAIProvider( api_key=p.api_key, api_base=p.api_base, default_model=model, ) - # OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3 - elif provider_name == "ovms": - from nanobot.providers.custom_provider import CustomProvider - provider = CustomProvider( - api_key=p.api_key if p else "no-key", - api_base=config.get_api_base(model) or "http://localhost:8000/v3", - default_model=model, - ) - else: - from nanobot.providers.litellm_provider import LiteLLMProvider - from nanobot.providers.registry import find_by_name - spec = find_by_name(provider_name) - if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)): - console.print("[red]Error: No API key configured.[/red]") - console.print("Set one in ~/.nanobot/config.json under providers section") - raise typer.Exit(1) - provider = LiteLLMProvider( + elif backend == "anthropic": + from nanobot.providers.anthropic_provider import AnthropicProvider + provider = AnthropicProvider( api_key=p.api_key if p else None, api_base=config.get_api_base(model), default_model=model, extra_headers=p.extra_headers if p else None, - provider_name=provider_name, + ) + else: + from nanobot.providers.openai_compat_provider import OpenAICompatProvider + provider = OpenAICompatProvider( + api_key=p.api_key if p else None, + api_base=config.get_api_base(model), + default_model=model, + extra_headers=p.extra_headers if p else None, + spec=spec, ) defaults = config.agents.defaults @@ -1203,11 +1203,20 @@ def _login_openai_codex() -> None: def _login_github_copilot() -> None: import asyncio + from openai import AsyncOpenAI + console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n") async def _trigger(): - from litellm import acompletion - await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1) + client = AsyncOpenAI( + api_key="dummy", + base_url="https://api.githubcopilot.com", + ) + await client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + ) try: asyncio.run(_trigger()) diff --git a/nanobot/cli/models.py b/nanobot/cli/models.py index 520370c4b..0ba24018f 100644 --- a/nanobot/cli/models.py +++ b/nanobot/cli/models.py @@ -1,229 +1,29 @@ """Model information helpers for the onboard wizard. -Provides model context window lookup and autocomplete suggestions using litellm. +Model database / autocomplete is temporarily disabled while litellm is +being replaced. All public function signatures are preserved so callers +continue to work without changes. """ from __future__ import annotations -from functools import lru_cache from typing import Any -def _litellm(): - """Lazy accessor for litellm (heavy import deferred until actually needed).""" - import litellm as _ll - - return _ll - - -@lru_cache(maxsize=1) -def _get_model_cost_map() -> dict[str, Any]: - """Get litellm's model cost map (cached).""" - return getattr(_litellm(), "model_cost", {}) - - -@lru_cache(maxsize=1) def get_all_models() -> list[str]: - """Get all known model names from litellm. - """ - models = set() - - # From model_cost (has pricing info) - cost_map = _get_model_cost_map() - for k in cost_map.keys(): - if k != "sample_spec": - models.add(k) - - # From models_by_provider (more complete provider coverage) - for provider_models in getattr(_litellm(), "models_by_provider", {}).values(): - if isinstance(provider_models, (set, list)): - models.update(provider_models) - - return sorted(models) - - -def _normalize_model_name(model: str) -> str: - """Normalize model name for comparison.""" - return model.lower().replace("-", "_").replace(".", "") + return [] def find_model_info(model_name: str) -> dict[str, Any] | None: - """Find model info with fuzzy matching. - - Args: - model_name: Model name in any common format - - Returns: - Model info dict or None if not found - """ - cost_map = _get_model_cost_map() - if not cost_map: - return None - - # Direct match - if model_name in cost_map: - return cost_map[model_name] - - # Extract base name (without provider prefix) - base_name = model_name.split("/")[-1] if "/" in model_name else model_name - base_normalized = _normalize_model_name(base_name) - - candidates = [] - - for key, info in cost_map.items(): - if key == "sample_spec": - continue - - key_base = key.split("/")[-1] if "/" in key else key - key_base_normalized = _normalize_model_name(key_base) - - # Score the match - score = 0 - - # Exact base name match (highest priority) - if base_normalized == key_base_normalized: - score = 100 - # Base name contains model - elif base_normalized in key_base_normalized: - score = 80 - # Model contains base name - elif key_base_normalized in base_normalized: - score = 70 - # Partial match - elif base_normalized[:10] in key_base_normalized: - score = 50 - - if score > 0: - # Prefer models with max_input_tokens - if info.get("max_input_tokens"): - score += 10 - candidates.append((score, key, info)) - - if not candidates: - return None - - # Return the best match - candidates.sort(key=lambda x: (-x[0], x[1])) - return candidates[0][2] - - -def get_model_context_limit(model: str, provider: str = "auto") -> int | None: - """Get the maximum input context tokens for a model. - - Args: - model: Model name (e.g., "claude-3.5-sonnet", "gpt-4o") - provider: Provider name for informational purposes (not yet used for filtering) - - Returns: - Maximum input tokens, or None if unknown - - Note: - The provider parameter is currently informational only. Future versions may - use it to prefer provider-specific model variants in the lookup. - """ - # First try fuzzy search in model_cost (has more accurate max_input_tokens) - info = find_model_info(model) - if info: - # Prefer max_input_tokens (this is what we want for context window) - max_input = info.get("max_input_tokens") - if max_input and isinstance(max_input, int): - return max_input - - # Fall back to litellm's get_max_tokens (returns max_output_tokens typically) - try: - result = _litellm().get_max_tokens(model) - if result and result > 0: - return result - except (KeyError, ValueError, AttributeError): - # Model not found in litellm's database or invalid response - pass - - # Last resort: use max_tokens from model_cost - if info: - max_tokens = info.get("max_tokens") - if max_tokens and isinstance(max_tokens, int): - return max_tokens - return None -@lru_cache(maxsize=1) -def _get_provider_keywords() -> dict[str, list[str]]: - """Build provider keywords mapping from nanobot's provider registry. - - Returns: - Dict mapping provider name to list of keywords for model filtering. - """ - try: - from nanobot.providers.registry import PROVIDERS - - mapping = {} - for spec in PROVIDERS: - if spec.keywords: - mapping[spec.name] = list(spec.keywords) - return mapping - except ImportError: - return {} +def get_model_context_limit(model: str, provider: str = "auto") -> int | None: + return None def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]: - """Get autocomplete suggestions for model names. - - Args: - partial: Partial model name typed by user - provider: Provider name for filtering (e.g., "openrouter", "minimax") - limit: Maximum number of suggestions to return - - Returns: - List of matching model names - """ - all_models = get_all_models() - if not all_models: - return [] - - partial_lower = partial.lower() - partial_normalized = _normalize_model_name(partial) - - # Get provider keywords from registry - provider_keywords = _get_provider_keywords() - - # Filter by provider if specified - allowed_keywords = None - if provider and provider != "auto": - allowed_keywords = provider_keywords.get(provider.lower()) - - matches = [] - - for model in all_models: - model_lower = model.lower() - - # Apply provider filter - if allowed_keywords: - if not any(kw in model_lower for kw in allowed_keywords): - continue - - # Match against partial input - if not partial: - matches.append(model) - continue - - if partial_lower in model_lower: - # Score by position of match (earlier = better) - pos = model_lower.find(partial_lower) - score = 100 - pos - matches.append((score, model)) - elif partial_normalized in _normalize_model_name(model): - score = 50 - matches.append((score, model)) - - # Sort by score if we have scored matches - if matches and isinstance(matches[0], tuple): - matches.sort(key=lambda x: (-x[0], x[1])) - matches = [m[1] for m in matches] - else: - matches.sort() - - return matches[:limit] + return [] def format_token_count(tokens: int) -> str: diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index b31f3061a..9ae662ec8 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -249,8 +249,7 @@ class Config(BaseSettings): if p and p.api_base: return p.api_base # Only gateways get a default api_base here. Standard providers - # (like Moonshot) set their base URL via env vars in _setup_env - # to avoid polluting the global litellm.api_base. + # resolve their base URL from the registry in the provider constructor. if name: spec = find_by_name(name) if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base: diff --git a/nanobot/providers/__init__.py b/nanobot/providers/__init__.py index 9d4994eb1..0e259e6f0 100644 --- a/nanobot/providers/__init__.py +++ b/nanobot/providers/__init__.py @@ -7,17 +7,26 @@ from typing import TYPE_CHECKING from nanobot.providers.base import LLMProvider, LLMResponse -__all__ = ["LLMProvider", "LLMResponse", "LiteLLMProvider", "OpenAICodexProvider", "AzureOpenAIProvider"] +__all__ = [ + "LLMProvider", + "LLMResponse", + "AnthropicProvider", + "OpenAICompatProvider", + "OpenAICodexProvider", + "AzureOpenAIProvider", +] _LAZY_IMPORTS = { - "LiteLLMProvider": ".litellm_provider", + "AnthropicProvider": ".anthropic_provider", + "OpenAICompatProvider": ".openai_compat_provider", "OpenAICodexProvider": ".openai_codex_provider", "AzureOpenAIProvider": ".azure_openai_provider", } if TYPE_CHECKING: + from nanobot.providers.anthropic_provider import AnthropicProvider from nanobot.providers.azure_openai_provider import AzureOpenAIProvider - from nanobot.providers.litellm_provider import LiteLLMProvider + from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.openai_codex_provider import OpenAICodexProvider diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py new file mode 100644 index 000000000..3c789e730 --- /dev/null +++ b/nanobot/providers/anthropic_provider.py @@ -0,0 +1,441 @@ +"""Anthropic provider — direct SDK integration for Claude models.""" + +from __future__ import annotations + +import re +import secrets +import string +from collections.abc import Awaitable, Callable +from typing import Any + +import json_repair +from loguru import logger + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +_ALNUM = string.ascii_letters + string.digits + + +def _gen_tool_id() -> str: + return "toolu_" + "".join(secrets.choice(_ALNUM) for _ in range(22)) + + +class AnthropicProvider(LLMProvider): + """LLM provider using the native Anthropic SDK for Claude models. + + Handles message format conversion (OpenAI → Anthropic Messages API), + prompt caching, extended thinking, tool calls, and streaming. + """ + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + default_model: str = "claude-sonnet-4-20250514", + extra_headers: dict[str, str] | None = None, + ): + super().__init__(api_key, api_base) + self.default_model = default_model + self.extra_headers = extra_headers or {} + + from anthropic import AsyncAnthropic + + client_kw: dict[str, Any] = {} + if api_key: + client_kw["api_key"] = api_key + if api_base: + client_kw["base_url"] = api_base + if extra_headers: + client_kw["default_headers"] = extra_headers + self._client = AsyncAnthropic(**client_kw) + + @staticmethod + def _strip_prefix(model: str) -> str: + if model.startswith("anthropic/"): + return model[len("anthropic/"):] + return model + + # ------------------------------------------------------------------ + # Message conversion: OpenAI chat format → Anthropic Messages API + # ------------------------------------------------------------------ + + def _convert_messages( + self, messages: list[dict[str, Any]], + ) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]]]: + """Return ``(system, anthropic_messages)``.""" + system: str | list[dict[str, Any]] = "" + raw: list[dict[str, Any]] = [] + + for msg in messages: + role = msg.get("role", "") + content = msg.get("content") + + if role == "system": + system = content if isinstance(content, (str, list)) else str(content or "") + continue + + if role == "tool": + block = self._tool_result_block(msg) + if raw and raw[-1]["role"] == "user": + prev_c = raw[-1]["content"] + if isinstance(prev_c, list): + prev_c.append(block) + else: + raw[-1]["content"] = [ + {"type": "text", "text": prev_c or ""}, block, + ] + else: + raw.append({"role": "user", "content": [block]}) + continue + + if role == "assistant": + raw.append({"role": "assistant", "content": self._assistant_blocks(msg)}) + continue + + if role == "user": + raw.append({ + "role": "user", + "content": self._convert_user_content(content), + }) + continue + + return system, self._merge_consecutive(raw) + + @staticmethod + def _tool_result_block(msg: dict[str, Any]) -> dict[str, Any]: + content = msg.get("content") + block: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": msg.get("tool_call_id", ""), + } + if isinstance(content, (str, list)): + block["content"] = content + else: + block["content"] = str(content) if content else "" + return block + + @staticmethod + def _assistant_blocks(msg: dict[str, Any]) -> list[dict[str, Any]]: + blocks: list[dict[str, Any]] = [] + content = msg.get("content") + + for tb in msg.get("thinking_blocks") or []: + if isinstance(tb, dict) and tb.get("type") == "thinking": + blocks.append({ + "type": "thinking", + "thinking": tb.get("thinking", ""), + "signature": tb.get("signature", ""), + }) + + if isinstance(content, str) and content: + blocks.append({"type": "text", "text": content}) + elif isinstance(content, list): + for item in content: + blocks.append(item if isinstance(item, dict) else {"type": "text", "text": str(item)}) + + for tc in msg.get("tool_calls") or []: + if not isinstance(tc, dict): + continue + func = tc.get("function", {}) + args = func.get("arguments", "{}") + if isinstance(args, str): + args = json_repair.loads(args) + blocks.append({ + "type": "tool_use", + "id": tc.get("id") or _gen_tool_id(), + "name": func.get("name", ""), + "input": args, + }) + + return blocks or [{"type": "text", "text": ""}] + + def _convert_user_content(self, content: Any) -> Any: + """Convert user message content, translating image_url blocks.""" + if isinstance(content, str) or content is None: + return content or "(empty)" + if not isinstance(content, list): + return str(content) + + result: list[dict[str, Any]] = [] + for item in content: + if not isinstance(item, dict): + result.append({"type": "text", "text": str(item)}) + continue + if item.get("type") == "image_url": + converted = self._convert_image_block(item) + if converted: + result.append(converted) + continue + result.append(item) + return result or "(empty)" + + @staticmethod + def _convert_image_block(block: dict[str, Any]) -> dict[str, Any] | None: + """Convert OpenAI image_url block to Anthropic image block.""" + url = (block.get("image_url") or {}).get("url", "") + if not url: + return None + m = re.match(r"data:(image/\w+);base64,(.+)", url, re.DOTALL) + if m: + return { + "type": "image", + "source": {"type": "base64", "media_type": m.group(1), "data": m.group(2)}, + } + return { + "type": "image", + "source": {"type": "url", "url": url}, + } + + @staticmethod + def _merge_consecutive(msgs: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Anthropic requires alternating user/assistant roles.""" + merged: list[dict[str, Any]] = [] + for msg in msgs: + if merged and merged[-1]["role"] == msg["role"]: + prev_c = merged[-1]["content"] + cur_c = msg["content"] + if isinstance(prev_c, str): + prev_c = [{"type": "text", "text": prev_c}] + if isinstance(cur_c, str): + cur_c = [{"type": "text", "text": cur_c}] + if isinstance(cur_c, list): + prev_c.extend(cur_c) + merged[-1]["content"] = prev_c + else: + merged.append(msg) + return merged + + # ------------------------------------------------------------------ + # Tool definition conversion + # ------------------------------------------------------------------ + + @staticmethod + def _convert_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]] | None: + if not tools: + return None + result = [] + for tool in tools: + func = tool.get("function", tool) + entry: dict[str, Any] = { + "name": func.get("name", ""), + "input_schema": func.get("parameters", {"type": "object", "properties": {}}), + } + desc = func.get("description") + if desc: + entry["description"] = desc + if "cache_control" in tool: + entry["cache_control"] = tool["cache_control"] + result.append(entry) + return result + + @staticmethod + def _convert_tool_choice( + tool_choice: str | dict[str, Any] | None, + thinking_enabled: bool = False, + ) -> dict[str, Any] | None: + if thinking_enabled: + return {"type": "auto"} + if tool_choice is None or tool_choice == "auto": + return {"type": "auto"} + if tool_choice == "required": + return {"type": "any"} + if tool_choice == "none": + return None + if isinstance(tool_choice, dict): + name = tool_choice.get("function", {}).get("name") + if name: + return {"type": "tool", "name": name} + return {"type": "auto"} + + # ------------------------------------------------------------------ + # Prompt caching + # ------------------------------------------------------------------ + + @staticmethod + def _apply_cache_control( + system: str | list[dict[str, Any]], + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[str | list[dict[str, Any]], list[dict[str, Any]], list[dict[str, Any]] | None]: + marker = {"type": "ephemeral"} + + if isinstance(system, str) and system: + system = [{"type": "text", "text": system, "cache_control": marker}] + elif isinstance(system, list) and system: + system = list(system) + system[-1] = {**system[-1], "cache_control": marker} + + new_msgs = list(messages) + if len(new_msgs) >= 3: + m = new_msgs[-2] + c = m.get("content") + if isinstance(c, str): + new_msgs[-2] = {**m, "content": [{"type": "text", "text": c, "cache_control": marker}]} + elif isinstance(c, list) and c: + nc = list(c) + nc[-1] = {**nc[-1], "cache_control": marker} + new_msgs[-2] = {**m, "content": nc} + + new_tools = tools + if tools: + new_tools = list(tools) + new_tools[-1] = {**new_tools[-1], "cache_control": marker} + + return system, new_msgs, new_tools + + # ------------------------------------------------------------------ + # Build API kwargs + # ------------------------------------------------------------------ + + def _build_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + supports_caching: bool = True, + ) -> dict[str, Any]: + model_name = self._strip_prefix(model or self.default_model) + system, anthropic_msgs = self._convert_messages(self._sanitize_empty_content(messages)) + anthropic_tools = self._convert_tools(tools) + + if supports_caching: + system, anthropic_msgs, anthropic_tools = self._apply_cache_control( + system, anthropic_msgs, anthropic_tools, + ) + + max_tokens = max(1, max_tokens) + thinking_enabled = bool(reasoning_effort) + + kwargs: dict[str, Any] = { + "model": model_name, + "messages": anthropic_msgs, + "max_tokens": max_tokens, + } + + if system: + kwargs["system"] = system + + if thinking_enabled: + budget_map = {"low": 1024, "medium": 4096, "high": max(8192, max_tokens)} + budget = budget_map.get(reasoning_effort.lower(), 4096) # type: ignore[union-attr] + kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget} + kwargs["max_tokens"] = max(max_tokens, budget + 4096) + kwargs["temperature"] = 1.0 + else: + kwargs["temperature"] = temperature + + if anthropic_tools: + kwargs["tools"] = anthropic_tools + tc = self._convert_tool_choice(tool_choice, thinking_enabled) + if tc: + kwargs["tool_choice"] = tc + + if self.extra_headers: + kwargs["extra_headers"] = self.extra_headers + + return kwargs + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_response(response: Any) -> LLMResponse: + content_parts: list[str] = [] + tool_calls: list[ToolCallRequest] = [] + thinking_blocks: list[dict[str, Any]] = [] + + for block in response.content: + if block.type == "text": + content_parts.append(block.text) + elif block.type == "tool_use": + tool_calls.append(ToolCallRequest( + id=block.id, + name=block.name, + arguments=block.input if isinstance(block.input, dict) else {}, + )) + elif block.type == "thinking": + thinking_blocks.append({ + "type": "thinking", + "thinking": block.thinking, + "signature": getattr(block, "signature", ""), + }) + + stop_map = {"tool_use": "tool_calls", "end_turn": "stop", "max_tokens": "length"} + finish_reason = stop_map.get(response.stop_reason or "", response.stop_reason or "stop") + + usage: dict[str, int] = {} + if response.usage: + usage = { + "prompt_tokens": response.usage.input_tokens, + "completion_tokens": response.usage.output_tokens, + "total_tokens": response.usage.input_tokens + response.usage.output_tokens, + } + for attr in ("cache_creation_input_tokens", "cache_read_input_tokens"): + val = getattr(response.usage, attr, 0) + if val: + usage[attr] = val + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=tool_calls, + finish_reason=finish_reason, + usage=usage, + thinking_blocks=thinking_blocks or None, + ) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + response = await self._client.messages.create(**kwargs) + return self._parse_response(response) + except Exception as e: + return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + async with self._client.messages.stream(**kwargs) as stream: + if on_content_delta: + async for text in stream.text_stream: + await on_content_delta(text) + response = await stream.get_final_message() + return self._parse_response(response) + except Exception as e: + return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") + + def get_default_model(self) -> str: + return self.default_model diff --git a/nanobot/providers/custom_provider.py b/nanobot/providers/custom_provider.py deleted file mode 100644 index a47dae7cd..000000000 --- a/nanobot/providers/custom_provider.py +++ /dev/null @@ -1,152 +0,0 @@ -"""Direct OpenAI-compatible provider — bypasses LiteLLM.""" - -from __future__ import annotations - -import uuid -from collections.abc import Awaitable, Callable -from typing import Any - -import json_repair -from openai import AsyncOpenAI - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest - - -class CustomProvider(LLMProvider): - - def __init__( - self, - api_key: str = "no-key", - api_base: str = "http://localhost:8000/v1", - default_model: str = "default", - extra_headers: dict[str, str] | None = None, - ): - super().__init__(api_key, api_base) - self.default_model = default_model - self._client = AsyncOpenAI( - api_key=api_key, - base_url=api_base, - default_headers={ - "x-session-affinity": uuid.uuid4().hex, - **(extra_headers or {}), - }, - ) - - def _build_kwargs( - self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None, - model: str | None, max_tokens: int, temperature: float, - reasoning_effort: str | None, tool_choice: str | dict[str, Any] | None, - ) -> dict[str, Any]: - kwargs: dict[str, Any] = { - "model": model or self.default_model, - "messages": self._sanitize_empty_content(messages), - "max_tokens": max(1, max_tokens), - "temperature": temperature, - } - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - if tools: - kwargs.update(tools=tools, tool_choice=tool_choice or "auto") - return kwargs - - def _handle_error(self, e: Exception) -> LLMResponse: - body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) - msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error: {e}" - return LLMResponse(content=msg, finish_reason="error") - - async def chat(self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None) -> LLMResponse: - kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) - try: - return self._parse(await self._client.chat.completions.create(**kwargs)) - except Exception as e: - return self._handle_error(e) - - async def chat_stream( - self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, - model: str | None = None, max_tokens: int = 4096, temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: - kwargs = self._build_kwargs(messages, tools, model, max_tokens, temperature, reasoning_effort, tool_choice) - kwargs["stream"] = True - try: - stream = await self._client.chat.completions.create(**kwargs) - chunks: list[Any] = [] - async for chunk in stream: - chunks.append(chunk) - if on_content_delta and chunk.choices: - text = getattr(chunk.choices[0].delta, "content", None) - if text: - await on_content_delta(text) - return self._parse_chunks(chunks) - except Exception as e: - return self._handle_error(e) - - def _parse(self, response: Any) -> LLMResponse: - if not response.choices: - return LLMResponse( - content="Error: API returned empty choices.", - finish_reason="error", - ) - choice = response.choices[0] - msg = choice.message - tool_calls = [ - ToolCallRequest( - id=tc.id, name=tc.function.name, - arguments=json_repair.loads(tc.function.arguments) if isinstance(tc.function.arguments, str) else tc.function.arguments, - ) - for tc in (msg.tool_calls or []) - ] - u = response.usage - return LLMResponse( - content=msg.content, tool_calls=tool_calls, - finish_reason=choice.finish_reason or "stop", - usage={"prompt_tokens": u.prompt_tokens, "completion_tokens": u.completion_tokens, "total_tokens": u.total_tokens} if u else {}, - reasoning_content=getattr(msg, "reasoning_content", None) or None, - ) - - def _parse_chunks(self, chunks: list[Any]) -> LLMResponse: - """Reassemble streamed chunks into a single LLMResponse.""" - content_parts: list[str] = [] - tc_bufs: dict[int, dict[str, str]] = {} - finish_reason = "stop" - usage: dict[str, int] = {} - - for chunk in chunks: - if not chunk.choices: - if hasattr(chunk, "usage") and chunk.usage: - u = chunk.usage - usage = {"prompt_tokens": u.prompt_tokens or 0, "completion_tokens": u.completion_tokens or 0, - "total_tokens": u.total_tokens or 0} - continue - choice = chunk.choices[0] - if choice.finish_reason: - finish_reason = choice.finish_reason - delta = choice.delta - if delta and delta.content: - content_parts.append(delta.content) - for tc in (delta.tool_calls or []) if delta else []: - buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) - if tc.id: - buf["id"] = tc.id - if tc.function and tc.function.name: - buf["name"] = tc.function.name - if tc.function and tc.function.arguments: - buf["arguments"] += tc.function.arguments - - return LLMResponse( - content="".join(content_parts) or None, - tool_calls=[ - ToolCallRequest(id=b["id"], name=b["name"], arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}) - for b in tc_bufs.values() - ], - finish_reason=finish_reason, - usage=usage, - ) - - def get_default_model(self) -> str: - return self.default_model diff --git a/nanobot/providers/litellm_provider.py b/nanobot/providers/litellm_provider.py deleted file mode 100644 index 9aa0ba680..000000000 --- a/nanobot/providers/litellm_provider.py +++ /dev/null @@ -1,413 +0,0 @@ -"""LiteLLM provider implementation for multi-provider support.""" - -import hashlib -import os -import secrets -import string -from collections.abc import Awaitable, Callable -from typing import Any - -import json_repair -import litellm -from litellm import acompletion -from loguru import logger - -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest -from nanobot.providers.registry import find_by_model, find_gateway - -# Standard chat-completion message keys. -_ALLOWED_MSG_KEYS = frozenset({"role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content"}) -_ANTHROPIC_EXTRA_KEYS = frozenset({"thinking_blocks"}) -_ALNUM = string.ascii_letters + string.digits - -def _short_tool_id() -> str: - """Generate a 9-char alphanumeric ID compatible with all providers (incl. Mistral).""" - return "".join(secrets.choice(_ALNUM) for _ in range(9)) - - -class LiteLLMProvider(LLMProvider): - """ - LLM provider using LiteLLM for multi-provider support. - - Supports OpenRouter, Anthropic, OpenAI, Gemini, MiniMax, and many other providers through - a unified interface. Provider-specific logic is driven by the registry - (see providers/registry.py) — no if-elif chains needed here. - """ - - def __init__( - self, - api_key: str | None = None, - api_base: str | None = None, - default_model: str = "anthropic/claude-opus-4-5", - extra_headers: dict[str, str] | None = None, - provider_name: str | None = None, - ): - super().__init__(api_key, api_base) - self.default_model = default_model - self.extra_headers = extra_headers or {} - - # Detect gateway / local deployment. - # provider_name (from config key) is the primary signal; - # api_key / api_base are fallback for auto-detection. - self._gateway = find_gateway(provider_name, api_key, api_base) - - # Configure environment variables - if api_key: - self._setup_env(api_key, api_base, default_model) - - if api_base: - litellm.api_base = api_base - - # Disable LiteLLM logging noise - litellm.suppress_debug_info = True - # Drop unsupported parameters for providers (e.g., gpt-5 rejects some params) - litellm.drop_params = True - - self._langsmith_enabled = bool(os.getenv("LANGSMITH_API_KEY")) - - def _setup_env(self, api_key: str, api_base: str | None, model: str) -> None: - """Set environment variables based on detected provider.""" - spec = self._gateway or find_by_model(model) - if not spec: - return - if not spec.env_key: - # OAuth/provider-only specs (for example: openai_codex) - return - - # Gateway/local overrides existing env; standard provider doesn't - if self._gateway: - os.environ[spec.env_key] = api_key - else: - os.environ.setdefault(spec.env_key, api_key) - - # Resolve env_extras placeholders: - # {api_key} → user's API key - # {api_base} → user's api_base, falling back to spec.default_api_base - effective_base = api_base or spec.default_api_base - for env_name, env_val in spec.env_extras: - resolved = env_val.replace("{api_key}", api_key) - resolved = resolved.replace("{api_base}", effective_base) - os.environ.setdefault(env_name, resolved) - - def _resolve_model(self, model: str) -> str: - """Resolve model name by applying provider/gateway prefixes.""" - if self._gateway: - prefix = self._gateway.litellm_prefix - if self._gateway.strip_model_prefix: - model = model.split("/")[-1] - if prefix: - model = f"{prefix}/{model}" - return model - - # Standard mode: auto-prefix for known providers - spec = find_by_model(model) - if spec and spec.litellm_prefix: - model = self._canonicalize_explicit_prefix(model, spec.name, spec.litellm_prefix) - if not any(model.startswith(s) for s in spec.skip_prefixes): - model = f"{spec.litellm_prefix}/{model}" - - return model - - @staticmethod - def _canonicalize_explicit_prefix(model: str, spec_name: str, canonical_prefix: str) -> str: - """Normalize explicit provider prefixes like `github-copilot/...`.""" - if "/" not in model: - return model - prefix, remainder = model.split("/", 1) - if prefix.lower().replace("-", "_") != spec_name: - return model - return f"{canonical_prefix}/{remainder}" - - def _supports_cache_control(self, model: str) -> bool: - """Return True when the provider supports cache_control on content blocks.""" - if self._gateway is not None: - return self._gateway.supports_prompt_caching - spec = find_by_model(model) - return spec is not None and spec.supports_prompt_caching - - def _apply_cache_control( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: - """Return copies of messages and tools with cache_control injected. - - Two breakpoints are placed: - 1. System message — caches the static system prompt - 2. Second-to-last message — caches the conversation history prefix - This maximises cache hits across multi-turn conversations. - """ - cache_marker = {"type": "ephemeral"} - new_messages = list(messages) - - def _mark(msg: dict[str, Any]) -> dict[str, Any]: - content = msg.get("content") - if isinstance(content, str): - return {**msg, "content": [ - {"type": "text", "text": content, "cache_control": cache_marker} - ]} - elif isinstance(content, list) and content: - new_content = list(content) - new_content[-1] = {**new_content[-1], "cache_control": cache_marker} - return {**msg, "content": new_content} - return msg - - # Breakpoint 1: system message - if new_messages and new_messages[0].get("role") == "system": - new_messages[0] = _mark(new_messages[0]) - - # Breakpoint 2: second-to-last message (caches conversation history prefix) - if len(new_messages) >= 3: - new_messages[-2] = _mark(new_messages[-2]) - - new_tools = tools - if tools: - new_tools = list(tools) - new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker} - - return new_messages, new_tools - - def _apply_model_overrides(self, model: str, kwargs: dict[str, Any]) -> None: - """Apply model-specific parameter overrides from the registry.""" - model_lower = model.lower() - spec = find_by_model(model) - if spec: - for pattern, overrides in spec.model_overrides: - if pattern in model_lower: - kwargs.update(overrides) - return - - @staticmethod - def _extra_msg_keys(original_model: str, resolved_model: str) -> frozenset[str]: - """Return provider-specific extra keys to preserve in request messages.""" - spec = find_by_model(original_model) or find_by_model(resolved_model) - if (spec and spec.name == "anthropic") or "claude" in original_model.lower() or resolved_model.startswith("anthropic/"): - return _ANTHROPIC_EXTRA_KEYS - return frozenset() - - @staticmethod - def _normalize_tool_call_id(tool_call_id: Any) -> Any: - """Normalize tool_call_id to a provider-safe 9-char alphanumeric form.""" - if not isinstance(tool_call_id, str): - return tool_call_id - if len(tool_call_id) == 9 and tool_call_id.isalnum(): - return tool_call_id - return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] - - @staticmethod - def _sanitize_messages(messages: list[dict[str, Any]], extra_keys: frozenset[str] = frozenset()) -> list[dict[str, Any]]: - """Strip non-standard keys and ensure assistant messages have a content key.""" - allowed = _ALLOWED_MSG_KEYS | extra_keys - sanitized = LLMProvider._sanitize_request_messages(messages, allowed) - id_map: dict[str, str] = {} - - def map_id(value: Any) -> Any: - if not isinstance(value, str): - return value - return id_map.setdefault(value, LiteLLMProvider._normalize_tool_call_id(value)) - - for clean in sanitized: - # Keep assistant tool_calls[].id and tool tool_call_id in sync after - # shortening, otherwise strict providers reject the broken linkage. - if isinstance(clean.get("tool_calls"), list): - normalized_tool_calls = [] - for tc in clean["tool_calls"]: - if not isinstance(tc, dict): - normalized_tool_calls.append(tc) - continue - tc_clean = dict(tc) - tc_clean["id"] = map_id(tc_clean.get("id")) - normalized_tool_calls.append(tc_clean) - clean["tool_calls"] = normalized_tool_calls - - if "tool_call_id" in clean and clean["tool_call_id"]: - clean["tool_call_id"] = map_id(clean["tool_call_id"]) - return sanitized - - def _build_chat_kwargs( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None, - model: str | None, - max_tokens: int, - temperature: float, - reasoning_effort: str | None, - tool_choice: str | dict[str, Any] | None, - ) -> tuple[dict[str, Any], str]: - """Build the kwargs dict for ``acompletion``. - - Returns ``(kwargs, original_model)`` so callers can reuse the - original model string for downstream logic. - """ - original_model = model or self.default_model - resolved = self._resolve_model(original_model) - extra_msg_keys = self._extra_msg_keys(original_model, resolved) - - if self._supports_cache_control(original_model): - messages, tools = self._apply_cache_control(messages, tools) - - max_tokens = max(1, max_tokens) - - kwargs: dict[str, Any] = { - "model": resolved, - "messages": self._sanitize_messages( - self._sanitize_empty_content(messages), extra_keys=extra_msg_keys, - ), - "max_tokens": max_tokens, - "temperature": temperature, - } - - if self._gateway: - kwargs.update(self._gateway.litellm_kwargs) - - self._apply_model_overrides(resolved, kwargs) - - if self._langsmith_enabled: - kwargs.setdefault("callbacks", []).append("langsmith") - - if self.api_key: - kwargs["api_key"] = self.api_key - if self.api_base: - kwargs["api_base"] = self.api_base - if self.extra_headers: - kwargs["extra_headers"] = self.extra_headers - - if reasoning_effort: - kwargs["reasoning_effort"] = reasoning_effort - kwargs["drop_params"] = True - - if tools: - kwargs["tools"] = tools - kwargs["tool_choice"] = tool_choice or "auto" - - return kwargs, original_model - - async def chat( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - ) -> LLMResponse: - """Send a chat completion request via LiteLLM.""" - kwargs, _ = self._build_chat_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - try: - response = await acompletion(**kwargs) - return self._parse_response(response) - except Exception as e: - return LLMResponse( - content=f"Error calling LLM: {str(e)}", - finish_reason="error", - ) - - async def chat_stream( - self, - messages: list[dict[str, Any]], - tools: list[dict[str, Any]] | None = None, - model: str | None = None, - max_tokens: int = 4096, - temperature: float = 0.7, - reasoning_effort: str | None = None, - tool_choice: str | dict[str, Any] | None = None, - on_content_delta: Callable[[str], Awaitable[None]] | None = None, - ) -> LLMResponse: - """Stream a chat completion via LiteLLM, forwarding text deltas.""" - kwargs, _ = self._build_chat_kwargs( - messages, tools, model, max_tokens, temperature, - reasoning_effort, tool_choice, - ) - kwargs["stream"] = True - - try: - stream = await acompletion(**kwargs) - chunks: list[Any] = [] - async for chunk in stream: - chunks.append(chunk) - if on_content_delta: - delta = chunk.choices[0].delta if chunk.choices else None - text = getattr(delta, "content", None) if delta else None - if text: - await on_content_delta(text) - - full_response = litellm.stream_chunk_builder( - chunks, messages=kwargs["messages"], - ) - return self._parse_response(full_response) - except Exception as e: - return LLMResponse( - content=f"Error calling LLM: {str(e)}", - finish_reason="error", - ) - - def _parse_response(self, response: Any) -> LLMResponse: - """Parse LiteLLM response into our standard format.""" - choice = response.choices[0] - message = choice.message - content = message.content - finish_reason = choice.finish_reason - - # Some providers (e.g. GitHub Copilot) split content and tool_calls - # across multiple choices. Merge them so tool_calls are not lost. - raw_tool_calls = [] - for ch in response.choices: - msg = ch.message - if hasattr(msg, "tool_calls") and msg.tool_calls: - raw_tool_calls.extend(msg.tool_calls) - if ch.finish_reason in ("tool_calls", "stop"): - finish_reason = ch.finish_reason - if not content and msg.content: - content = msg.content - - if len(response.choices) > 1: - logger.debug("LiteLLM response has {} choices, merged {} tool_calls", - len(response.choices), len(raw_tool_calls)) - - tool_calls = [] - for tc in raw_tool_calls: - # Parse arguments from JSON string if needed - args = tc.function.arguments - if isinstance(args, str): - args = json_repair.loads(args) - - provider_specific_fields = getattr(tc, "provider_specific_fields", None) or None - function_provider_specific_fields = ( - getattr(tc.function, "provider_specific_fields", None) or None - ) - - tool_calls.append(ToolCallRequest( - id=_short_tool_id(), - name=tc.function.name, - arguments=args, - provider_specific_fields=provider_specific_fields, - function_provider_specific_fields=function_provider_specific_fields, - )) - - usage = {} - if hasattr(response, "usage") and response.usage: - usage = { - "prompt_tokens": response.usage.prompt_tokens, - "completion_tokens": response.usage.completion_tokens, - "total_tokens": response.usage.total_tokens, - } - - reasoning_content = getattr(message, "reasoning_content", None) or None - thinking_blocks = getattr(message, "thinking_blocks", None) or None - - return LLMResponse( - content=content, - tool_calls=tool_calls, - finish_reason=finish_reason or "stop", - usage=usage, - reasoning_content=reasoning_content, - thinking_blocks=thinking_blocks, - ) - - def get_default_model(self) -> str: - """Get the default model.""" - return self.default_model diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py new file mode 100644 index 000000000..a210bf72d --- /dev/null +++ b/nanobot/providers/openai_compat_provider.py @@ -0,0 +1,349 @@ +"""OpenAI-compatible provider for all non-Anthropic LLM APIs.""" + +from __future__ import annotations + +import hashlib +import os +import secrets +import string +import uuid +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +import json_repair +from openai import AsyncOpenAI + +from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest + +if TYPE_CHECKING: + from nanobot.providers.registry import ProviderSpec + +_ALLOWED_MSG_KEYS = frozenset({ + "role", "content", "tool_calls", "tool_call_id", "name", "reasoning_content", +}) +_ALNUM = string.ascii_letters + string.digits + + +def _short_tool_id() -> str: + """9-char alphanumeric ID compatible with all providers (incl. Mistral).""" + return "".join(secrets.choice(_ALNUM) for _ in range(9)) + + +class OpenAICompatProvider(LLMProvider): + """Unified provider for all OpenAI-compatible APIs. + + Receives a resolved ``ProviderSpec`` from the caller — no internal + registry lookups needed. + """ + + def __init__( + self, + api_key: str | None = None, + api_base: str | None = None, + default_model: str = "gpt-4o", + extra_headers: dict[str, str] | None = None, + spec: ProviderSpec | None = None, + ): + super().__init__(api_key, api_base) + self.default_model = default_model + self.extra_headers = extra_headers or {} + self._spec = spec + + if api_key and spec and spec.env_key: + self._setup_env(api_key, api_base) + + effective_base = api_base or (spec.default_api_base if spec else None) or None + + self._client = AsyncOpenAI( + api_key=api_key or "no-key", + base_url=effective_base, + default_headers={ + "x-session-affinity": uuid.uuid4().hex, + **(extra_headers or {}), + }, + ) + + def _setup_env(self, api_key: str, api_base: str | None) -> None: + """Set environment variables based on provider spec.""" + spec = self._spec + if not spec or not spec.env_key: + return + if spec.is_gateway: + os.environ[spec.env_key] = api_key + else: + os.environ.setdefault(spec.env_key, api_key) + effective_base = api_base or spec.default_api_base + for env_name, env_val in spec.env_extras: + resolved = env_val.replace("{api_key}", api_key).replace("{api_base}", effective_base) + os.environ.setdefault(env_name, resolved) + + @staticmethod + def _apply_cache_control( + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]] | None]: + """Inject cache_control markers for prompt caching.""" + cache_marker = {"type": "ephemeral"} + new_messages = list(messages) + + def _mark(msg: dict[str, Any]) -> dict[str, Any]: + content = msg.get("content") + if isinstance(content, str): + return {**msg, "content": [ + {"type": "text", "text": content, "cache_control": cache_marker}, + ]} + if isinstance(content, list) and content: + nc = list(content) + nc[-1] = {**nc[-1], "cache_control": cache_marker} + return {**msg, "content": nc} + return msg + + if new_messages and new_messages[0].get("role") == "system": + new_messages[0] = _mark(new_messages[0]) + if len(new_messages) >= 3: + new_messages[-2] = _mark(new_messages[-2]) + + new_tools = tools + if tools: + new_tools = list(tools) + new_tools[-1] = {**new_tools[-1], "cache_control": cache_marker} + return new_messages, new_tools + + @staticmethod + def _normalize_tool_call_id(tool_call_id: Any) -> Any: + """Normalize to a provider-safe 9-char alphanumeric form.""" + if not isinstance(tool_call_id, str): + return tool_call_id + if len(tool_call_id) == 9 and tool_call_id.isalnum(): + return tool_call_id + return hashlib.sha1(tool_call_id.encode()).hexdigest()[:9] + + def _sanitize_messages(self, messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Strip non-standard keys, normalize tool_call IDs.""" + sanitized = LLMProvider._sanitize_request_messages(messages, _ALLOWED_MSG_KEYS) + id_map: dict[str, str] = {} + + def map_id(value: Any) -> Any: + if not isinstance(value, str): + return value + return id_map.setdefault(value, self._normalize_tool_call_id(value)) + + for clean in sanitized: + if isinstance(clean.get("tool_calls"), list): + normalized = [] + for tc in clean["tool_calls"]: + if not isinstance(tc, dict): + normalized.append(tc) + continue + tc_clean = dict(tc) + tc_clean["id"] = map_id(tc_clean.get("id")) + normalized.append(tc_clean) + clean["tool_calls"] = normalized + if "tool_call_id" in clean and clean["tool_call_id"]: + clean["tool_call_id"] = map_id(clean["tool_call_id"]) + return sanitized + + # ------------------------------------------------------------------ + # Build kwargs + # ------------------------------------------------------------------ + + def _build_kwargs( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None, + model: str | None, + max_tokens: int, + temperature: float, + reasoning_effort: str | None, + tool_choice: str | dict[str, Any] | None, + ) -> dict[str, Any]: + model_name = model or self.default_model + spec = self._spec + + if spec and spec.supports_prompt_caching: + messages, tools = self._apply_cache_control(messages, tools) + + if spec and spec.strip_model_prefix: + model_name = model_name.split("/")[-1] + + kwargs: dict[str, Any] = { + "model": model_name, + "messages": self._sanitize_messages(self._sanitize_empty_content(messages)), + "max_tokens": max(1, max_tokens), + "temperature": temperature, + } + + if spec: + model_lower = model_name.lower() + for pattern, overrides in spec.model_overrides: + if pattern in model_lower: + kwargs.update(overrides) + break + + if reasoning_effort: + kwargs["reasoning_effort"] = reasoning_effort + + if tools: + kwargs["tools"] = tools + kwargs["tool_choice"] = tool_choice or "auto" + + return kwargs + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + def _parse(self, response: Any) -> LLMResponse: + if not response.choices: + return LLMResponse(content="Error: API returned empty choices.", finish_reason="error") + + choice = response.choices[0] + msg = choice.message + content = msg.content + finish_reason = choice.finish_reason + + raw_tool_calls: list[Any] = [] + for ch in response.choices: + m = ch.message + if hasattr(m, "tool_calls") and m.tool_calls: + raw_tool_calls.extend(m.tool_calls) + if ch.finish_reason in ("tool_calls", "stop"): + finish_reason = ch.finish_reason + if not content and m.content: + content = m.content + + tool_calls = [] + for tc in raw_tool_calls: + args = tc.function.arguments + if isinstance(args, str): + args = json_repair.loads(args) + tool_calls.append(ToolCallRequest( + id=_short_tool_id(), + name=tc.function.name, + arguments=args, + )) + + usage: dict[str, int] = {} + if hasattr(response, "usage") and response.usage: + u = response.usage + usage = { + "prompt_tokens": u.prompt_tokens or 0, + "completion_tokens": u.completion_tokens or 0, + "total_tokens": u.total_tokens or 0, + } + + return LLMResponse( + content=content, + tool_calls=tool_calls, + finish_reason=finish_reason or "stop", + usage=usage, + reasoning_content=getattr(msg, "reasoning_content", None) or None, + ) + + @staticmethod + def _parse_chunks(chunks: list[Any]) -> LLMResponse: + content_parts: list[str] = [] + tc_bufs: dict[int, dict[str, str]] = {} + finish_reason = "stop" + usage: dict[str, int] = {} + + for chunk in chunks: + if not chunk.choices: + if hasattr(chunk, "usage") and chunk.usage: + u = chunk.usage + usage = { + "prompt_tokens": u.prompt_tokens or 0, + "completion_tokens": u.completion_tokens or 0, + "total_tokens": u.total_tokens or 0, + } + continue + choice = chunk.choices[0] + if choice.finish_reason: + finish_reason = choice.finish_reason + delta = choice.delta + if delta and delta.content: + content_parts.append(delta.content) + for tc in (delta.tool_calls or []) if delta else []: + buf = tc_bufs.setdefault(tc.index, {"id": "", "name": "", "arguments": ""}) + if tc.id: + buf["id"] = tc.id + if tc.function and tc.function.name: + buf["name"] = tc.function.name + if tc.function and tc.function.arguments: + buf["arguments"] += tc.function.arguments + + return LLMResponse( + content="".join(content_parts) or None, + tool_calls=[ + ToolCallRequest( + id=b["id"] or _short_tool_id(), + name=b["name"], + arguments=json_repair.loads(b["arguments"]) if b["arguments"] else {}, + ) + for b in tc_bufs.values() + ], + finish_reason=finish_reason, + usage=usage, + ) + + @staticmethod + def _handle_error(e: Exception) -> LLMResponse: + body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) + msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}" + return LLMResponse(content=msg, finish_reason="error") + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + async def chat( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + try: + return self._parse(await self._client.chat.completions.create(**kwargs)) + except Exception as e: + return self._handle_error(e) + + async def chat_stream( + self, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + model: str | None = None, + max_tokens: int = 4096, + temperature: float = 0.7, + reasoning_effort: str | None = None, + tool_choice: str | dict[str, Any] | None = None, + on_content_delta: Callable[[str], Awaitable[None]] | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs( + messages, tools, model, max_tokens, temperature, + reasoning_effort, tool_choice, + ) + kwargs["stream"] = True + kwargs["stream_options"] = {"include_usage": True} + try: + stream = await self._client.chat.completions.create(**kwargs) + chunks: list[Any] = [] + async for chunk in stream: + chunks.append(chunk) + if on_content_delta and chunk.choices: + text = getattr(chunk.choices[0].delta, "content", None) + if text: + await on_content_delta(text) + return self._parse_chunks(chunks) + except Exception as e: + return self._handle_error(e) + + def get_default_model(self) -> str: + return self.default_model diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 10e0fec9d..206b0b504 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -4,7 +4,7 @@ Provider Registry — single source of truth for LLM provider metadata. Adding a new provider: 1. Add a ProviderSpec to PROVIDERS below. 2. Add a field to ProvidersConfig in config/schema.py. - Done. Env vars, prefixing, config matching, status display all derive from here. + Done. Env vars, config matching, status display all derive from here. Order matters — it controls match priority and fallback. Gateways first. Every entry writes out all fields so you can copy-paste as a template. @@ -12,7 +12,7 @@ Every entry writes out all fields so you can copy-paste as a template. from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any from pydantic.alias_generators import to_snake @@ -30,12 +30,12 @@ class ProviderSpec: # identity name: str # config field name, e.g. "dashscope" keywords: tuple[str, ...] # model-name keywords for matching (lowercase) - env_key: str # LiteLLM env var, e.g. "DASHSCOPE_API_KEY" + env_key: str # env var for API key, e.g. "DASHSCOPE_API_KEY" display_name: str = "" # shown in `nanobot status` - # model prefixing - litellm_prefix: str = "" # "dashscope" → model becomes "dashscope/{model}" - skip_prefixes: tuple[str, ...] = () # don't prefix if model already starts with these + # which provider implementation to use + # "openai_compat" | "anthropic" | "azure_openai" | "openai_codex" + backend: str = "openai_compat" # extra env vars, e.g. (("ZHIPUAI_API_KEY", "{api_key}"),) env_extras: tuple[tuple[str, str], ...] = () @@ -45,19 +45,18 @@ class ProviderSpec: is_local: bool = False # local deployment (vLLM, Ollama) detect_by_key_prefix: str = "" # match api_key prefix, e.g. "sk-or-" detect_by_base_keyword: str = "" # match substring in api_base URL - default_api_base: str = "" # fallback base URL + default_api_base: str = "" # OpenAI-compatible base URL for this provider # gateway behavior - strip_model_prefix: bool = False # strip "provider/" before re-prefixing - litellm_kwargs: dict[str, Any] = field(default_factory=dict) # extra kwargs passed to LiteLLM + strip_model_prefix: bool = False # strip "provider/" before sending to gateway # per-model param overrides, e.g. (("kimi-k2.5", {"temperature": 1.0}),) model_overrides: tuple[tuple[str, dict[str, Any]], ...] = () # OAuth-based providers (e.g., OpenAI Codex) don't use API keys - is_oauth: bool = False # if True, uses OAuth flow instead of API key + is_oauth: bool = False - # Direct providers bypass LiteLLM entirely (e.g., CustomProvider) + # Direct providers skip API-key validation (user supplies everything) is_direct: bool = False # Provider supports cache_control on content blocks (e.g. Anthropic prompt caching) @@ -73,13 +72,13 @@ class ProviderSpec: # --------------------------------------------------------------------------- PROVIDERS: tuple[ProviderSpec, ...] = ( - # === Custom (direct OpenAI-compatible endpoint, bypasses LiteLLM) ====== + # === Custom (direct OpenAI-compatible endpoint) ======================== ProviderSpec( name="custom", keywords=(), env_key="", display_name="Custom", - litellm_prefix="", + backend="openai_compat", is_direct=True, ), @@ -89,7 +88,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("azure", "azure-openai"), env_key="", display_name="Azure OpenAI", - litellm_prefix="", + backend="azure_openai", is_direct=True, ), # === Gateways (detected by api_key / api_base, not model name) ========= @@ -100,36 +99,26 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("openrouter",), env_key="OPENROUTER_API_KEY", display_name="OpenRouter", - litellm_prefix="openrouter", # anthropic/claude-3 → openrouter/anthropic/claude-3 - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, detect_by_key_prefix="sk-or-", detect_by_base_keyword="openrouter", default_api_base="https://openrouter.ai/api/v1", - strip_model_prefix=False, - model_overrides=(), supports_prompt_caching=True, ), # AiHubMix: global gateway, OpenAI-compatible interface. - # strip_model_prefix=True: it doesn't understand "anthropic/claude-3", - # so we strip to bare "claude-3" then re-prefix as "openai/claude-3". + # strip_model_prefix=True: doesn't understand "anthropic/claude-3", + # strips to bare "claude-3". ProviderSpec( name="aihubmix", keywords=("aihubmix",), - env_key="OPENAI_API_KEY", # OpenAI-compatible + env_key="OPENAI_API_KEY", display_name="AiHubMix", - litellm_prefix="openai", # → openai/{model} - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="aihubmix", default_api_base="https://aihubmix.com/v1", - strip_model_prefix=True, # anthropic/claude-3 → claude-3 → openai/claude-3 - model_overrides=(), + strip_model_prefix=True, ), # SiliconFlow (硅基流动): OpenAI-compatible gateway, model names keep org prefix ProviderSpec( @@ -137,16 +126,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("siliconflow",), env_key="OPENAI_API_KEY", display_name="SiliconFlow", - litellm_prefix="openai", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="siliconflow", default_api_base="https://api.siliconflow.cn/v1", - strip_model_prefix=False, - model_overrides=(), ), # VolcEngine (火山引擎): OpenAI-compatible gateway, pay-per-use models @@ -155,16 +138,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("volcengine", "volces", "ark"), env_key="OPENAI_API_KEY", display_name="VolcEngine", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="volces", default_api_base="https://ark.cn-beijing.volces.com/api/v3", - strip_model_prefix=False, - model_overrides=(), ), # VolcEngine Coding Plan (火山引擎 Coding Plan): same key as volcengine @@ -173,16 +150,10 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("volcengine-plan",), env_key="OPENAI_API_KEY", display_name="VolcEngine Coding Plan", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", default_api_base="https://ark.cn-beijing.volces.com/api/coding/v3", strip_model_prefix=True, - model_overrides=(), ), # BytePlus: VolcEngine international, pay-per-use models @@ -191,16 +162,11 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("byteplus",), env_key="OPENAI_API_KEY", display_name="BytePlus", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", detect_by_base_keyword="bytepluses", default_api_base="https://ark.ap-southeast.bytepluses.com/api/v3", strip_model_prefix=True, - model_overrides=(), ), # BytePlus Coding Plan: same key as byteplus @@ -209,250 +175,137 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("byteplus-plan",), env_key="OPENAI_API_KEY", display_name="BytePlus Coding Plan", - litellm_prefix="volcengine", - skip_prefixes=(), - env_extras=(), + backend="openai_compat", is_gateway=True, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", default_api_base="https://ark.ap-southeast.bytepluses.com/api/coding/v3", strip_model_prefix=True, - model_overrides=(), ), # === Standard providers (matched by model-name keywords) =============== - # Anthropic: LiteLLM recognizes "claude-*" natively, no prefix needed. + # Anthropic: native Anthropic SDK ProviderSpec( name="anthropic", keywords=("anthropic", "claude"), env_key="ANTHROPIC_API_KEY", display_name="Anthropic", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="anthropic", supports_prompt_caching=True, ), - # OpenAI: LiteLLM recognizes "gpt-*" natively, no prefix needed. + # OpenAI: SDK default base URL (no override needed) ProviderSpec( name="openai", keywords=("openai", "gpt"), env_key="OPENAI_API_KEY", display_name="OpenAI", - litellm_prefix="", - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", ), - # OpenAI Codex: uses OAuth, not API key. + # OpenAI Codex: OAuth-based, dedicated provider ProviderSpec( name="openai_codex", keywords=("openai-codex",), - env_key="", # OAuth-based, no API key + env_key="", display_name="OpenAI Codex", - litellm_prefix="", # Not routed through LiteLLM - skip_prefixes=(), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", + backend="openai_codex", detect_by_base_keyword="codex", default_api_base="https://chatgpt.com/backend-api", - strip_model_prefix=False, - model_overrides=(), - is_oauth=True, # OAuth-based authentication + is_oauth=True, ), - # Github Copilot: uses OAuth, not API key. + # GitHub Copilot: OAuth-based ProviderSpec( name="github_copilot", keywords=("github_copilot", "copilot"), - env_key="", # OAuth-based, no API key + env_key="", display_name="Github Copilot", - litellm_prefix="github_copilot", # github_copilot/model → github_copilot/model - skip_prefixes=("github_copilot/",), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), - is_oauth=True, # OAuth-based authentication + backend="openai_compat", + default_api_base="https://api.githubcopilot.com", + is_oauth=True, ), - # DeepSeek: needs "deepseek/" prefix for LiteLLM routing. + # DeepSeek: OpenAI-compatible at api.deepseek.com ProviderSpec( name="deepseek", keywords=("deepseek",), env_key="DEEPSEEK_API_KEY", display_name="DeepSeek", - litellm_prefix="deepseek", # deepseek-chat → deepseek/deepseek-chat - skip_prefixes=("deepseek/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://api.deepseek.com", ), - # Gemini: needs "gemini/" prefix for LiteLLM. + # Gemini: Google's OpenAI-compatible endpoint ProviderSpec( name="gemini", keywords=("gemini",), env_key="GEMINI_API_KEY", display_name="Gemini", - litellm_prefix="gemini", # gemini-pro → gemini/gemini-pro - skip_prefixes=("gemini/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://generativelanguage.googleapis.com/v1beta/openai/", ), - # Zhipu: LiteLLM uses "zai/" prefix. - # Also mirrors key to ZHIPUAI_API_KEY (some LiteLLM paths check that). - # skip_prefixes: don't add "zai/" when already routed via gateway. + # Zhipu (智谱): OpenAI-compatible at open.bigmodel.cn ProviderSpec( name="zhipu", keywords=("zhipu", "glm", "zai"), env_key="ZAI_API_KEY", display_name="Zhipu AI", - litellm_prefix="zai", # glm-4 → zai/glm-4 - skip_prefixes=("zhipu/", "zai/", "openrouter/", "hosted_vllm/"), + backend="openai_compat", env_extras=(("ZHIPUAI_API_KEY", "{api_key}"),), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + default_api_base="https://open.bigmodel.cn/api/paas/v4", ), - # DashScope: Qwen models, needs "dashscope/" prefix. + # DashScope (通义): Qwen models, OpenAI-compatible endpoint ProviderSpec( name="dashscope", keywords=("qwen", "dashscope"), env_key="DASHSCOPE_API_KEY", display_name="DashScope", - litellm_prefix="dashscope", # qwen-max → dashscope/qwen-max - skip_prefixes=("dashscope/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://dashscope.aliyuncs.com/compatible-mode/v1", ), - # Moonshot: Kimi models, needs "moonshot/" prefix. - # LiteLLM requires MOONSHOT_API_BASE env var to find the endpoint. - # Kimi K2.5 API enforces temperature >= 1.0. + # Moonshot (月之暗面): Kimi models. K2.5 enforces temperature >= 1.0. ProviderSpec( name="moonshot", keywords=("moonshot", "kimi"), env_key="MOONSHOT_API_KEY", display_name="Moonshot", - litellm_prefix="moonshot", # kimi-k2.5 → moonshot/kimi-k2.5 - skip_prefixes=("moonshot/", "openrouter/"), - env_extras=(("MOONSHOT_API_BASE", "{api_base}"),), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="https://api.moonshot.ai/v1", # intl; use api.moonshot.cn for China - strip_model_prefix=False, + backend="openai_compat", + default_api_base="https://api.moonshot.ai/v1", model_overrides=(("kimi-k2.5", {"temperature": 1.0}),), ), - # MiniMax: needs "minimax/" prefix for LiteLLM routing. - # Uses OpenAI-compatible API at api.minimax.io/v1. + # MiniMax: OpenAI-compatible API ProviderSpec( name="minimax", keywords=("minimax",), env_key="MINIMAX_API_KEY", display_name="MiniMax", - litellm_prefix="minimax", # MiniMax-M2.1 → minimax/MiniMax-M2.1 - skip_prefixes=("minimax/", "openrouter/"), - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", + backend="openai_compat", default_api_base="https://api.minimax.io/v1", - strip_model_prefix=False, - model_overrides=(), ), - # Mistral AI: OpenAI-compatible API at api.mistral.ai/v1. + # Mistral AI: OpenAI-compatible API ProviderSpec( name="mistral", keywords=("mistral",), env_key="MISTRAL_API_KEY", display_name="Mistral", - litellm_prefix="mistral", # mistral-large-latest → mistral/mistral-large-latest - skip_prefixes=("mistral/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", + backend="openai_compat", default_api_base="https://api.mistral.ai/v1", - strip_model_prefix=False, - model_overrides=(), ), # === Local deployment (matched by config key, NOT by api_base) ========= - # vLLM / any OpenAI-compatible local server. - # Detected when config key is "vllm" (provider_name="vllm"). + # vLLM / any OpenAI-compatible local server ProviderSpec( name="vllm", keywords=("vllm",), env_key="HOSTED_VLLM_API_KEY", display_name="vLLM/Local", - litellm_prefix="hosted_vllm", # Llama-3-8B → hosted_vllm/Llama-3-8B - skip_prefixes=(), - env_extras=(), - is_gateway=False, + backend="openai_compat", is_local=True, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", # user must provide in config - strip_model_prefix=False, - model_overrides=(), ), - # === Ollama (local, OpenAI-compatible) =================================== + # Ollama (local, OpenAI-compatible) ProviderSpec( name="ollama", keywords=("ollama", "nemotron"), env_key="OLLAMA_API_KEY", display_name="Ollama", - litellm_prefix="ollama_chat", # model → ollama_chat/model - skip_prefixes=("ollama/", "ollama_chat/"), - env_extras=(), - is_gateway=False, + backend="openai_compat", is_local=True, - detect_by_key_prefix="", detect_by_base_keyword="11434", - default_api_base="http://localhost:11434", - strip_model_prefix=False, - model_overrides=(), + default_api_base="http://localhost:11434/v1", ), # === OpenVINO Model Server (direct, local, OpenAI-compatible at /v3) === ProviderSpec( @@ -460,29 +313,20 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( keywords=("openvino", "ovms"), env_key="", display_name="OpenVINO Model Server", - litellm_prefix="", + backend="openai_compat", is_direct=True, is_local=True, default_api_base="http://localhost:8000/v3", ), # === Auxiliary (not a primary LLM provider) ============================ - # Groq: mainly used for Whisper voice transcription, also usable for LLM. - # Needs "groq/" prefix for LiteLLM routing. Placed last — it rarely wins fallback. + # Groq: mainly used for Whisper voice transcription, also usable for LLM ProviderSpec( name="groq", keywords=("groq",), env_key="GROQ_API_KEY", display_name="Groq", - litellm_prefix="groq", # llama3-8b-8192 → groq/llama3-8b-8192 - skip_prefixes=("groq/",), # avoid double-prefix - env_extras=(), - is_gateway=False, - is_local=False, - detect_by_key_prefix="", - detect_by_base_keyword="", - default_api_base="", - strip_model_prefix=False, - model_overrides=(), + backend="openai_compat", + default_api_base="https://api.groq.com/openai/v1", ), ) @@ -492,59 +336,6 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( # --------------------------------------------------------------------------- -def find_by_model(model: str) -> ProviderSpec | None: - """Match a standard provider by model-name keyword (case-insensitive). - Skips gateways/local — those are matched by api_key/api_base instead.""" - model_lower = model.lower() - model_normalized = model_lower.replace("-", "_") - model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else "" - normalized_prefix = model_prefix.replace("-", "_") - std_specs = [s for s in PROVIDERS if not s.is_gateway and not s.is_local] - - # Prefer explicit provider prefix — prevents `github-copilot/...codex` matching openai_codex. - for spec in std_specs: - if model_prefix and normalized_prefix == spec.name: - return spec - - for spec in std_specs: - if any( - kw in model_lower or kw.replace("-", "_") in model_normalized for kw in spec.keywords - ): - return spec - return None - - -def find_gateway( - provider_name: str | None = None, - api_key: str | None = None, - api_base: str | None = None, -) -> ProviderSpec | None: - """Detect gateway/local provider. - - Priority: - 1. provider_name — if it maps to a gateway/local spec, use it directly. - 2. api_key prefix — e.g. "sk-or-" → OpenRouter. - 3. api_base keyword — e.g. "aihubmix" in URL → AiHubMix. - - A standard provider with a custom api_base (e.g. DeepSeek behind a proxy) - will NOT be mistaken for vLLM — the old fallback is gone. - """ - # 1. Direct match by config key - if provider_name: - spec = find_by_name(provider_name) - if spec and (spec.is_gateway or spec.is_local): - return spec - - # 2. Auto-detect by api_key prefix / api_base keyword - for spec in PROVIDERS: - if spec.detect_by_key_prefix and api_key and api_key.startswith(spec.detect_by_key_prefix): - return spec - if spec.detect_by_base_keyword and api_base and spec.detect_by_base_keyword in api_base: - return spec - - return None - - def find_by_name(name: str) -> ProviderSpec | None: """Find a provider spec by config field name, e.g. "dashscope".""" normalized = to_snake(name.replace("-", "_")) diff --git a/pyproject.toml b/pyproject.toml index 246ca3074..aca72777d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ classifiers = [ dependencies = [ "typer>=0.20.0,<1.0.0", - "litellm>=1.82.1,<=1.82.6", + "anthropic>=0.45.0,<1.0.0", "pydantic>=2.12.0,<3.0.0", "pydantic-settings>=2.12.0,<3.0.0", "websockets>=16.0,<17.0", diff --git a/tests/agent/test_gemini_thought_signature.py b/tests/agent/test_gemini_thought_signature.py index bc4132c37..35739602a 100644 --- a/tests/agent/test_gemini_thought_signature.py +++ b/tests/agent/test_gemini_thought_signature.py @@ -1,40 +1,6 @@ from types import SimpleNamespace from nanobot.providers.base import ToolCallRequest -from nanobot.providers.litellm_provider import LiteLLMProvider - - -def test_litellm_parse_response_preserves_tool_call_provider_fields() -> None: - provider = LiteLLMProvider(default_model="gemini/gemini-3-flash") - - response = SimpleNamespace( - choices=[ - SimpleNamespace( - finish_reason="tool_calls", - message=SimpleNamespace( - content=None, - tool_calls=[ - SimpleNamespace( - id="call_123", - function=SimpleNamespace( - name="read_file", - arguments='{"path":"todo.md"}', - provider_specific_fields={"inner": "value"}, - ), - provider_specific_fields={"thought_signature": "signed-token"}, - ) - ], - ), - ) - ], - usage=None, - ) - - parsed = provider._parse_response(response) - - assert len(parsed.tool_calls) == 1 - assert parsed.tool_calls[0].provider_specific_fields == {"thought_signature": "signed-token"} - assert parsed.tool_calls[0].function_provider_specific_fields == {"inner": "value"} def test_tool_call_request_serializes_provider_fields() -> None: diff --git a/tests/agent/test_memory_consolidation_types.py b/tests/agent/test_memory_consolidation_types.py index d63cc9047..203e39a90 100644 --- a/tests/agent/test_memory_consolidation_types.py +++ b/tests/agent/test_memory_consolidation_types.py @@ -380,7 +380,7 @@ class TestMemoryConsolidationTypeHandling: """Forced tool_choice rejected by provider -> retry with auto and succeed.""" store = MemoryStore(tmp_path) error_resp = LLMResponse( - content="Error calling LLM: litellm.BadRequestError: " + content="Error calling LLM: BadRequestError: " "The tool_choice parameter does not support being set to required or object", finish_reason="error", tool_calls=[], diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 4e79fc717..a8fcc4aa0 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -9,9 +9,8 @@ from typer.testing import CliRunner from nanobot.bus.events import OutboundMessage from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config -from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.openai_codex_provider import _strip_model_prefix -from nanobot.providers.registry import find_by_model, find_by_name +from nanobot.providers.registry import find_by_name runner = CliRunner() @@ -228,7 +227,7 @@ def test_config_matches_explicit_ollama_prefix_without_api_key(): config.agents.defaults.model = "ollama/llama3.2" assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): @@ -237,7 +236,7 @@ def test_config_explicit_ollama_provider_uses_default_localhost_api_base(): config.agents.defaults.model = "llama3.2" assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_accepts_camel_case_explicit_provider_name_for_coding_plan(): @@ -272,12 +271,12 @@ def test_config_auto_detects_ollama_from_local_api_base(): config = Config.model_validate( { "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}}, - "providers": {"ollama": {"apiBase": "http://localhost:11434"}}, + "providers": {"ollama": {"apiBase": "http://localhost:11434/v1"}}, } ) assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured(): @@ -286,13 +285,13 @@ def test_config_prefers_ollama_over_vllm_when_both_local_providers_configured(): "agents": {"defaults": {"provider": "auto", "model": "llama3.2"}}, "providers": { "vllm": {"apiBase": "http://localhost:8000"}, - "ollama": {"apiBase": "http://localhost:11434"}, + "ollama": {"apiBase": "http://localhost:11434/v1"}, }, } ) assert config.get_provider_name() == "ollama" - assert config.get_api_base() == "http://localhost:11434" + assert config.get_api_base() == "http://localhost:11434/v1" def test_config_falls_back_to_vllm_when_ollama_not_configured(): @@ -309,19 +308,13 @@ def test_config_falls_back_to_vllm_when_ollama_not_configured(): assert config.get_api_base() == "http://localhost:8000" -def test_find_by_model_prefers_explicit_prefix_over_generic_codex_keyword(): - spec = find_by_model("github-copilot/gpt-5.3-codex") +def test_openai_compat_provider_passes_model_through(): + from nanobot.providers.openai_compat_provider import OpenAICompatProvider - assert spec is not None - assert spec.name == "github_copilot" + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider(default_model="github-copilot/gpt-5.3-codex") - -def test_litellm_provider_canonicalizes_github_copilot_hyphen_prefix(): - provider = LiteLLMProvider(default_model="github-copilot/gpt-5.3-codex") - - resolved = provider._resolve_model("github-copilot/gpt-5.3-codex") - - assert resolved == "github_copilot/gpt-5.3-codex" + assert provider.get_default_model() == "github-copilot/gpt-5.3-codex" def test_openai_codex_strip_prefix_supports_hyphen_and_underscore(): @@ -346,7 +339,7 @@ def test_make_provider_passes_extra_headers_to_custom_provider(): } ) - with patch("nanobot.providers.custom_provider.AsyncOpenAI") as mock_async_openai: + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as mock_async_openai: _make_provider(config) kwargs = mock_async_openai.call_args.kwargs diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index 463affedc..bb46b887a 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -1,10 +1,14 @@ -from types import SimpleNamespace +"""Tests for OpenAICompatProvider handling custom/direct endpoints.""" -from nanobot.providers.custom_provider import CustomProvider +from types import SimpleNamespace +from unittest.mock import patch + +from nanobot.providers.openai_compat_provider import OpenAICompatProvider def test_custom_provider_parse_handles_empty_choices() -> None: - provider = CustomProvider() + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider() response = SimpleNamespace(choices=[]) result = provider._parse(response) diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py index 437f8a555..c55857b3b 100644 --- a/tests/providers/test_litellm_kwargs.py +++ b/tests/providers/test_litellm_kwargs.py @@ -1,161 +1,122 @@ -"""Regression tests for PR #2026 — litellm_kwargs injection from ProviderSpec. +"""Tests for OpenAICompatProvider spec-driven behavior. Validates that: -- OpenRouter uses litellm_prefix (NOT custom_llm_provider) to avoid LiteLLM double-prefixing. -- The litellm_kwargs mechanism works correctly for providers that declare it. -- Non-gateway providers are unaffected. +- OpenRouter (no strip) keeps model names intact. +- AiHubMix (strip_model_prefix=True) strips provider prefixes. +- Standard providers pass model names through as-is. """ from __future__ import annotations from types import SimpleNamespace -from typing import Any from unittest.mock import AsyncMock, patch import pytest -from nanobot.providers.litellm_provider import LiteLLMProvider +from nanobot.providers.openai_compat_provider import OpenAICompatProvider from nanobot.providers.registry import find_by_name -def _fake_response(content: str = "ok") -> SimpleNamespace: - """Build a minimal acompletion-shaped response object.""" +def _fake_chat_response(content: str = "ok") -> SimpleNamespace: + """Build a minimal OpenAI chat completion response.""" message = SimpleNamespace( content=content, tool_calls=None, reasoning_content=None, - thinking_blocks=None, ) choice = SimpleNamespace(message=message, finish_reason="stop") usage = SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15) return SimpleNamespace(choices=[choice], usage=usage) -def test_openrouter_spec_uses_prefix_not_custom_llm_provider() -> None: - """OpenRouter must rely on litellm_prefix, not custom_llm_provider kwarg. - - LiteLLM internally adds a provider/ prefix when custom_llm_provider is set, - which double-prefixes models (openrouter/anthropic/model) and breaks the API. - """ +def test_openrouter_spec_is_gateway() -> None: spec = find_by_name("openrouter") assert spec is not None - assert spec.litellm_prefix == "openrouter" - assert "custom_llm_provider" not in spec.litellm_kwargs, ( - "custom_llm_provider causes LiteLLM to double-prefix the model name" - ) + assert spec.is_gateway is True + assert spec.default_api_base == "https://openrouter.ai/api/v1" @pytest.mark.asyncio -async def test_openrouter_prefixes_model_correctly() -> None: - """OpenRouter should prefix model as openrouter/vendor/model for LiteLLM routing.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) +async def test_openrouter_keeps_model_name_intact() -> None: + """OpenRouter gateway keeps the full model name (gateway does its own routing).""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("openrouter") - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create + + provider = OpenAICompatProvider( api_key="sk-or-test-key", api_base="https://openrouter.ai/api/v1", default_model="anthropic/claude-sonnet-4-5", - provider_name="openrouter", + spec=spec, ) await provider.chat( messages=[{"role": "user", "content": "hello"}], model="anthropic/claude-sonnet-4-5", ) - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", ( - "LiteLLM needs openrouter/ prefix to detect the provider and strip it before API call" - ) - assert "custom_llm_provider" not in call_kwargs + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-sonnet-4-5" @pytest.mark.asyncio -async def test_non_gateway_provider_no_extra_kwargs() -> None: - """Standard (non-gateway) providers must NOT inject any litellm_kwargs.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) +async def test_aihubmix_strips_model_prefix() -> None: + """AiHubMix strips the provider prefix (strip_model_prefix=True).""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("aihubmix") - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-ant-test-key", - default_model="claude-sonnet-4-5", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="claude-sonnet-4-5", - ) + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create - call_kwargs = mock_acompletion.call_args.kwargs - assert "custom_llm_provider" not in call_kwargs, ( - "Standard Anthropic provider should NOT inject custom_llm_provider" - ) - - -@pytest.mark.asyncio -async def test_gateway_without_litellm_kwargs_injects_nothing_extra() -> None: - """Gateways without litellm_kwargs (e.g. AiHubMix) must not add extra keys.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( + provider = OpenAICompatProvider( api_key="sk-aihub-test-key", api_base="https://aihubmix.com/v1", default_model="claude-sonnet-4-5", - provider_name="aihubmix", - ) - await provider.chat( - messages=[{"role": "user", "content": "hello"}], - model="claude-sonnet-4-5", - ) - - call_kwargs = mock_acompletion.call_args.kwargs - assert "custom_llm_provider" not in call_kwargs - - -@pytest.mark.asyncio -async def test_openrouter_autodetect_by_key_prefix() -> None: - """OpenRouter should be auto-detected by sk-or- key prefix even without explicit provider_name.""" - mock_acompletion = AsyncMock(return_value=_fake_response()) - - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-auto-detect-key", - default_model="anthropic/claude-sonnet-4-5", + spec=spec, ) await provider.chat( messages=[{"role": "user", "content": "hello"}], model="anthropic/claude-sonnet-4-5", ) - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/anthropic/claude-sonnet-4-5", ( - "Auto-detected OpenRouter should prefix model for LiteLLM routing" - ) + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "claude-sonnet-4-5" @pytest.mark.asyncio -async def test_openrouter_native_model_id_gets_double_prefixed() -> None: - """Models like openrouter/free must be double-prefixed so LiteLLM strips one layer. +async def test_standard_provider_passes_model_through() -> None: + """Standard provider (e.g. deepseek) passes model name through as-is.""" + mock_create = AsyncMock(return_value=_fake_chat_response()) + spec = find_by_name("deepseek") - openrouter/free is an actual OpenRouter model ID. LiteLLM strips the first - openrouter/ for routing, so we must send openrouter/openrouter/free to ensure - the API receives openrouter/free. - """ - mock_acompletion = AsyncMock(return_value=_fake_response()) + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI") as MockClient: + client_instance = MockClient.return_value + client_instance.chat.completions.create = mock_create - with patch("nanobot.providers.litellm_provider.acompletion", mock_acompletion): - provider = LiteLLMProvider( - api_key="sk-or-test-key", - api_base="https://openrouter.ai/api/v1", - default_model="openrouter/free", - provider_name="openrouter", + provider = OpenAICompatProvider( + api_key="sk-deepseek-test-key", + default_model="deepseek-chat", + spec=spec, ) await provider.chat( messages=[{"role": "user", "content": "hello"}], - model="openrouter/free", + model="deepseek-chat", ) - call_kwargs = mock_acompletion.call_args.kwargs - assert call_kwargs["model"] == "openrouter/openrouter/free", ( - "openrouter/free must become openrouter/openrouter/free — " - "LiteLLM strips one layer so the API receives openrouter/free" - ) + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "deepseek-chat" + + +def test_openai_model_passthrough() -> None: + """OpenAI models pass through unchanged.""" + spec = find_by_name("openai") + with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): + provider = OpenAICompatProvider( + api_key="sk-test-key", + default_model="gpt-4o", + spec=spec, + ) + assert provider.get_default_model() == "gpt-4o" diff --git a/tests/providers/test_mistral_provider.py b/tests/providers/test_mistral_provider.py index 401122178..30023afe7 100644 --- a/tests/providers/test_mistral_provider.py +++ b/tests/providers/test_mistral_provider.py @@ -17,6 +17,4 @@ def test_mistral_provider_in_registry(): mistral = specs["mistral"] assert mistral.env_key == "MISTRAL_API_KEY" - assert mistral.litellm_prefix == "mistral" assert mistral.default_api_base == "https://api.mistral.ai/v1" - assert "mistral/" in mistral.skip_prefixes diff --git a/tests/providers/test_providers_init.py b/tests/providers/test_providers_init.py index 02ab7c1ef..32cbab478 100644 --- a/tests/providers/test_providers_init.py +++ b/tests/providers/test_providers_init.py @@ -8,19 +8,22 @@ import sys def test_importing_providers_package_is_lazy(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False) - monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.openai_compat_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.openai_codex_provider", raising=False) monkeypatch.delitem(sys.modules, "nanobot.providers.azure_openai_provider", raising=False) providers = importlib.import_module("nanobot.providers") - assert "nanobot.providers.litellm_provider" not in sys.modules + assert "nanobot.providers.anthropic_provider" not in sys.modules + assert "nanobot.providers.openai_compat_provider" not in sys.modules assert "nanobot.providers.openai_codex_provider" not in sys.modules assert "nanobot.providers.azure_openai_provider" not in sys.modules assert providers.__all__ == [ "LLMProvider", "LLMResponse", - "LiteLLMProvider", + "AnthropicProvider", + "OpenAICompatProvider", "OpenAICodexProvider", "AzureOpenAIProvider", ] @@ -28,10 +31,10 @@ def test_importing_providers_package_is_lazy(monkeypatch) -> None: def test_explicit_provider_import_still_works(monkeypatch) -> None: monkeypatch.delitem(sys.modules, "nanobot.providers", raising=False) - monkeypatch.delitem(sys.modules, "nanobot.providers.litellm_provider", raising=False) + monkeypatch.delitem(sys.modules, "nanobot.providers.anthropic_provider", raising=False) namespace: dict[str, object] = {} - exec("from nanobot.providers import LiteLLMProvider", namespace) + exec("from nanobot.providers import AnthropicProvider", namespace) - assert namespace["LiteLLMProvider"].__name__ == "LiteLLMProvider" - assert "nanobot.providers.litellm_provider" in sys.modules + assert namespace["AnthropicProvider"].__name__ == "AnthropicProvider" + assert "nanobot.providers.anthropic_provider" in sys.modules