mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(providers): allow retry and fallback on stream stalled timeout
When a stream stalls mid-response, both the retry layer and FallbackProvider blocked recovery because content had already been emitted via on_content_delta. This left users with truncated replies and no automatic recovery. For error_kind="timeout" specifically: - _run_with_retry now suppresses delta callbacks and retries the same model instead of returning immediately - FallbackProvider now allows failover to a different model with delta callbacks suppressed Non-timeout errors retain the original "skip retry/failover after streamed content" behavior to avoid duplicate output.
This commit is contained in:
parent
dadb35af49
commit
2c5a4e0703
@ -827,10 +827,22 @@ class LLMProvider(ABC):
|
|||||||
return response
|
return response
|
||||||
last_response = response
|
last_response = response
|
||||||
if should_retry_guard is not None and not should_retry_guard():
|
if should_retry_guard is not None and not should_retry_guard():
|
||||||
logger.warning(
|
is_timeout = (response.error_kind or "").lower() == "timeout"
|
||||||
"LLM stream failed after content was emitted; skipping retry"
|
if is_timeout:
|
||||||
)
|
logger.warning(
|
||||||
return response
|
"LLM stream stalled after content was emitted; "
|
||||||
|
"suppressing delta callbacks and retrying"
|
||||||
|
)
|
||||||
|
kw.setdefault("on_content_delta", None)
|
||||||
|
kw["on_content_delta"] = None
|
||||||
|
kw["on_thinking_delta"] = None
|
||||||
|
kw["on_tool_call_delta"] = None
|
||||||
|
should_retry_guard = None
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"LLM stream failed after content was emitted; skipping retry"
|
||||||
|
)
|
||||||
|
return response
|
||||||
error_key = ((response.content or "").strip().lower() or None)
|
error_key = ((response.content or "").strip().lower() or None)
|
||||||
if error_key and error_key == last_error_key:
|
if error_key and error_key == last_error_key:
|
||||||
identical_error_count += 1
|
identical_error_count += 1
|
||||||
|
|||||||
@ -149,10 +149,20 @@ class FallbackProvider(LLMProvider):
|
|||||||
return response
|
return response
|
||||||
|
|
||||||
if has_streamed is not None and has_streamed[0]:
|
if has_streamed is not None and has_streamed[0]:
|
||||||
logger.warning(
|
is_timeout = (response.error_kind or "").lower() == "timeout"
|
||||||
"Primary model error but content already streamed; skipping failover"
|
if is_timeout:
|
||||||
)
|
logger.warning(
|
||||||
return response
|
"Primary model '{}' stream stalled after content was emitted; "
|
||||||
|
"attempting failover anyway",
|
||||||
|
primary_model,
|
||||||
|
)
|
||||||
|
has_streamed[0] = False
|
||||||
|
kwargs["on_content_delta"] = None
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Primary model error but content already streamed; skipping failover"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
if not self._should_fallback(response):
|
if not self._should_fallback(response):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@ -287,7 +287,7 @@ class TestFallbackOnPrimaryError:
|
|||||||
|
|
||||||
class TestNoFallbackWhenContentStreamed:
|
class TestNoFallbackWhenContentStreamed:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test(self) -> None:
|
async def test_non_timeout_error_skips_failover(self) -> None:
|
||||||
primary = _FakeProvider("primary", _error_response())
|
primary = _FakeProvider("primary", _error_response())
|
||||||
factory = MagicMock()
|
factory = MagicMock()
|
||||||
fb = FallbackProvider(
|
fb = FallbackProvider(
|
||||||
@ -303,12 +303,40 @@ class TestNoFallbackWhenContentStreamed:
|
|||||||
messages=[{"role": "user", "content": "hi"}],
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
on_content_delta=_delta,
|
on_content_delta=_delta,
|
||||||
)
|
)
|
||||||
# Primary returns error but content was "streamed" (FakeProvider calls delta)
|
|
||||||
# so failover should be skipped
|
|
||||||
assert result.finish_reason == "error"
|
assert result.finish_reason == "error"
|
||||||
factory.assert_not_called()
|
factory.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFallbackOnStreamStalledAfterContent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout_with_streamed_content_falls_back(self) -> None:
|
||||||
|
primary = _FakeProvider(
|
||||||
|
"primary",
|
||||||
|
_make_response("stream stalled", finish_reason="error", error_kind="timeout"),
|
||||||
|
)
|
||||||
|
fallback = _FakeProvider("fallback", _make_response("fallback ok"))
|
||||||
|
factory = MagicMock(return_value=fallback)
|
||||||
|
fb = FallbackProvider(
|
||||||
|
primary=primary,
|
||||||
|
fallback_presets=[_fallback("fallback-a")],
|
||||||
|
provider_factory=factory,
|
||||||
|
)
|
||||||
|
|
||||||
|
streamed: list[str] = []
|
||||||
|
|
||||||
|
async def _delta(text: str) -> None:
|
||||||
|
streamed.append(text)
|
||||||
|
|
||||||
|
result = await fb.chat_stream(
|
||||||
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
|
on_content_delta=_delta,
|
||||||
|
)
|
||||||
|
assert result.finish_reason == "stop"
|
||||||
|
assert result.content == "fallback ok"
|
||||||
|
factory.assert_called_once_with(_fallback("fallback-a"))
|
||||||
|
assert "stream stalled" in streamed
|
||||||
|
|
||||||
|
|
||||||
class TestFailoverOnTransientError:
|
class TestFailoverOnTransientError:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_rate_limit(self) -> None:
|
async def test_rate_limit(self) -> None:
|
||||||
|
|||||||
@ -163,6 +163,42 @@ async def test_chat_stream_with_retry_does_not_retry_after_emitting_content(monk
|
|||||||
assert delays == []
|
assert delays == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(monkeypatch) -> None:
|
||||||
|
first = LLMResponse(
|
||||||
|
content="Error calling LLM: stream stalled for more than 30 seconds",
|
||||||
|
finish_reason="error",
|
||||||
|
error_kind="timeout",
|
||||||
|
)
|
||||||
|
first._test_stream_delta = "partial" # type: ignore[attr-defined]
|
||||||
|
provider = ScriptedProvider([
|
||||||
|
first,
|
||||||
|
LLMResponse(content="full retry response"),
|
||||||
|
])
|
||||||
|
deltas: list[str] = []
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
async def _on_delta(delta: str) -> None:
|
||||||
|
deltas.append(delta)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_stream_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
on_content_delta=_on_delta,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content == "full retry response"
|
||||||
|
assert response.finish_reason == "stop"
|
||||||
|
assert provider.calls == 2
|
||||||
|
assert deltas == ["partial"]
|
||||||
|
assert delays == [1]
|
||||||
|
assert provider.last_kwargs.get("on_content_delta") is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||||
"""When callers omit generation params, provider.generation defaults are used."""
|
"""When callers omit generation params, provider.generation defaults are used."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user