Xubin Ren 02b059a616 feat(runner): support structured fallback models
Bind fallback model chains to the active model configuration so defaults and presets do not inherit or merge fallback behavior implicitly. Require explicit fallback providers while preserving per-fallback generation overrides and context-window safety.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-13 13:57:30 +00:00

236 lines
8.4 KiB
Python

"""Create LLM providers from config."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from nanobot.config.schema import Config, ModelFallbackConfig, ModelPresetConfig
from nanobot.providers.base import LLMProvider
from nanobot.providers.fallback_provider import FallbackProvider
from nanobot.providers.registry import find_by_name
@dataclass(frozen=True)
class ProviderSnapshot:
    """Immutable bundle describing one constructed provider chain.

    Instances are assembled by ``build_provider_snapshot``.
    """

    # Provider returned by ``make_provider`` (a FallbackProvider wrapper when
    # the resolved preset lists fallback models).
    provider: LLMProvider
    # Model name of the resolved (primary) preset.
    model: str
    # Smallest context window across the primary preset and all its fallbacks.
    context_window_tokens: int
    # Tuple produced by ``provider_signature`` for the same config/preset.
    signature: tuple[object, ...]
def _resolve_model_preset(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
) -> ModelPresetConfig:
    """Return the preset to operate on.

    An explicitly supplied *preset* always wins; otherwise the lookup is
    delegated to ``config.resolve_preset(preset_name)`` (with *preset_name*
    possibly ``None``).
    """
    if preset is None:
        return config.resolve_preset(preset_name)
    return preset
def _make_provider_core(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
    model: str | None = None,
) -> LLMProvider:
    """Build a single, unwrapped LLM provider (no failover layer).

    Args:
        config: Application config used to look up provider settings.
        preset_name: Name of the preset to resolve when *preset* is not given.
        preset: Already-resolved preset; takes precedence over *preset_name*.
        model: Optional model name overriding the preset's model.

    Returns:
        A provider instance for the selected backend, with its ``generation``
        attribute set from the resolved preset.

    Raises:
        ValueError: If the Azure OpenAI backend lacks ``api_key``/``api_base``,
            or an OpenAI-compatible (non ``bedrock/``) model has no API key and
            its provider spec is not OAuth, local, or direct.
    """
    active = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
    model = model or active.model
    provider_name = config.get_provider_name(model, preset=active)
    provider_cfg = config.get_provider(model, preset=active)
    spec = find_by_name(provider_name) if provider_name else None
    # Providers without a registry spec are treated as OpenAI-compatible.
    backend = "openai_compat" if not spec else spec.backend

    # --- credential validation -------------------------------------------
    if backend == "azure_openai":
        has_credentials = bool(provider_cfg and provider_cfg.api_key and provider_cfg.api_base)
        if not has_credentials:
            raise ValueError("Azure OpenAI requires api_key and api_base in config.")
    elif backend == "openai_compat" and not model.startswith("bedrock/"):
        has_key = bool(provider_cfg and provider_cfg.api_key)
        key_optional = bool(spec and (spec.is_oauth or spec.is_local or spec.is_direct))
        if not has_key and not key_optional:
            raise ValueError(f"No API key configured for provider '{provider_name}'.")

    def _cfg_value(field_name: str) -> object:
        """Read a required attribute from the provider config, or None without a config."""
        return getattr(provider_cfg, field_name) if provider_cfg else None

    def _cfg_optional(field_name: str) -> object:
        """Like ``_cfg_value`` but tolerates configs that lack the attribute."""
        return getattr(provider_cfg, field_name, None) if provider_cfg else None

    # --- instantiation (backend modules are imported lazily) --------------
    if backend == "openai_codex":
        from nanobot.providers.openai_codex_provider import OpenAICodexProvider

        provider = OpenAICodexProvider(default_model=model)
    elif backend == "azure_openai":
        from nanobot.providers.azure_openai_provider import AzureOpenAIProvider

        # provider_cfg is guaranteed non-empty here by the validation above.
        provider = AzureOpenAIProvider(
            api_key=provider_cfg.api_key,
            api_base=provider_cfg.api_base,
            default_model=model,
        )
    elif backend == "github_copilot":
        from nanobot.providers.github_copilot_provider import GitHubCopilotProvider

        provider = GitHubCopilotProvider(default_model=model)
    elif backend == "anthropic":
        from nanobot.providers.anthropic_provider import AnthropicProvider

        provider = AnthropicProvider(
            api_key=_cfg_value("api_key"),
            api_base=config.get_api_base(model, preset=active),
            default_model=model,
            extra_headers=_cfg_value("extra_headers"),
        )
    elif backend == "bedrock":
        from nanobot.providers.bedrock_provider import BedrockProvider

        provider = BedrockProvider(
            api_key=_cfg_value("api_key"),
            api_base=_cfg_value("api_base"),
            default_model=model,
            region=_cfg_optional("region"),
            profile=_cfg_optional("profile"),
            extra_body=_cfg_value("extra_body"),
        )
    else:
        from nanobot.providers.openai_compat_provider import OpenAICompatProvider

        provider = OpenAICompatProvider(
            api_key=_cfg_value("api_key"),
            api_base=config.get_api_base(model, preset=active),
            default_model=model,
            extra_headers=_cfg_value("extra_headers"),
            spec=spec,
            extra_body=_cfg_value("extra_body"),
        )

    provider.generation = active.to_generation_settings()
    return provider
def _fallback_preset(
    primary: ModelPresetConfig, fallback: ModelFallbackConfig
) -> ModelPresetConfig:
    """Derive the effective preset for one fallback model.

    The model and provider always come from *fallback*. Each generation
    setting (``max_tokens``, ``context_window_tokens``, ``temperature``,
    ``reasoning_effort``) uses the fallback's own value when it is not
    ``None`` and otherwise inherits the value from *primary*.
    """

    def _inherit(field_name: str) -> object:
        """Return the fallback's override for *field_name*, else the primary's value."""
        own_value = getattr(fallback, field_name)
        if own_value is None:
            return getattr(primary, field_name)
        return own_value

    return ModelPresetConfig(
        model=fallback.model,
        provider=fallback.provider,
        max_tokens=_inherit("max_tokens"),
        context_window_tokens=_inherit("context_window_tokens"),
        temperature=_inherit("temperature"),
        reasoning_effort=_inherit("reasoning_effort"),
    )
def make_provider(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
    model: str | None = None,
) -> LLMProvider:
    """Create the LLM provider implied by config.

    The primary provider is built for the resolved preset; when *model* is
    given it overrides the preset's model for that primary provider. If the
    resolved preset lists fallback models, the primary is wrapped in a
    ``FallbackProvider`` whose factory builds each fallback provider from the
    preset derived by ``_fallback_preset``.
    """
    active = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
    primary = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)

    # No fallback chain configured: hand back the plain provider.
    if not active.fallback_models:
        return primary

    def _build_fallback(fb: ModelFallbackConfig) -> LLMProvider:
        """Construct the unwrapped provider for a single fallback entry."""
        return _make_provider_core(
            config, preset_name=preset_name, preset=_fallback_preset(active, fb)
        )

    return FallbackProvider(
        primary=primary,
        fallback_models=active.fallback_models,
        provider_factory=_build_fallback,
    )
def provider_signature(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
) -> tuple[object, ...]:
    """Return the config fields that affect the active provider chain.

    The result holds the primary model's identity, endpoint settings and
    generation settings, followed by one nested tuple per fallback model
    (computed against the preset from ``_fallback_preset``).
    """
    active = _resolve_model_preset(config, preset_name=preset_name, preset=preset)

    def _endpoint_fields(model_name: str, model_preset: ModelPresetConfig) -> tuple[object, ...]:
        """Provider name, credentials and per-provider extras for one model."""
        provider_cfg = config.get_provider(model_name, preset=model_preset)
        return (
            config.get_provider_name(model_name, preset=model_preset),
            config.get_api_key(model_name, preset=model_preset),
            config.get_api_base(model_name, preset=model_preset),
            provider_cfg.extra_headers if provider_cfg else None,
            provider_cfg.extra_body if provider_cfg else None,
            getattr(provider_cfg, "region", None) if provider_cfg else None,
            getattr(provider_cfg, "profile", None) if provider_cfg else None,
        )

    def _generation_fields(model_preset: ModelPresetConfig) -> tuple[object, ...]:
        """Generation settings of a preset, in signature order."""
        return (
            model_preset.max_tokens,
            model_preset.temperature,
            model_preset.reasoning_effort,
            model_preset.context_window_tokens,
        )

    def _fallback_signature(fallback: ModelFallbackConfig) -> tuple[object, ...]:
        """Signature entry for one fallback: identity, generation, then endpoint."""
        effective = _fallback_preset(active, fallback)
        identity = (fallback.model, fallback.provider)
        return (
            identity
            + _generation_fields(effective)
            + _endpoint_fields(fallback.model, effective)
        )

    primary_endpoint = _endpoint_fields(active.model, active)
    fallback_signatures = tuple(map(_fallback_signature, active.fallback_models))

    # Primary entry layout: identity, endpoint, generation, nested fallbacks.
    return (
        (active.model, active.provider)
        + primary_endpoint
        + _generation_fields(active)
        + (fallback_signatures,)
    )
def build_provider_snapshot(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
) -> ProviderSnapshot:
    """Build a ``ProviderSnapshot`` for the resolved preset.

    The snapshot's context window is the minimum over the primary preset and
    every fallback's effective preset, so the reported window is no larger
    than that of any model in the chain.
    """
    active = _resolve_model_preset(config, preset_name=preset_name, preset=preset)

    # Start with the primary window, then append each fallback's effective one.
    windows = [active.context_window_tokens]
    windows.extend(
        _fallback_preset(active, fb).context_window_tokens for fb in active.fallback_models
    )

    provider = make_provider(config, preset=active)
    smallest_window = min(windows)
    signature = provider_signature(config, preset=active)
    return ProviderSnapshot(
        provider=provider,
        model=active.model,
        context_window_tokens=smallest_window,
        signature=signature,
    )
def load_provider_snapshot(
    config_path: Path | None = None,
    *,
    preset_name: str | None = None,
) -> ProviderSnapshot:
    """Load config from *config_path* and build a provider snapshot from it.

    The loaded config is passed through ``resolve_config_env_vars`` before
    being handed to ``build_provider_snapshot`` together with *preset_name*.
    """
    # Imported lazily, inside the function, as in the original code.
    from nanobot.config.loader import load_config, resolve_config_env_vars

    loaded = load_config(config_path)
    resolved_config = resolve_config_env_vars(loaded)
    return build_provider_snapshot(resolved_config, preset_name=preset_name)