From 2c5a4e070375cb2aed99752952e3fa2adb1f798f Mon Sep 17 00:00:00 2001 From: aiguozhi123456 <126325311+aiguozhi123456@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:38:11 +0800 Subject: [PATCH] fix(providers): allow retry and fallback on stream stalled timeout When a stream stalls mid-response, both the retry layer and FallbackProvider blocked recovery because content had already been emitted via on_content_delta. This left users with truncated replies and no automatic recovery. For error_kind="timeout" specifically: - _run_with_retry now suppresses delta callbacks and retries the same model instead of returning immediately - FallbackProvider now allows failover to a different model with delta callbacks suppressed Non-timeout errors retain the original "skip retry/failover after streamed content" behavior to avoid duplicate output. --- nanobot/providers/base.py | 20 +++++++++++--- nanobot/providers/fallback_provider.py | 18 ++++++++++--- tests/agent/test_runner_fallback.py | 34 +++++++++++++++++++++--- tests/providers/test_provider_retry.py | 36 ++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 11 deletions(-) diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 4a692b424..640a5c910 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -827,10 +827,22 @@ class LLMProvider(ABC): return response last_response = response if should_retry_guard is not None and not should_retry_guard(): - logger.warning( - "LLM stream failed after content was emitted; skipping retry" - ) - return response + is_timeout = (response.error_kind or "").lower() == "timeout" + if is_timeout: + logger.warning( + "LLM stream stalled after content was emitted; " + "suppressing delta callbacks and retrying" + ) + kw.setdefault("on_content_delta", None) + kw["on_content_delta"] = None + kw["on_thinking_delta"] = None + kw["on_tool_call_delta"] = None + should_retry_guard = None + else: + logger.warning( + "LLM stream failed after content was emitted; skipping retry" + ) + return response error_key = ((response.content or "").strip().lower() or None) if error_key and error_key == last_error_key: identical_error_count += 1 diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py index c082c2361..d8ee4a5fa 100644 --- a/nanobot/providers/fallback_provider.py +++ b/nanobot/providers/fallback_provider.py @@ -149,10 +149,20 @@ class FallbackProvider(LLMProvider): return response if has_streamed is not None and has_streamed[0]: - logger.warning( - "Primary model error but content already streamed; skipping failover" - ) - return response + is_timeout = (response.error_kind or "").lower() == "timeout" + if is_timeout: + logger.warning( + "Primary model '{}' stream stalled after content was emitted; " + "attempting failover anyway", + primary_model, + ) + has_streamed[0] = False + kwargs["on_content_delta"] = None + else: + logger.warning( + "Primary model error but content already streamed; skipping failover" + ) + return response if not self._should_fallback(response): logger.warning( diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index a7a6f7c30..70d44e71d 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -287,7 +287,7 @@ class TestFallbackOnPrimaryError: class TestNoFallbackWhenContentStreamed: @pytest.mark.asyncio - async def test(self) -> None: + async def test_non_timeout_error_skips_failover(self) -> None: primary = _FakeProvider("primary", _error_response()) factory = MagicMock() fb = FallbackProvider( @@ -303,12 +303,40 @@ class TestNoFallbackWhenContentStreamed: messages=[{"role": "user", "content": "hi"}], on_content_delta=_delta, ) - # Primary returns error but content was "streamed" (FakeProvider calls delta) - # so failover should be skipped assert result.finish_reason == "error" factory.assert_not_called() +class TestFallbackOnStreamStalledAfterContent: + @pytest.mark.asyncio + async def test_timeout_with_streamed_content_falls_back(self) -> None: + primary = _FakeProvider( + "primary", + _make_response("stream stalled", finish_reason="error", error_kind="timeout"), + ) + fallback = _FakeProvider("fallback", _make_response("fallback ok")) + factory = MagicMock(return_value=fallback) + fb = FallbackProvider( + primary=primary, + fallback_presets=[_fallback("fallback-a")], + provider_factory=factory, + ) + + streamed: list[str] = [] + + async def _delta(text: str) -> None: + streamed.append(text) + + result = await fb.chat_stream( + messages=[{"role": "user", "content": "hi"}], + on_content_delta=_delta, + ) + assert result.finish_reason == "stop" + assert result.content == "fallback ok" + factory.assert_called_once_with(_fallback("fallback-a")) + assert "stream stalled" in streamed + + class TestFailoverOnTransientError: @pytest.mark.asyncio async def test_rate_limit(self) -> None: diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index 6fc2137df..07c3b1b18 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -163,6 +163,42 @@ async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monk assert delays == [] +@pytest.mark.asyncio +async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(monkeypatch) -> None: + first = LLMResponse( + content="Error calling LLM: stream stalled for more than 30 seconds", + finish_reason="error", + error_kind="timeout", + ) + first._test_stream_delta = "partial" # type: ignore[attr-defined] + provider = ScriptedProvider([ + first, + LLMResponse(content="full retry response"), + ]) + deltas: list[str] = [] + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + async def _on_delta(delta: str) -> None: + deltas.append(delta) + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_stream_with_retry( + messages=[{"role": "user", "content": "hello"}], + on_content_delta=_on_delta, + ) + + assert response.content == "full retry response" + assert response.finish_reason == "stop" + assert provider.calls == 2 + assert deltas == ["partial"] + assert delays == [1] + assert provider.last_kwargs.get("on_content_delta") is None + + @pytest.mark.asyncio async def test_chat_with_retry_uses_provider_generation_defaults() -> None: """When callers omit generation params, provider.generation defaults are used."""