feat(runner): add model failover with fallback_models

When the primary model returns a non-transient error and no content
has been streamed yet, the runner now tries each model listed in the
active preset's fallback_models in order.  Each fallback model may
reside on a different provider — a temporary provider instance is
created on-the-fly via make_provider(config, model=...).

Key design:
- Failover is request-scoped (does not affect subagents/dream/consolidator)
- Provider is restored via try/finally after each fallback attempt
- Skipped when content was already streamed to avoid duplicate output
- Recursive failover prevented by clearing fallback_models on fallback spec
- Circuit breaker trips open after 3 consecutive primary failures (60s cooldown)
- Cross-provider routing: fallback model prefix (e.g. groq/) determines provider

Fixes: cross-provider fallback was broken because the factory passed the
original preset (with provider forced to primary's provider) when creating
fallback providers.  Now uses provider="auto" so the model string prefix
correctly routes to the right provider.

Also fixes: log messages now distinguish between primary-failed,
previous-fallback-failed, and circuit-open scenarios.

closes: https://github.com/HKUDS/nanobot/issues/3376
This commit is contained in:
chengyongru 2026-05-12 16:51:48 +08:00
parent 07f9ab580a
commit 913b0774d8
4 changed files with 584 additions and 3 deletions

View File

@ -82,6 +82,7 @@ class ModelPresetConfig(Base):
context_window_tokens: int = 65_536 context_window_tokens: int = 65_536
temperature: float = 0.1 temperature: float = 0.1
reasoning_effort: str | None = None reasoning_effort: str | None = None
fallback_models: list[str] = Field(default_factory=list)
def to_generation_settings(self) -> Any: def to_generation_settings(self) -> Any:
from nanobot.providers.base import GenerationSettings from nanobot.providers.base import GenerationSettings

View File

@ -7,6 +7,7 @@ from pathlib import Path
from nanobot.config.schema import Config, ModelPresetConfig from nanobot.config.schema import Config, ModelPresetConfig
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.providers.fallback_provider import FallbackProvider
from nanobot.providers.registry import find_by_name from nanobot.providers.registry import find_by_name
@ -27,15 +28,16 @@ def _resolve_model_preset(
return preset if preset is not None else config.resolve_preset(preset_name) return preset if preset is not None else config.resolve_preset(preset_name)
def make_provider( def _make_provider_core(
config: Config, config: Config,
*, *,
preset_name: str | None = None, preset_name: str | None = None,
preset: ModelPresetConfig | None = None, preset: ModelPresetConfig | None = None,
model: str | None = None,
) -> LLMProvider: ) -> LLMProvider:
"""Create the LLM provider implied by config.""" """Create a plain LLM provider without failover wrapping."""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset) resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
model = resolved.model model = model or resolved.model
provider_name = config.get_provider_name(model, preset=resolved) provider_name = config.get_provider_name(model, preset=resolved)
p = config.get_provider(model, preset=resolved) p = config.get_provider(model, preset=resolved)
spec = find_by_name(provider_name) if provider_name else None spec = find_by_name(provider_name) if provider_name else None
@ -102,6 +104,34 @@ def make_provider(
return provider return provider
def make_provider(
config: Config,
*,
preset_name: str | None = None,
preset: ModelPresetConfig | None = None,
model: str | None = None,
) -> LLMProvider:
"""Create the LLM provider implied by config.
When *model* is given, it overrides the resolved/preset model used by
the failover path to create providers for fallback models.
"""
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
provider = _make_provider_core(config, preset_name=preset_name, preset=preset, model=model)
if resolved.fallback_models:
fb_preset = resolved.model_copy(update={"provider": "auto", "fallback_models": []})
provider = FallbackProvider(
primary=provider,
fallback_models=resolved.fallback_models,
provider_factory=lambda m: _make_provider_core(
config, preset_name=preset_name, preset=fb_preset, model=m
),
)
return provider
def provider_signature( def provider_signature(
config: Config, config: Config,
*, *,

View File

@ -0,0 +1,186 @@
"""Provider wrapper that transparently fails over to fallback models on error."""
from __future__ import annotations
import time
from collections.abc import Awaitable, Callable
from typing import Any
from loguru import logger
from nanobot.providers.base import LLMProvider, LLMResponse
# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker.
_PRIMARY_FAILURE_THRESHOLD = 3
_PRIMARY_COOLDOWN_S = 60
class FallbackProvider(LLMProvider):
"""Wrap a primary provider and transparently failover to fallback models.
When the primary model returns an error and no content has been streamed yet,
the wrapper tries each fallback model in order. Each fallback model may
reside on a different provider a factory callable creates the underlying
provider on-the-fly.
Key design:
- Failover is request-scoped (the wrapper itself is stateless between turns).
- Skipped when content was already streamed to avoid duplicate output.
- Recursive failover is prevented by the factory returning plain providers.
- Primary provider is circuit-broken after repeated failures to avoid
wasting requests on a known-bad endpoint.
"""
def __init__(
self,
primary: LLMProvider,
fallback_models: list[str],
provider_factory: Callable[[str], LLMProvider],
):
self._primary = primary
self._fallback_models = list(fallback_models)
self._provider_factory = provider_factory
self._has_fallbacks = bool(fallback_models)
self._primary_failures = 0
self._primary_tripped_at: float | None = None
@property
def generation(self):
return self._primary.generation
@generation.setter
def generation(self, value):
self._primary.generation = value
def get_default_model(self) -> str:
return self._primary.get_default_model()
def _primary_available(self) -> bool:
"""Return True if the primary provider is not currently tripped."""
if self._primary_tripped_at is None:
return True
if time.monotonic() - self._primary_tripped_at >= _PRIMARY_COOLDOWN_S:
# Half-open: allow one probe attempt.
return True
return False
async def chat(self, **kwargs: Any) -> LLMResponse:
if not self._has_fallbacks:
return await self._primary.chat(**kwargs)
return await self._try_with_fallback(
lambda p, kw: p.chat(**kw), kwargs, has_streamed=None
)
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
if not self._has_fallbacks:
return await self._primary.chat_stream(**kwargs)
has_streamed: list[bool] = [False]
original_delta = kwargs.get("on_content_delta")
async def _tracking_delta(text: str) -> None:
if text:
has_streamed[0] = True
if original_delta:
await original_delta(text)
kwargs["on_content_delta"] = _tracking_delta
return await self._try_with_fallback(
lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed
)
async def _try_with_fallback(
self,
call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
kwargs: dict[str, Any],
has_streamed: list[bool] | None,
) -> LLMResponse:
primary_model = kwargs.get("model") or self._primary.get_default_model()
if self._primary_available():
response = await call(self._primary, kwargs)
if response.finish_reason != "error":
self._primary_failures = 0
self._primary_tripped_at = None
return response
if has_streamed is not None and has_streamed[0]:
logger.warning(
"Primary model error but content already streamed; skipping failover"
)
return response
self._primary_failures += 1
if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD:
self._primary_tripped_at = time.monotonic()
logger.warning(
"Primary model '{}' circuit open after {} consecutive failures",
primary_model, self._primary_failures,
)
else:
logger.debug("Primary model '{}' circuit open; skipping", primary_model)
last_response: LLMResponse | None = None
primary_skipped = not self._primary_available()
for idx, fallback_model in enumerate(self._fallback_models):
if has_streamed is not None and has_streamed[0]:
break
if idx == 0 and primary_skipped:
logger.info(
"Primary model '{}' circuit open, trying fallback '{}'",
primary_model, fallback_model,
)
elif idx == 0:
logger.info(
"Primary model '{}' failed, trying fallback '{}'",
primary_model, fallback_model,
)
else:
logger.info(
"Fallback '{}' also failed, trying next fallback '{}'",
self._fallback_models[idx - 1], fallback_model,
)
try:
fallback_provider = self._provider_factory(fallback_model)
except Exception as exc:
logger.warning(
"Failed to create provider for fallback '{}': {}", fallback_model, exc
)
continue
original_model = kwargs.get("model")
kwargs["model"] = fallback_model
try:
fallback_response = await call(fallback_provider, kwargs)
finally:
if original_model is not None:
kwargs["model"] = original_model
else:
kwargs.pop("model", None)
if fallback_response.finish_reason != "error":
logger.info(
"Fallback '{}' succeeded after primary '{}' failed",
fallback_model, primary_model,
)
return fallback_response
last_response = fallback_response
logger.warning(
"Fallback '{}' also failed: {}",
fallback_model,
(fallback_response.content or "")[:120],
)
logger.warning(
"All {} fallback model(s) failed",
len(self._fallback_models),
)
# Return the last error response we saw (primary or last fallback).
if last_response is not None:
return last_response
# Primary was tripped and we have no fallbacks — synthesize an error.
return LLMResponse(
content=f"Primary model '{primary_model}' circuit open and no fallbacks available",
finish_reason="error",
)

View File

@ -0,0 +1,364 @@
"""Tests for FallbackProvider model failover."""
from __future__ import annotations
from typing import Any
from unittest.mock import MagicMock
import pytest
from nanobot.providers.base import LLMProvider, LLMResponse
from nanobot.providers.fallback_provider import FallbackProvider
def _make_response(
content: str = "ok",
finish_reason: str = "stop",
*,
error_kind: str | None = None,
) -> LLMResponse:
return LLMResponse(content=content, finish_reason=finish_reason, error_kind=error_kind)
def _error_response(content: str = "api error") -> LLMResponse:
return _make_response(content, finish_reason="error", error_kind="server_error")
class _FakeProvider(LLMProvider):
"""Fake provider for testing."""
def __init__(self, name: str = "fake", response: LLMResponse | None = None):
super().__init__()
self.name = name
self._response = response or _make_response()
self.chat_calls: list[dict[str, Any]] = []
self.chat_stream_calls: list[dict[str, Any]] = []
def get_default_model(self) -> str:
return f"{self.name}/model"
async def chat(self, **kwargs: Any) -> LLMResponse:
self.chat_calls.append(dict(kwargs))
return self._response
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
self.chat_stream_calls.append(dict(kwargs))
on_delta = kwargs.get("on_content_delta")
if on_delta and self._response.content:
await on_delta(self._response.content)
return self._response
# -- config-level tests --
def test_fallback_models_default_empty() -> None:
from nanobot.config.schema import ModelPresetConfig
p = ModelPresetConfig(model="test/model")
assert p.fallback_models == []
def test_fallback_models_accepts_list() -> None:
from nanobot.config.schema import ModelPresetConfig
p = ModelPresetConfig(model="test/primary", fallback_models=["test/a", "test/b"])
assert p.fallback_models == ["test/a", "test/b"]
def test_fallback_models_from_camel_case() -> None:
from nanobot.config.schema import ModelPresetConfig
p = ModelPresetConfig.model_validate({
"model": "test/primary",
"fallbackModels": ["test/a"],
})
assert p.fallback_models == ["test/a"]
# -- FallbackProvider tests --
class TestNoFallbackWhenPrimarySucceeds:
@pytest.mark.asyncio
async def test(self) -> None:
primary = _FakeProvider("primary", _make_response("primary ok"))
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "primary ok"
assert result.finish_reason == "stop"
factory.assert_not_called()
class TestFallbackOnPrimaryError:
@pytest.mark.asyncio
async def test_first_fallback_succeeds(self) -> None:
primary = _FakeProvider("primary", _error_response())
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
assert result.content == "fallback ok"
assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a")
assert primary.chat_calls[0]["model"] == "primary-model"
assert fallback.chat_calls[0]["model"] == "fallback-a"
class TestNoFallbackWhenContentStreamed:
@pytest.mark.asyncio
async def test(self) -> None:
primary = _FakeProvider("primary", _error_response())
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
async def _delta(text: str) -> None:
pass
result = await fb.chat_stream(
messages=[{"role": "user", "content": "hi"}],
on_content_delta=_delta,
)
# Primary returns error but content was "streamed" (FakeProvider calls delta)
# so failover should be skipped
assert result.finish_reason == "error"
factory.assert_not_called()
class TestFailoverOnTransientError:
@pytest.mark.asyncio
async def test_rate_limit(self) -> None:
primary = _FakeProvider("primary", _error_response("rate limit exceeded"))
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a")
@pytest.mark.asyncio
async def test_timeout(self) -> None:
primary = _FakeProvider(
"primary",
_make_response("timed out", finish_reason="error", error_kind="timeout"),
)
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert result.finish_reason == "stop"
factory.assert_called_once_with("fallback-a")
class TestFallbackTriesModelsInOrder:
@pytest.mark.asyncio
async def test(self) -> None:
primary = _FakeProvider("primary", _error_response("primary fail"))
fallback_a = _FakeProvider("a", _error_response("a fail"))
fallback_b = _FakeProvider("b", _make_response("b ok"))
factory = MagicMock(side_effect=[fallback_a, fallback_b])
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a", "fallback-b"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "b ok"
assert factory.call_count == 2
factory.assert_any_call("fallback-a")
factory.assert_any_call("fallback-b")
class TestAllFallbacksFail:
@pytest.mark.asyncio
async def test(self) -> None:
primary = _FakeProvider("primary", _error_response("primary fail"))
fallback = _FakeProvider("fallback", _error_response("all fail"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.finish_reason == "error"
assert "all fail" in result.content
class TestFactoryExceptionSkipsModel:
@pytest.mark.asyncio
async def test(self) -> None:
primary = _FakeProvider("primary", _error_response())
fallback_b = _FakeProvider("b", _make_response("b ok"))
factory = MagicMock(side_effect=[ValueError("no key"), fallback_b])
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a", "fallback-b"],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "b ok"
assert factory.call_count == 2
class TestFallbackModelParameter:
@pytest.mark.asyncio
async def test(self) -> None:
"""Fallback calls should use the fallback model name."""
primary = _FakeProvider("primary", _error_response())
fallback = _FakeProvider("fallback", _make_response("ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-model"],
provider_factory=factory,
)
await fb.chat(messages=[{"role": "user", "content": "hi"}], model="primary-model")
assert fallback.chat_calls[0]["model"] == "fallback-model"
class TestNoFallbackWhenEmptyList:
@pytest.mark.asyncio
async def test(self) -> None:
primary = _FakeProvider("primary", _error_response())
factory = MagicMock()
fb = FallbackProvider(
primary=primary,
fallback_models=[],
provider_factory=factory,
)
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.finish_reason == "error"
factory.assert_not_called()
class TestChatStreamFailover:
@pytest.mark.asyncio
async def test_fallback_succeeds(self) -> None:
# Use empty content so on_content_delta is not triggered on the error
primary = _FakeProvider("primary", _error_response(""))
fallback = _FakeProvider("fallback", _make_response("stream ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
result = await fb.chat_stream(messages=[{"role": "user", "content": "hi"}])
assert result.content == "stream ok"
assert result.finish_reason == "stop"
class TestGetDefaultModel:
def test(self) -> None:
primary = _FakeProvider("primary")
fb = FallbackProvider(
primary=primary,
fallback_models=["a"],
provider_factory=MagicMock(),
)
assert fb.get_default_model() == "primary/model"
class TestCircuitBreaker:
@pytest.mark.asyncio
async def test_skips_primary_after_three_failures(self) -> None:
primary = _FakeProvider("primary", _error_response())
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
# 3 failures — primary should still be called each time
for _ in range(3):
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert len(primary.chat_calls) == 3
# 4th call — primary circuit is open, should be skipped
primary.chat_calls.clear()
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert len(primary.chat_calls) == 0
@pytest.mark.asyncio
async def test_resets_on_success(self) -> None:
primary = _FakeProvider("primary", _error_response())
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
factory = MagicMock(return_value=fallback)
fb = FallbackProvider(
primary=primary,
fallback_models=["fallback-a"],
provider_factory=factory,
)
# 2 failures
for _ in range(2):
await fb.chat(messages=[{"role": "user", "content": "hi"}])
# 3rd call: primary succeeds — circuit resets
primary._response = _make_response("primary ok")
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "primary ok"
# 4th call: primary fails again — should still be called (counter reset)
primary._response = _error_response()
primary.chat_calls.clear()
result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
assert result.content == "fallback ok"
assert len(primary.chat_calls) == 1
class TestGenerationForwarded:
def test(self) -> None:
from nanobot.providers.base import GenerationSettings
primary = _FakeProvider("primary")
primary.generation = GenerationSettings(temperature=0.5, max_tokens=1024)
fb = FallbackProvider(
primary=primary,
fallback_models=["a"],
provider_factory=MagicMock(),
)
assert fb.generation.temperature == 0.5
assert fb.generation.max_tokens == 1024