mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 08:32:25 +00:00
feat(runner): add model failover with fallback_models
When the primary model returns a non-transient error and no content has been streamed yet, the runner now tries each model listed in the active preset's fallback_models, in order. Each fallback model may reside on a different provider — a temporary provider instance is created on the fly via make_provider(config, model=...).

Key design:
- Failover is request-scoped (does not affect subagents/dream/consolidator).
- Provider is restored via try/finally after each fallback attempt.
- Failover is skipped when content was already streamed, to avoid duplicate output.
- Recursive failover is prevented by clearing fallback_models on the fallback spec.
- The circuit breaker trips open after 3 consecutive primary failures (60s cooldown).
- Cross-provider routing: the fallback model prefix (e.g. groq/) determines the provider.

Fixes: cross-provider fallback was broken because the factory passed the original preset (with provider forced to the primary's provider) when creating fallback providers. It now uses provider="auto" so that the model string prefix routes to the correct provider.

Also fixes: log messages now distinguish between the primary-failed, previous-fallback-failed, and circuit-open scenarios.

Closes: https://github.com/HKUDS/nanobot/issues/3376
This commit is contained in:
parent
07f9ab580a
commit
913b0774d8
@ -82,6 +82,7 @@ class ModelPresetConfig(Base):
|
|||||||
context_window_tokens: int = 65_536
|
context_window_tokens: int = 65_536
|
||||||
temperature: float = 0.1
|
temperature: float = 0.1
|
||||||
reasoning_effort: str | None = None
|
reasoning_effort: str | None = None
|
||||||
|
fallback_models: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
def to_generation_settings(self) -> Any:
|
def to_generation_settings(self) -> Any:
|
||||||
from nanobot.providers.base import GenerationSettings
|
from nanobot.providers.base import GenerationSettings
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
from nanobot.config.schema import Config, ModelPresetConfig
|
from nanobot.config.schema import Config, ModelPresetConfig
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.providers.fallback_provider import FallbackProvider
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
|
|
||||||
@ -27,15 +28,16 @@ def _resolve_model_preset(
|
|||||||
return preset if preset is not None else config.resolve_preset(preset_name)
|
return preset if preset is not None else config.resolve_preset(preset_name)
|
||||||
|
|
||||||
|
|
||||||
def make_provider(
|
def _make_provider_core(
|
||||||
config: Config,
|
config: Config,
|
||||||
*,
|
*,
|
||||||
preset_name: str | None = None,
|
preset_name: str | None = None,
|
||||||
preset: ModelPresetConfig | None = None,
|
preset: ModelPresetConfig | None = None,
|
||||||
|
model: str | None = None,
|
||||||
) -> LLMProvider:
|
) -> LLMProvider:
|
||||||
"""Create the LLM provider implied by config."""
|
"""Create a plain LLM provider without failover wrapping."""
|
||||||
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
resolved = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
|
||||||
model = resolved.model
|
model = model or resolved.model
|
||||||
provider_name = config.get_provider_name(model, preset=resolved)
|
provider_name = config.get_provider_name(model, preset=resolved)
|
||||||
p = config.get_provider(model, preset=resolved)
|
p = config.get_provider(model, preset=resolved)
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
@ -102,6 +104,34 @@ def make_provider(
|
|||||||
return provider
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def make_provider(
    config: Config,
    *,
    preset_name: str | None = None,
    preset: ModelPresetConfig | None = None,
    model: str | None = None,
) -> LLMProvider:
    """Build the provider described by *config*, adding failover when configured.

    Args:
        config: Application configuration used to resolve presets and providers.
        preset_name: Optional name of the preset to resolve.
        preset: Optional already-resolved preset; takes the place of a lookup.
        model: Optional model override applied on top of the resolved preset.
            The failover path uses it to build providers for fallback models.

    Returns:
        The plain provider when the resolved preset lists no fallback models,
        otherwise a ``FallbackProvider`` wrapping that provider.
    """
    active_preset = _resolve_model_preset(config, preset_name=preset_name, preset=preset)
    base_provider = _make_provider_core(
        config, preset_name=preset_name, preset=preset, model=model
    )

    # Nothing to fail over to: hand back the unwrapped provider.
    if not active_preset.fallback_models:
        return base_provider

    # Fallback providers are built from a copy of the preset whose provider is
    # reset to "auto" and whose fallback list is emptied, so the factory always
    # yields plain (non-wrapping) providers.
    plain_preset = active_preset.model_copy(
        update={"provider": "auto", "fallback_models": []}
    )

    def _build_fallback(fallback_model: str) -> LLMProvider:
        return _make_provider_core(
            config, preset_name=preset_name, preset=plain_preset, model=fallback_model
        )

    return FallbackProvider(
        primary=base_provider,
        fallback_models=active_preset.fallback_models,
        provider_factory=_build_fallback,
    )
|
||||||
|
|
||||||
|
|
||||||
def provider_signature(
|
def provider_signature(
|
||||||
config: Config,
|
config: Config,
|
||||||
*,
|
*,
|
||||||
|
|||||||
186
nanobot/providers/fallback_provider.py
Normal file
186
nanobot/providers/fallback_provider.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
"""Provider wrapper that transparently fails over to fallback models on error."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
|
|
||||||
|
# Circuit breaker tuned to match OpenAICompatProvider's Responses API breaker.
|
||||||
|
_PRIMARY_FAILURE_THRESHOLD = 3
|
||||||
|
_PRIMARY_COOLDOWN_S = 60
|
||||||
|
|
||||||
|
|
||||||
|
class FallbackProvider(LLMProvider):
    """Wrap a primary provider and transparently fail over to fallback models.

    When the primary model returns an error response (``finish_reason ==
    "error"``) and no content has been streamed yet, the wrapper tries each
    fallback model in order. Each fallback model may reside on a different
    provider — a factory callable creates the underlying provider on the fly.

    Key design:
    - Failover happens within a single request; the only state kept between
      requests is the primary's circuit-breaker counter and trip timestamp.
    - Skipped when content was already streamed to avoid duplicate output.
    - Recursive failover is prevented by the factory returning plain providers.
    - Primary provider is circuit-broken after repeated failures to avoid
      wasting requests on a known-bad endpoint.
    """

    def __init__(
        self,
        primary: LLMProvider,
        fallback_models: list[str],
        provider_factory: Callable[[str], LLMProvider],
    ):
        """Store the primary provider, the fallback chain and the factory.

        Args:
            primary: Provider tried first on every request (unless tripped).
            fallback_models: Model names to try, in order, after a failure.
            provider_factory: Callable building a provider for a model name.
        """
        # NOTE(review): the base-class initialiser is not invoked here, and
        # ``generation`` is proxied to the primary through the property below —
        # confirm LLMProvider needs no other per-instance initialisation.
        self._primary = primary
        self._fallback_models = list(fallback_models)
        self._provider_factory = provider_factory
        self._has_fallbacks = bool(fallback_models)
        # Circuit-breaker state for the primary provider.
        self._primary_failures = 0
        self._primary_tripped_at: float | None = None

    @property
    def generation(self) -> Any:
        """Generation settings, read from the primary provider."""
        return self._primary.generation

    @generation.setter
    def generation(self, value: Any) -> None:
        self._primary.generation = value

    def get_default_model(self) -> str:
        """Return the primary provider's default model name."""
        return self._primary.get_default_model()

    def _primary_available(self) -> bool:
        """Return True if the primary provider is not currently tripped."""
        if self._primary_tripped_at is None:
            return True
        if time.monotonic() - self._primary_tripped_at >= _PRIMARY_COOLDOWN_S:
            # Half-open: allow one probe attempt.
            return True
        return False

    def _record_primary_failure(self, primary_model: str) -> None:
        """Count one primary failure and (re)open the circuit at the threshold."""
        self._primary_failures += 1
        if self._primary_failures >= _PRIMARY_FAILURE_THRESHOLD:
            self._primary_tripped_at = time.monotonic()
            logger.warning(
                "Primary model '{}' circuit open after {} consecutive failures",
                primary_model, self._primary_failures,
            )

    @staticmethod
    def _content_streamed(has_streamed: list[bool] | None) -> bool:
        """Return True when a streaming request has already emitted content."""
        return has_streamed is not None and has_streamed[0]

    async def chat(self, **kwargs: Any) -> LLMResponse:
        """Run a non-streaming chat request, failing over on error responses."""
        if not self._has_fallbacks:
            return await self._primary.chat(**kwargs)
        return await self._try_with_fallback(
            lambda p, kw: p.chat(**kw), kwargs, has_streamed=None
        )

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Run a streaming chat request, failing over only before any output."""
        if not self._has_fallbacks:
            return await self._primary.chat_stream(**kwargs)

        # One-element list used as a mutable flag shared with the callback.
        has_streamed: list[bool] = [False]
        original_delta = kwargs.get("on_content_delta")

        async def _tracking_delta(text: str) -> None:
            # Remember that non-empty content reached the caller, then forward it.
            if text:
                has_streamed[0] = True
            if original_delta:
                await original_delta(text)

        kwargs["on_content_delta"] = _tracking_delta
        return await self._try_with_fallback(
            lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed
        )

    async def _try_with_fallback(
        self,
        call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
        kwargs: dict[str, Any],
        has_streamed: list[bool] | None,
    ) -> LLMResponse:
        """Try the primary, then each fallback model, until one succeeds.

        Args:
            call: Invokes ``chat`` or ``chat_stream`` on a provider with kwargs.
            kwargs: Request keyword arguments; ``model`` is overridden per fallback.
            has_streamed: Streaming flag holder, or ``None`` for non-streaming.

        Returns:
            The first non-error response; otherwise the most recent error
            response seen (from the primary or a fallback), or a synthesized
            error when no provider could be called at all.
        """
        primary_model = kwargs.get("model") or self._primary.get_default_model()

        # Decide up front whether the primary is skipped, so the log lines below
        # do not confuse "primary failed just now" with "circuit already open".
        primary_skipped = not self._primary_available()
        last_response: LLMResponse | None = None

        if primary_skipped:
            logger.debug("Primary model '{}' circuit open; skipping", primary_model)
        else:
            response = await call(self._primary, kwargs)
            if response.finish_reason != "error":
                self._primary_failures = 0
                self._primary_tripped_at = None
                return response

            if self._content_streamed(has_streamed):
                logger.warning(
                    "Primary model error but content already streamed; skipping failover"
                )
                return response

            self._record_primary_failure(primary_model)
            # Keep the primary's error so it is returned if no fallback yields
            # a response (e.g. every fallback provider fails to construct).
            last_response = response

        for idx, fallback_model in enumerate(self._fallback_models):
            # A previous fallback may have streamed content before erroring.
            if self._content_streamed(has_streamed):
                break
            if idx == 0 and primary_skipped:
                logger.info(
                    "Primary model '{}' circuit open, trying fallback '{}'",
                    primary_model, fallback_model,
                )
            elif idx == 0:
                logger.info(
                    "Primary model '{}' failed, trying fallback '{}'",
                    primary_model, fallback_model,
                )
            else:
                logger.info(
                    "Fallback '{}' also failed, trying next fallback '{}'",
                    self._fallback_models[idx - 1], fallback_model,
                )
            try:
                fallback_provider = self._provider_factory(fallback_model)
            except Exception as exc:
                # Best effort: an unusable fallback must not abort the chain.
                logger.warning(
                    "Failed to create provider for fallback '{}': {}", fallback_model, exc
                )
                continue

            # Use a per-attempt copy so the caller's kwargs keep their model.
            fallback_kwargs = {**kwargs, "model": fallback_model}
            fallback_response = await call(fallback_provider, fallback_kwargs)

            if fallback_response.finish_reason != "error":
                logger.info(
                    "Fallback '{}' succeeded after primary '{}' failed",
                    fallback_model, primary_model,
                )
                return fallback_response

            last_response = fallback_response
            logger.warning(
                "Fallback '{}' also failed: {}",
                fallback_model,
                (fallback_response.content or "")[:120],
            )

        logger.warning(
            "All {} fallback model(s) failed",
            len(self._fallback_models),
        )
        # Return the last error response we saw (primary or last fallback).
        if last_response is not None:
            return last_response
        # Primary was skipped and no fallback could be called — synthesize an error.
        return LLMResponse(
            content=f"Primary model '{primary_model}' circuit open and no fallbacks available",
            finish_reason="error",
        )
|
||||||
364
tests/agent/test_runner_fallback.py
Normal file
364
tests/agent/test_runner_fallback.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
"""Tests for FallbackProvider model failover."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider, LLMResponse
|
||||||
|
from nanobot.providers.fallback_provider import FallbackProvider
|
||||||
|
|
||||||
|
|
||||||
|
def _make_response(
    content: str = "ok",
    finish_reason: str = "stop",
    *,
    error_kind: str | None = None,
) -> LLMResponse:
    """Build an ``LLMResponse`` from content, finish reason and error kind."""
    fields = {
        "content": content,
        "finish_reason": finish_reason,
        "error_kind": error_kind,
    }
    return LLMResponse(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def _error_response(content: str = "api error") -> LLMResponse:
    """Build an error response (``finish_reason="error"``, kind ``server_error``)."""
    return _make_response(
        content=content,
        finish_reason="error",
        error_kind="server_error",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeProvider(LLMProvider):
    """Provider stub that records every call and replays one canned response."""

    def __init__(self, name: str = "fake", response: LLMResponse | None = None):
        super().__init__()
        self.name = name
        self._response = response or _make_response()
        # Keyword arguments of each call, copied at call time.
        self.chat_calls: list[dict[str, Any]] = []
        self.chat_stream_calls: list[dict[str, Any]] = []

    def get_default_model(self) -> str:
        """Return ``"<name>/model"``."""
        return f"{self.name}/model"

    async def chat(self, **kwargs: Any) -> LLMResponse:
        """Record the call and return the canned response."""
        self.chat_calls.append(dict(kwargs))
        return self._response

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Record the call, emit non-empty canned content once, then return it."""
        self.chat_stream_calls.append(dict(kwargs))
        canned = self._response
        emit = kwargs.get("on_content_delta")
        if emit and canned.content:
            await emit(canned.content)
        return canned
|
||||||
|
|
||||||
|
|
||||||
|
# -- config-level tests --
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_models_default_empty() -> None:
    """A preset created without ``fallback_models`` has an empty list."""
    from nanobot.config.schema import ModelPresetConfig

    preset = ModelPresetConfig(model="test/model")
    assert preset.fallback_models == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_models_accepts_list() -> None:
    """An explicit ``fallback_models`` list is stored unchanged."""
    from nanobot.config.schema import ModelPresetConfig

    preset = ModelPresetConfig(
        model="test/primary",
        fallback_models=["test/a", "test/b"],
    )
    assert preset.fallback_models == ["test/a", "test/b"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fallback_models_from_camel_case() -> None:
    """The camelCase key ``fallbackModels`` populates ``fallback_models``."""
    from nanobot.config.schema import ModelPresetConfig

    raw = {"model": "test/primary", "fallbackModels": ["test/a"]}
    preset = ModelPresetConfig.model_validate(raw)
    assert preset.fallback_models == ["test/a"]
|
||||||
|
|
||||||
|
|
||||||
|
# -- FallbackProvider tests --
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoFallbackWhenPrimarySucceeds:
    """A healthy primary answers directly and the factory is never used."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        factory = MagicMock()
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _make_response("primary ok")),
            fallback_models=["fallback-a"],
            provider_factory=factory,
        )
        messages = [{"role": "user", "content": "hi"}]

        outcome = await wrapper.chat(messages=messages)

        assert outcome.content == "primary ok"
        assert outcome.finish_reason == "stop"
        factory.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackOnPrimaryError:
    """A primary error routes the request to the first fallback model."""

    @pytest.mark.asyncio
    async def test_first_fallback_succeeds(self) -> None:
        failing_primary = _FakeProvider("primary", _error_response())
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        factory = MagicMock(return_value=healthy_fallback)
        wrapper = FallbackProvider(
            primary=failing_primary,
            fallback_models=["fallback-a"],
            provider_factory=factory,
        )
        messages = [{"role": "user", "content": "hi"}]

        outcome = await wrapper.chat(messages=messages, model="primary-model")

        assert outcome.content == "fallback ok"
        assert outcome.finish_reason == "stop"
        factory.assert_called_once_with("fallback-a")
        # Each provider must have been asked for its own model.
        first_primary_call = failing_primary.chat_calls[0]
        first_fallback_call = healthy_fallback.chat_calls[0]
        assert first_primary_call["model"] == "primary-model"
        assert first_fallback_call["model"] == "fallback-a"
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoFallbackWhenContentStreamed:
    """Once content has been streamed, a primary error is returned as-is."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        factory = MagicMock()
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_models=["fallback-a"],
            provider_factory=factory,
        )

        async def _ignore_delta(text: str) -> None:
            return None

        outcome = await wrapper.chat_stream(
            messages=[{"role": "user", "content": "hi"}],
            on_content_delta=_ignore_delta,
        )

        # The fake provider pushes its (non-empty) error content through the
        # delta callback, so the wrapper sees streamed output and must not
        # fail over.
        assert outcome.finish_reason == "error"
        factory.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFailoverOnTransientError:
    """Error responses describing transient conditions also trigger failover."""

    @pytest.mark.asyncio
    async def test_rate_limit(self) -> None:
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        factory = MagicMock(return_value=healthy_fallback)
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response("rate limit exceeded")),
            fallback_models=["fallback-a"],
            provider_factory=factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "fallback ok"
        assert outcome.finish_reason == "stop"
        factory.assert_called_once_with("fallback-a")

    @pytest.mark.asyncio
    async def test_timeout(self) -> None:
        timeout_error = _make_response(
            "timed out", finish_reason="error", error_kind="timeout"
        )
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        factory = MagicMock(return_value=healthy_fallback)
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", timeout_error),
            fallback_models=["fallback-a"],
            provider_factory=factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "fallback ok"
        assert outcome.finish_reason == "stop"
        factory.assert_called_once_with("fallback-a")
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackTriesModelsInOrder:
    """Fallback models are attempted strictly in their configured order."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        primary = _FakeProvider("primary", _error_response("primary fail"))
        fallback_a = _FakeProvider("a", _error_response("a fail"))
        fallback_b = _FakeProvider("b", _make_response("b ok"))
        factory = MagicMock(side_effect=[fallback_a, fallback_b])

        fb = FallbackProvider(
            primary=primary,
            fallback_models=["fallback-a", "fallback-b"],
            provider_factory=factory,
        )

        result = await fb.chat(messages=[{"role": "user", "content": "hi"}])
        assert result.content == "b ok"

        # Compare the full call sequence: ``assert_any_call`` would also pass
        # if the fallbacks were tried in the wrong order.
        requested = [recorded.args for recorded in factory.call_args_list]
        assert requested == [("fallback-a",), ("fallback-b",)]

        # Each provider is called exactly once, with its own model name.
        assert len(primary.chat_calls) == 1
        assert [c["model"] for c in fallback_a.chat_calls] == ["fallback-a"]
        assert [c["model"] for c in fallback_b.chat_calls] == ["fallback-b"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAllFallbacksFail:
    """When every model errors, the last fallback's error is returned."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        failing_fallback = _FakeProvider("fallback", _error_response("all fail"))
        factory = MagicMock(return_value=failing_fallback)
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response("primary fail")),
            fallback_models=["fallback-a"],
            provider_factory=factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.finish_reason == "error"
        assert "all fail" in outcome.content
|
||||||
|
|
||||||
|
|
||||||
|
class TestFactoryExceptionSkipsModel:
    """A fallback whose provider cannot be built is skipped, not fatal."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        second_fallback = _FakeProvider("b", _make_response("b ok"))
        # The first factory call raises; the second yields a working provider.
        factory = MagicMock(side_effect=[ValueError("no key"), second_fallback])
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_models=["fallback-a", "fallback-b"],
            provider_factory=factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        assert outcome.content == "b ok"
        assert factory.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackModelParameter:
    """The ``model`` keyword is rewritten for the fallback call."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        """Fallback calls should use the fallback model name."""
        healthy_fallback = _FakeProvider("fallback", _make_response("ok"))
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_models=["fallback-model"],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )

        await wrapper.chat(
            messages=[{"role": "user", "content": "hi"}],
            model="primary-model",
        )

        first_call = healthy_fallback.chat_calls[0]
        assert first_call["model"] == "fallback-model"
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoFallbackWhenEmptyList:
    """With an empty fallback list the primary's error passes straight through."""

    @pytest.mark.asyncio
    async def test(self) -> None:
        factory = MagicMock()
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary", _error_response()),
            fallback_models=[],
            provider_factory=factory,
        )

        outcome = await wrapper.chat(messages=[{"role": "user", "content": "hi"}])

        factory.assert_not_called()
        assert outcome.finish_reason == "error"
|
||||||
|
|
||||||
|
|
||||||
|
class TestChatStreamFailover:
    """Streaming requests fail over when the primary errors before any output."""

    @pytest.mark.asyncio
    async def test_fallback_succeeds(self) -> None:
        # Empty error content keeps the fake provider from invoking
        # on_content_delta, so nothing counts as already streamed.
        silent_failure = _FakeProvider("primary", _error_response(""))
        healthy_fallback = _FakeProvider("fallback", _make_response("stream ok"))
        wrapper = FallbackProvider(
            primary=silent_failure,
            fallback_models=["fallback-a"],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )

        outcome = await wrapper.chat_stream(
            messages=[{"role": "user", "content": "hi"}]
        )

        assert outcome.content == "stream ok"
        assert outcome.finish_reason == "stop"
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDefaultModel:
    """``get_default_model`` is delegated to the primary provider."""

    def test(self) -> None:
        wrapper = FallbackProvider(
            primary=_FakeProvider("primary"),
            fallback_models=["a"],
            provider_factory=MagicMock(),
        )
        default_model = wrapper.get_default_model()
        assert default_model == "primary/model"
|
||||||
|
|
||||||
|
|
||||||
|
class TestCircuitBreaker:
    """The primary is skipped after repeated failures and reset by a success."""

    @pytest.mark.asyncio
    async def test_skips_primary_after_three_failures(self) -> None:
        failing_primary = _FakeProvider("primary", _error_response())
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        wrapper = FallbackProvider(
            primary=failing_primary,
            fallback_models=["fallback-a"],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )
        messages = [{"role": "user", "content": "hi"}]

        # The first three requests each reach the primary before falling back.
        attempts = 0
        while attempts < 3:
            outcome = await wrapper.chat(messages=messages)
            assert outcome.content == "fallback ok"
            attempts += 1

        assert len(failing_primary.chat_calls) == 3

        # Fourth request: the circuit is open, so the primary is not called.
        failing_primary.chat_calls.clear()
        outcome = await wrapper.chat(messages=messages)
        assert outcome.content == "fallback ok"
        assert len(failing_primary.chat_calls) == 0

    @pytest.mark.asyncio
    async def test_resets_on_success(self) -> None:
        flaky_primary = _FakeProvider("primary", _error_response())
        healthy_fallback = _FakeProvider("fallback", _make_response("fallback ok"))
        wrapper = FallbackProvider(
            primary=flaky_primary,
            fallback_models=["fallback-a"],
            provider_factory=MagicMock(return_value=healthy_fallback),
        )
        messages = [{"role": "user", "content": "hi"}]

        # Two failures: one short of the threshold.
        await wrapper.chat(messages=messages)
        await wrapper.chat(messages=messages)

        # Third request succeeds on the primary, which resets the counter.
        flaky_primary._response = _make_response("primary ok")
        outcome = await wrapper.chat(messages=messages)
        assert outcome.content == "primary ok"

        # Fourth request fails again; the primary must still be tried because
        # the failure count started over.
        flaky_primary._response = _error_response()
        flaky_primary.chat_calls.clear()
        outcome = await wrapper.chat(messages=messages)
        assert outcome.content == "fallback ok"
        assert len(flaky_primary.chat_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerationForwarded:
    """The wrapper exposes the primary provider's generation settings."""

    def test(self) -> None:
        from nanobot.providers.base import GenerationSettings

        settings = GenerationSettings(temperature=0.5, max_tokens=1024)
        inner = _FakeProvider("primary")
        inner.generation = settings
        wrapper = FallbackProvider(
            primary=inner,
            fallback_models=["a"],
            provider_factory=MagicMock(),
        )

        exposed = wrapper.generation
        assert exposed.temperature == 0.5
        assert exposed.max_tokens == 1024
|
||||||
Loading…
x
Reference in New Issue
Block a user