Refine fallback routing on model presets

This commit is contained in:
hanyuanling 2026-04-28 14:31:02 +08:00 committed by chengyongru
parent 2e5930e355
commit 7c270577e1
5 changed files with 393 additions and 30 deletions

View File

@ -16,6 +16,7 @@ from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.ask import AskUserInterrupt from nanobot.agent.tools.ask import AskUserInterrupt
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.failover import ModelCandidate, ModelRouter
from nanobot.utils.helpers import ( from nanobot.utils.helpers import (
build_assistant_message, build_assistant_message,
estimate_message_tokens, estimate_message_tokens,
@ -621,24 +622,14 @@ class AgentRunner:
messages, messages,
tools=spec.tools.get_definitions(), tools=spec.tools.get_definitions(),
) )
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s) provider: LLMProvider = self.provider
request_timeout = timeout_s
if response.finish_reason == "error" and spec.fallback_models: if spec.fallback_models:
for fb_model in spec.fallback_models: provider = self._build_model_router(spec, timeout_s)
logger.warning( # ModelRouter applies the same timeout per candidate, preserving
"Primary model {} failed, trying fallback: {}", # fallback on primary timeouts instead of timing out the whole chain.
spec.model, request_timeout = None
fb_model, return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout)
)
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
fb_kwargs = dict(kwargs, model=resolved_model)
response = await self._call_provider(
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
)
if response.finish_reason != "error":
break
return response
async def _call_provider( async def _call_provider(
self, self,
@ -716,6 +707,25 @@ class AgentRunner:
return provider, provider.get_default_model() return provider, provider.get_default_model()
return self.provider, model return self.provider, model
def _build_model_router(
self,
spec: AgentRunSpec,
timeout_s: float | None,
) -> ModelRouter:
candidates = [
ModelCandidate(
label=model,
resolver=lambda m=model: self._resolve_fallback_provider(m),
)
for model in spec.fallback_models
]
return ModelRouter(
primary_provider=self.provider,
primary_model=spec.model,
fallback_candidates=candidates,
per_candidate_timeout_s=timeout_s,
)
async def _request_finalization_retry( async def _request_finalization_retry(
self, self,
spec: AgentRunSpec, spec: AgentRunSpec,

View File

@ -527,8 +527,7 @@ def _make_cli_provider_factory(config: Config):
def factory(model_or_preset: str): def factory(model_or_preset: str):
preset = presets.get(model_or_preset) preset = presets.get(model_or_preset)
actual_model = preset.model if preset else model_or_preset actual_model = preset.model if preset else model_or_preset
provider_name = config.get_provider_name(actual_model) key = actual_model
key = provider_name or actual_model
if key not in cache: if key not in cache:
cache[key] = _make_provider_for_model(config, actual_model, preset=preset) cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
return cache[key] return cache[key]

View File

@ -220,8 +220,7 @@ def _make_provider_factory(config: Any):
def factory(model_or_preset: str): def factory(model_or_preset: str):
preset = presets.get(model_or_preset) preset = presets.get(model_or_preset)
actual_model = preset.model if preset else model_or_preset actual_model = preset.model if preset else model_or_preset
provider_name = config.get_provider_name(actual_model) key = actual_model
key = provider_name or actual_model
if key not in cache: if key not in cache:
cache[key] = _make_provider_for_model(config, actual_model, preset=preset) cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
return cache[key] return cache[key]

View File

@ -0,0 +1,276 @@
"""Provider-like failover router used after provider-local retry is exhausted."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from loguru import logger
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
@dataclass(frozen=True)
class ModelCandidate:
"""A lazily resolved model/provider candidate."""
label: str
resolver: Callable[[], tuple[LLMProvider, str]]
class ModelRouter(LLMProvider):
"""Try fallback model candidates for eligible transient final errors."""
supports_progress_deltas = False
_BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422})
_QUOTA_MARKERS = (
"insufficient_quota",
"insufficient quota",
"quota exceeded",
"quota_exceeded",
"quota exhausted",
"quota_exhausted",
"billing hard limit",
"billing_hard_limit_reached",
"billing not active",
"insufficient balance",
"insufficient_balance",
"credit balance too low",
"payment required",
"out of credits",
"out of quota",
"exceeded your current quota",
)
_NON_FAILOVER_MARKERS = (
"context length",
"context_length",
"maximum context",
"max context",
"token budget",
"too many tokens",
"schema",
"invalid request",
"invalid_request",
"invalid parameter",
"invalid_parameter",
"unsupported",
"unauthorized",
"authentication",
"permission",
"forbidden",
"refusal",
"content policy",
"content_filter",
"policy violation",
"safety",
)
def __init__(
self,
*,
primary_provider: LLMProvider,
primary_model: str,
fallback_candidates: list[ModelCandidate],
per_candidate_timeout_s: float | None = None,
) -> None:
super().__init__(
api_key=getattr(primary_provider, "api_key", None),
api_base=getattr(primary_provider, "api_base", None),
)
self.primary_provider = primary_provider
self.primary_model = primary_model
self.fallback_candidates = list(fallback_candidates)
self.per_candidate_timeout_s = per_candidate_timeout_s
self.generation = getattr(primary_provider, "generation", GenerationSettings())
def get_default_model(self) -> str:
return self.primary_model
async def chat(self, **kwargs: Any) -> LLMResponse:
return await self.primary_provider.chat(**kwargs)
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
return await self.primary_provider.chat_stream(**kwargs)
@classmethod
def _is_quota_error(cls, response: LLMResponse) -> bool:
tokens = {
cls._normalize_error_token(response.error_type),
cls._normalize_error_token(response.error_code),
}
if any(token in cls._NON_RETRYABLE_429_ERROR_TOKENS for token in tokens if token):
return True
content = (response.content or "").lower()
return any(marker in content for marker in cls._QUOTA_MARKERS)
@classmethod
def _is_blocked_error(cls, response: LLMResponse) -> bool:
status = response.error_status_code
if status is not None and int(status) in cls._BLOCKED_STATUS_CODES:
return True
if response.finish_reason in {"refusal", "content_filter"}:
return True
content = (response.content or "").lower()
return any(marker in content for marker in cls._NON_FAILOVER_MARKERS)
@classmethod
def _should_failover(cls, response: LLMResponse) -> bool:
if response.finish_reason != "error":
return False
if cls._is_blocked_error(response):
return False
if cls._is_quota_error(response):
return False
return cls._is_transient_response(response)
async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
timeout_s = self.per_candidate_timeout_s
if timeout_s is None:
return await coro
try:
return await asyncio.wait_for(coro, timeout=timeout_s)
except asyncio.TimeoutError:
return LLMResponse(
content=f"Error calling LLM: timed out after {timeout_s:g}s",
finish_reason="error",
error_kind="timeout",
)
@staticmethod
def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse:
logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc)
return LLMResponse(
content=f"Error configuring fallback model {candidate.label}: {exc}",
finish_reason="error",
error_kind="configuration",
error_should_retry=False,
)
def _candidate_chain(self) -> list[ModelCandidate]:
return [
ModelCandidate(
label=self.primary_model,
resolver=lambda: (self.primary_provider, self.primary_model),
),
*self.fallback_candidates,
]
async def _route(
self,
call: Callable[[LLMProvider, str, Callable[[str], Awaitable[None]] | None], Awaitable[LLMResponse]],
*,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
last_response: LLMResponse | None = None
chain = self._candidate_chain()
for index, candidate in enumerate(chain):
try:
provider, model = candidate.resolver()
except asyncio.CancelledError:
raise
except Exception as exc:
response = self._resolver_error(candidate, exc)
else:
response = await self._with_timeout(call(provider, model, on_content_delta))
if response.finish_reason != "error":
if index > 0:
logger.info("LLM failover selected model={}", candidate.label)
return response
last_response = response
if not self._should_failover(response):
return response
if index + 1 >= len(chain):
logger.warning("LLM failover exhausted after model={}", candidate.label)
return response
logger.warning(
"LLM failover model={} next_model={} status={} kind={}",
candidate.label,
chain[index + 1].label,
response.error_status_code,
response.error_kind or response.error_type or response.error_code or "unknown",
)
return last_response or LLMResponse(
content="No available fallback model candidate.",
finish_reason="error",
error_kind="configuration",
error_should_retry=False,
)
async def chat_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: object = LLMProvider._SENTINEL,
temperature: object = LLMProvider._SENTINEL,
reasoning_effort: object = LLMProvider._SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
async def call(
provider: LLMProvider,
candidate_model: str,
_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
return await provider.chat_with_retry(
messages=messages,
tools=tools,
model=candidate_model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
return await self._route(call)
async def chat_stream_with_retry(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
model: str | None = None,
max_tokens: object = LLMProvider._SENTINEL,
temperature: object = LLMProvider._SENTINEL,
reasoning_effort: object = LLMProvider._SENTINEL,
tool_choice: str | dict[str, Any] | None = None,
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse:
async def call(
provider: LLMProvider,
candidate_model: str,
external_delta: Callable[[str], Awaitable[None]] | None,
) -> LLMResponse:
buffered: list[str] = []
async def buffer_delta(delta: str) -> None:
buffered.append(delta)
response = await provider.chat_stream_with_retry(
messages=messages,
tools=tools,
model=candidate_model,
max_tokens=max_tokens,
temperature=temperature,
reasoning_effort=reasoning_effort,
tool_choice=tool_choice,
on_content_delta=buffer_delta if external_delta else None,
retry_mode=retry_mode,
on_retry_wait=on_retry_wait,
)
if response.finish_reason != "error" and external_delta:
for delta in buffered:
await external_delta(delta)
return response
return await self._route(call, on_content_delta=on_content_delta)

View File

@ -6,7 +6,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.providers.base import LLMResponse from nanobot.providers.base import LLMResponse
@ -24,6 +25,10 @@ def _make_provider(*, model_response: LLMResponse | None = None):
return p return p
def _transient_error(content: str = "server unavailable") -> LLMResponse:
return LLMResponse(content=content, finish_reason="error", error_status_code=503)
def _base_spec(**overrides) -> AgentRunSpec: def _base_spec(**overrides) -> AgentRunSpec:
defaults = dict( defaults = dict(
initial_messages=[ initial_messages=[
@ -56,7 +61,7 @@ async def test_no_fallback_when_primary_succeeds():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fallback_triggered_on_primary_error(): async def test_fallback_triggered_on_primary_error():
"""Primary fails -> first fallback succeeds.""" """Primary fails -> first fallback succeeds."""
err = LLMResponse(content=None, finish_reason="error", usage={}) err = _transient_error()
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={}) ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err) primary = _make_provider(model_response=err)
@ -76,12 +81,12 @@ async def test_fallback_triggered_on_primary_error():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_all_fallbacks_fail_returns_last_error(): async def test_all_fallbacks_fail_returns_last_error():
"""Primary + all fallbacks fail -> return last error response.""" """Primary + all fallbacks fail -> return last error response."""
err = LLMResponse(content=None, finish_reason="error", usage={}) err = _transient_error()
primary = _make_provider(model_response=err) primary = _make_provider(model_response=err)
fb1 = _make_provider(model_response=err) fb1 = _make_provider(model_response=err)
fb2 = _make_provider(model_response=LLMResponse( fb2 = _make_provider(model_response=LLMResponse(
content="last-error", finish_reason="error", usage={}, content="last-error", finish_reason="error", error_status_code=500, usage={},
)) ))
providers = {"fb-1": fb1, "fb-2": fb2} providers = {"fb-1": fb1, "fb-2": fb2}
@ -110,7 +115,7 @@ async def test_empty_fallback_list_no_retry():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cross_provider_fallback(): async def test_cross_provider_fallback():
"""Fallback uses a different provider instance (cross-provider).""" """Fallback uses a different provider instance (cross-provider)."""
err = LLMResponse(content=None, finish_reason="error", usage={}) err = _transient_error()
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={}) ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err) primary = _make_provider(model_response=err)
@ -132,7 +137,7 @@ async def test_cross_provider_fallback():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fallback_skips_to_second_on_first_error(): async def test_fallback_skips_to_second_on_first_error():
"""First fallback also fails -> second fallback succeeds.""" """First fallback also fails -> second fallback succeeds."""
err = LLMResponse(content=None, finish_reason="error", usage={}) err = _transient_error()
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={}) ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err) primary = _make_provider(model_response=err)
@ -158,7 +163,7 @@ async def test_fallback_reuses_same_provider_without_factory():
async def chat_with_retry(*, messages, model, **kw): async def chat_with_retry(*, messages, model, **kw):
call_count["n"] += 1 call_count["n"] += 1
if call_count["n"] == 1: if call_count["n"] == 1:
return LLMResponse(content=None, finish_reason="error", usage={}) return _transient_error()
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={}) return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
primary = MagicMock() primary = MagicMock()
@ -173,7 +178,7 @@ async def test_fallback_reuses_same_provider_without_factory():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fallback_provider_cached(): async def test_fallback_provider_cached():
"""Provider factory is called once per unique provider, not per attempt.""" """Provider factory is called once per unique provider, not per attempt."""
err = LLMResponse(content=None, finish_reason="error", usage={}) err = _transient_error()
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={}) ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err) primary = _make_provider(model_response=err)
@ -188,3 +193,77 @@ async def test_fallback_provider_cached():
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"])) result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
assert result.final_content == "cached-ok" assert result.final_content == "cached-ok"
@pytest.mark.asyncio
async def test_non_transient_error_does_not_fallback():
"""Auth/config-style errors should surface instead of hiding bugs via fallback."""
primary = _make_provider(model_response=LLMResponse(
content="401 unauthorized",
finish_reason="error",
error_status_code=401,
))
fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
factory = MagicMock(return_value=fallback)
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
factory.assert_not_called()
assert result.error is not None
@pytest.mark.asyncio
async def test_quota_error_does_not_fallback_by_default():
"""Quota/billing/payment 429s should not route by default."""
primary = _make_provider(model_response=LLMResponse(
content="insufficient quota",
finish_reason="error",
error_status_code=429,
error_code="insufficient_quota",
))
fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
factory = MagicMock(return_value=fallback)
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
factory.assert_not_called()
assert result.error is not None
@pytest.mark.asyncio
async def test_streaming_fallback_discards_failed_primary_deltas():
"""Buffered streaming prevents primary partial output from leaking on fallback."""
streamed: list[str] = []
class StreamingHook(AgentHook):
def wants_streaming(self) -> bool:
return True
async def on_stream(self, context: AgentHookContext, delta: str) -> None:
streamed.append(delta)
async def primary_stream(*, on_content_delta, **kwargs):
await on_content_delta("bad partial")
return _transient_error()
async def fallback_stream(*, on_content_delta, **kwargs):
await on_content_delta("good")
await on_content_delta(" answer")
return LLMResponse(content="good answer", tool_calls=[], usage={})
primary = MagicMock()
primary.chat_stream_with_retry = primary_stream
fallback = MagicMock()
fallback.chat_stream_with_retry = fallback_stream
factory = MagicMock(return_value=fallback)
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(
fallback_models=["fb-model"],
hook=StreamingHook(),
))
assert result.final_content == "good answer"
assert streamed == ["good", " answer"]