fix(providers): allow retry and fallback on stream stalled timeout

When a stream stalls mid-response, both the retry layer and
FallbackProvider blocked recovery because content had already been
emitted via on_content_delta. This left users with truncated replies
and no automatic recovery.

For error_kind="timeout" specifically:
- _run_with_retry now clears the on_content_delta, on_thinking_delta and
  on_tool_call_delta callbacks and retries the same model instead of
  returning immediately
- FallbackProvider now allows failover to a different model with the
  on_content_delta callback suppressed

Non-timeout errors retain the original "skip retry/failover after
streamed content" behavior to avoid duplicate output.
This commit is contained in:
aiguozhi123456 2026-06-10 14:38:11 +08:00 committed by Xubin Ren
parent dadb35af49
commit 2c5a4e0703
4 changed files with 97 additions and 11 deletions

View File

@ -827,10 +827,22 @@ class LLMProvider(ABC):
return response return response
last_response = response last_response = response
if should_retry_guard is not None and not should_retry_guard(): if should_retry_guard is not None and not should_retry_guard():
logger.warning( is_timeout = (response.error_kind or "").lower() == "timeout"
"LLM stream failed after content was emitted; skipping retry" if is_timeout:
) logger.warning(
return response "LLM stream stalled after content was emitted; "
"suppressing delta callbacks and retrying"
)
kw.setdefault("on_content_delta", None)
kw["on_content_delta"] = None
kw["on_thinking_delta"] = None
kw["on_tool_call_delta"] = None
should_retry_guard = None
else:
logger.warning(
"LLM stream failed after content was emitted; skipping retry"
)
return response
error_key = ((response.content or "").strip().lower() or None) error_key = ((response.content or "").strip().lower() or None)
if error_key and error_key == last_error_key: if error_key and error_key == last_error_key:
identical_error_count += 1 identical_error_count += 1

View File

@ -149,10 +149,20 @@ class FallbackProvider(LLMProvider):
return response return response
if has_streamed is not None and has_streamed[0]: if has_streamed is not None and has_streamed[0]:
logger.warning( is_timeout = (response.error_kind or "").lower() == "timeout"
"Primary model error but content already streamed; skipping failover" if is_timeout:
) logger.warning(
return response "Primary model '{}' stream stalled after content was emitted; "
"attempting failover anyway",
primary_model,
)
has_streamed[0] = False
kwargs["on_content_delta"] = None
else:
logger.warning(
"Primary model error but content already streamed; skipping failover"
)
return response
if not self._should_fallback(response): if not self._should_fallback(response):
logger.warning( logger.warning(

View File

@ -287,7 +287,7 @@ class TestFallbackOnPrimaryError:
class TestNoFallbackWhenContentStreamed: class TestNoFallbackWhenContentStreamed:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test(self) -> None: async def test_non_timeout_error_skips_failover(self) -> None:
primary = _FakeProvider("primary", _error_response()) primary = _FakeProvider("primary", _error_response())
factory = MagicMock() factory = MagicMock()
fb = FallbackProvider( fb = FallbackProvider(
@ -303,12 +303,40 @@ class TestNoFallbackWhenContentStreamed:
messages=[{"role": "user", "content": "hi"}], messages=[{"role": "user", "content": "hi"}],
on_content_delta=_delta, on_content_delta=_delta,
) )
# Primary returns error but content was "streamed" (FakeProvider calls delta)
# so failover should be skipped
assert result.finish_reason == "error" assert result.finish_reason == "error"
factory.assert_not_called() factory.assert_not_called()
class TestFallbackOnStreamStalledAfterContent:
    """Failover behaviour when the primary errors with ``error_kind="timeout"``
    while a content-delta callback is attached."""

    @pytest.mark.asyncio
    async def test_timeout_with_streamed_content_falls_back(self) -> None:
        """A timeout error from the primary still results in the fallback's reply."""
        # Primary answers with an error response tagged as a timeout.
        stalled_reply = _make_response(
            "stream stalled",
            finish_reason="error",
            error_kind="timeout",
        )
        primary_provider = _FakeProvider("primary", stalled_reply)
        fallback_provider = _FakeProvider("fallback", _make_response("fallback ok"))
        provider_factory = MagicMock(return_value=fallback_provider)
        wrapper = FallbackProvider(
            primary=primary_provider,
            fallback_presets=[_fallback("fallback-a")],
            provider_factory=provider_factory,
        )

        received: list[str] = []

        async def _collect(text: str) -> None:
            received.append(text)

        outcome = await wrapper.chat_stream(
            messages=[{"role": "user", "content": "hi"}],
            on_content_delta=_collect,
        )

        # The caller gets the fallback model's successful response.
        assert outcome.finish_reason == "stop"
        assert outcome.content == "fallback ok"
        # Exactly one fallback provider was built, from the configured preset.
        provider_factory.assert_called_once_with(_fallback("fallback-a"))
        # The primary's text must have reached the caller's delta callback.
        assert "stream stalled" in received
class TestFailoverOnTransientError: class TestFailoverOnTransientError:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_rate_limit(self) -> None: async def test_rate_limit(self) -> None:

View File

@ -163,6 +163,42 @@ async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monk
assert delays == [] assert delays == []
@pytest.mark.asyncio
async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(monkeypatch) -> None:
    """A timeout error that follows an emitted delta is retried, not returned as-is."""
    stalled = LLMResponse(
        content="Error calling LLM: stream stalled for more than 30 seconds",
        finish_reason="error",
        error_kind="timeout",
    )
    # NOTE(review): presumably ScriptedProvider emits this attribute through
    # on_content_delta before returning the response - confirm against the helper.
    stalled._test_stream_delta = "partial"  # type: ignore[attr-defined]
    recovered = LLMResponse(content="full retry response")
    provider = ScriptedProvider([stalled, recovered])

    seen_deltas: list[str] = []
    sleep_calls: list[int] = []

    async def _record_sleep(delay: int) -> None:
        sleep_calls.append(delay)

    async def _record_delta(delta: str) -> None:
        seen_deltas.append(delta)

    # Replace the backoff sleep so the test runs instantly and records each delay.
    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    result = await provider.chat_stream_with_retry(
        messages=[{"role": "user", "content": "hello"}],
        on_content_delta=_record_delta,
    )

    # The second scripted response is what the caller receives.
    assert result.content == "full retry response"
    assert result.finish_reason == "stop"
    assert provider.calls == 2
    # Only the first attempt's delta reached the caller's callback.
    assert seen_deltas == ["partial"]
    # One backoff sleep of 1 was recorded between the two attempts.
    assert sleep_calls == [1]
    # The last provider call was made without a content-delta callback.
    assert provider.last_kwargs.get("on_content_delta") is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None: async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used.""" """When callers omit generation params, provider.generation defaults are used."""