Refine fallback routing on model presets

This commit is contained in:
hanyuanling 2026-04-28 14:31:02 +08:00 committed by chengyongru
parent 15b7e65358
commit ecbe56dd92
5 changed files with 393 additions and 30 deletions

View File

@ -16,6 +16,7 @@ from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.tools.ask import AskUserInterrupt
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.providers.failover import ModelCandidate, ModelRouter
from nanobot.utils.helpers import (
build_assistant_message,
estimate_message_tokens,
@ -621,24 +622,14 @@ class AgentRunner:
messages,
tools=spec.tools.get_definitions(),
)
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s)
if response.finish_reason == "error" and spec.fallback_models:
for fb_model in spec.fallback_models:
logger.warning(
"Primary model {} failed, trying fallback: {}",
spec.model,
fb_model,
)
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
fb_kwargs = dict(kwargs, model=resolved_model)
response = await self._call_provider(
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
)
if response.finish_reason != "error":
break
return response
provider: LLMProvider = self.provider
request_timeout = timeout_s
if spec.fallback_models:
provider = self._build_model_router(spec, timeout_s)
# ModelRouter applies the same timeout per candidate, preserving
# fallback on primary timeouts instead of timing out the whole chain.
request_timeout = None
return await self._call_provider(provider, kwargs, hook, context, spec, request_timeout)
async def _call_provider(
self,
@ -716,6 +707,25 @@ class AgentRunner:
return provider, provider.get_default_model()
return self.provider, model
def _build_model_router(
    self,
    spec: AgentRunSpec,
    timeout_s: float | None,
) -> ModelRouter:
    """Assemble a ModelRouter for this runner's provider and ``spec``.

    Args:
        spec: Run spec supplying the primary model name (``spec.model``)
            and the ordered fallback names (``spec.fallback_models``).
        timeout_s: Handed to the router as its per-candidate timeout;
            ``None`` means no timeout is applied by the router.

    Returns:
        A router whose primary is ``self.provider`` / ``spec.model`` and
        whose fallbacks are resolved lazily, one per fallback name.
    """
    fallbacks: list[ModelCandidate] = []
    for fallback_name in spec.fallback_models:
        fallbacks.append(
            ModelCandidate(
                label=fallback_name,
                # The default argument freezes the name for this candidate;
                # a plain closure would see only the loop's final value.
                resolver=lambda m=fallback_name: self._resolve_fallback_provider(m),
            )
        )

    return ModelRouter(
        primary_provider=self.provider,
        primary_model=spec.model,
        fallback_candidates=fallbacks,
        per_candidate_timeout_s=timeout_s,
    )
async def _request_finalization_retry(
self,
spec: AgentRunSpec,

View File

@ -527,8 +527,7 @@ def _make_cli_provider_factory(config: Config):
def factory(model_or_preset: str):
preset = presets.get(model_or_preset)
actual_model = preset.model if preset else model_or_preset
provider_name = config.get_provider_name(actual_model)
key = provider_name or actual_model
key = actual_model
if key not in cache:
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
return cache[key]

View File

@ -224,8 +224,7 @@ def _make_provider_factory(config: Any):
def factory(model_or_preset: str):
preset = presets.get(model_or_preset)
actual_model = preset.model if preset else model_or_preset
provider_name = config.get_provider_name(actual_model)
key = provider_name or actual_model
key = actual_model
if key not in cache:
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
return cache[key]

View File

@ -0,0 +1,276 @@
"""Provider-like failover router used after provider-local retry is exhausted."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any
from loguru import logger
from nanobot.providers.base import GenerationSettings, LLMProvider, LLMResponse
@dataclass(frozen=True)
class ModelCandidate:
    """One entry in a failover chain: a label plus a deferred provider lookup.

    Nothing is resolved at construction time; the provider and model name
    are produced only when ``resolver`` is invoked.
    """

    # Name shown in log messages and error text for this candidate.
    label: str

    # Zero-argument callable returning ``(provider, model_name)``.
    resolver: Callable[[], tuple[LLMProvider, str]]
class ModelRouter(LLMProvider):
    """Provider facade that walks a primary -> fallback chain of models.

    The ``*_with_retry`` entry points try the primary model first and move
    on to the next candidate only when the error response is judged
    eligible by ``_should_failover``. The plain ``chat`` / ``chat_stream``
    methods go straight to the primary provider with no failover.

    NOTE(review): ``_NON_RETRYABLE_429_ERROR_TOKENS``,
    ``_normalize_error_token``, ``_is_transient_response`` and ``_SENTINEL``
    are not defined in this class; presumably they are inherited from
    ``LLMProvider`` -- confirm against the base class.
    """

    # NOTE(review): streamed deltas are buffered and replayed only after a
    # candidate succeeds (see ``chat_stream_with_retry``); presumably this
    # flag advertises that to callers -- confirm against LLMProvider.
    supports_progress_deltas = False

    # HTTP statuses for which the router never tries another model.
    _BLOCKED_STATUS_CODES = frozenset({400, 401, 403, 404, 422})

    # Lower-case substrings of error text treated as quota/billing failures.
    _QUOTA_MARKERS = (
        "insufficient_quota",
        "insufficient quota",
        "quota exceeded",
        "quota_exceeded",
        "quota exhausted",
        "quota_exhausted",
        "billing hard limit",
        "billing_hard_limit_reached",
        "billing not active",
        "insufficient balance",
        "insufficient_balance",
        "credit balance too low",
        "payment required",
        "out of credits",
        "out of quota",
        "exceeded your current quota",
    )

    # Lower-case substrings of error text that block failover outright.
    _NON_FAILOVER_MARKERS = (
        "context length",
        "context_length",
        "maximum context",
        "max context",
        "token budget",
        "too many tokens",
        "schema",
        "invalid request",
        "invalid_request",
        "invalid parameter",
        "invalid_parameter",
        "unsupported",
        "unauthorized",
        "authentication",
        "permission",
        "forbidden",
        "refusal",
        "content policy",
        "content_filter",
        "policy violation",
        "safety",
    )

    def __init__(
        self,
        *,
        primary_provider: LLMProvider,
        primary_model: str,
        fallback_candidates: list[ModelCandidate],
        per_candidate_timeout_s: float | None = None,
    ) -> None:
        """Create a router around ``primary_provider``.

        Args:
            primary_provider: Provider tried first; its ``api_key``,
                ``api_base`` and ``generation`` attributes are mirrored
                onto the router when present.
            primary_model: Model name used for the primary attempt.
            fallback_candidates: Ordered candidates tried after the primary;
                the list is copied.
            per_candidate_timeout_s: Timeout applied to each attempt
                individually, or ``None`` for no timeout.
        """
        primary_key = getattr(primary_provider, "api_key", None)
        primary_base = getattr(primary_provider, "api_base", None)
        super().__init__(api_key=primary_key, api_base=primary_base)

        self.primary_provider = primary_provider
        self.primary_model = primary_model
        self.fallback_candidates = list(fallback_candidates)
        self.per_candidate_timeout_s = per_candidate_timeout_s
        self.generation = getattr(primary_provider, "generation", GenerationSettings())

    def get_default_model(self) -> str:
        """Return the primary model name."""
        return self.primary_model

    async def chat(self, **kwargs: Any) -> LLMResponse:
        """Forward to the primary provider's ``chat``; no failover here."""
        return await self.primary_provider.chat(**kwargs)

    async def chat_stream(self, **kwargs: Any) -> LLMResponse:
        """Forward to the primary provider's ``chat_stream``; no failover here."""
        return await self.primary_provider.chat_stream(**kwargs)

    @classmethod
    def _is_quota_error(cls, response: LLMResponse) -> bool:
        """Return True when the response looks like a quota/billing failure.

        Checks the normalized ``error_type`` / ``error_code`` against
        ``_NON_RETRYABLE_429_ERROR_TOKENS`` first, then scans the lower-cased
        content for any ``_QUOTA_MARKERS`` substring.
        """
        type_token = cls._normalize_error_token(response.error_type)
        code_token = cls._normalize_error_token(response.error_code)
        for token in (type_token, code_token):
            if token and token in cls._NON_RETRYABLE_429_ERROR_TOKENS:
                return True

        text = (response.content or "").lower()
        return any(marker in text for marker in cls._QUOTA_MARKERS)

    @classmethod
    def _is_blocked_error(cls, response: LLMResponse) -> bool:
        """Return True when the error must never trigger a fallback.

        A response is blocked when its status code is in
        ``_BLOCKED_STATUS_CODES``, its finish reason is a refusal or content
        filter, or its lower-cased content contains a
        ``_NON_FAILOVER_MARKERS`` substring.
        """
        status_code = response.error_status_code
        if status_code is not None and int(status_code) in cls._BLOCKED_STATUS_CODES:
            return True
        if response.finish_reason in {"refusal", "content_filter"}:
            return True

        text = (response.content or "").lower()
        return any(marker in text for marker in cls._NON_FAILOVER_MARKERS)

    @classmethod
    def _should_failover(cls, response: LLMResponse) -> bool:
        """Decide whether the next candidate should be tried.

        Only error responses that are neither blocked nor quota-related are
        considered, and the final verdict is delegated to
        ``_is_transient_response``.
        """
        if response.finish_reason != "error":
            return False
        if cls._is_blocked_error(response) or cls._is_quota_error(response):
            return False
        return cls._is_transient_response(response)

    async def _with_timeout(self, coro: Awaitable[LLMResponse]) -> LLMResponse:
        """Await ``coro`` under the per-candidate timeout, if one is set.

        A timeout is converted into an error ``LLMResponse`` with
        ``error_kind="timeout"`` instead of being raised.
        """
        limit = self.per_candidate_timeout_s
        if limit is not None:
            try:
                return await asyncio.wait_for(coro, timeout=limit)
            except asyncio.TimeoutError:
                return LLMResponse(
                    content=f"Error calling LLM: timed out after {limit:g}s",
                    finish_reason="error",
                    error_kind="timeout",
                )
        return await coro

    @staticmethod
    def _resolver_error(candidate: ModelCandidate, exc: Exception) -> LLMResponse:
        """Log a failed candidate resolution and wrap it as an error response."""
        logger.warning("Failed to resolve fallback model {}: {}", candidate.label, exc)
        message = f"Error configuring fallback model {candidate.label}: {exc}"
        return LLMResponse(
            content=message,
            finish_reason="error",
            error_kind="configuration",
            error_should_retry=False,
        )

    def _candidate_chain(self) -> list[ModelCandidate]:
        """Return the primary model as a candidate followed by the fallbacks."""
        primary = ModelCandidate(
            label=self.primary_model,
            resolver=lambda: (self.primary_provider, self.primary_model),
        )
        return [primary, *self.fallback_candidates]

    async def _route(
        self,
        call: Callable[
            [LLMProvider, str, Callable[[str], Awaitable[None]] | None],
            Awaitable[LLMResponse],
        ],
        *,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Run ``call`` against each candidate until one settles the request.

        Args:
            call: Coroutine function invoked as
                ``call(provider, model, on_content_delta)`` per candidate.
            on_content_delta: Streaming callback passed through to ``call``.

        Returns:
            The first non-error response; otherwise the first error that is
            not eligible for failover, or the last candidate's error once
            the chain is exhausted.
        """
        candidates = self._candidate_chain()
        final_position = len(candidates) - 1
        most_recent: LLMResponse | None = None

        for position, candidate in enumerate(candidates):
            try:
                provider, model = candidate.resolver()
            except asyncio.CancelledError:
                # Cancellation must propagate; it is not a resolver failure.
                raise
            except Exception as exc:
                outcome = self._resolver_error(candidate, exc)
            else:
                outcome = await self._with_timeout(call(provider, model, on_content_delta))

            if outcome.finish_reason != "error":
                if position > 0:
                    logger.info("LLM failover selected model={}", candidate.label)
                return outcome

            most_recent = outcome
            if not self._should_failover(outcome):
                return outcome
            if position >= final_position:
                logger.warning("LLM failover exhausted after model={}", candidate.label)
                return outcome

            upcoming = candidates[position + 1]
            logger.warning(
                "LLM failover model={} next_model={} status={} kind={}",
                candidate.label,
                upcoming.label,
                outcome.error_status_code,
                outcome.error_kind or outcome.error_type or outcome.error_code or "unknown",
            )

        # Reached only if the loop finishes without returning.
        return most_recent or LLMResponse(
            content="No available fallback model candidate.",
            finish_reason="error",
            error_kind="configuration",
            error_should_retry=False,
        )

    async def chat_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = LLMProvider._SENTINEL,
        temperature: object = LLMProvider._SENTINEL,
        reasoning_effort: object = LLMProvider._SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Non-streaming request routed through the candidate chain.

        ``model`` is accepted but not used: each candidate supplies its own
        model name. Every other argument is forwarded unchanged to the
        candidate provider's ``chat_with_retry``.
        """
        request: dict[str, Any] = {
            "messages": messages,
            "tools": tools,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "reasoning_effort": reasoning_effort,
            "tool_choice": tool_choice,
            "retry_mode": retry_mode,
            "on_retry_wait": on_retry_wait,
        }

        async def attempt(
            provider: LLMProvider,
            candidate_model: str,
            _delta: Callable[[str], Awaitable[None]] | None,
        ) -> LLMResponse:
            return await provider.chat_with_retry(model=candidate_model, **request)

        return await self._route(attempt)

    async def chat_stream_with_retry(
        self,
        messages: list[dict[str, Any]],
        tools: list[dict[str, Any]] | None = None,
        model: str | None = None,
        max_tokens: object = LLMProvider._SENTINEL,
        temperature: object = LLMProvider._SENTINEL,
        reasoning_effort: object = LLMProvider._SENTINEL,
        tool_choice: str | dict[str, Any] | None = None,
        on_content_delta: Callable[[str], Awaitable[None]] | None = None,
        retry_mode: str = "standard",
        on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
    ) -> LLMResponse:
        """Streaming request routed through the candidate chain.

        Deltas from each attempt are collected in memory and replayed to
        ``on_content_delta`` only after that attempt returns a non-error
        response, so output from a failed candidate is never forwarded.
        ``model`` is accepted but not used: each candidate supplies its own
        model name.
        """
        request: dict[str, Any] = {
            "messages": messages,
            "tools": tools,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "reasoning_effort": reasoning_effort,
            "tool_choice": tool_choice,
            "retry_mode": retry_mode,
            "on_retry_wait": on_retry_wait,
        }

        async def attempt(
            provider: LLMProvider,
            candidate_model: str,
            external_delta: Callable[[str], Awaitable[None]] | None,
        ) -> LLMResponse:
            pending: list[str] = []

            async def collect(delta: str) -> None:
                pending.append(delta)

            response = await provider.chat_stream_with_retry(
                model=candidate_model,
                # Stream into the buffer only when somebody is listening.
                on_content_delta=collect if external_delta else None,
                **request,
            )
            if response.finish_reason != "error" and external_delta:
                for chunk in pending:
                    await external_delta(chunk)
            return response

        return await self._route(attempt, on_content_delta=on_content_delta)

View File

@ -6,7 +6,8 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.providers.base import LLMResponse
@ -24,6 +25,10 @@ def _make_provider(*, model_response: LLMResponse | None = None):
return p
def _transient_error(content: str = "server unavailable") -> LLMResponse:
    """Build an error response carrying HTTP status 503 and ``content``."""
    return LLMResponse(
        content=content,
        finish_reason="error",
        error_status_code=503,
    )
def _base_spec(**overrides) -> AgentRunSpec:
defaults = dict(
initial_messages=[
@ -56,7 +61,7 @@ async def test_no_fallback_when_primary_succeeds():
@pytest.mark.asyncio
async def test_fallback_triggered_on_primary_error():
"""Primary fails -> first fallback succeeds."""
err = LLMResponse(content=None, finish_reason="error", usage={})
err = _transient_error()
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
@ -76,12 +81,12 @@ async def test_fallback_triggered_on_primary_error():
@pytest.mark.asyncio
async def test_all_fallbacks_fail_returns_last_error():
"""Primary + all fallbacks fail -> return last error response."""
err = LLMResponse(content=None, finish_reason="error", usage={})
err = _transient_error()
primary = _make_provider(model_response=err)
fb1 = _make_provider(model_response=err)
fb2 = _make_provider(model_response=LLMResponse(
content="last-error", finish_reason="error", usage={},
content="last-error", finish_reason="error", error_status_code=500, usage={},
))
providers = {"fb-1": fb1, "fb-2": fb2}
@ -110,7 +115,7 @@ async def test_empty_fallback_list_no_retry():
@pytest.mark.asyncio
async def test_cross_provider_fallback():
"""Fallback uses a different provider instance (cross-provider)."""
err = LLMResponse(content=None, finish_reason="error", usage={})
err = _transient_error()
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
@ -132,7 +137,7 @@ async def test_cross_provider_fallback():
@pytest.mark.asyncio
async def test_fallback_skips_to_second_on_first_error():
"""First fallback also fails -> second fallback succeeds."""
err = LLMResponse(content=None, finish_reason="error", usage={})
err = _transient_error()
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
@ -158,7 +163,7 @@ async def test_fallback_reuses_same_provider_without_factory():
async def chat_with_retry(*, messages, model, **kw):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(content=None, finish_reason="error", usage={})
return _transient_error()
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
primary = MagicMock()
@ -173,7 +178,7 @@ async def test_fallback_reuses_same_provider_without_factory():
@pytest.mark.asyncio
async def test_fallback_provider_cached():
"""Provider factory is called once per unique provider, not per attempt."""
err = LLMResponse(content=None, finish_reason="error", usage={})
err = _transient_error()
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
@ -188,3 +193,77 @@ async def test_fallback_provider_cached():
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
assert result.final_content == "cached-ok"
@pytest.mark.asyncio
async def test_non_transient_error_does_not_fallback():
    """A 401 from the primary surfaces as an error; no fallback is built."""
    auth_failure = LLMResponse(
        content="401 unauthorized",
        finish_reason="error",
        error_status_code=401,
    )
    primary = _make_provider(model_response=auth_failure)
    fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
    factory = MagicMock(return_value=fallback)

    runner = AgentRunner(primary, provider_factory=factory)
    spec = _base_spec(fallback_models=["fb-model"])
    result = await runner.run(spec)

    # The factory is the only route to the fallback provider, so an
    # untouched factory means no fallback attempt was made.
    factory.assert_not_called()
    assert result.error is not None
@pytest.mark.asyncio
async def test_quota_error_does_not_fallback_by_default():
    """A 429 carrying an ``insufficient_quota`` code is not routed onward."""
    quota_failure = LLMResponse(
        content="insufficient quota",
        finish_reason="error",
        error_status_code=429,
        error_code="insufficient_quota",
    )
    primary = _make_provider(model_response=quota_failure)
    fallback = _make_provider(model_response=LLMResponse(content="fallback-ok"))
    factory = MagicMock(return_value=fallback)

    runner = AgentRunner(primary, provider_factory=factory)
    spec = _base_spec(fallback_models=["fb-model"])
    result = await runner.run(spec)

    # No factory call means the fallback provider was never requested.
    factory.assert_not_called()
    assert result.error is not None
@pytest.mark.asyncio
async def test_streaming_fallback_discards_failed_primary_deltas():
    """Deltas from a failed primary never reach the hook; the fallback's do."""
    received: list[str] = []

    class StreamingHook(AgentHook):
        """Hook that opts into streaming and records every delta it sees."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream(self, context: AgentHookContext, delta: str) -> None:
            received.append(delta)

    async def primary_stream(*, on_content_delta, **kwargs):
        # Emit partial output first, then fail with a transient error.
        await on_content_delta("bad partial")
        return _transient_error()

    async def fallback_stream(*, on_content_delta, **kwargs):
        for piece in ("good", " answer"):
            await on_content_delta(piece)
        return LLMResponse(content="good answer", tool_calls=[], usage={})

    primary = MagicMock()
    primary.chat_stream_with_retry = primary_stream
    fallback = MagicMock()
    fallback.chat_stream_with_retry = fallback_stream
    factory = MagicMock(return_value=fallback)

    runner = AgentRunner(primary, provider_factory=factory)
    spec = _base_spec(fallback_models=["fb-model"], hook=StreamingHook())
    result = await runner.run(spec)

    assert result.final_content == "good answer"
    assert received == ["good", " answer"]