mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(providers): allow retry and fallback on stream stalled timeout
When a stream stalls mid-response, both the retry layer and FallbackProvider blocked recovery because content had already been emitted via on_content_delta. This left users with truncated replies and no automatic recovery. For error_kind="timeout" specifically: - _run_with_retry now suppresses delta callbacks and retries the same model instead of returning immediately - FallbackProvider now allows failover to a different model with delta callbacks suppressed Non-timeout errors retain the original "skip retry/failover after streamed content" behavior to avoid duplicate output.
This commit is contained in:
parent
dadb35af49
commit
2c5a4e0703
@ -827,10 +827,22 @@ class LLMProvider(ABC):
|
||||
return response
|
||||
last_response = response
|
||||
if should_retry_guard is not None and not should_retry_guard():
|
||||
logger.warning(
|
||||
"LLM stream failed after content was emitted; skipping retry"
|
||||
)
|
||||
return response
|
||||
is_timeout = (response.error_kind or "").lower() == "timeout"
|
||||
if is_timeout:
|
||||
logger.warning(
|
||||
"LLM stream stalled after content was emitted; "
|
||||
"suppressing delta callbacks and retrying"
|
||||
)
|
||||
kw.setdefault("on_content_delta", None)
|
||||
kw["on_content_delta"] = None
|
||||
kw["on_thinking_delta"] = None
|
||||
kw["on_tool_call_delta"] = None
|
||||
should_retry_guard = None
|
||||
else:
|
||||
logger.warning(
|
||||
"LLM stream failed after content was emitted; skipping retry"
|
||||
)
|
||||
return response
|
||||
error_key = ((response.content or "").strip().lower() or None)
|
||||
if error_key and error_key == last_error_key:
|
||||
identical_error_count += 1
|
||||
|
||||
@ -149,10 +149,20 @@ class FallbackProvider(LLMProvider):
|
||||
return response
|
||||
|
||||
if has_streamed is not None and has_streamed[0]:
|
||||
logger.warning(
|
||||
"Primary model error but content already streamed; skipping failover"
|
||||
)
|
||||
return response
|
||||
is_timeout = (response.error_kind or "").lower() == "timeout"
|
||||
if is_timeout:
|
||||
logger.warning(
|
||||
"Primary model '{}' stream stalled after content was emitted; "
|
||||
"attempting failover anyway",
|
||||
primary_model,
|
||||
)
|
||||
has_streamed[0] = False
|
||||
kwargs["on_content_delta"] = None
|
||||
else:
|
||||
logger.warning(
|
||||
"Primary model error but content already streamed; skipping failover"
|
||||
)
|
||||
return response
|
||||
|
||||
if not self._should_fallback(response):
|
||||
logger.warning(
|
||||
|
||||
@ -287,7 +287,7 @@ class TestFallbackOnPrimaryError:
|
||||
|
||||
class TestNoFallbackWhenContentStreamed:
|
||||
@pytest.mark.asyncio
|
||||
async def test(self) -> None:
|
||||
async def test_non_timeout_error_skips_failover(self) -> None:
|
||||
primary = _FakeProvider("primary", _error_response())
|
||||
factory = MagicMock()
|
||||
fb = FallbackProvider(
|
||||
@ -303,12 +303,40 @@ class TestNoFallbackWhenContentStreamed:
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
on_content_delta=_delta,
|
||||
)
|
||||
# Primary returns error but content was "streamed" (FakeProvider calls delta)
|
||||
# so failover should be skipped
|
||||
assert result.finish_reason == "error"
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
class TestFallbackOnStreamStalledAfterContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_with_streamed_content_falls_back(self) -> None:
|
||||
primary = _FakeProvider(
|
||||
"primary",
|
||||
_make_response("stream stalled", finish_reason="error", error_kind="timeout"),
|
||||
)
|
||||
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||
factory = MagicMock(return_value=fallback)
|
||||
fb = FallbackProvider(
|
||||
primary=primary,
|
||||
fallback_presets=[_fallback("fallback-a")],
|
||||
provider_factory=factory,
|
||||
)
|
||||
|
||||
streamed: list[str] = []
|
||||
|
||||
async def _delta(text: str) -> None:
|
||||
streamed.append(text)
|
||||
|
||||
result = await fb.chat_stream(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
on_content_delta=_delta,
|
||||
)
|
||||
assert result.finish_reason == "stop"
|
||||
assert result.content == "fallback ok"
|
||||
factory.assert_called_once_with(_fallback("fallback-a"))
|
||||
assert "stream stalled" in streamed
|
||||
|
||||
|
||||
class TestFailoverOnTransientError:
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit(self) -> None:
|
||||
|
||||
@ -163,6 +163,42 @@ async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monk
|
||||
assert delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(monkeypatch) -> None:
|
||||
first = LLMResponse(
|
||||
content="Error calling LLM: stream stalled for more than 30 seconds",
|
||||
finish_reason="error",
|
||||
error_kind="timeout",
|
||||
)
|
||||
first._test_stream_delta = "partial" # type: ignore[attr-defined]
|
||||
provider = ScriptedProvider([
|
||||
first,
|
||||
LLMResponse(content="full retry response"),
|
||||
])
|
||||
deltas: list[str] = []
|
||||
delays: list[int] = []
|
||||
|
||||
async def _fake_sleep(delay: int) -> None:
|
||||
delays.append(delay)
|
||||
|
||||
async def _on_delta(delta: str) -> None:
|
||||
deltas.append(delta)
|
||||
|
||||
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||
|
||||
response = await provider.chat_stream_with_retry(
|
||||
messages=[{"role": "user", "content": "hello"}],
|
||||
on_content_delta=_on_delta,
|
||||
)
|
||||
|
||||
assert response.content == "full retry response"
|
||||
assert response.finish_reason == "stop"
|
||||
assert provider.calls == 2
|
||||
assert deltas == ["partial"]
|
||||
assert delays == [1]
|
||||
assert provider.last_kwargs.get("on_content_delta") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||
"""When callers omit generation params, provider.generation defaults are used."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user