feat(runner): support fallback candidates

Resolve fallbackModels as preset references or explicit inline provider configs so failover uses complete model settings without exposing fallback logic to the agent loop.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-13 15:34:03 +00:00
parent 43db848db0
commit 5efd67919b
5 changed files with 502 additions and 59 deletions

View File

@ -672,7 +672,8 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
"maxTokens": 8192,
"contextWindowTokens": 128000,
"temperature": 0.1,
"modelPreset": null
"modelPreset": "fast",
"fallbackModels": ["deep"]
}
},
"modelPresets": {
@ -708,6 +709,40 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
`default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`.
### Model Fallbacks
`agents.defaults.fallbackModels` defines an ordered failover chain for the active model configuration. The primary model is still selected by `agents.defaults.modelPreset` (or the implicit default config when no preset is active).
Each fallback candidate can be either:
- A preset name from `modelPresets`, such as `"deep"`. The preset's full model, provider, generation, and context-window config is used.
- An inline fallback object with at least `provider` and `model`. Optional `maxTokens`, `contextWindowTokens`, and `temperature` fields inherit from the active primary config when omitted. `reasoningEffort` does not inherit; omit it to leave reasoning off for that fallback, or set it explicitly for models that support reasoning.
```json
{
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": [
"deep",
{
"provider": "deepseek",
"model": "deepseek-v4-pro",
"maxTokens": 4096,
"contextWindowTokens": 262144
}
]
}
}
}
```
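String entries are preset names from `modelPresets`, not raw model names; a string that does not match a defined preset is rejected when the config is loaded. To use a model that is not already defined as a preset, use the inline object form.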
Failover only runs when the primary provider returns a retryable model/provider error before any answer text has been streamed. Typical fallback cases include timeouts, connection errors, 5xx server errors, 429 rate limits, overloads, and quota/balance exhaustion. It does not run for malformed requests, authentication/permission errors, content filtering/refusals, or context-length/message-format errors.
If fallback candidates use smaller `contextWindowTokens` values, nanobot builds context using the smallest window in the active chain so every candidate can receive the same prompt.
Set `agents.defaults.modelPreset` to start with a named preset:
```json

View File

@ -74,6 +74,20 @@ class DreamConfig(Base):
return f"every {hours}h"
class InlineFallbackConfig(Base):
"""One inline fallback model configuration."""
# `model` and `provider` have no default, so every inline entry must set both.
model: str
provider: str
# None means "not set here": when the inline entry is expanded into a full
# preset, these three fields are filled from the active primary preset.
max_tokens: int | None = None
context_window_tokens: int | None = None
temperature: float | None = None
# Not inherited from the primary: the value (including None) is copied into
# the expanded preset unchanged.
reasoning_effort: str | None = None
# One entry of `fallback_models`: a str is the name of a preset in
# `model_presets`; an InlineFallbackConfig carries its own provider/model.
FallbackCandidate = str | InlineFallbackConfig
class ModelPresetConfig(Base):
"""A named set of model + generation parameters for quick switching."""
@ -83,7 +97,6 @@ class ModelPresetConfig(Base):
context_window_tokens: int = 65_536
temperature: float = 0.1
reasoning_effort: str | None = None
fallback_models: list[str] = Field(default_factory=list)
def to_generation_settings(self) -> Any:
from nanobot.providers.base import GenerationSettings
@ -107,6 +120,7 @@ class AgentDefaults(Base):
context_window_tokens: int = 65_536
context_block_limit: int | None = None
temperature: float = 0.1
fallback_models: list[FallbackCandidate] = Field(default_factory=list)
max_tool_iterations: int = 200
max_concurrent_subagents: int = Field(default=1, ge=1)
max_tool_result_chars: int = 16_000
@ -288,6 +302,9 @@ class Config(BaseSettings):
name = self.agents.defaults.model_preset
if name and name != "default" and name not in self.model_presets:
raise ValueError(f"model_preset {name!r} not found in model_presets")
for fallback in self.agents.defaults.fallback_models:
if isinstance(fallback, str) and fallback not in self.model_presets:
raise ValueError(f"fallback_models entry {fallback!r} not found in model_presets")
return self
def resolve_default_preset(self) -> ModelPresetConfig:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from nanobot.config.schema import Config, ModelPresetConfig
from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
from nanobot.providers.base import LLMProvider
from nanobot.providers.fallback_provider import FallbackProvider
from nanobot.providers.registry import find_by_name
@ -104,6 +104,36 @@ def _make_provider_core(
return provider
def _inline_fallback_preset(
    primary: ModelPresetConfig,
    fallback: InlineFallbackConfig,
) -> ModelPresetConfig:
    """Expand an inline fallback entry into a complete preset.

    ``max_tokens``, ``context_window_tokens`` and ``temperature`` come from
    *fallback* when it sets them and from *primary* otherwise.
    ``reasoning_effort`` is copied from *fallback* as-is, so an unset value
    stays ``None`` instead of picking up the primary's setting.
    """
    max_tokens = fallback.max_tokens
    if max_tokens is None:
        max_tokens = primary.max_tokens

    context_window_tokens = fallback.context_window_tokens
    if context_window_tokens is None:
        context_window_tokens = primary.context_window_tokens

    temperature = fallback.temperature
    if temperature is None:
        temperature = primary.temperature

    return ModelPresetConfig(
        model=fallback.model,
        provider=fallback.provider,
        max_tokens=max_tokens,
        context_window_tokens=context_window_tokens,
        temperature=temperature,
        reasoning_effort=fallback.reasoning_effort,
    )
def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]:
presets: list[ModelPresetConfig] = []
for fallback in config.agents.defaults.fallback_models:
if isinstance(fallback, str):
presets.append(config.model_presets[fallback])
else:
presets.append(_inline_fallback_preset(primary, fallback))
return presets
def make_provider(
config: Config,
*,
@ -118,14 +148,14 @@ def make_provider(
"""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)
fallback_presets = _resolve_fallback_presets(config, resolved)
if resolved.fallback_models:
fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []})
if fallback_presets:
provider = FallbackProvider(
primary=provider,
fallback_models=resolved.fallback_models,
provider_factory=lambda m: _make_provider_core(
config, preset_name=preset_name, preset=fb_preset, model=m
fallback_presets=fallback_presets,
provider_factory=lambda fb: _make_provider_core(
config, preset_name=preset_name, preset=fb
),
)
@ -138,9 +168,29 @@ def provider_signature(
preset_name: str | None = None,
preset: ModelPresetConfig | None = None,
) -> tuple[object, ...]:
"""Return the config fields that affect the primary LLM provider."""
"""Return the config fields that affect the active provider chain."""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
p = config.get_provider(resolved.model, preset=resolved)
fallback_presets = _resolve_fallback_presets(config, resolved)
def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]:
fp = config.get_provider(fallback.model, preset=fallback)
return (
fallback.model,
fallback.provider,
config.get_provider_name(fallback.model, preset=fallback),
config.get_api_key(fallback.model, preset=fallback),
config.get_api_base(fallback.model, preset=fallback),
fp.extra_headers if fp else None,
fp.extra_body if fp else None,
getattr(fp, "region", None) if fp else None,
getattr(fp, "profile", None) if fp else None,
fallback.max_tokens,
fallback.temperature,
fallback.reasoning_effort,
fallback.context_window_tokens,
)
return (
resolved.model,
resolved.provider,
@ -155,6 +205,7 @@ def provider_signature(
resolved.temperature,
resolved.reasoning_effort,
resolved.context_window_tokens,
tuple(_fallback_signature(fallback) for fallback in fallback_presets),
)
@ -165,10 +216,14 @@ def build_provider_snapshot(
preset: ModelPresetConfig | None = None,
) -> ProviderSnapshot:
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
fallback_windows = [
fallback.context_window_tokens
for fallback in _resolve_fallback_presets(config, resolved)
]
return ProviderSnapshot(
provider=make_provider(config, preset=resolved),
model=resolved.model,
context_window_tokens=resolved.context_window_tokens,
context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]),
signature=provider_signature(config, preset=resolved),
)

View File

@ -13,6 +13,46 @@ from nanobot.providers.base import LLMProvider, LLMResponse
# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker.
_PRIMARY_FAILURE_THRESHOLD = 3
_PRIMARY_COOLDOWN_S = 60
_MISSING = object()
_FALLBACK_ERROR_KINDS = frozenset({
"timeout",
"connection",
"server_error",
"rate_limit",
"overloaded",
})
_NON_FALLBACK_ERROR_KINDS = frozenset({
"authentication",
"auth",
"permission",
"content_filter",
"refusal",
"context_length",
"invalid_request",
})
_FALLBACK_ERROR_TOKENS = (
"rate_limit",
"rate limit",
"too_many_requests",
"too many requests",
"overloaded",
"server_error",
"server error",
"temporarily unavailable",
"timeout",
"timed out",
"connection",
"insufficient_quota",
"insufficient quota",
"quota_exceeded",
"quota exceeded",
"quota_exhausted",
"quota exhausted",
"billing_hard_limit",
"insufficient_balance",
"balance",
"out of credits",
)
class FallbackProvider(LLMProvider):
@ -34,13 +74,13 @@ class FallbackProvider(LLMProvider):
def __init__(
self,
primary: LLMProvider,
fallback_models: list[str],
provider_factory: Callable[[str], LLMProvider],
fallback_presets: list[Any],
provider_factory: Callable[[Any], LLMProvider],
):
self._primary = primary
self._fallback_models = list(fallback_models)
self._fallback_presets = list(fallback_presets)
self._provider_factory = provider_factory
self._has_fallbacks = bool(fallback_models)
self._has_fallbacks = bool(fallback_presets)
self._primary_failures = 0
self._primary_tripped_at: float | None = None
@ -55,6 +95,10 @@ class FallbackProvider(LLMProvider):
def get_default_model(self) -> str:
return self._primary.get_default_model()
@property
def supports_progress_deltas(self) -> bool:
    """Mirror the primary provider's ``supports_progress_deltas`` flag.

    Reports ``False`` when the primary does not define the attribute.
    """
    primary_flag = getattr(self._primary, "supports_progress_deltas", False)
    return bool(primary_flag)
def _primary_available(self) -> bool:
"""Return True if the primary provider is not currently tripped."""
if self._primary_tripped_at is None:
@ -110,6 +154,14 @@ class FallbackProvider(LLMProvider):
)
return response
if not self._should_fallback(response):
logger.warning(
"Primary model '{}' returned non-fallbackable error: {}",
primary_model,
(response.content or "")[:120],
)
return response
self._primary_failures += 1
if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD:
self._primary_tripped_at = time.monotonic()
@ -122,7 +174,8 @@ class FallbackProvider(LLMProvider):
last_response: LLMResponse | None = None
primary_skipped = not self._primary_available()
for idx, fallback_model in enumerate(self._fallback_models):
for idx, fallback in enumerate(self._fallback_presets):
fallback_model = fallback.model
if has_streamed is not None and has_streamed[0]:
break
if idx == 0 and primary_skipped:
@ -138,25 +191,35 @@ class FallbackProvider(LLMProvider):
else:
logger.info(
"Fallback '{}' also failed, trying next fallback '{}'",
self._fallback_models[idx - 1], fallback_model,
self._fallback_presets[idx - 1].model, fallback_model,
)
try:
fallback_provider = self._provider_factory(fallback_model)
fallback_provider = self._provider_factory(fallback)
except Exception as exc:
logger.warning(
"Failed to create provider for fallback '{}': {}", fallback_model, exc
)
continue
original_model = kwargs.get("model")
original_values = {
name: kwargs.get(name, _MISSING)
for name in ("model", "max_tokens", "temperature", "reasoning_effort")
}
kwargs["model"] = fallback_model
kwargs["max_tokens"] = fallback.max_tokens
kwargs["temperature"] = fallback.temperature
if fallback.reasoning_effort is None:
kwargs.pop("reasoning_effort", None)
else:
kwargs["reasoning_effort"] = fallback.reasoning_effort
try:
fallback_response = await call(fallback_provider, kwargs)
finally:
if original_model is not None:
kwargs["model"] = original_model
else:
kwargs.pop("model", None)
for name, value in original_values.items():
if value is _MISSING:
kwargs.pop(name, None)
else:
kwargs[name] = value
if fallback_response.finish_reason != "error":
logger.info(
@ -174,7 +237,7 @@ class FallbackProvider(LLMProvider):
logger.warning(
"All {} fallback model(s) failed",
len(self._fallback_models),
len(self._fallback_presets),
)
# Return the last error response we saw (primary or last fallback).
if last_response is not None:
@ -184,3 +247,27 @@ class FallbackProvider(LLMProvider):
content=f"Primary model '{primary_model}' circuit open and no fallbacks available",
finish_reason="error",
)
@staticmethod
def _should_fallback(response: LLMResponse) -> bool:
if response.error_should_retry is False:
return False
status = response.error_status_code
kind = (response.error_kind or "").lower()
error_type = (response.error_type or "").lower()
code = (response.error_code or "").lower()
text = (response.content or "").lower()
if status in {400, 401, 403, 404, 422}:
return False
if kind in _NON_FALLBACK_ERROR_KINDS:
return False
if any(token in value for value in (kind, error_type, code) for token in _NON_FALLBACK_ERROR_KINDS):
return False
if response.error_should_retry is True:
return True
if status is not None and (status in {408, 409, 429} or 500 <= status <= 599):
return True
if kind in _FALLBACK_ERROR_KINDS:
return True
return any(token in value for value in (kind, error_type, code, text) for token in _FALLBACK_ERROR_TOKENS)

View File

@ -3,10 +3,11 @@
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.fallback_provider import FallbackProvider
@ -16,14 +17,45 @@ def _make_response(
finish_reason: str = "stop",
*,
error_kind: str | None = None,
error_status_code: int | None = None,
error_type: str | None = None,
error_code: str | None = None,
error_should_retry: bool | None = None,
) -> LLMResponse:
return LLMResponse(content=content, finish_reason=finish_reason, error_kind=error_kind)
return LLMResponse(
content=content,
finish_reason=finish_reason,
error_kind=error_kind,
error_status_code=error_status_code,
error_type=error_type,
error_code=error_code,
error_should_retry=error_should_retry,
)
def _error_response(content: str = "api error") -> LLMResponse:
return _make_response(content, finish_reason="error", error_kind="server_error")
def _fallback(
    model: str,
    provider: str = "custom",
    *,
    max_tokens: int = 8192,
    context_window_tokens: int = 65_536,
    temperature: float = 0.1,
    reasoning_effort: str | None = None,
) -> ModelPresetConfig:
    """Build a fallback preset for tests, overriding only what a test needs."""
    preset_fields = {
        "model": model,
        "provider": provider,
        "max_tokens": max_tokens,
        "context_window_tokens": context_window_tokens,
        "temperature": temperature,
        "reasoning_effort": reasoning_effort,
    }
    return ModelPresetConfig(**preset_fields)
class _FakeProvider(LLMProvider):
"""Fake provider for testing."""
@ -53,24 +85,163 @@ class _FakeProvider(LLMProvider):
def test_fallback_models_default_empty() -> None:
from nanobot.config.schema import ModelPresetConfig
p = ModelPresetConfig(model="test/model")
assert p.fallback_models == []
from nanobot.config.schema import AgentDefaults
defaults = AgentDefaults()
assert defaults.fallback_models == []
def test_fallback_models_accepts_list() -> None:
from nanobot.config.schema import ModelPresetConfig
p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"])
assert p.fallback_models == ["test/a", "test/b"]
def test_fallback_models_accept_preset_refs_and_inline_configs() -> None:
from nanobot.config.schema import Config, InlineFallbackConfig
def test_fallback_models_from_camel_case() -> None:
from nanobot.config.schema import ModelPresetConfig
p = ModelPresetConfig.model_validate({
"model": "test/primary",
"fallbackModels": ["test/a"],
config = Config.model_validate({
"agents": {
"defaults": {
"fallbackModels": [
"deep",
{
"provider": "openai",
"model": "gpt-4.1",
"maxTokens": 4096,
},
]
}
},
"modelPresets": {
"deep": {"provider": "anthropic", "model": "claude-opus-4-7"}
},
})
assert p.fallback_models == ["test/a"]
assert config.agents.defaults.fallback_models[0] == "deep"
assert config.agents.defaults.fallback_models[1] == InlineFallbackConfig(
provider="openai",
model="gpt-4.1",
max_tokens=4096,
)
def test_fallback_model_preset_ref_must_exist() -> None:
from nanobot.config.schema import Config
with pytest.raises(ValueError, match="fallback_models.*not found"):
Config.model_validate({
"agents": {"defaults": {"fallbackModels": ["missing"]}},
"modelPresets": {},
})
def test_provider_signature_tracks_fallback_presets_and_provider_config() -> None:
from nanobot.config.schema import Config
from nanobot.providers.factory import provider_signature
base = {
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": ["deep"],
}
},
"modelPresets": {
"fast": {"model": "openai/gpt-4.1", "provider": "openai"},
"deep": {"model": "anthropic/claude-sonnet-4-6", "provider": "anthropic"},
},
"providers": {
"openai": {"apiKey": "primary-key"},
"anthropic": {"apiKey": "fallback-key"},
},
}
changed_fallback = {
**base,
"agents": {"defaults": {"modelPreset": "fast", "fallbackModels": ["backup"]}},
"modelPresets": {
**base["modelPresets"],
"backup": {"model": "deepseek/deepseek-chat", "provider": "deepseek"},
},
"providers": {
**base["providers"],
"deepseek": {"apiKey": "deepseek-key"},
},
}
changed_key = {
**base,
"providers": {
"openai": {"apiKey": "primary-key"},
"anthropic": {"apiKey": "new-fallback-key"},
},
}
signature = provider_signature(Config.model_validate(base))
assert signature != provider_signature(Config.model_validate(changed_fallback))
assert signature != provider_signature(Config.model_validate(changed_key))
def test_provider_snapshot_uses_smallest_fallback_context_window() -> None:
from nanobot.config.schema import Config
from nanobot.providers.factory import build_provider_snapshot
config = Config.model_validate({
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": ["deep"],
}
},
"modelPresets": {
"fast": {
"model": "openai/gpt-4.1",
"provider": "openai",
"contextWindowTokens": 128000,
},
"deep": {
"model": "deepseek/deepseek-chat",
"provider": "deepseek",
"contextWindowTokens": 64000,
},
},
"providers": {
"openai": {"apiKey": "primary-key"},
"deepseek": {"apiKey": "fallback-key"},
},
})
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
snapshot = build_provider_snapshot(config)
assert snapshot.context_window_tokens == 64000
def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
# The primary preset sets reasoningEffort="high"; the inline fallback omits
# it, so the resolved fallback must end up with reasoning_effort=None.
from nanobot.config.schema import Config
from nanobot.providers.factory import provider_signature
config = Config.model_validate({
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": [
{"provider": "openai", "model": "gpt-4.1"}
],
}
},
"modelPresets": {
"fast": {
"model": "anthropic/claude-opus-4-5",
"provider": "anthropic",
"reasoningEffort": "high",
}
},
"providers": {
"anthropic": {"apiKey": "primary-key"},
"openai": {"apiKey": "fallback-key"},
},
})
signature = provider_signature(config)
# The last element of the provider signature is the tuple of per-fallback
# signatures, one per configured fallback.
fallback_signatures = signature[-1]
# Index 11 of a per-fallback signature is reasoning_effort; this index must
# be kept in sync with the tuple order built in provider_signature.
assert fallback_signatures[0][11] is None
# -- FallbackProvider tests --
@ -83,7 +254,7 @@ class TestNoFallbackWhenPrimarySucceeds:
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
@ -102,14 +273,14 @@ class TestFallbackOnPrimaryError:
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
assert result.content == "fallback ok"
assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a")
factory.assert_called_once_with(_fallback("fallback-a"))
assert primary.chat_calls[0]["model"] == "primary-model"
assert fallback.chat_calls[0]["model"] == "fallback-a"
@ -121,7 +292,7 @@ class TestNoFallbackWhenContentStreamed:
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
@ -146,14 +317,62 @@ class TestFailoverOnTransientError:
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a")
factory.assert_called_once_with(_fallback("fallback-a"))
class TestNoFallbackOnNonRetryableError:
@pytest.mark.asyncio
async def test_bad_request(self) -> None:
primary = _FakeProvider(
"primary",
_make_response(
"invalid request",
finish_reason="error",
error_status_code=400,
error_kind="invalid_request",
),
)
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.finish_reason == "error"
factory.assert_not_called()
@pytest.mark.asyncio
async def test_auth_error(self) -> None:
primary = _FakeProvider(
"primary",
_make_response(
"unauthorized",
finish_reason="error",
error_status_code=401,
error_kind="authentication",
),
)
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.finish_reason == "error"
factory.assert_not_called()
@pytest.mark.asyncio
async def test_timeout(self) -> None:
@ -165,14 +384,14 @@ class TestFailoverOnTransientError:
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a")
factory.assert_called_once_with(_fallback("fallback-a"))
class TestFallbackTriesModelsInOrder:
@ -185,15 +404,15 @@ class TestFallbackTriesModelsInOrder:
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a", "fallback-b"],
fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "b ok"
assert factory.call_count == 2
factory.assert_any_call("fallback-a")
factory.assert_any_call("fallback-b")
factory.assert_any_call(_fallback("fallback-a"))
factory.assert_any_call(_fallback("fallback-b"))
class TestAllFallbacksFail:
@ -205,7 +424,7 @@ class TestAllFallbacksFail:
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
@ -223,7 +442,7 @@ class TestFactoryExceptionSkipsModel:
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a", "fallback-b"],
fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
provider_factory=factory,
)
@ -242,13 +461,43 @@ class TestFallbackModelParameter:
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-model"],
fallback_presets=[_fallback("fallback-model")],
provider_factory=factory,
)
await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
assert fallback.chat_calls[0]["model"] == "fallback-model"
@pytest.mark.asyncio
async def test_uses_fallback_generation_fields(self) -> None:
primary = _FakeProvider("primary", _error_response())
fallback = _FakeProvider("fallback", _make_response("ok"))
fb = FallbackProvider(
primary=primary,
fallback_presets=[
_fallback(
"fallback-model",
max_tokens=1234,
temperature=0.4,
reasoning_effort=None,
)
],
provider_factory=MagicMock(return_value=fallback),
)
await fb.chat(
messages=[{"role": "user", "content": "hi"}],
model="primary-model",
max_tokens=8192,
temperature=0.1,
reasoning_effort="high",
)
assert fallback.chat_calls[0]["model"] == "fallback-model"
assert fallback.chat_calls[0]["max_tokens"] == 1234
assert fallback.chat_calls[0]["temperature"] == 0.4
assert "reasoning_effort" not in fallback.chat_calls[0]
class TestNoFallbackWhenEmptyList:
@pytest.mark.asyncio
@ -258,7 +507,7 @@ class TestNoFallbackWhenEmptyList:
fb = FallbackProvider(
primary=primary,
fallback_models=[],
fallback_presets=[],
provider_factory=factory,
)
@ -277,7 +526,7 @@ class TestChatStreamFailover:
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
@ -291,7 +540,7 @@ class TestGetDefaultModel:
primary = _FakeProvider("primary")
fb = FallbackProvider(
primary=primary,
fallback_models=["a"],
fallback_presets=[_fallback("a")],
provider_factory=MagicMock(),
)
assert fb.get_default_model() == "primary/model"
@ -305,7 +554,7 @@ class TestCircuitBreaker:
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
@ -329,7 +578,7 @@ class TestCircuitBreaker:
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
fallback_presets=[_fallback("fallback-a")],
provider_factory=factory,
)
@ -357,7 +606,7 @@ class TestGenerationForwarded:
primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024)
fb = FallbackProvider(
primary=primary,
fallback_models=["a"],
fallback_presets=[_fallback("a")],
provider_factory=MagicMock(),
)
assert fb.generation.temperature == 0.5