Merge PR #3756: feat(runner): model failover with fallback_models

feat(runner): model failover with fallback_models
This commit is contained in:
Xubin Ren 2026-05-13 23:38:14 +08:00 committed by GitHub
commit 921fe259f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1031 additions and 7 deletions

View File

@ -672,7 +672,8 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
"maxTokens": 8192, "maxTokens": 8192,
"contextWindowTokens": 128000, "contextWindowTokens": 128000,
"temperature": 0.1, "temperature": 0.1,
"modelPreset": null "modelPreset": "fast",
"fallbackModels": ["deep"]
} }
}, },
"modelPresets": { "modelPresets": {
@ -708,6 +709,40 @@ Existing configs do not need to change. If you do not set `modelPresets` or `age
`default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`. `default` is reserved and always means the implicit preset built from `agents.defaults.*`; do not define `modelPresets.default`. Use `/model default` to switch back to `agents.defaults.*`.
### Model Fallbacks
`agents.defaults.fallbackModels` defines an ordered failover chain for the active model configuration. The primary model is still selected by `agents.defaults.modelPreset` (or the implicit default config when no preset is active).
Each fallback candidate can be either:
- A preset name from `modelPresets`, such as `"deep"`. The preset's full model, provider, generation, and context-window config is used.
- An inline fallback object with at least `provider` and `model`. Optional `maxTokens`, `contextWindowTokens`, and `temperature` fields inherit from the active primary config when omitted. `reasoningEffort` does not inherit; omit it to leave reasoning off for that fallback, or set it explicitly for models that support reasoning.
```json
{
"agents": {
"defaults": {
"modelPreset": "fast",
"fallbackModels": [
"deep",
{
"provider": "deepseek",
"model": "deepseek-v4-pro",
"maxTokens": 4096,
"contextWindowTokens": 262144
}
]
}
}
}
```
String entries are preset names, not raw model names. If you want to use a model that is not already a preset, use the inline object form.
Failover only runs when the primary provider returns a retryable model/provider error before any answer text has been streamed. Typical fallback cases include timeouts, connection errors, 5xx server errors, 429 rate limits, overloads, and quota/balance exhaustion. It does not run for malformed requests, authentication/permission errors, content filtering/refusals, or context-length/message-format errors.
If fallback candidates use smaller `contextWindowTokens` values, nanobot builds context using the smallest window in the active chain so every candidate can receive the same prompt.
Set `agents.defaults.modelPreset` to start with a named preset: Set `agents.defaults.modelPreset` to start with a named preset:
```json ```json

View File

@ -74,6 +74,20 @@ class DreamConfig(Base):
return f"every {hours}h" return f"every {hours}h"
class InlineFallbackConfig(Base):
    """One inline fallback model configuration.

    Used as an entry of ``AgentDefaults.fallback_models`` when the fallback
    is described in place instead of by preset name. Only ``model`` and
    ``provider`` are required; every other field defaults to ``None``.
    """

    # Model identifier used for the fallback candidate.
    model: str
    # Provider name used for the fallback candidate.
    provider: str
    # None means "not set"; the factory substitutes the primary preset's value.
    max_tokens: int | None = None
    # None means "not set"; the factory substitutes the primary preset's value.
    context_window_tokens: int | None = None
    # None means "not set"; the factory substitutes the primary preset's value.
    temperature: float | None = None
    # Passed through unchanged when the candidate is resolved, so None is
    # NOT replaced by the primary preset's reasoning effort.
    reasoning_effort: str | None = None
FallbackCandidate = str | InlineFallbackConfig
class ModelPresetConfig(Base): class ModelPresetConfig(Base):
"""A named set of model + generation parameters for quick switching.""" """A named set of model + generation parameters for quick switching."""
@ -106,6 +120,7 @@ class AgentDefaults(Base):
context_window_tokens: int = 65_536 context_window_tokens: int = 65_536
context_block_limit: int | None = None context_block_limit: int | None = None
temperature: float = 0.1 temperature: float = 0.1
fallback_models: list[FallbackCandidate] = Field(default_factory=list)
max_tool_iterations: int = 200 max_tool_iterations: int = 200
max_concurrent_subagents: int = Field(default=1, ge=1) max_concurrent_subagents: int = Field(default=1, ge=1)
max_tool_result_chars: int = 16_000 max_tool_result_chars: int = 16_000
@ -287,6 +302,9 @@ class Config(BaseSettings):
name = self.agents.defaults.model_preset name = self.agents.defaults.model_preset
if name and name != "default" and name not in self.model_presets: if name and name != "default" and name not in self.model_presets:
raise ValueError(f"model_preset {name!r} not found in model_presets") raise ValueError(f"model_preset {name!r} not found in model_presets")
for fallback in self.agents.defaults.fallback_models:
if isinstance(fallback, str) and fallback not in self.model_presets:
raise ValueError(f"fallback_models entry {fallback!r} not found in model_presets")
return self return self
def resolve_default_preset(self) -> ModelPresetConfig: def resolve_default_preset(self) -> ModelPresetConfig:

View File

@ -5,8 +5,9 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from nanobot.config.schema import Config, ModelPresetConfig from nanobot.config.schema import Config, InlineFallbackConfig, ModelPresetConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.providers.fallback_provider import FallbackProvider
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
@ -27,15 +28,16 @@ def _resolve_model_preset(
return preset if preset is not None else config.resolve_preset(preset_name) return preset if preset is not None else config.resolve_preset(preset_name)
def make_provider( def _make_provider_core(
config: Config, config: Config,
*, *,
preset_name: str | None = None, preset_name: str | None = None,
preset: ModelPresetConfig | None = None, preset: ModelPresetConfig | None = None,
model: str | None = None,
) -> LLMProvider: ) -> LLMProvider:
"""Create the LLM provider implied by config.""" """Create a plain LLM provider without failover wrapping."""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
model = resolved.model model = model or resolved.model
provider_name = config.get_provider_name(model, preset=resolved) provider_name = config.get_provider_name(model, preset=resolved)
p = config.get_provider(model, preset=resolved) p = config.get_provider(model, preset=resolved)
spec = find_by_name(provider_name) if provider_name else None spec = find_by_name(provider_name) if provider_name else None
@ -102,15 +104,93 @@ def make_provider(
return provider return provider
def _inline_fallback_preset(
    primary: ModelPresetConfig,
    fallback: InlineFallbackConfig,
) -> ModelPresetConfig:
    """Expand an inline fallback entry into a full ``ModelPresetConfig``.

    ``max_tokens``, ``context_window_tokens`` and ``temperature`` are taken
    from *fallback* when set and from *primary* when left as ``None``.
    ``reasoning_effort`` is copied from *fallback* as-is, so it never inherits
    from the primary preset.
    """

    def _own_or_inherited(own, inherited):
        # Explicit None check: 0 / 0.0 are legitimate explicit values.
        return inherited if own is None else own

    max_tokens = _own_or_inherited(fallback.max_tokens, primary.max_tokens)
    context_window = _own_or_inherited(
        fallback.context_window_tokens, primary.context_window_tokens
    )
    temperature = _own_or_inherited(fallback.temperature, primary.temperature)

    return ModelPresetConfig(
        model=fallback.model,
        provider=fallback.provider,
        max_tokens=max_tokens,
        context_window_tokens=context_window,
        temperature=temperature,
        reasoning_effort=fallback.reasoning_effort,
    )
def _resolve_fallback_presets(config: Config, primary: ModelPresetConfig) -> list[ModelPresetConfig]:
    """Return the configured fallback chain as concrete presets, in order.

    String entries are looked up in ``config.model_presets``; any other entry
    is treated as an inline config and expanded against *primary*.
    """
    return [
        config.model_presets[candidate]
        if isinstance(candidate, str)
        else _inline_fallback_preset(primary, candidate)
        for candidate in config.agents.defaults.fallback_models
    ]
def make_provider(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
    model: str | None = None,
) -> LLMProvider:
    """Create the LLM provider implied by config, with failover if configured.

    The primary provider is built by ``_make_provider_core``. When *model* is
    given it is forwarded to that call only; fallback providers are always
    built from their own presets. If the config declares no fallback models
    the plain primary provider is returned unwrapped.
    """
    resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
    primary = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)

    chain = _resolve_fallback_presets(config, resolved)
    if not chain:
        return primary

    def _build_fallback(fallback_preset):
        # Plain (unwrapped) provider per candidate, so failover never nests.
        return _make_provider_core(config, preset_name=preset_name, preset=fallback_preset)

    return FallbackProvider(
        primary=primary,
        fallback_presets=chain,
        provider_factory=_build_fallback,
    )
def provider_signature( def provider_signature(
config: Config, config: Config,
*, *,
preset_name: str | None = None, preset_name: str | None = None,
preset: ModelPresetConfig | None = None, preset: ModelPresetConfig | None = None,
) -> tuple[object, ...]: ) -> tuple[object, ...]:
"""Return the config fields that affect the primary LLM provider.""" """Return the config fields that affect the active provider chain."""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
p = config.get_provider(resolved.model, preset=resolved) p = config.get_provider(resolved.model, preset=resolved)
fallback_presets = _resolve_fallback_presets(config, resolved)
def _fallback_signature(fallback: ModelPresetConfig) -> tuple[object, ...]:
fp = config.get_provider(fallback.model, preset=fallback)
return (
fallback.model,
fallback.provider,
config.get_provider_name(fallback.model, preset=fallback),
config.get_api_key(fallback.model, preset=fallback),
config.get_api_base(fallback.model, preset=fallback),
fp.extra_headers if fp else None,
fp.extra_body if fp else None,
getattr(fp, "region", None) if fp else None,
getattr(fp, "profile", None) if fp else None,
fallback.max_tokens,
fallback.temperature,
fallback.reasoning_effort,
fallback.context_window_tokens,
)
return ( return (
resolved.model, resolved.model,
resolved.provider, resolved.provider,
@ -125,6 +205,7 @@ def provider_signature(
resolved.temperature, resolved.temperature,
resolved.reasoning_effort, resolved.reasoning_effort,
resolved.context_window_tokens, resolved.context_window_tokens,
tuple(_fallback_signature(fallback) for fallback in fallback_presets),
) )
@ -135,10 +216,14 @@ def build_provider_snapshot(
preset: ModelPresetConfig | None = None, preset: ModelPresetConfig | None = None,
) -> ProviderSnapshot: ) -> ProviderSnapshot:
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
fallback_windows = [
fallback.context_window_tokens
for fallback in _resolve_fallback_presets(config, resolved)
]
return ProviderSnapshot( return ProviderSnapshot(
provider=make_provider(config, preset=resolved), provider=make_provider(config, preset=resolved),
model=resolved.model, model=resolved.model,
context_window_tokens=resolved.context_window_tokens, context_window_tokens=min([resolved.context_window_tokens, *fallback_windows]),
signature=provider_signature(config, preset=resolved), signature=provider_signature(config, preset=resolved),
) )

View File

@ -0,0 +1,273 @@
"""Provider wrapper that transparently fails over to fallback models on error."""
from __future__ import annotations
import time
from collections.abc import Awaitable, Callable
from typing import Any
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse
# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker.
# Number of consecutive fallback-eligible primary failures that opens the circuit.
_PRIMARY_FAILURE_THRESHOLD = 3
# Seconds (compared against time.monotonic()) before the primary is probed again.
_PRIMARY_COOLDOWN_S = 60
# Sentinel distinguishing "kwarg absent" from "kwarg present with value None".
_MISSING = object()
# error_kind values that always qualify for failover (exact match).
_FALLBACK_ERROR_KINDS = frozenset({
    "timeout",
    "connection",
    "server_error",
    "rate_limit",
    "overloaded",
})
# Error kinds that must never trigger failover. Matched exactly against
# error_kind and as substrings of error_kind / error_type / error_code.
_NON_FALLBACK_ERROR_KINDS = frozenset({
    "authentication",
    "auth",
    "permission",
    "content_filter",
    "refusal",
    "context_length",
    "invalid_request",
})
# Last-resort substring heuristics, searched in the lower-cased
# error_kind / error_type / error_code / response content.
_FALLBACK_ERROR_TOKENS = (
    "rate_limit",
    "rate limit",
    "too_many_requests",
    "too many requests",
    "overloaded",
    "server_error",
    "server error",
    "temporarily unavailable",
    "timeout",
    "timed out",
    "connection",
    "insufficient_quota",
    "insufficient quota",
    "quota_exceeded",
    "quota exceeded",
    "quota_exhausted",
    "quota exhausted",
    "billing_hard_limit",
    "insufficient_balance",
    "balance",
    "out of credits",
)
class FallbackProvider(LLMProvider):
    """Wrap a primary provider and transparently fail over to fallback models.

    When the primary model returns an error and no content has been streamed
    yet, the wrapper tries each fallback model in order. Each fallback model
    may reside on a different provider; a factory callable creates the
    underlying provider on the fly.

    Key design:
    - Every request starts from the primary again; the only state kept
      between calls is the primary's circuit-breaker counter and trip time.
    - Failover is skipped when content was already streamed, to avoid
      duplicate output.
    - Recursive failover is prevented by the factory returning plain providers.
    - The primary provider is circuit-broken after repeated failures to avoid
      wasting requests on a known-bad endpoint.
    """

    def __init__(
        self,
        primary: LLMProvider,
        fallback_presets: list[Any],
        provider_factory: Callable[[Any], LLMProvider],
    ):
        """Store the primary provider, the ordered fallback presets and the factory.

        Args:
            primary: Provider tried first on every request.
            fallback_presets: Ordered candidates; each must expose ``model``,
                ``max_tokens``, ``temperature`` and ``reasoning_effort``.
            provider_factory: Builds a provider for one fallback preset.
        """
        # NOTE(review): LLMProvider.__init__ is not invoked here; generation
        # settings are proxied to the primary instead — confirm intentional.
        self._primary = primary
        self._fallback_presets = list(fallback_presets)
        self._provider_factory = provider_factory
        # Derived from the stored copy so any iterable input is handled.
        self._has_fallbacks = bool(self._fallback_presets)
        self._primary_failures = 0
        self._primary_tripped_at: float | None = None

    @property
    def generation(self):
        """Generation settings of the primary provider."""
        return self._primary.generation

    @generation.setter
    def generation(self, value):
        self._primary.generation = value

    def get_default_model(self) -> str:
        """Return the primary provider's default model."""
        return self._primary.get_default_model()

    @property
    def supports_progress_deltas(self) -> bool:
        """Mirror the primary provider's flag; False when it has none."""
        return bool(getattr(self._primary, "supports_progress_deltas", False))

    def _primary_available(self) -> bool:
        """Return True if the primary provider is not currently tripped."""
        if self._primary_tripped_at is None:
            return True
        if time.monotonic() - self._primary_tripped_at >= _PRIMARY_COOLDOWN_S:
            # Half-open: allow one probe attempt.
            return True
        return False

    async def chat(self, **kwargs: Any) -> LLMResponse:
        """Run a non-streaming chat request with failover."""
        if not self._has_fallbacks:
            return await self._primary.chat(**kwargs)
        return await self._try_with_fallback(
            lambda p, kw: p.chat(**kw), kwargs, has_streamed=None
        )

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Run a streaming chat request with failover.

        The caller's ``on_content_delta`` is wrapped so the wrapper knows
        whether any non-empty text already reached the caller; once it has,
        failover is suppressed.
        """
        if not self._has_fallbacks:
            return await self._primary.chat_stream(**kwargs)

        # One-element list acts as a mutable flag shared with the closure.
        has_streamed: list[bool] = [False]
        original_delta = kwargs.get("on_content_delta")

        async def _tracking_delta(text: str) -> None:
            if text:
                has_streamed[0] = True
            if original_delta:
                await original_delta(text)

        kwargs["on_content_delta"] = _tracking_delta
        return await self._try_with_fallback(
            lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed
        )

    async def _try_with_fallback(
        self,
        call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
        kwargs: dict[str, Any],
        has_streamed: list[bool] | None,
    ) -> LLMResponse:
        """Try the primary, then each fallback in order, returning the first success.

        Args:
            call: Invokes ``chat`` or ``chat_stream`` on a provider with kwargs.
            kwargs: Request kwargs; temporarily overridden per fallback and
                restored afterwards.
            has_streamed: Streaming flag holder, or ``None`` for non-streaming.

        Returns:
            The first non-error response, otherwise the most recent error
            response seen (primary or fallback). An error is synthesized only
            when the primary was skipped and no fallback could be invoked.
        """
        primary_model = kwargs.get("model") or self._primary.get_default_model()
        last_response: LLMResponse | None = None
        # Decide once, before the attempt, so the logs below report whether
        # the primary was really skipped rather than tripped by this request.
        primary_skipped = not self._primary_available()

        if primary_skipped:
            logger.debug("Primary model '{}' circuit open; skipping", primary_model)
        else:
            response = await call(self._primary, kwargs)
            if response.finish_reason != "error":
                self._primary_failures = 0
                self._primary_tripped_at = None
                return response

            if has_streamed is not None and has_streamed[0]:
                logger.warning(
                    "Primary model error but content already streamed; skipping failover"
                )
                return response

            if not self._should_fallback(response):
                logger.warning(
                    "Primary model '{}' returned non-fallbackable error: {}",
                    primary_model,
                    (response.content or "")[:120],
                )
                return response

            self._primary_failures += 1
            if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD:
                self._primary_tripped_at = time.monotonic()
                logger.warning(
                    "Primary model '{}' circuit open after {} consecutive failures",
                    primary_model, self._primary_failures,
                )
            # Keep the primary's error so it is still returned if no fallback
            # provider can even be constructed.
            last_response = response

        for idx, fallback in enumerate(self._fallback_presets):
            fallback_model = fallback.model
            if has_streamed is not None and has_streamed[0]:
                break

            if idx == 0 and primary_skipped:
                logger.info(
                    "Primary model '{}' circuit open, trying fallback '{}'",
                    primary_model, fallback_model,
                )
            elif idx == 0:
                logger.info(
                    "Primary model '{}' failed, trying fallback '{}'",
                    primary_model, fallback_model,
                )
            else:
                logger.info(
                    "Fallback '{}' also failed, trying next fallback '{}'",
                    self._fallback_presets[idx - 1].model, fallback_model,
                )

            try:
                fallback_provider = self._provider_factory(fallback)
            except Exception as exc:
                # Best effort: a candidate that cannot be built is skipped.
                logger.warning(
                    "Failed to create provider for fallback '{}': {}", fallback_model, exc
                )
                continue

            # Remember the caller's values so they can be restored afterwards.
            original_values = {
                name: kwargs.get(name, _MISSING)
                for name in ("model", "max_tokens", "temperature", "reasoning_effort")
            }
            kwargs["model"] = fallback_model
            kwargs["max_tokens"] = fallback.max_tokens
            kwargs["temperature"] = fallback.temperature
            if fallback.reasoning_effort is None:
                kwargs.pop("reasoning_effort", None)
            else:
                kwargs["reasoning_effort"] = fallback.reasoning_effort
            try:
                fallback_response = await call(fallback_provider, kwargs)
            finally:
                for name, value in original_values.items():
                    if value is _MISSING:
                        kwargs.pop(name, None)
                    else:
                        kwargs[name] = value

            if fallback_response.finish_reason != "error":
                logger.info(
                    "Fallback '{}' succeeded after primary '{}' failed",
                    fallback_model, primary_model,
                )
                return fallback_response

            last_response = fallback_response
            logger.warning(
                "Fallback '{}' also failed: {}",
                fallback_model,
                (fallback_response.content or "")[:120],
            )

        logger.warning(
            "All {} fallback model(s) failed",
            len(self._fallback_presets),
        )
        # Return the last error response we saw (primary or last fallback).
        if last_response is not None:
            return last_response
        # Primary was skipped and no fallback produced a response.
        return LLMResponse(
            content=f"Primary model '{primary_model}' circuit open and no fallbacks available",
            finish_reason="error",
        )

    @staticmethod
    def _should_fallback(response: LLMResponse) -> bool:
        """Decide whether an error response from the primary warrants failover.

        Checks run in priority order: an explicit "do not retry" flag, then
        non-retryable status codes and error kinds, then an explicit "retry"
        flag, retryable status codes and kinds, and finally substring
        heuristics over the error fields and message text.
        """
        if response.error_should_retry is False:
            return False

        status = response.error_status_code
        # str() guards against providers that report non-string codes.
        kind = str(response.error_kind or "").lower()
        error_type = str(response.error_type or "").lower()
        code = str(response.error_code or "").lower()
        text = str(response.content or "").lower()

        if status in {400, 401, 403, 404, 422}:
            return False
        if kind in _NON_FALLBACK_ERROR_KINDS:
            return False
        if any(
            token in value
            for value in (kind, error_type, code)
            for token in _NON_FALLBACK_ERROR_KINDS
        ):
            return False

        if response.error_should_retry is True:
            return True
        if status is not None and (status in {408, 409, 429} or 500 <= status <= 599):
            return True
        if kind in _FALLBACK_ERROR_KINDS:
            return True
        return any(
            token in value
            for value in (kind, error_type, code, text)
            for token in _FALLBACK_ERROR_TOKENS
        )

View File

@ -0,0 +1,613 @@
"""Tests for FallbackProvider model failover."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from nanobot.config.schema import ModelPresetConfig
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.fallback_provider import FallbackProvider
def _make_response(
    content: str = "ok",
    finish_reason: str = "stop",
    *,
    error_kind: str | None = None,
    error_status_code: int | None = None,
    error_type: str | None = None,
    error_code: str | None = None,
    error_should_retry: bool | None = None,
) -> LLMResponse:
    """Build an ``LLMResponse``; error metadata defaults to ``None``."""
    error_fields = {
        "error_kind": error_kind,
        "error_status_code": error_status_code,
        "error_type": error_type,
        "error_code": error_code,
        "error_should_retry": error_should_retry,
    }
    return LLMResponse(content=content, finish_reason=finish_reason, **error_fields)
def _error_response(content: str = "api error") -> LLMResponse:
    """Build a fallback-eligible error response (kind ``server_error``)."""
    return _make_response(
        content,
        finish_reason="error",
        error_kind="server_error",
    )
def _fallback(
    model: str,
    provider: str = "custom",
    *,
    max_tokens: int = 8192,
    context_window_tokens: int = 65_536,
    temperature: float = 0.1,
    reasoning_effort: str | None = None,
) -> ModelPresetConfig:
    """Build a fallback preset for tests; only *model* is mandatory."""
    preset_fields = {
        "model": model,
        "provider": provider,
        "max_tokens": max_tokens,
        "context_window_tokens": context_window_tokens,
        "temperature": temperature,
        "reasoning_effort": reasoning_effort,
    }
    return ModelPresetConfig(**preset_fields)
class _FakeProvider(LLMProvider):
    """In-memory provider that returns a canned response and records calls."""

    def __init__(self, name: str = "fake", response: LLMResponse | None = None):
        super().__init__()
        self.name = name
        self._response = response or _make_response()
        # Snapshots of the kwargs received by each call, in call order.
        self.chat_calls: list[dict[str, Any]] = []
        self.chat_stream_calls: list[dict[str, Any]] = []

    def get_default_model(self) -> str:
        return f"{self.name}/model"

    async def chat(self, **kwargs: Any) -> LLMResponse:
        self.chat_calls.append(dict(kwargs))
        return self._response

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        self.chat_stream_calls.append(dict(kwargs))
        callback = kwargs.get("on_content_delta")
        text = self._response.content
        # Emit the whole canned content as one delta, but only when non-empty.
        if callback and text:
            await callback(text)
        return self._response
# -- config-level tests --
def test_fallback_models_default_empty() -> None:
    """A default ``AgentDefaults`` has no fallback models."""
    from nanobot.config.schema import AgentDefaults

    assert AgentDefaults().fallback_models == []
def test_fallback_models_accept_preset_refs_and_inline_configs() -> None:
    """Preset-name strings and inline objects may be mixed in one chain."""
    from nanobot.config.schema import Config, InlineFallbackConfig

    inline_entry = {
        "provider": "openai",
        "model": "gpt-4.1",
        "maxTokens": 4096,
    }
    raw = {
        "agents": {"defaults": {"fallbackModels": ["deep", inline_entry]}},
        "modelPresets": {
            "deep": {"provider": "anthropic", "model": "claude-opus-4-7"}
        },
    }

    preset_ref, inline = Config.model_validate(raw).agents.defaults.fallback_models[:2]

    assert preset_ref == "deep"
    assert inline == InlineFallbackConfig(
        provider="openai",
        model="gpt-4.1",
        max_tokens=4096,
    )
def test_fallback_model_preset_ref_must_exist() -> None:
    """A string fallback that names no known preset is rejected at validation."""
    from nanobot.config.schema import Config

    raw = {
        "agents": {"defaults": {"fallbackModels": ["missing"]}},
        "modelPresets": {},
    }
    with pytest.raises(ValueError, match="fallback_models.*not found"):
        Config.model_validate(raw)
def test_provider_signature_tracks_fallback_presets_and_provider_config() -> None:
    """The signature changes when the fallback chain or its credentials change."""
    from nanobot.config.schema import Config
    from nanobot.providers.factory import provider_signature

    def _signature_of(raw: dict) -> tuple:
        return provider_signature(Config.model_validate(raw))

    presets = {
        "fast": {"model": "openai/gpt-4.1", "provider": "openai"},
        "deep": {"model": "anthropic/claude-sonnet-4-6", "provider": "anthropic"},
    }
    providers = {
        "openai": {"apiKey": "primary-key"},
        "anthropic": {"apiKey": "fallback-key"},
    }
    base = {
        "agents": {
            "defaults": {
                "modelPreset": "fast",
                "fallbackModels": ["deep"],
            }
        },
        "modelPresets": presets,
        "providers": providers,
    }

    # Variant 1: the chain points at a different preset on another provider.
    changed_fallback = dict(base)
    changed_fallback["agents"] = {
        "defaults": {"modelPreset": "fast", "fallbackModels": ["backup"]}
    }
    changed_fallback["modelPresets"] = {
        **presets,
        "backup": {"model": "deepseek/deepseek-chat", "provider": "deepseek"},
    }
    changed_fallback["providers"] = {
        **providers,
        "deepseek": {"apiKey": "deepseek-key"},
    }

    # Variant 2: same chain, but the fallback provider's API key differs.
    changed_key = dict(base)
    changed_key["providers"] = {
        "openai": {"apiKey": "primary-key"},
        "anthropic": {"apiKey": "new-fallback-key"},
    }

    signature = _signature_of(base)

    assert signature != _signature_of(changed_fallback)
    assert signature != _signature_of(changed_key)
def test_provider_snapshot_uses_smallest_fallback_context_window() -> None:
    """The snapshot's context window is the minimum across the whole chain."""
    from nanobot.config.schema import Config
    from nanobot.providers.factory import build_provider_snapshot

    primary_preset = {
        "model": "openai/gpt-4.1",
        "provider": "openai",
        "contextWindowTokens": 128000,
    }
    smaller_fallback = {
        "model": "deepseek/deepseek-chat",
        "provider": "deepseek",
        "contextWindowTokens": 64000,
    }
    config = Config.model_validate({
        "agents": {
            "defaults": {
                "modelPreset": "fast",
                "fallbackModels": ["deep"],
            }
        },
        "modelPresets": {"fast": primary_preset, "deep": smaller_fallback},
        "providers": {
            "openai": {"apiKey": "primary-key"},
            "deepseek": {"apiKey": "fallback-key"},
        },
    })

    # Patch the client class so building the provider creates no real client.
    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
        snapshot = build_provider_snapshot(config)

    assert snapshot.context_window_tokens == 64000
def test_inline_fallback_reasoning_effort_does_not_inherit_primary() -> None:
    """An inline fallback without ``reasoningEffort`` keeps it unset."""
    from nanobot.config.schema import Config
    from nanobot.providers.factory import provider_signature

    raw = {
        "agents": {
            "defaults": {
                "modelPreset": "fast",
                "fallbackModels": [
                    {"provider": "openai", "model": "gpt-4.1"}
                ],
            }
        },
        "modelPresets": {
            "fast": {
                "model": "anthropic/claude-opus-4-5",
                "provider": "anthropic",
                "reasoningEffort": "high",
            }
        },
        "providers": {
            "anthropic": {"apiKey": "primary-key"},
            "openai": {"apiKey": "fallback-key"},
        },
    }

    *_, fallback_signatures = provider_signature(Config.model_validate(raw))
    # Each fallback signature is a 13-tuple ending in
    # (..., reasoning_effort, context_window_tokens); index 11 == index -2.
    *_, reasoning_effort, _context_window = fallback_signatures[0]

    assert reasoning_effort is None
# -- FallbackProvider tests --
class TestNoFallbackWhenPrimarySucceeds:
    """A successful primary response is returned without building fallbacks."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        provider_factory = MagicMock()
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _make_response("primary ok")),
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=provider_factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "primary ok"
        assert outcome.finish_reason == "stop"
        provider_factory.assert_not_called()
class TestFallbackOnPrimaryError:
    """A fallback-eligible primary error hands the request to the first fallback."""

    @pytest.mark.asyncio
    async def test_first_fallback_succeeds(self) -> None:
        failing_primary = _FakeProvider("primary", _error_response())
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        provider_factory = MagicMock(return_value=healthy_fallback)
        wrapper = FallbackProvider(
            primary=failing_primary,
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=provider_factory,
        )

        outcome = await wrapper.chat(
            messages=[{"role": "user", "content": "hi"}], model="primary-model"
        )

        assert outcome.content == "fallback ok"
        assert outcome.finish_reason == "stop"
        provider_factory.assert_called_once_with(_fallback("fallback-a"))
        # Each provider must have been called with its own model name.
        assert failing_primary.chat_calls[0]["model"] == "primary-model"
        assert healthy_fallback.chat_calls[0]["model"] == "fallback-a"
class TestNoFallbackWhenContentStreamed:
    """Once text has reached the caller, an error must not trigger failover."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        provider_factory = MagicMock()
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=provider_factory,
        )

        async def _ignore_delta(text: str) -> None:
            pass

        outcome = await wrapper.chat_stream(
            messages=[{"role": "user", "content": "hi"}],
            on_content_delta=_ignore_delta,
        )

        # _FakeProvider emits its (non-empty) error content through the delta
        # callback, which counts as streamed output, so failover is skipped.
        assert outcome.finish_reason == "error"
        provider_factory.assert_not_called()
class TestFailoverOnTransientError:
    """Transient provider errors must hand the request to a fallback."""

    @pytest.mark.asyncio
    async def test_rate_limit(self) -> None:
        # Build a genuine rate-limit error. _error_response() would tag the
        # response as "server_error", and the test would then pass without
        # ever exercising rate-limit classification.
        primary = _FakeProvider(
            "primary",
            _make_response(
                "rate limit exceeded",
                finish_reason="error",
                error_status_code=429,
                error_kind="rate_limit",
            ),
        )
        fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        factory = MagicMock(return_value=fallback)
        fb = FallbackProvider(
            primary=primary,
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=factory,
        )

        result = await fb.chat(messages=[{"role": "user", "content": "hi"}])

        assert result.content == "fallback ok"
        assert result.finish_reason == "stop"
        factory.assert_called_once_with(_fallback("fallback-a"))
class TestNoFallbackOnNonRetryableError:
    """Errors a different model cannot fix are returned without failover."""

    @staticmethod
    def _wrap(primary: _FakeProvider, factory: MagicMock) -> FallbackProvider:
        """Wrap *primary* with a single fallback preset named ``fallback-a``."""
        return FallbackProvider(
            primary=primary,
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=factory,
        )

    @pytest.mark.asyncio
    async def test_bad_request(self) -> None:
        bad_request = _make_response(
            "invalid request",
            finish_reason="error",
            error_status_code=400,
            error_kind="invalid_request",
        )
        provider_factory = MagicMock()
        wrapper = self._wrap(_FakeProvider("primary", bad_request), provider_factory)

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.finish_reason == "error"
        provider_factory.assert_not_called()

    @pytest.mark.asyncio
    async def test_auth_error(self) -> None:
        unauthorized = _make_response(
            "unauthorized",
            finish_reason="error",
            error_status_code=401,
            error_kind="authentication",
        )
        provider_factory = MagicMock()
        wrapper = self._wrap(_FakeProvider("primary", unauthorized), provider_factory)

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.finish_reason == "error"
        provider_factory.assert_not_called()

    @pytest.mark.asyncio
    async def test_timeout(self) -> None:
        # Positive control: unlike the cases above, a timeout DOES fail over.
        timed_out = _make_response("timed out", finish_reason="error", error_kind="timeout")
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        provider_factory = MagicMock(return_value=healthy_fallback)
        wrapper = self._wrap(_FakeProvider("primary", timed_out), provider_factory)

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "fallback ok"
        assert outcome.finish_reason == "stop"
        provider_factory.assert_called_once_with(_fallback("fallback-a"))
class TestFallbackTriesModelsInOrder:
    """When the first fallback also fails, the next one in the chain is tried."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        chain = [_fallback("fallback-a"), _fallback("fallback-b")]
        # The factory yields a failing provider first, then a healthy one.
        provider_factory = MagicMock(
            side_effect=[
                _FakeProvider("a", _error_response("a fail")),
                _FakeProvider("b", _make_response("b ok")),
            ]
        )
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response("primary fail")),
            fallback_presets=chain,
            provider_factory=provider_factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "b ok"
        assert provider_factory.call_count == 2
        provider_factory.assert_any_call(_fallback("fallback-a"))
        provider_factory.assert_any_call(_fallback("fallback-b"))
class TestAllFallbacksFail:
    """If every candidate errors, the last fallback's error is returned."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        failing_fallback = _FakeProvider("fallback", _error_response("all fail"))
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response("primary fail")),
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=MagicMock(return_value=failing_fallback),
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.finish_reason == "error"
        assert "all fail" in outcome.content
class TestFactoryExceptionSkipsModel:
    """A fallback whose provider cannot be built is skipped, not fatal."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        healthy = _FakeProvider("b", _make_response("b ok"))
        # First factory call raises; the second returns a working provider.
        provider_factory = MagicMock(side_effect=[ValueError("no key"), healthy])
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_presets=[_fallback("fallback-a"), _fallback("fallback-b")],
            provider_factory=provider_factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "b ok"
        assert provider_factory.call_count == 2
class TestFallbackModelParameter:
    """Fallback calls carry the fallback preset's parameters, not the caller's."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        """Fallback calls should use the fallback model name."""
        healthy_fallback = _FakeProvider("fallback", _make_response("ok"))
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_presets=[_fallback("fallback-model")],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )

        await wrapper.chat(
            messages=[{"role": "user", "content": "hi"}], model="primary-model"
        )

        assert healthy_fallback.chat_calls[0]["model"] == "fallback-model"

    @pytest.mark.asyncio
    async def test_uses_fallback_generation_fields(self) -> None:
        healthy_fallback = _FakeProvider("fallback", _make_response("ok"))
        preset = _fallback(
            "fallback-model",
            max_tokens=1234,
            temperature=0.4,
            reasoning_effort=None,
        )
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_presets=[preset],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )

        await wrapper.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="primary-model",
            max_tokens=8192,
            temperature=0.1,
            reasoning_effort="high",
        )

        received = healthy_fallback.chat_calls[0]
        assert received["model"] == "fallback-model"
        assert received["max_tokens"] == 1234
        assert received["temperature"] == 0.4
        # The preset has no reasoning effort, so the caller's value is dropped.
        assert "reasoning_effort" not in received
class TestNoFallbackWhenEmptyList:
    """With an empty chain the wrapper simply returns the primary's result."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        provider_factory = MagicMock()
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_presets=[],
            provider_factory=provider_factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.finish_reason == "error"
        provider_factory.assert_not_called()
class TestChatStreamFailover:
    """Streaming requests fail over when nothing was streamed before the error."""

    @pytest.mark.asyncio
    async def test_fallback_succeeds(self) -> None:
        # Use empty content so on_content_delta is not triggered on the error
        silent_failure = _FakeProvider("primary", _error_response(""))
        healthy_fallback = _FakeProvider("fallback", _make_response("stream ok"))
        wrapper = FallbackProvider(
            primary=silent_failure,
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )

        outcome = await wrapper.chat_stream(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "stream ok"
        assert outcome.finish_reason == "stop"
class TestGetDefaultModel:
    """``get_default_model`` is delegated to the primary provider."""

    def test(self) -> None:
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary"),
            fallback_presets=[_fallback("a")],
            provider_factory=MagicMock(),
        )

        assert wrapper.get_default_model() == "primary/model"
class TestCircuitBreaker:
    """The primary is skipped after repeated failures and re-armed on success."""

    @staticmethod
    def _failing_primary_setup() -> tuple[_FakeProvider, FallbackProvider]:
        """Return a failing primary and a wrapper with one healthy fallback."""
        primary = _FakeProvider("primary", _error_response())
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        wrapper = FallbackProvider(
            primary=primary,
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )
        return primary, wrapper

    @pytest.mark.asyncio
    async def test_skips_primary_after_three_failures(self) -> None:
        primary, wrapper = self._failing_primary_setup()
        request = [{"role": "user", "content": "hi"}]

        # 3 failures — primary should still be called each time
        for _attempt in range(3):
            outcome = await wrapper.chat(messages=request)
            assert outcome.content == "fallback ok"
        assert len(primary.chat_calls) == 3

        # 4th call — primary circuit is open, should be skipped
        primary.chat_calls.clear()
        outcome = await wrapper.chat(messages=request)
        assert outcome.content == "fallback ok"
        assert len(primary.chat_calls) == 0

    @pytest.mark.asyncio
    async def test_resets_on_success(self) -> None:
        primary, wrapper = self._failing_primary_setup()
        request = [{"role": "user", "content": "hi"}]

        # 2 failures
        for _attempt in range(2):
            await wrapper.chat(messages=request)

        # 3rd call: primary succeeds — circuit resets
        primary._response = _make_response("primary ok")
        outcome = await wrapper.chat(messages=request)
        assert outcome.content == "primary ok"

        # 4th call: primary fails again — should still be called (counter reset)
        primary._response = _error_response()
        primary.chat_calls.clear()
        outcome = await wrapper.chat(messages=request)
        assert outcome.content == "fallback ok"
        assert len(primary.chat_calls) == 1
class TestGenerationForwarded:
    """The wrapper's ``generation`` property reads through to the primary."""

    def test(self) -> None:
        from nanobot.providers.base import GenerationSettings

        primary = _FakeProvider("primary")
        primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024)
        wrapper = FallbackProvider(
            primary=primary,
            fallback_presets=[_fallback("a")],
            provider_factory=MagicMock(),
        )

        forwarded = wrapper.generation
        assert forwarded.temperature == 0.5
        assert forwarded.max_tokens == 1024