feat: add fallback_models support for automatic model failover

When the primary model fails (finish_reason="error" after exhausting
provider-level retries), automatically try each model in the configured
fallback_models list. Supports cross-provider fallback via a cached
provider_factory that resolves the correct provider for each model string.

Config:
  agents.defaults.fallback_models: ["model-b", "provider/model-c"]

Changes:
- AgentDefaults: add fallback_models field
- AgentRunSpec: add fallback_models field
- AgentRunner: add provider_factory, _call_provider, _resolve_fallback_provider
- AgentLoop: accept and forward fallback_models + provider_factory
- nanobot.py: extract _make_provider_for_model, add _make_provider_factory
- cli/commands.py: add _make_cli_provider_factory, wire all AgentLoop sites
- tests/agent/test_runner_fallback.py: 8 test cases covering primary success,
  single/multi fallback, cross-provider, no-factory reuse, caching

Made-with: Cursor
This commit is contained in:
LeftX 2026-05-06 14:49:07 +08:00 committed by chengyongru
parent 83f437a088
commit 2e5930e355
6 changed files with 350 additions and 34 deletions

View File

@ -200,6 +200,8 @@ class AgentLoop:
max_tool_result_chars: int | None = None,
provider_retry_mode: str = "standard",
tool_hint_max_length: int | None = None,
fallback_models: list[str] | None = None,
provider_factory: Any | None = None,
web_config: WebToolsConfig | None = None,
exec_config: ExecToolConfig | None = None,
cron_service: CronService | None = None,
@ -250,6 +252,7 @@ class AgentLoop:
tool_hint_max_length if tool_hint_max_length is not None
else defaults.tool_hint_max_length
)
self.fallback_models = fallback_models or []
self.web_config = web_config or WebToolsConfig()
self.exec_config = exec_config or ExecToolConfig()
self.cron_service = cron_service
@ -263,7 +266,7 @@ class AgentLoop:
# One file-read/write tracker per logical session. The tool registry is
# shared by this loop, so tools resolve the active state via contextvars.
self._file_state_store = FileStateStore()
self.runner = AgentRunner(provider)
self.runner = AgentRunner(provider, provider_factory=provider_factory)
self.subagents = SubagentManager(
provider=provider,
workspace=workspace,
@ -681,6 +684,7 @@ class AgentLoop:
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
fallback_models=self.fallback_models,
progress_callback=on_progress,
stream_progress_deltas=on_stream is not None,
retry_wait_callback=on_retry_wait,

View File

@ -75,6 +75,7 @@ class AgentRunSpec:
context_window_tokens: int | None = None
context_block_limit: int | None = None
provider_retry_mode: str = "standard"
fallback_models: list[str] = field(default_factory=list)
progress_callback: Any | None = None
stream_progress_deltas: bool = True
retry_wait_callback: Any | None = None
@ -97,11 +98,21 @@ class AgentRunResult:
had_injections: bool = False
ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import
class AgentRunner:
"""Run a tool-capable LLM loop without product-layer concerns."""
def __init__(self, provider: LLMProvider):
def __init__(
self,
provider: LLMProvider,
*,
provider_factory: ProviderFactory | None = None,
):
self.provider = provider
self._provider_factory = provider_factory
self._fallback_providers: dict[str, LLMProvider] = {}
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
@ -594,12 +605,9 @@ class AgentRunner:
messages: list[dict[str, Any]],
hook: AgentHook,
context: AgentHookContext,
):
) -> LLMResponse:
timeout_s: float | None = spec.llm_timeout_s
if timeout_s is None:
# Default to a finite timeout to avoid per-session lock starvation when an LLM
# request hangs indefinitely (e.g. gateway/network stall).
# Set NANOBOT_LLM_TIMEOUT_S=0 to disable.
raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip()
try:
timeout_s = float(raw)
@ -613,12 +621,40 @@ class AgentRunner:
messages,
tools=spec.tools.get_definitions(),
)
response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s)
if response.finish_reason == "error" and spec.fallback_models:
for fb_model in spec.fallback_models:
logger.warning(
"Primary model {} failed, trying fallback: {}",
spec.model,
fb_model,
)
fb_provider, resolved_model = self._resolve_fallback_provider(fb_model)
fb_kwargs = dict(kwargs, model=resolved_model)
response = await self._call_provider(
fb_provider, fb_kwargs, hook, context, spec, timeout_s,
)
if response.finish_reason != "error":
break
return response
async def _call_provider(
self,
provider: LLMProvider,
kwargs: dict[str, Any],
hook: AgentHook,
context: AgentHookContext,
spec: AgentRunSpec,
timeout_s: float | None = None,
) -> LLMResponse:
wants_streaming = hook.wants_streaming()
wants_progress_streaming = (
not wants_streaming
and spec.stream_progress_deltas
and spec.progress_callback is not None
and getattr(self.provider, "supports_progress_deltas", False) is True
and getattr(provider, "supports_progress_deltas", False) is True
)
if wants_streaming:
@ -627,7 +663,7 @@ class AgentRunner:
context.streamed_content = True
await hook.on_stream(context, delta)
coro = self.provider.chat_stream_with_retry(
coro = provider.chat_stream_with_retry(
**kwargs,
on_content_delta=_stream,
)
@ -646,12 +682,12 @@ class AgentRunner:
context.streamed_content = True
await spec.progress_callback(incremental)
coro = self.provider.chat_stream_with_retry(
coro = provider.chat_stream_with_retry(
**kwargs,
on_content_delta=_stream_progress,
)
else:
coro = self.provider.chat_with_retry(**kwargs)
coro = provider.chat_with_retry(**kwargs)
if timeout_s is None:
return await coro
@ -664,6 +700,22 @@ class AgentRunner:
error_kind="timeout",
)
def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]:
"""Return (provider, actual_model_name) for a fallback model.
When a provider_factory is available (and the model string may be a
preset name), the factory resolves the actual model; otherwise the
primary provider is reused with the raw model string.
"""
if model in self._fallback_providers:
p = self._fallback_providers[model]
return p, p.get_default_model()
if self._provider_factory:
provider = self._provider_factory(model)
self._fallback_providers[model] = provider
return provider, provider.get_default_model()
return self.provider, model
async def _request_finalization_retry(
self,
spec: AgentRunSpec,

View File

@ -513,6 +513,29 @@ def _make_provider(config: Config):
return provider
def _make_cli_provider_factory(config: Config):
"""Build a cached factory for fallback model providers (CLI side).
Supports preset names: if a model string matches a preset, the preset's
full config is used for provider creation.
"""
from nanobot.nanobot import _make_provider_for_model
cache: dict[str, Any] = {}
presets = getattr(config, "model_presets", {}) or {}
def factory(model_or_preset: str):
preset = presets.get(model_or_preset)
actual_model = preset.model if preset else model_or_preset
provider_name = config.get_provider_name(actual_model)
key = provider_name or actual_model
if key not in cache:
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
return cache[key]
return factory
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
"""Load config and optionally override the active workspace."""
from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path
@ -608,6 +631,8 @@ def serve(
sync_workspace_templates(runtime_config.workspace_path)
bus = MessageBus()
provider = _make_provider(runtime_config)
defaults = runtime_config.agents.defaults
pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None
session_manager = SessionManager(runtime_config.workspace_path)
_resolved = runtime_config.resolve_preset()
agent_loop = AgentLoop(
@ -615,12 +640,13 @@ def serve(
provider=provider,
workspace=runtime_config.workspace_path,
model=_resolved.model,
max_iterations=runtime_config.agents.defaults.max_tool_iterations,
max_iterations=defaults.max_tool_iterations,
context_window_tokens=_resolved.context_window_tokens,
context_block_limit=runtime_config.agents.defaults.context_block_limit,
max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
tool_hint_max_length=runtime_config.agents.defaults.tool_hint_max_length,
context_block_limit=defaults.context_block_limit,
max_tool_result_chars=defaults.max_tool_result_chars,
provider_retry_mode=defaults.provider_retry_mode,
fallback_models=defaults.fallback_models,
provider_factory=pf,
web_config=runtime_config.tools.web,
exec_config=runtime_config.tools.exec,
restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
@ -639,7 +665,7 @@ def serve(
)
model_name = _resolved.model
preset_name = runtime_config.agents.defaults.model_preset
preset_name = defaults.model_preset
preset_tag = f" (preset: {preset_name})" if preset_name else ""
console.print(f"{__logo__} Starting OpenAI-compatible API server")
console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
@ -721,12 +747,14 @@ def _run_gateway(
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
gw_defaults = config.agents.defaults
gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None
try:
provider_snapshot = build_provider_snapshot(config)
except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]")
raise typer.Exit(1) from exc
provider = provider_snapshot.provider
session_manager = SessionManager(config.workspace_path)
# Preserve existing single-workspace installs, but keep custom workspaces clean.
@ -744,13 +772,14 @@ def _run_gateway(
provider=provider,
workspace=config.workspace_path,
model=_resolved.model,
max_iterations=config.agents.defaults.max_tool_iterations,
max_iterations=gw_defaults.max_tool_iterations,
context_window_tokens=_resolved.context_window_tokens,
web_config=config.tools.web,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
context_block_limit=gw_defaults.context_block_limit,
max_tool_result_chars=gw_defaults.max_tool_result_chars,
provider_retry_mode=gw_defaults.provider_retry_mode,
fallback_models=gw_defaults.fallback_models,
provider_factory=gw_pf,
exec_config=config.tools.exec,
cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace,
@ -1120,6 +1149,8 @@ def agent(
bus = MessageBus()
provider = _make_provider(config)
chat_defaults = config.agents.defaults
chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None
# Preserve existing single-workspace installs, but keep custom workspaces clean.
if is_default_workspace(config.workspace_path):
@ -1140,13 +1171,14 @@ def agent(
provider=provider,
workspace=config.workspace_path,
model=_resolved.model,
max_iterations=config.agents.defaults.max_tool_iterations,
max_iterations=chat_defaults.max_tool_iterations,
context_window_tokens=_resolved.context_window_tokens,
web_config=config.tools.web,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
provider_retry_mode=config.agents.defaults.provider_retry_mode,
tool_hint_max_length=config.agents.defaults.tool_hint_max_length,
context_block_limit=chat_defaults.context_block_limit,
max_tool_result_chars=chat_defaults.max_tool_result_chars,
provider_retry_mode=chat_defaults.provider_retry_mode,
fallback_models=chat_defaults.fallback_models,
provider_factory=chat_pf,
exec_config=config.tools.exec,
cron_service=cron,
restrict_to_workspace=config.tools.restrict_to_workspace,

View File

@ -105,6 +105,7 @@ class AgentDefaults(Base):
serialization_alias="toolHintMaxLength",
) # Max characters for tool hint display (e.g. "$ cd …/project && npm test")
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
fallback_models: list[str] = Field(default_factory=list)
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
unified_session: bool = False # Share one session across all channels (single-user multi-device)
disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"])

View File

@ -66,6 +66,7 @@ class Nanobot:
bus = MessageBus()
defaults = config.agents.defaults
_resolved = config.resolve_preset()
pf = _make_provider_factory(config) if defaults.fallback_models else None
loop = AgentLoop(
bus=bus,
@ -78,6 +79,8 @@ class Nanobot:
max_tool_result_chars=defaults.max_tool_result_chars,
provider_retry_mode=defaults.provider_retry_mode,
tool_hint_max_length=defaults.tool_hint_max_length,
fallback_models=defaults.fallback_models,
provider_factory=pf,
web_config=config.tools.web,
exec_config=config.tools.exec,
restrict_to_workspace=config.tools.restrict_to_workspace,
@ -127,14 +130,22 @@ class Nanobot:
)
def _make_provider(config: Any) -> Any:
"""Create the LLM provider from config (extracted from CLI)."""
def _make_provider_for_model(
config: Any,
model: str,
*,
preset: Any | None = None,
) -> Any:
"""Create an LLM provider instance for a specific model string.
When *preset* is given, its generation settings (temperature, max_tokens,
reasoning_effort) override the active preset defaults.
"""
from nanobot.providers.base import GenerationSettings
from nanobot.providers.factory import make_provider
from nanobot.providers.registry import find_by_name
resolved = config.resolve_preset()
model = resolved.model
gen_src = preset or config.resolve_preset()
provider_name = config.get_provider_name(model)
p = config.get_provider(model)
spec = find_by_name(provider_name) if provider_name else None
@ -185,8 +196,34 @@ def _make_provider(config: Any) -> Any:
)
provider.generation = GenerationSettings(
temperature=resolved.temperature,
max_tokens=resolved.max_tokens,
reasoning_effort=resolved.reasoning_effort,
temperature=gen_src.temperature,
max_tokens=gen_src.max_tokens,
reasoning_effort=gen_src.reasoning_effort,
)
return provider
def _make_provider(config: Any) -> Any:
"""Create the LLM provider for the primary model from config."""
return _make_provider_for_model(config, config.resolve_preset().model)
def _make_provider_factory(config: Any):
"""Build a cached factory that creates providers for arbitrary model strings.
If a model string matches a preset name in ``config.model_presets``, the
preset's full config (model, temperature, max_tokens, …) is used.
"""
cache: dict[str, Any] = {}
presets = getattr(config, "model_presets", {}) or {}
def factory(model_or_preset: str):
preset = presets.get(model_or_preset)
actual_model = preset.model if preset else model_or_preset
provider_name = config.get_provider_name(actual_model)
key = provider_name or actual_model
if key not in cache:
cache[key] = _make_provider_for_model(config, actual_model, preset=preset)
return cache[key]
return factory

View File

@ -0,0 +1,190 @@
"""Tests for the provider fallback models feature in AgentRunner."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.providers.base import LLMResponse
def _make_tools():
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="ok")
return tools
def _make_provider(*, model_response: LLMResponse | None = None):
p = MagicMock()
if model_response is not None:
p.chat_with_retry = AsyncMock(return_value=model_response)
return p
def _base_spec(**overrides) -> AgentRunSpec:
defaults = dict(
initial_messages=[
{"role": "system", "content": "sys"},
{"role": "user", "content": "hello"},
],
tools=_make_tools(),
model="primary-model",
max_iterations=1,
max_tool_result_chars=8000,
)
defaults.update(overrides)
return AgentRunSpec(**defaults)
@pytest.mark.asyncio
async def test_no_fallback_when_primary_succeeds():
"""Primary succeeds -> fallback list never consulted."""
ok = LLMResponse(content="done", tool_calls=[], usage={})
provider = _make_provider(model_response=ok)
factory = MagicMock()
runner = AgentRunner(provider, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
assert result.final_content == "done"
factory.assert_not_called()
@pytest.mark.asyncio
async def test_fallback_triggered_on_primary_error():
"""Primary fails -> first fallback succeeds."""
err = LLMResponse(content=None, finish_reason="error", usage={})
ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
fb_provider = MagicMock()
fb_provider.chat_with_retry = AsyncMock(return_value=ok)
factory = MagicMock(return_value=fb_provider)
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["fb-model"]))
assert result.final_content == "fallback-ok"
factory.assert_called_once_with("fb-model")
fb_provider.chat_with_retry.assert_awaited_once()
@pytest.mark.asyncio
async def test_all_fallbacks_fail_returns_last_error():
"""Primary + all fallbacks fail -> return last error response."""
err = LLMResponse(content=None, finish_reason="error", usage={})
primary = _make_provider(model_response=err)
fb1 = _make_provider(model_response=err)
fb2 = _make_provider(model_response=LLMResponse(
content="last-error", finish_reason="error", usage={},
))
providers = {"fb-1": fb1, "fb-2": fb2}
factory = MagicMock(side_effect=lambda m: providers[m])
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
assert result.error is not None or result.final_content is not None
@pytest.mark.asyncio
async def test_empty_fallback_list_no_retry():
"""Empty fallback_models -> no fallback attempted."""
err = LLMResponse(content=None, finish_reason="error", usage={})
primary = _make_provider(model_response=err)
factory = MagicMock()
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=[]))
factory.assert_not_called()
assert result.error is not None
@pytest.mark.asyncio
async def test_cross_provider_fallback():
"""Fallback uses a different provider instance (cross-provider)."""
err = LLMResponse(content=None, finish_reason="error", usage={})
ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
anthropic_provider = MagicMock()
anthropic_provider.chat_with_retry = AsyncMock(return_value=ok)
def cross_factory(model: str):
if model == "anthropic/claude-sonnet":
return anthropic_provider
raise ValueError(f"unexpected model: {model}")
runner = AgentRunner(primary, provider_factory=cross_factory)
result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"]))
assert result.final_content == "cross-provider-ok"
anthropic_provider.chat_with_retry.assert_awaited_once()
@pytest.mark.asyncio
async def test_fallback_skips_to_second_on_first_error():
"""First fallback also fails -> second fallback succeeds."""
err = LLMResponse(content=None, finish_reason="error", usage={})
ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
fb1 = _make_provider(model_response=err)
fb2 = MagicMock()
fb2.chat_with_retry = AsyncMock(return_value=ok)
providers = {"fb-1": fb1, "fb-2": fb2}
factory = MagicMock(side_effect=lambda m: providers[m])
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"]))
assert result.final_content == "second-fb-ok"
assert factory.call_count == 2
@pytest.mark.asyncio
async def test_fallback_reuses_same_provider_without_factory():
"""No provider_factory -> fallback reuses primary provider with different model."""
call_count = {"n": 0}
async def chat_with_retry(*, messages, model, **kw):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(content=None, finish_reason="error", usage={})
return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={})
primary = MagicMock()
primary.chat_with_retry = chat_with_retry
runner = AgentRunner(primary, provider_factory=None)
result = await runner.run(_base_spec(fallback_models=["fallback-model"]))
assert result.final_content == "ok-via-fallback-model"
@pytest.mark.asyncio
async def test_fallback_provider_cached():
"""Provider factory is called once per unique provider, not per attempt."""
err = LLMResponse(content=None, finish_reason="error", usage={})
ok = LLMResponse(content="cached-ok", tool_calls=[], usage={})
primary = _make_provider(model_response=err)
fb_provider = MagicMock()
call_seq = [err, ok]
fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq)
factory = MagicMock(return_value=fb_provider)
runner = AgentRunner(primary, provider_factory=factory)
result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"]))
assert result.final_content == "cached-ok"