From 2e5930e355c9dff6348ba11c6e240d132af134fc Mon Sep 17 00:00:00 2001 From: LeftX <53989315+xzqling@users.noreply.github.com> Date: Wed, 6 May 2026 14:49:07 +0800 Subject: [PATCH] feat: add fallback_models support for automatic model failover When the primary model fails (finish_reason="error" after exhausting provider-level retries), automatically try each model in the configured fallback_models list. Supports cross-provider fallback via a cached provider_factory that resolves the correct provider for each model string. Config: agents.defaults.fallback_models: ["model-b", "provider/model-c"] Changes: - AgentDefaults: add fallback_models field - AgentRunSpec: add fallback_models field - AgentRunner: add provider_factory, _call_provider, _resolve_fallback_provider - AgentLoop: accept and forward fallback_models + provider_factory - nanobot.py: extract _make_provider_for_model, add _make_provider_factory - cli/commands.py: add _make_cli_provider_factory, wire all AgentLoop sites - tests/agent/test_runner_fallback.py: 8 test cases covering primary success, single/multi fallback, cross-provider, no-factory reuse, caching Made-with: Cursor --- nanobot/agent/loop.py | 6 +- nanobot/agent/runner.py | 70 ++++++++-- nanobot/cli/commands.py | 66 +++++++--- nanobot/config/schema.py | 1 + nanobot/nanobot.py | 51 +++++++- tests/agent/test_runner_fallback.py | 190 ++++++++++++++++++++++++++++ 6 files changed, 350 insertions(+), 34 deletions(-) create mode 100644 tests/agent/test_runner_fallback.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 66c3162db..4a50d4536 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -200,6 +200,8 @@ class AgentLoop: max_tool_result_chars: int | None = None, provider_retry_mode: str = "standard", tool_hint_max_length: int | None = None, + fallback_models: list[str] | None = None, + provider_factory: Any | None = None, web_config: WebToolsConfig | None = None, exec_config: ExecToolConfig | None = None, cron_service: CronService | None = None, @@ -250,6 +252,7 @@ class AgentLoop: tool_hint_max_length if tool_hint_max_length is not None else defaults.tool_hint_max_length ) + self.fallback_models = fallback_models or [] self.web_config = web_config or WebToolsConfig() self.exec_config = exec_config or ExecToolConfig() self.cron_service = cron_service @@ -263,7 +266,7 @@ class AgentLoop: # One file-read/write tracker per logical session. The tool registry is # shared by this loop, so tools resolve the active state via contextvars. self._file_state_store = FileStateStore() - self.runner = AgentRunner(provider) + self.runner = AgentRunner(provider, provider_factory=provider_factory) self.subagents = SubagentManager( provider=provider, workspace=workspace, @@ -681,6 +684,7 @@ class AgentLoop: context_window_tokens=self.context_window_tokens, context_block_limit=self.context_block_limit, provider_retry_mode=self.provider_retry_mode, + fallback_models=self.fallback_models, progress_callback=on_progress, stream_progress_deltas=on_stream is not None, retry_wait_callback=on_retry_wait, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 7fe92ad51..4a25b5ca3 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -75,6 +75,7 @@ class AgentRunSpec: context_window_tokens: int | None = None context_block_limit: int | None = None provider_retry_mode: str = "standard" + fallback_models: list[str] = field(default_factory=list) progress_callback: Any | None = None stream_progress_deltas: bool = True retry_wait_callback: Any | None = None @@ -97,11 +98,21 @@ class AgentRunResult: had_injections: bool = False +ProviderFactory = Any # Callable[[str], LLMProvider] — avoids circular import + + class AgentRunner: """Run a tool-capable LLM loop without product-layer concerns.""" - def __init__(self, provider: LLMProvider): + def __init__( + self, + provider: LLMProvider, + *, + provider_factory: ProviderFactory | None = None, + ): self.provider = provider + self._provider_factory = provider_factory + self._fallback_providers: dict[str, LLMProvider] = {} @staticmethod def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: @@ -594,12 +605,9 @@ class AgentRunner: messages: list[dict[str, Any]], hook: AgentHook, context: AgentHookContext, - ): + ) -> LLMResponse: timeout_s: float | None = spec.llm_timeout_s if timeout_s is None: - # Default to a finite timeout to avoid per-session lock starvation when an LLM - # request hangs indefinitely (e.g. gateway/network stall). - # Set NANOBOT_LLM_TIMEOUT_S=0 to disable. raw = os.environ.get("NANOBOT_LLM_TIMEOUT_S", "300").strip() try: timeout_s = float(raw) @@ -613,12 +621,40 @@ class AgentRunner: messages, tools=spec.tools.get_definitions(), ) + response = await self._call_provider(self.provider, kwargs, hook, context, spec, timeout_s) + + if response.finish_reason == "error" and spec.fallback_models: + for fb_model in spec.fallback_models: + logger.warning( + "Primary model {} failed, trying fallback: {}", + spec.model, + fb_model, + ) + fb_provider, resolved_model = self._resolve_fallback_provider(fb_model) + fb_kwargs = dict(kwargs, model=resolved_model) + response = await self._call_provider( + fb_provider, fb_kwargs, hook, context, spec, timeout_s, + ) + if response.finish_reason != "error": + break + + return response + + async def _call_provider( + self, + provider: LLMProvider, + kwargs: dict[str, Any], + hook: AgentHook, + context: AgentHookContext, + spec: AgentRunSpec, + timeout_s: float | None = None, + ) -> LLMResponse: wants_streaming = hook.wants_streaming() wants_progress_streaming = ( not wants_streaming and spec.stream_progress_deltas and spec.progress_callback is not None - and getattr(self.provider, "supports_progress_deltas", False) is True + and getattr(provider, "supports_progress_deltas", False) is True ) if wants_streaming: @@ -627,7 +663,7 @@ class AgentRunner: context.streamed_content = True await hook.on_stream(context, delta) - coro = self.provider.chat_stream_with_retry( + coro = provider.chat_stream_with_retry( **kwargs, on_content_delta=_stream, ) @@ -646,12 +682,12 @@ class AgentRunner: context.streamed_content = True await spec.progress_callback(incremental) - coro = self.provider.chat_stream_with_retry( + coro = provider.chat_stream_with_retry( **kwargs, on_content_delta=_stream_progress, ) else: - coro = self.provider.chat_with_retry(**kwargs) + coro = provider.chat_with_retry(**kwargs) if timeout_s is None: return await coro @@ -664,6 +700,22 @@ class AgentRunner: error_kind="timeout", ) + def _resolve_fallback_provider(self, model: str) -> tuple[LLMProvider, str]: + """Return (provider, actual_model_name) for a fallback model. + + When a provider_factory is available (and the model string may be a + preset name), the factory resolves the actual model; otherwise the + primary provider is reused with the raw model string. + """ + if model in self._fallback_providers: + p = self._fallback_providers[model] + return p, p.get_default_model() + if self._provider_factory: + provider = self._provider_factory(model) + self._fallback_providers[model] = provider + return provider, provider.get_default_model() + return self.provider, model + async def _request_finalization_retry( self, spec: AgentRunSpec, diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 14458ff8a..ae11d3ffc 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -513,6 +513,29 @@ def _make_provider(config: Config): return provider +def _make_cli_provider_factory(config: Config): + """Build a cached factory for fallback model providers (CLI side). + + Supports preset names: if a model string matches a preset, the preset's + full config is used for provider creation. + """ + from nanobot.nanobot import _make_provider_for_model + + cache: dict[str, Any] = {} + presets = getattr(config, "model_presets", {}) or {} + + def factory(model_or_preset: str): + preset = presets.get(model_or_preset) + actual_model = preset.model if preset else model_or_preset + provider_name = config.get_provider_name(actual_model) + key = provider_name or actual_model + if key not in cache: + cache[key] = _make_provider_for_model(config, actual_model, preset=preset) + return cache[key] + + return factory + + def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config: """Load config and optionally override the active workspace.""" from nanobot.config.loader import load_config, resolve_config_env_vars, set_config_path @@ -608,6 +631,8 @@ def serve( sync_workspace_templates(runtime_config.workspace_path) bus = MessageBus() provider = _make_provider(runtime_config) + defaults = runtime_config.agents.defaults + pf = _make_cli_provider_factory(runtime_config) if defaults.fallback_models else None session_manager = SessionManager(runtime_config.workspace_path) _resolved = runtime_config.resolve_preset() agent_loop = AgentLoop( @@ -615,12 +640,13 @@ def serve( provider=provider, workspace=runtime_config.workspace_path, model=_resolved.model, - max_iterations=runtime_config.agents.defaults.max_tool_iterations, + max_iterations=defaults.max_tool_iterations, context_window_tokens=_resolved.context_window_tokens, - context_block_limit=runtime_config.agents.defaults.context_block_limit, - max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars, - provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode, - tool_hint_max_length=runtime_config.agents.defaults.tool_hint_max_length, + context_block_limit=defaults.context_block_limit, + max_tool_result_chars=defaults.max_tool_result_chars, + provider_retry_mode=defaults.provider_retry_mode, + fallback_models=defaults.fallback_models, + provider_factory=pf, web_config=runtime_config.tools.web, exec_config=runtime_config.tools.exec, restrict_to_workspace=runtime_config.tools.restrict_to_workspace, @@ -639,7 +665,7 @@ def serve( ) model_name = _resolved.model - preset_name = runtime_config.agents.defaults.model_preset + preset_name = defaults.model_preset preset_tag = f" (preset: {preset_name})" if preset_name else "" console.print(f"{__logo__} Starting OpenAI-compatible API server") console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions") @@ -721,12 +747,14 @@ def _run_gateway( console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...") sync_workspace_templates(config.workspace_path) bus = MessageBus() + provider = _make_provider(config) + gw_defaults = config.agents.defaults + gw_pf = _make_cli_provider_factory(config) if gw_defaults.fallback_models else None try: provider_snapshot = build_provider_snapshot(config) except ValueError as exc: console.print(f"[red]Error: {exc}[/red]") raise typer.Exit(1) from exc - provider = provider_snapshot.provider session_manager = SessionManager(config.workspace_path) # Preserve existing single-workspace installs, but keep custom workspaces clean. @@ -744,13 +772,14 @@ def _run_gateway( provider=provider, workspace=config.workspace_path, model=_resolved.model, - max_iterations=config.agents.defaults.max_tool_iterations, + max_iterations=gw_defaults.max_tool_iterations, context_window_tokens=_resolved.context_window_tokens, web_config=config.tools.web, - context_block_limit=config.agents.defaults.context_block_limit, - max_tool_result_chars=config.agents.defaults.max_tool_result_chars, - provider_retry_mode=config.agents.defaults.provider_retry_mode, - tool_hint_max_length=config.agents.defaults.tool_hint_max_length, + context_block_limit=gw_defaults.context_block_limit, + max_tool_result_chars=gw_defaults.max_tool_result_chars, + provider_retry_mode=gw_defaults.provider_retry_mode, + fallback_models=gw_defaults.fallback_models, + provider_factory=gw_pf, exec_config=config.tools.exec, cron_service=cron, restrict_to_workspace=config.tools.restrict_to_workspace, @@ -1120,6 +1149,8 @@ def agent( bus = MessageBus() provider = _make_provider(config) + chat_defaults = config.agents.defaults + chat_pf = _make_cli_provider_factory(config) if chat_defaults.fallback_models else None # Preserve existing single-workspace installs, but keep custom workspaces clean. if is_default_workspace(config.workspace_path): @@ -1140,13 +1171,14 @@ def agent( provider=provider, workspace=config.workspace_path, model=_resolved.model, - max_iterations=config.agents.defaults.max_tool_iterations, + max_iterations=chat_defaults.max_tool_iterations, context_window_tokens=_resolved.context_window_tokens, web_config=config.tools.web, - context_block_limit=config.agents.defaults.context_block_limit, - max_tool_result_chars=config.agents.defaults.max_tool_result_chars, - provider_retry_mode=config.agents.defaults.provider_retry_mode, - tool_hint_max_length=config.agents.defaults.tool_hint_max_length, + context_block_limit=chat_defaults.context_block_limit, + max_tool_result_chars=chat_defaults.max_tool_result_chars, + provider_retry_mode=chat_defaults.provider_retry_mode, + fallback_models=chat_defaults.fallback_models, + provider_factory=chat_pf, exec_config=config.tools.exec, cron_service=cron, restrict_to_workspace=config.tools.restrict_to_workspace, diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index 6a1b89cf4..1bc78be98 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -105,6 +105,7 @@ class AgentDefaults(Base): serialization_alias="toolHintMaxLength", ) # Max characters for tool hint display (e.g. "$ cd …/project && npm test") reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode + fallback_models: list[str] = Field(default_factory=list) timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" unified_session: bool = False # Share one session across all channels (single-user multi-device) disabled_skills: list[str] = Field(default_factory=list) # Skill names to exclude from loading (e.g. ["summarize", "skill-creator"]) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index fd1e40b26..bf3342f44 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -66,6 +66,7 @@ class Nanobot: bus = MessageBus() defaults = config.agents.defaults _resolved = config.resolve_preset() + pf = _make_provider_factory(config) if defaults.fallback_models else None loop = AgentLoop( bus=bus, @@ -78,6 +79,8 @@ class Nanobot: max_tool_result_chars=defaults.max_tool_result_chars, provider_retry_mode=defaults.provider_retry_mode, tool_hint_max_length=defaults.tool_hint_max_length, + fallback_models=defaults.fallback_models, + provider_factory=pf, web_config=config.tools.web, exec_config=config.tools.exec, restrict_to_workspace=config.tools.restrict_to_workspace, @@ -127,14 +130,22 @@ class Nanobot: ) -def _make_provider(config: Any) -> Any: - """Create the LLM provider from config (extracted from CLI).""" +def _make_provider_for_model( + config: Any, + model: str, + *, + preset: Any | None = None, +) -> Any: + """Create an LLM provider instance for a specific model string. + + When *preset* is given, its generation settings (temperature, max_tokens, + reasoning_effort) override the active preset defaults. + """ from nanobot.providers.base import GenerationSettings from nanobot.providers.factory import make_provider from nanobot.providers.registry import find_by_name - resolved = config.resolve_preset() - model = resolved.model + gen_src = preset or config.resolve_preset() provider_name = config.get_provider_name(model) p = config.get_provider(model) spec = find_by_name(provider_name) if provider_name else None @@ -185,8 +196,34 @@ def _make_provider(config: Any) -> Any: ) provider.generation = GenerationSettings( - temperature=resolved.temperature, - max_tokens=resolved.max_tokens, - reasoning_effort=resolved.reasoning_effort, + temperature=gen_src.temperature, + max_tokens=gen_src.max_tokens, + reasoning_effort=gen_src.reasoning_effort, ) return provider + + +def _make_provider(config: Any) -> Any: + """Create the LLM provider for the primary model from config.""" + return _make_provider_for_model(config, config.resolve_preset().model) + + +def _make_provider_factory(config: Any): + """Build a cached factory that creates providers for arbitrary model strings. + + If a model string matches a preset name in ``config.model_presets``, the + preset's full config (model, temperature, max_tokens, …) is used. + """ + cache: dict[str, Any] = {} + presets = getattr(config, "model_presets", {}) or {} + + def factory(model_or_preset: str): + preset = presets.get(model_or_preset) + actual_model = preset.model if preset else model_or_preset + provider_name = config.get_provider_name(actual_model) + key = provider_name or actual_model + if key not in cache: + cache[key] = _make_provider_for_model(config, actual_model, preset=preset) + return cache[key] + + return factory diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py new file mode 100644 index 000000000..bfd0bc225 --- /dev/null +++ b/tests/agent/test_runner_fallback.py @@ -0,0 +1,190 @@ +"""Tests for the provider fallback models feature in AgentRunner.""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.runner import AgentRunSpec, AgentRunner +from nanobot.providers.base import LLMResponse + + +def _make_tools(): + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="ok") + return tools + + +def _make_provider(*, model_response: LLMResponse | None = None): + p = MagicMock() + if model_response is not None: + p.chat_with_retry = AsyncMock(return_value=model_response) + return p + + +def _base_spec(**overrides) -> AgentRunSpec: + defaults = dict( + initial_messages=[ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "hello"}, + ], + tools=_make_tools(), + model="primary-model", + max_iterations=1, + max_tool_result_chars=8000, + ) + defaults.update(overrides) + return AgentRunSpec(**defaults) + + +@pytest.mark.asyncio +async def test_no_fallback_when_primary_succeeds(): + """Primary succeeds -> fallback list never consulted.""" + ok = LLMResponse(content="done", tool_calls=[], usage={}) + provider = _make_provider(model_response=ok) + factory = MagicMock() + + runner = AgentRunner(provider, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"])) + + assert result.final_content == "done" + factory.assert_not_called() + + +@pytest.mark.asyncio +async def test_fallback_triggered_on_primary_error(): + """Primary fails -> first fallback succeeds.""" + err = LLMResponse(content=None, finish_reason="error", usage={}) + ok = LLMResponse(content="fallback-ok", tool_calls=[], usage={}) + + primary = _make_provider(model_response=err) + + fb_provider = MagicMock() + fb_provider.chat_with_retry = AsyncMock(return_value=ok) + factory = MagicMock(return_value=fb_provider) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["fb-model"])) + + assert result.final_content == "fallback-ok" + factory.assert_called_once_with("fb-model") + fb_provider.chat_with_retry.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_all_fallbacks_fail_returns_last_error(): + """Primary + all fallbacks fail -> return last error response.""" + err = LLMResponse(content=None, finish_reason="error", usage={}) + + primary = _make_provider(model_response=err) + fb1 = _make_provider(model_response=err) + fb2 = _make_provider(model_response=LLMResponse( + content="last-error", finish_reason="error", usage={}, + )) + + providers = {"fb-1": fb1, "fb-2": fb2} + factory = MagicMock(side_effect=lambda m: providers[m]) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"])) + + assert result.error is not None or result.final_content is not None + + +@pytest.mark.asyncio +async def test_empty_fallback_list_no_retry(): + """Empty fallback_models -> no fallback attempted.""" + err = LLMResponse(content=None, finish_reason="error", usage={}) + primary = _make_provider(model_response=err) + factory = MagicMock() + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=[])) + + factory.assert_not_called() + assert result.error is not None + + +@pytest.mark.asyncio +async def test_cross_provider_fallback(): + """Fallback uses a different provider instance (cross-provider).""" + err = LLMResponse(content=None, finish_reason="error", usage={}) + ok = LLMResponse(content="cross-provider-ok", tool_calls=[], usage={}) + + primary = _make_provider(model_response=err) + anthropic_provider = MagicMock() + anthropic_provider.chat_with_retry = AsyncMock(return_value=ok) + + def cross_factory(model: str): + if model == "anthropic/claude-sonnet": + return anthropic_provider + raise ValueError(f"unexpected model: {model}") + + runner = AgentRunner(primary, provider_factory=cross_factory) + result = await runner.run(_base_spec(fallback_models=["anthropic/claude-sonnet"])) + + assert result.final_content == "cross-provider-ok" + anthropic_provider.chat_with_retry.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_fallback_skips_to_second_on_first_error(): + """First fallback also fails -> second fallback succeeds.""" + err = LLMResponse(content=None, finish_reason="error", usage={}) + ok = LLMResponse(content="second-fb-ok", tool_calls=[], usage={}) + + primary = _make_provider(model_response=err) + fb1 = _make_provider(model_response=err) + fb2 = MagicMock() + fb2.chat_with_retry = AsyncMock(return_value=ok) + + providers = {"fb-1": fb1, "fb-2": fb2} + factory = MagicMock(side_effect=lambda m: providers[m]) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["fb-1", "fb-2"])) + + assert result.final_content == "second-fb-ok" + assert factory.call_count == 2 + + +@pytest.mark.asyncio +async def test_fallback_reuses_same_provider_without_factory(): + """No provider_factory -> fallback reuses primary provider with different model.""" + call_count = {"n": 0} + + async def chat_with_retry(*, messages, model, **kw): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content=None, finish_reason="error", usage={}) + return LLMResponse(content=f"ok-via-{model}", tool_calls=[], usage={}) + + primary = MagicMock() + primary.chat_with_retry = chat_with_retry + + runner = AgentRunner(primary, provider_factory=None) + result = await runner.run(_base_spec(fallback_models=["fallback-model"])) + + assert result.final_content == "ok-via-fallback-model" + + +@pytest.mark.asyncio +async def test_fallback_provider_cached(): + """Provider factory is called once per unique provider, not per attempt.""" + err = LLMResponse(content=None, finish_reason="error", usage={}) + ok = LLMResponse(content="cached-ok", tool_calls=[], usage={}) + + primary = _make_provider(model_response=err) + + fb_provider = MagicMock() + call_seq = [err, ok] + fb_provider.chat_with_retry = AsyncMock(side_effect=call_seq) + + factory = MagicMock(return_value=fb_provider) + + runner = AgentRunner(primary, provider_factory=factory) + result = await runner.run(_base_spec(fallback_models=["same-provider-model-a", "same-provider-model-b"])) + + assert result.final_content == "cached-ok"