mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
Resolve fallbackModels as preset references or explicit inline provider configs so failover uses complete model settings without exposing fallback logic to the agent loop. Co-authored-by: Cursor <cursoragent@cursor.com>
242 lines
8.5 KiB
Python
242 lines
8.5 KiB
Python
"""Create LLM providers from config."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
|
|
from nanobot.providers.base import LLMProvider
|
|
from nanobot.providers.fallback_provider import FallbackProvider
|
|
from nanobot.providers.registry import find_by_name
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ProviderSnapshot:
|
|
provider: LLMProvider
|
|
model: str
|
|
context_window_tokens: int
|
|
signature: tuple[object, ...]
|
|
|
|
|
|
def _resolve_model_preset(
|
|
config: Config,
|
|
*,
|
|
preset_name: str | None = None,
|
|
preset: ModelPresetConfig | None = None,
|
|
) -> ModelPresetConfig:
|
|
return preset if preset is not None else config.resolve_preset(preset_name)
|
|
|
|
|
|
def _make_provider_core(
|
|
config: Config,
|
|
*,
|
|
preset_name: str | None = None,
|
|
preset: ModelPresetConfig | None = None,
|
|
model: str | None = None,
|
|
) -> LLMProvider:
|
|
"""Create a plain LLM provider without failover wrapping."""
|
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
|
model = model or resolved.model
|
|
provider_name = config.get_provider_name(model, preset=resolved)
|
|
p = config.get_provider(model, preset=resolved)
|
|
spec = find_by_name(provider_name) if provider_name else None
|
|
backend = spec.backend if spec else "openai_compat"
|
|
|
|
if backend == "azure_openai":
|
|
if not p or not p.api_key or not p.api_base:
|
|
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
|
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
|
needs_key = not (p and p.api_key)
|
|
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
|
if needs_key and not exempt:
|
|
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
|
|
|
if backend == "openai_codex":
|
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
|
|
|
provider = OpenAICodexProvider(default_model=model)
|
|
elif backend == "azure_openai":
|
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|
|
|
provider = AzureOpenAIProvider(
|
|
api_key=p.api_key,
|
|
api_base=p.api_base,
|
|
default_model=model,
|
|
)
|
|
elif backend == "github_copilot":
|
|
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
|
|
|
provider = GitHubCopilotProvider(default_model=model)
|
|
elif backend == "anthropic":
|
|
from nanobot.providers.anthropic_provider import AnthropicProvider
|
|
|
|
provider = AnthropicProvider(
|
|
api_key=p.api_key if p else None,
|
|
api_base=config.get_api_base(model, preset=resolved),
|
|
default_model=model,
|
|
extra_headers=p.extra_headers if p else None,
|
|
)
|
|
elif backend == "bedrock":
|
|
from nanobot.providers.bedrock_provider import BedrockProvider
|
|
|
|
provider = BedrockProvider(
|
|
api_key=p.api_key if p else None,
|
|
api_base=p.api_base if p else None,
|
|
default_model=model,
|
|
region=getattr(p, "region", None) if p else None,
|
|
profile=getattr(p, "profile", None) if p else None,
|
|
extra_body=p.extra_body if p else None,
|
|
)
|
|
else:
|
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
|
|
|
provider = OpenAICompatProvider(
|
|
api_key=p.api_key if p else None,
|
|
api_base=config.get_api_base(model, preset=resolved),
|
|
default_model=model,
|
|
extra_headers=p.extra_headers if p else None,
|
|
spec=spec,
|
|
extra_body=p.extra_body if p else None,
|
|
)
|
|
|
|
provider.generation = resolved.to_generation_settings()
|
|
return provider
|
|
|
|
|
|
def _inline_fallback_preset(
|
|
primary: ModelPresetConfig,
|
|
fallback: InlineFallbackConfig,
|
|
) -> ModelPresetConfig:
|
|
return ModelPresetConfig(
|
|
model=fallback.model,
|
|
provider=fallback.provider,
|
|
max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens,
|
|
context_window_tokens=(
|
|
fallback.context_window_tokens
|
|
if fallback.context_window_tokens is not None
|
|
else primary.context_window_tokens
|
|
),
|
|
temperature=(
|
|
fallback.temperature if fallback.temperature is not None else primary.temperature
|
|
),
|
|
reasoning_effort=fallback.reasoning_effort,
|
|
)
|
|
|
|
|
|
def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]:
|
|
presets: list[ModelPresetConfig] = []
|
|
for fallback in config.agents.defaults.fallback_models:
|
|
if isinstance(fallback, str):
|
|
presets.append(config.model_presets[fallback])
|
|
else:
|
|
presets.append(_inline_fallback_preset(primary, fallback))
|
|
return presets
|
|
|
|
|
|
def make_provider(
|
|
config: Config,
|
|
*,
|
|
preset_name: str | None = None,
|
|
preset: ModelPresetConfig | None = None,
|
|
model: str | None = None,
|
|
) -> LLMProvider:
|
|
"""Create the LLM provider implied by config.
|
|
|
|
When *model* is given, it overrides the resolved/preset model — used by
|
|
the failover path to create providers for fallback models.
|
|
"""
|
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
|
provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)
|
|
fallback_presets = _resolve_fallback_presets(config, resolved)
|
|
|
|
if fallback_presets:
|
|
provider = FallbackProvider(
|
|
primary=provider,
|
|
fallback_presets=fallback_presets,
|
|
provider_factory=lambda fb: _make_provider_core(
|
|
config, preset_name=preset_name, preset=fb
|
|
),
|
|
)
|
|
|
|
return provider
|
|
|
|
|
|
def provider_signature(
|
|
config: Config,
|
|
*,
|
|
preset_name: str | None = None,
|
|
preset: ModelPresetConfig | None = None,
|
|
) -> tuple[object, ...]:
|
|
"""Return the config fields that affect the active provider chain."""
|
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
|
p = config.get_provider(resolved.model, preset=resolved)
|
|
fallback_presets = _resolve_fallback_presets(config, resolved)
|
|
|
|
def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]:
|
|
fp = config.get_provider(fallback.model, preset=fallback)
|
|
return (
|
|
fallback.model,
|
|
fallback.provider,
|
|
config.get_provider_name(fallback.model, preset=fallback),
|
|
config.get_api_key(fallback.model, preset=fallback),
|
|
config.get_api_base(fallback.model, preset=fallback),
|
|
fp.extra_headers if fp else None,
|
|
fp.extra_body if fp else None,
|
|
getattr(fp, "region", None) if fp else None,
|
|
getattr(fp, "profile", None) if fp else None,
|
|
fallback.max_tokens,
|
|
fallback.temperature,
|
|
fallback.reasoning_effort,
|
|
fallback.context_window_tokens,
|
|
)
|
|
|
|
return (
|
|
resolved.model,
|
|
resolved.provider,
|
|
config.get_provider_name(resolved.model, preset=resolved),
|
|
config.get_api_key(resolved.model, preset=resolved),
|
|
config.get_api_base(resolved.model, preset=resolved),
|
|
p.extra_headers if p else None,
|
|
p.extra_body if p else None,
|
|
getattr(p, "region", None) if p else None,
|
|
getattr(p, "profile", None) if p else None,
|
|
resolved.max_tokens,
|
|
resolved.temperature,
|
|
resolved.reasoning_effort,
|
|
resolved.context_window_tokens,
|
|
tuple(_fallback_signature(fallback) for fallback in fallback_presets),
|
|
)
|
|
|
|
|
|
def build_provider_snapshot(
|
|
config: Config,
|
|
*,
|
|
preset_name: str | None = None,
|
|
preset: ModelPresetConfig | None = None,
|
|
) -> ProviderSnapshot:
|
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
|
fallback_windows = [
|
|
fallback.context_window_tokens
|
|
for fallback in _resolve_fallback_presets(config, resolved)
|
|
]
|
|
return ProviderSnapshot(
|
|
provider=make_provider(config, preset=resolved),
|
|
model=resolved.model,
|
|
context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]),
|
|
signature=provider_signature(config, preset=resolved),
|
|
)
|
|
|
|
|
|
def load_provider_snapshot(
|
|
config_path: Path | None = None,
|
|
*,
|
|
preset_name: str | None = None,
|
|
) -> ProviderSnapshot:
|
|
from nanobot.config.loader import load_config, resolve_config_env_vars
|
|
|
|
return build_provider_snapshot(
|
|
resolve_config_env_vars(load_config(config_path)),
|
|
preset_name=preset_name,
|
|
)
|