feat(runner): support fallback candidates

Resolve fallbackModels as preset references or explicit inline provider configs so failover uses complete model settings without exposing fallback logic to the agent loop.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-13 15:34:03 +00:00
parent 43db848db0
commit 5efd67919b
5 changed files with 502 additions and 59 deletions

View File

@ -672,7 +672,8 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
"maxTokens": 8192, "maxTokens": 8192,
"contextWindowTokens": 128000, "contextWindowTokens": 128000,
"temperature": 0.1, "temperature": 0.1,
"modelPreset": null "modelPreset": "fast",
"fallbackModels": ["deep"]
} }
}, },
"modelPresets": { "modelPresets": {
@ -708,6 +709,40 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
`default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`. `default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`.
### Model Fallbacks
`agents.defaults.fallbackModels` defines an ordered failover chain for the active model configuration. The primary model is still selected by `agents.defaults.modelPreset` (or the implicit default config when no preset is active).
Each fallback candidate can be either:
- A preset name from `modelPresets`, such as `"deep"`. The preset's full model, provider, generation, and context-window config is used.
- An inline fallback object with at least `provider` and `model`. Optional `maxTokens`, `contextWindowTokens`, and `temperature` fields inherit from the active primary config when omitted. `reasoningEffort` does not inherit; omit it to leave reasoning off for that fallback, or set it explicitly for models that support reasoning.
```json
{
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": [
"deep",
{
"provider": "deepseek",
"model": "deepseek-v4-pro",
"maxTokens": 4096,
"contextWindowTokens": 262144
}
]
}
}
}
```
String entries are preset names, not raw model names. If you want to use a model that is not already a preset, use the inline object form.
Failover only runs when the primary provider returns a retryable model/provider error before any answer text has been streamed. Typical fallback cases include timeouts, connection errors, 5xx server errors, 429 rate limits, overloads, and quota/balance exhaustion. It does not run for malformed requests, authentication/permission errors, content filtering/refusals, or context-length/message-format errors.
If fallback candidates use smaller `contextWindowTokens` values, nanobot builds context using the smallest window in the active chain so every candidate can receive the same prompt.
Set `agents.defaults.modelPreset` to start with a named preset: Set `agents.defaults.modelPreset` to start with a named preset:
```json ```json

View File

@ -74,6 +74,20 @@ class DreamConfig(Base):
return f"every {hours}h" return f"every {hours}h"
class InlineFallbackConfig(Base):
"""One inline fallback model configuration."""
model: str
provider: str
max_tokens: int | None = None
context_window_tokens: int | None = None
temperature: float | None = None
reasoning_effort: str | None = None
FallbackCandidate = str | InlineFallbackConfig
class ModelPresetConfig(Base): class ModelPresetConfig(Base):
"""A named set of model + generation parameters for quick switching.""" """A named set of model + generation parameters for quick switching."""
@ -83,7 +97,6 @@ class ModelPresetConfig(Base):
context_window_tokens: int = 65_536 context_window_tokens: int = 65_536
temperature: float = 0.1 temperature: float = 0.1
reasoning_effort: str | None = None reasoning_effort: str | None = None
fallback_models: list[str] = Field(default_factory=list)
def to_generation_settings(self) -> Any: def to_generation_settings(self) -> Any:
from nanobot.providers.base import GenerationSettings from nanobot.providers.base import GenerationSettings
@ -107,6 +120,7 @@ class AgentDefaults(Base):
context_window_tokens: int = 65_536 context_window_tokens: int = 65_536
context_block_limit: int | None = None context_block_limit: int | None = None
temperature: float = 0.1 temperature: float = 0.1
fallback_models: list[FallbackCandidate] = Field(default_factory=list)
max_tool_iterations: int = 200 max_tool_iterations: int = 200
max_concurrent_subagents: int = Field(default=1, ge=1) max_concurrent_subagents: int = Field(default=1, ge=1)
max_tool_result_chars: int = 16_000 max_tool_result_chars: int = 16_000
@ -288,6 +302,9 @@ class Config(BaseSettings):
name = self.agents.defaults.model_preset name = self.agents.defaults.model_preset
if name and name != "default" and name not in self.model_presets: if name and name != "default" and name not in self.model_presets:
raise ValueError(f"model_preset {name!r} not found in model_presets") raise ValueError(f"model_preset {name!r} not found in model_presets")
for fallback in self.agents.defaults.fallback_models:
if isinstance(fallback, str) and fallback not in self.model_presets:
raise ValueError(f"fallback_models entry {fallback!r} not found in model_presets")
return self return self
def resolve_default_preset(self) -> ModelPresetConfig: def resolve_default_preset(self) -> ModelPresetConfig:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from nanobot.config.schema import Config, ModelPresetConfig from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.fallback_provider import FallbackProvider
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
@ -104,6 +104,36 @@ def _make_provider_core(
return provider return provider
def _inline_fallback_preset(
primary: ModelPresetConfig,
fallback: InlineFallbackConfig,
) -> ModelPresetConfig:
return ModelPresetConfig(
model=fallback.model,
provider=fallback.provider,
max_tokens=fallback.max_tokens if fallback.max_tokens is not None else primary.max_tokens,
context_window_tokens=(
fallback.context_window_tokens
if fallback.context_window_tokens is not None
else primary.context_window_tokens
),
temperature=(
fallback.temperature if fallback.temperature is not None else primary.temperature
),
reasoning_effort=fallback.reasoning_effort,
)
def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]:
presets: list[ModelPresetConfig] = []
for fallback in config.agents.defaults.fallback_models:
if isinstance(fallback, str):
presets.append(config.model_presets[fallback])
else:
presets.append(_inline_fallback_preset(primary, fallback))
return presets
def make_provider( def make_provider(
config: Config, config: Config,
*, *,
@ -118,14 +148,14 @@ def make_provider(
""" """
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model) provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)
fallback_presets = _resolve_fallback_presets(config, resolved)
if resolved.fallback_models: if fallback_presets:
fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []})
provider = FallbackProvider( provider = FallbackProvider(
primary=provider, primary=provider,
fallback_models=resolved.fallback_models, fallback_presets=fallback_presets,
provider_factory=lambda m: _make_provider_core( provider_factory=lambda fb: _make_provider_core(
config, preset_name=preset_name, preset=fb_preset, model=m config, preset_name=preset_name, preset=fb
), ),
) )
@ -138,9 +168,29 @@ def provider_signature(
preset_name: str | None = None, preset_name: str | None = None,
preset: ModelPresetConfig | None = None, preset: ModelPresetConfig | None = None,
) -> tuple[object, ...]: ) -> tuple[object, ...]:
"""Return the config fields that affect the primary LLM provider.""" """Return the config fields that affect the active provider chain."""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
p = config.get_provider(resolved.model, preset=resolved) p = config.get_provider(resolved.model, preset=resolved)
fallback_presets = _resolve_fallback_presets(config, resolved)
def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]:
fp = config.get_provider(fallback.model, preset=fallback)
return (
fallback.model,
fallback.provider,
config.get_provider_name(fallback.model, preset=fallback),
config.get_api_key(fallback.model, preset=fallback),
config.get_api_base(fallback.model, preset=fallback),
fp.extra_headers if fp else None,
fp.extra_body if fp else None,
getattr(fp, "region", None) if fp else None,
getattr(fp, "profile", None) if fp else None,
fallback.max_tokens,
fallback.temperature,
fallback.reasoning_effort,
fallback.context_window_tokens,
)
return ( return (
resolved.model, resolved.model,
resolved.provider, resolved.provider,
@ -155,6 +205,7 @@ def provider_signature(
resolved.temperature, resolved.temperature,
resolved.reasoning_effort, resolved.reasoning_effort,
resolved.context_window_tokens, resolved.context_window_tokens,
tuple(_fallback_signature(fallback) for fallback in fallback_presets),
) )
@ -165,10 +216,14 @@ def build_provider_snapshot(
preset: ModelPresetConfig | None = None, preset: ModelPresetConfig | None = None,
) -> ProviderSnapshot: ) -> ProviderSnapshot:
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
fallback_windows = [
fallback.context_window_tokens
for fallback in _resolve_fallback_presets(config, resolved)
]
return ProviderSnapshot( return ProviderSnapshot(
provider=make_provider(config, preset=resolved), provider=make_provider(config, preset=resolved),
model=resolved.model, model=resolved.model,
context_window_tokens=resolved.context_window_tokens, context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]),
signature=provider_signature(config, preset=resolved), signature=provider_signature(config, preset=resolved),
) )

View File

@ -13,6 +13,46 @@ from nanobot.providers.base import LLMProvider, LLMResponse
# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker. # Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker.
_PRIMARY_FAILURE_THRESHOLD = 3 _PRIMARY_FAILURE_THRESHOLD = 3
_PRIMARY_COOLDOWN_S = 60 _PRIMARY_COOLDOWN_S = 60
_MISSING = object()
_FALLBACK_ERROR_KINDS = frozenset({
"timeout",
"connection",
"server_error",
"rate_limit",
"overloaded",
})
_NON_FALLBACK_ERROR_KINDS = frozenset({
"authentication",
"auth",
"permission",
"content_filter",
"refusal",
"context_length",
"invalid_request",
})
_FALLBACK_ERROR_TOKENS = (
"rate_limit",
"rate limit",
"too_many_requests",
"too many requests",
"overloaded",
"server_error",
"server error",
"temporarily unavailable",
"timeout",
"timed out",
"connection",
"insufficient_quota",
"insufficient quota",
"quota_exceeded",
"quota exceeded",
"quota_exhausted",
"quota exhausted",
"billing_hard_limit",
"insufficient_balance",
"balance",
"out of credits",
)
class FallbackProvider(LLMProvider): class FallbackProvider(LLMProvider):
@ -34,13 +74,13 @@ class FallbackProvider(LLMProvider):
def __init__( def __init__(
self, self,
primary: LLMProvider, primary: LLMProvider,
fallback_models: list[str], fallback_presets: list[Any],
provider_factory: Callable[[str], LLMProvider], provider_factory: Callable[[Any], LLMProvider],
): ):
self._primary = primary self._primary = primary
self._fallback_models = list(fallback_models) self._fallback_presets = list(fallback_presets)
self._provider_factory = provider_factory self._provider_factory = provider_factory
self._has_fallbacks = bool(fallback_models) self._has_fallbacks = bool(fallback_presets)
self._primary_failures = 0 self._primary_failures = 0
self._primary_tripped_at: float | None = None self._primary_tripped_at: float | None = None
@ -55,6 +95,10 @@ class FallbackProvider(LLMProvider):
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self._primary.get_default_model() return self._primary.get_default_model()
@property
def supports_progress_deltas(self) -> bool:
return bool(getattr(self._primary, "supports_progress_deltas", False))
def _primary_available(self) -> bool: def _primary_available(self) -> bool:
"""Return True if the primary provider is not currently tripped.""" """Return True if the primary provider is not currently tripped."""
if self._primary_tripped_at is None: if self._primary_tripped_at is None:
@ -110,6 +154,14 @@ class FallbackProvider(LLMProvider):
) )
return response return response
if not self._should_fallback(response):
logger.warning(
"Primary model '{}' returned non-fallbackable error: {}",
primary_model,
(response.content or "")[:120],
)
return response
self._primary_failures += 1 self._primary_failures += 1
if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD: if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD:
self._primary_tripped_at = time.monotonic() self._primary_tripped_at = time.monotonic()
@ -122,7 +174,8 @@ class FallbackProvider(LLMProvider):
last_response: LLMResponse | None = None last_response: LLMResponse | None = None
primary_skipped = not self._primary_available() primary_skipped = not self._primary_available()
for idx, fallback_model in enumerate(self._fallback_models): for idx, fallback in enumerate(self._fallback_presets):
fallback_model = fallback.model
if has_streamed is not None and has_streamed[0]: if has_streamed is not None and has_streamed[0]:
break break
if idx == 0 and primary_skipped: if idx == 0 and primary_skipped:
@ -138,25 +191,35 @@ class FallbackProvider(LLMProvider):
else: else:
logger.info( logger.info(
"Fallback '{}' also failed, trying next fallback '{}'", "Fallback '{}' also failed, trying next fallback '{}'",
self._fallback_models[idx - 1], fallback_model, self._fallback_presets[idx - 1].model, fallback_model,
) )
try: try:
fallback_provider = self._provider_factory(fallback_model) fallback_provider = self._provider_factory(fallback)
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"Failed to create provider for fallback '{}': {}", fallback_model, exc "Failed to create provider for fallback '{}': {}", fallback_model, exc
) )
continue continue
original_model = kwargs.get("model") original_values = {
name: kwargs.get(name, _MISSING)
for name in ("model", "max_tokens", "temperature", "reasoning_effort")
}
kwargs["model"] = fallback_model kwargs["model"] = fallback_model
kwargs["max_tokens"] = fallback.max_tokens
kwargs["temperature"] = fallback.temperature
if fallback.reasoning_effort is None:
kwargs.pop("reasoning_effort", None)
else:
kwargs["reasoning_effort"] = fallback.reasoning_effort
try: try:
fallback_response = await call(fallback_provider, kwargs) fallback_response = await call(fallback_provider, kwargs)
finally: finally:
if original_model is not None: for name, value in original_values.items():
kwargs["model"] = original_model if value is _MISSING:
kwargs.pop(name, None)
else: else:
kwargs.pop("model", None) kwargs[name] = value
if fallback_response.finish_reason != "error": if fallback_response.finish_reason != "error":
logger.info( logger.info(
@ -174,7 +237,7 @@ class FallbackProvider(LLMProvider):
logger.warning( logger.warning(
"All {} fallback model(s) failed", "All {} fallback model(s) failed",
len(self._fallback_models), len(self._fallback_presets),
) )
# Return the last error response we saw (primary or last fallback). # Return the last error response we saw (primary or last fallback).
if last_response is not None: if last_response is not None:
@ -184,3 +247,27 @@ class FallbackProvider(LLMProvider):
content=f"Primary model '{primary_model}' circuit open and no fallbacks available", content=f"Primary model '{primary_model}' circuit open and no fallbacks available",
finish_reason="error", finish_reason="error",
) )
@staticmethod
def _should_fallback(response: LLMResponse) -> bool:
if response.error_should_retry is False:
return False
status = response.error_status_code
kind = (response.error_kind or "").lower()
error_type = (response.error_type or "").lower()
code = (response.error_code or "").lower()
text = (response.content or "").lower()
if status in {400, 401, 403, 404, 422}:
return False
if kind in _NON_FALLBACK_ERROR_KINDS:
return False
if any(token in value for value in (kind, error_type, code) for token in _NON_FALLBACK_ERROR_KINDS):
return False
if response.error_should_retry is True:
return True
if status is not None and (status in {408, 409, 429} or 500 <= status <= 599):
return True
if kind in _FALLBACK_ERROR_KINDS:
return True
return any(token in value for value in (kind, error_type, code, text) for token in _FALLBACK_ERROR_TOKENS)

View File

@ -3,10 +3,11 @@
from __future__ import annotations from __future__ import annotations
from typing import Any from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock, patch
import pytest import pytest
from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.base import LLMProvider, LLMResponse from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.fallback_provider import FallbackProvider from nanobot.providers.fallback_provider import FallbackProvider
@ -16,14 +17,45 @@ def _make_response(
finish_reason: str = "stop", finish_reason: str = "stop",
*, *,
error_kind: str | None = None, error_kind: str | None = None,
error_status_code: int | None = None,
error_type: str | None = None,
error_code: str | None = None,
error_should_retry: bool | None = None,
) -> LLMResponse: ) -> LLMResponse:
return LLMResponse(content=content, finish_reason=finish_reason, error_kind=error_kind) return LLMResponse(
content=content,
finish_reason=finish_reason,
error_kind=error_kind,
error_status_code=error_status_code,
error_type=error_type,
error_code=error_code,
error_should_retry=error_should_retry,
)
def _error_response(content: str = "api error") -> LLMResponse: def _error_response(content: str = "api error") -> LLMResponse:
return _make_response(content, finish_reason="error", error_kind="server_error") return _make_response(content, finish_reason="error", error_kind="server_error")
def _fallback(
model: str,
provider: str = "custom",
*,
max_tokens: int = 8192,
context_window_tokens: int = 65_536,
temperature: float = 0.1,
reasoning_effort: str | None = None,
) -> ModelPresetConfig:
return ModelPresetConfig(
model=model,
provider=provider,
max_tokens=max_tokens,
context_window_tokens=context_window_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
)
class _FakeProvider(LLMProvider): class _FakeProvider(LLMProvider):
"""Fake provider for testing.""" """Fake provider for testing."""
@ -53,24 +85,163 @@ class _FakeProvider(LLMProvider):
def test_fallback_models_default_empty() -> None: def test_fallback_models_default_empty() -> None:
from nanobot.config.schema import ModelPresetConfig from nanobot.config.schema import AgentDefaults
p = ModelPresetConfig(model="test/model")
assert p.fallback_models == [] defaults = AgentDefaults()
assert defaults.fallback_models == []
def test_fallback_models_accepts_list() -> None: def test_fallback_models_accept_preset_refs_and_inline_configs() -> None:
from nanobot.config.schema import ModelPresetConfig from nanobot.config.schema import Config, InlineFallbackConfig
p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"])
assert p.fallback_models == ["test/a", "test/b"]
config = Config.model_validate({
def test_fallback_models_from_camel_case() -> None: "agents": {
from nanobot.config.schema import ModelPresetConfig "defaults": {
p = ModelPresetConfig.model_validate({ "fallbackModels": [
"model": "test/primary", "deep",
"fallbackModels": ["test/a"], {
"provider": "openai",
"model": "gpt-4.1",
"maxTokens": 4096,
},
]
}
},
"modelPresets": {
"deep": {"provider": "anthropic", "model": "claude-opus-4-7"}
},
}) })
assert p.fallback_models == ["test/a"]
assert config.agents.defaults.fallback_models[0] == "deep"
assert config.agents.defaults.fallback_models[1] == InlineFallbackConfig(
provider="openai",
model="gpt-4.1",
max_tokens=4096,
)
def test_fallback_model_preset_ref_must_exist() -> None:
from nanobot.config.schema import Config
with pytest.raises(ValueError, match="fallback_models.*not found"):
Config.model_validate({
"agents": {"defaults": {"fallbackModels": ["missing"]}},
"modelPresets": {},
})
def test_provider_signature_tracks_fallback_presets_and_provider_config() -> None:
from nanobot.config.schema import Config
from nanobot.providers.factory import provider_signature
base = {
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": ["deep"],
}
},
"modelPresets": {
"fast": {"model": "openai/gpt-4.1", "provider": "openai"},
"deep": {"model": "anthropic/claude-sonnet-4-6", "provider": "anthropic"},
},
"providers": {
"openai": {"apiKey": "primary-key"},
"anthropic": {"apiKey": "fallback-key"},
},
}
changed_fallback = {
**base,
"agents": {"defaults": {"modelPreset": "fast", "fallbackModels": ["backup"]}},
"modelPresets": {
**base["modelPresets"],
"backup": {"model": "deepseek/deepseek-chat", "provider": "deepseek"},
},
"providers": {
**base["providers"],
"deepseek": {"apiKey": "deepseek-key"},
},
}
changed_key = {
**base,
"providers": {
"openai": {"apiKey": "primary-key"},
"anthropic": {"apiKey": "new-fallback-key"},
},
}
signature = provider_signature(Config.model_validate(base))
assert signature != provider_signature(Config.model_validate(changed_fallback))
assert signature != provider_signature(Config.model_validate(changed_key))
def test_provider_snapshot_uses_smallest_fallback_context_window() -> None:
from nanobot.config.schema import Config
from nanobot.providers.factory import build_provider_snapshot
config = Config.model_validate({
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": ["deep"],
}
},
"modelPresets": {
"fast": {
"model": "openai/gpt-4.1",
"provider": "openai",
"contextWindowTokens": 128000,
},
"deep": {
"model": "deepseek/deepseek-chat",
"provider": "deepseek",
"contextWindowTokens": 64000,
},
},
"providers": {
"openai": {"apiKey": "primary-key"},
"deepseek": {"apiKey": "fallback-key"},
},
})
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
snapshot = build_provider_snapshot(config)
assert snapshot.context_window_tokens == 64000
def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
from nanobot.config.schema import Config
from nanobot.providers.factory import provider_signature
config = Config.model_validate({
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": [
{"provider": "openai", "model": "gpt-4.1"}
],
}
},
"modelPresets": {
"fast": {
"model": "anthropic/claude-opus-4-5",
"provider": "anthropic",
"reasoningEffort": "high",
}
},
"providers": {
"anthropic": {"apiKey": "primary-key"},
"openai": {"apiKey": "fallback-key"},
},
})
signature = provider_signature(config)
fallback_signatures = signature[-1]
assert fallback_signatures[0][11] is None
# -- FallbackProvider tests -- # -- FallbackProvider tests --
@ -83,7 +254,7 @@ class TestNoFallbackWhenPrimarySucceeds:
factory = MagicMock() factory = MagicMock()
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
@ -102,14 +273,14 @@ class TestFallbackOnPrimaryError:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
assert result.content == "fallback ok" assert result.content == "fallback ok"
assert result.finish_reason == "stop" assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a") factory.assert_called_once_with(_fallback("fallback-a"))
assert primary.chat_calls[0]["model"] == "primary-model" assert primary.chat_calls[0]["model"] == "primary-model"
assert fallback.chat_calls[0]["model"] == "fallback-a" assert fallback.chat_calls[0]["model"] == "fallback-a"
@ -121,7 +292,7 @@ class TestNoFallbackWhenContentStreamed:
factory = MagicMock() factory = MagicMock()
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
@ -146,14 +317,62 @@ class TestFailoverOnTransientError:
factory = MagicMock(return_value=fallback) factory = MagicMock(return_value=fallback)
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok" assert result.content == "fallback ok"
assert result.finish_reason == "stop" assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a") factory.assert_called_once_with(_fallback("fallback-a"))
class TestNoFallbackOnNonRetryableError:
@pytest.mark.asyncio
async def test_bad_request(self) -> None:
primary = _FakeProvider(
"primary",
_make_response(
"invalid request",
finish_reason="error",
error_status_code=400,
error_kind="invalid_request",
),
)
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.finish_reason == "error"
factory.assert_not_called()
@pytest.mark.asyncio
async def test_auth_error(self) -> None:
primary = _FakeProvider(
"primary",
_make_response(
"unauthorized",
finish_reason="error",
error_status_code=401,
error_kind="authentication",
),
)
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.finish_reason == "error"
factory.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_timeout(self) -> None: async def test_timeout(self) -> None:
@ -165,14 +384,14 @@ class TestFailoverOnTransientError:
factory = MagicMock(return_value=fallback) factory = MagicMock(return_value=fallback)
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok" assert result.content == "fallback ok"
assert result.finish_reason == "stop" assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a") factory.assert_called_once_with(_fallback("fallback-a"))
class TestFallbackTriesModelsInOrder: class TestFallbackTriesModelsInOrder:
@ -185,15 +404,15 @@ class TestFallbackTriesModelsInOrder:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a", "fallback-b"], fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
provider_factory=factory, provider_factory=factory,
) )
result = await fb.chat(messages=[{"role": "user", "content": "hi"}]) result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "b ok" assert result.content == "b ok"
assert factory.call_count == 2 assert factory.call_count == 2
factory.assert_any_call("fallback-a") factory.assert_any_call(_fallback("fallback-a"))
factory.assert_any_call("fallback-b") factory.assert_any_call(_fallback("fallback-b"))
class TestAllFallbacksFail: class TestAllFallbacksFail:
@ -205,7 +424,7 @@ class TestAllFallbacksFail:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
@ -223,7 +442,7 @@ class TestFactoryExceptionSkipsModel:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a", "fallback-b"], fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
provider_factory=factory, provider_factory=factory,
) )
@ -242,13 +461,43 @@ class TestFallbackModelParameter:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-model"], fallback_presets=[_fallback("fallback-model")],
provider_factory=factory, provider_factory=factory,
) )
await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model") await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
assert fallback.chat_calls[0]["model"] == "fallback-model" assert fallback.chat_calls[0]["model"] == "fallback-model"
@pytest.mark.asyncio
async def test_uses_fallback_generation_fields(self) -> None:
primary = _FakeProvider("primary", _error_response())
fallback = _FakeProvider("fallback", _make_response("ok"))
fb = FallbackProvider(
primary=primary,
fallback_presets=[
_fallback(
"fallback-model",
max_tokens=1234,
temperature=0.4,
reasoning_effort=None,
)
],
provider_factory=MagicMock(return_value=fallback),
)
await fb.chat(
messages=[{"role": "user", "content": "hi"}],
model="primary-model",
max_tokens=8192,
temperature=0.1,
reasoning_effort="high",
)
assert fallback.chat_calls[0]["model"] == "fallback-model"
assert fallback.chat_calls[0]["max_tokens"] == 1234
assert fallback.chat_calls[0]["temperature"] == 0.4
assert "reasoning_effort" not in fallback.chat_calls[0]
class TestNoFallbackWhenEmptyList: class TestNoFallbackWhenEmptyList:
@pytest.mark.asyncio @pytest.mark.asyncio
@ -258,7 +507,7 @@ class TestNoFallbackWhenEmptyList:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=[], fallback_presets=[],
provider_factory=factory, provider_factory=factory,
) )
@ -277,7 +526,7 @@ class TestChatStreamFailover:
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
@ -291,7 +540,7 @@ class TestGetDefaultModel:
primary = _FakeProvider("primary") primary = _FakeProvider("primary")
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["a"], fallback_presets=[_fallback("a")],
provider_factory=MagicMock(), provider_factory=MagicMock(),
) )
assert fb.get_default_model() == "primary/model" assert fb.get_default_model() == "primary/model"
@ -305,7 +554,7 @@ class TestCircuitBreaker:
factory = MagicMock(return_value=fallback) factory = MagicMock(return_value=fallback)
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
@ -329,7 +578,7 @@ class TestCircuitBreaker:
factory = MagicMock(return_value=fallback) factory = MagicMock(return_value=fallback)
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["fallback-a"], fallback_presets=[_fallback("fallback-a")],
provider_factory=factory, provider_factory=factory,
) )
@ -357,7 +606,7 @@ class TestGenerationForwarded:
primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024) primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024)
fb = FallbackProvider( fb = FallbackProvider(
primary=primary, primary=primary,
fallback_models=["a"], fallback_presets=[_fallback("a")],
provider_factory=MagicMock(), provider_factory=MagicMock(),
) )
assert fb.generation.temperature == 0.5 assert fb.generation.temperature == 0.5