fix: continue recovered streams in a new segment

maintainer edit: streamed timeout recovery was returning the retried response internally while the channel still treated the final outbound as already streamed. End the current stream segment before retry/fallback recovery so subsequent deltas are delivered in a new segment.
This commit is contained in:
chengyongru 2026-06-10 15:53:54 +08:00 committed by Xubin Ren
parent 2c5a4e0703
commit bc4bb508a1
6 changed files with 162 additions and 13 deletions

View File

@ -754,11 +754,15 @@ class AgentRunner:
context.streamed_reasoning = True context.streamed_reasoning = True
await hook.emit_reasoning(delta) await hook.emit_reasoning(delta)
async def _stream_recover() -> None:
await hook.on_stream_end(context, resuming=True)
coro = self.provider.chat_stream_with_retry( coro = self.provider.chat_stream_with_retry(
**kwargs, **kwargs,
on_content_delta=_stream, on_content_delta=_stream,
on_thinking_delta=_thinking, on_thinking_delta=_thinking,
on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None, on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None,
on_stream_recover=_stream_recover,
) )
elif wants_progress_streaming: elif wants_progress_streaming:
stream_buf = "" stream_buf = ""

View File

@ -631,6 +631,7 @@ class LLMProvider(ABC):
on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_content_delta: Callable[[str], Awaitable[None]] | None = None,
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
on_stream_recover: Callable[[], Awaitable[None]] | None = None,
retry_mode: str = "standard", retry_mode: str = "standard",
on_retry_wait: Callable[[str], Awaitable[None]] | None = None, on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
) -> LLMResponse: ) -> LLMResponse:
@ -651,6 +652,12 @@ class LLMProvider(ABC):
if on_content_delta: if on_content_delta:
await on_content_delta(text) await on_content_delta(text)
async def _recover_stream() -> None:
nonlocal has_streamed_content
if on_stream_recover:
await on_stream_recover()
has_streamed_content = False
kw: dict[str, Any] = dict( kw: dict[str, Any] = dict(
messages=messages, tools=tools, model=model, messages=messages, tools=tools, model=model,
max_tokens=max_tokens, temperature=temperature, max_tokens=max_tokens, temperature=temperature,
@ -659,6 +666,8 @@ class LLMProvider(ABC):
on_thinking_delta=on_thinking_delta, on_thinking_delta=on_thinking_delta,
on_tool_call_delta=on_tool_call_delta, on_tool_call_delta=on_tool_call_delta,
) )
if on_stream_recover and getattr(self, "supports_stream_recover_callback", False):
kw["on_stream_recover"] = _recover_stream
return await self._run_with_retry( return await self._run_with_retry(
self._safe_chat_stream, self._safe_chat_stream,
kw, kw,
@ -666,6 +675,7 @@ class LLMProvider(ABC):
retry_mode=retry_mode, retry_mode=retry_mode,
on_retry_wait=on_retry_wait, on_retry_wait=on_retry_wait,
should_retry_guard=lambda: not has_streamed_content, should_retry_guard=lambda: not has_streamed_content,
on_stream_recover=_recover_stream if on_stream_recover else None,
) )
async def chat_with_retry( async def chat_with_retry(
@ -813,6 +823,7 @@ class LLMProvider(ABC):
retry_mode: str, retry_mode: str,
on_retry_wait: Callable[[str], Awaitable[None]] | None, on_retry_wait: Callable[[str], Awaitable[None]] | None,
should_retry_guard: Callable[[], bool] | None = None, should_retry_guard: Callable[[], bool] | None = None,
on_stream_recover: Callable[[], Awaitable[None]] | None = None,
) -> LLMResponse: ) -> LLMResponse:
attempt = 0 attempt = 0
delays = list(self._CHAT_RETRY_DELAYS) delays = list(self._CHAT_RETRY_DELAYS)
@ -829,15 +840,22 @@ class LLMProvider(ABC):
if should_retry_guard is not None and not should_retry_guard(): if should_retry_guard is not None and not should_retry_guard():
is_timeout = (response.error_kind or "").lower() == "timeout" is_timeout = (response.error_kind or "").lower() == "timeout"
if is_timeout: if is_timeout:
logger.warning( if on_stream_recover:
"LLM stream stalled after content was emitted; " logger.warning(
"suppressing delta callbacks and retrying" "LLM stream stalled after content was emitted; "
) "starting a new stream segment and retrying"
kw.setdefault("on_content_delta", None) )
kw["on_content_delta"] = None await on_stream_recover()
kw["on_thinking_delta"] = None else:
kw["on_tool_call_delta"] = None logger.warning(
should_retry_guard = None "LLM stream stalled after content was emitted; "
"suppressing delta callbacks and retrying"
)
kw.setdefault("on_content_delta", None)
kw["on_content_delta"] = None
kw["on_thinking_delta"] = None
kw["on_tool_call_delta"] = None
should_retry_guard = None
else: else:
logger.warning( logger.warning(
"LLM stream failed after content was emitted; skipping retry" "LLM stream failed after content was emitted; skipping retry"

View File

@ -71,6 +71,8 @@ class FallbackProvider(LLMProvider):
wasting requests on a known-bad endpoint. wasting requests on a known-bad endpoint.
""" """
supports_stream_recover_callback = True
def __init__( def __init__(
self, self,
primary: LLMProvider, primary: LLMProvider,
@ -116,6 +118,7 @@ class FallbackProvider(LLMProvider):
) )
async def chat_stream(self, **kwargs: Any) -> LLMResponse: async def chat_stream(self, **kwargs: Any) -> LLMResponse:
on_stream_recover = kwargs.pop("on_stream_recover", None)
if not self._has_fallbacks: if not self._has_fallbacks:
return await self._primary.chat_stream(**kwargs) return await self._primary.chat_stream(**kwargs)
@ -130,7 +133,10 @@ class FallbackProvider(LLMProvider):
kwargs["on_content_delta"] = _tracking_delta kwargs["on_content_delta"] = _tracking_delta
return await self._try_with_fallback( return await self._try_with_fallback(
lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed lambda p, kw: p.chat_stream(**kw),
kwargs,
has_streamed=has_streamed,
on_stream_recover=on_stream_recover,
) )
async def _try_with_fallback( async def _try_with_fallback(
@ -138,6 +144,7 @@ class FallbackProvider(LLMProvider):
call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]], call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
kwargs: dict[str, Any], kwargs: dict[str, Any],
has_streamed: list[bool] | None, has_streamed: list[bool] | None,
on_stream_recover: Callable[[], Awaitable[None]] | None = None,
) -> LLMResponse: ) -> LLMResponse:
primary_model = kwargs.get("model") or self._primary.get_default_model() primary_model = kwargs.get("model") or self._primary.get_default_model()
@ -157,7 +164,10 @@ class FallbackProvider(LLMProvider):
primary_model, primary_model,
) )
has_streamed[0] = False has_streamed[0] = False
kwargs["on_content_delta"] = None if on_stream_recover:
await on_stream_recover()
else:
kwargs["on_content_delta"] = None
else: else:
logger.warning( logger.warning(
"Primary model error but content already streamed; skipping failover" "Primary model error but content already streamed; skipping failover"
@ -187,7 +197,20 @@ class FallbackProvider(LLMProvider):
for idx, fallback in enumerate(self._fallback_presets): for idx, fallback in enumerate(self._fallback_presets):
fallback_model = fallback.model fallback_model = fallback.model
if has_streamed is not None and has_streamed[0]: if has_streamed is not None and has_streamed[0]:
break is_timeout = (
last_response is not None
and (last_response.error_kind or "").lower() == "timeout"
)
if is_timeout and on_stream_recover:
logger.warning(
"Fallback model '{}' stream stalled after content was emitted; "
"starting a new stream segment and trying next fallback",
self._fallback_presets[idx - 1].model if idx > 0 else primary_model,
)
has_streamed[0] = False
await on_stream_recover()
else:
break
if idx == 0 and primary_skipped: if idx == 0 and primary_skipped:
logger.info( logger.info(
"Primary model '{}' circuit open, trying fallback '{}'", "Primary model '{}' circuit open, trying fallback '{}'",

View File

@ -492,6 +492,61 @@ class TestToolEventProgress:
assert turn_end_msgs[0].content == "" assert turn_end_msgs[0].content == ""
provider.chat_with_retry.assert_not_awaited() provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio
async def test_stream_timeout_recovery_continues_in_new_segment(
self,
tmp_path: Path,
) -> None:
"""Recovered streaming output should use a new stream segment."""
bus = MessageBus()
provider = MagicMock()
provider.supports_progress_deltas = True
provider.get_default_model.return_value = "openai-codex/gpt-5.5"
async def chat_stream_with_retry(*, on_content_delta, on_stream_recover, **kwargs):
await on_content_delta("partial")
await on_stream_recover()
await on_content_delta("full retry response")
return LLMResponse(content="full retry response", tool_calls=[])
provider.chat_stream_with_retry = chat_stream_with_retry
provider.chat_with_retry = AsyncMock()
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5")
_attach_webui_runtime_events(loop, bus)
loop.tools.get_definitions = MagicMock(return_value=[])
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
await loop._dispatch(InboundMessage(
channel="websocket",
sender_id="u1",
chat_id="chat1",
content="say hello",
metadata={"_wants_stream": True},
))
outbound = []
while bus.outbound_size > 0:
outbound.append(await bus.consume_outbound())
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
stream_end = [m for m in outbound if m.metadata.get("_stream_end")]
final = [
m for m in outbound
if not m.metadata.get("_stream_delta")
and not m.metadata.get("_stream_end")
and not m.metadata.get("_turn_end")
and not m.metadata.get("_goal_status")
]
assert [m.content for m in deltas] == ["partial", "full retry response"]
assert [m.metadata.get("_resuming") for m in stream_end] == [True, False]
assert deltas[0].metadata.get("_stream_id") == stream_end[0].metadata.get("_stream_id")
assert deltas[1].metadata.get("_stream_id") == stream_end[1].metadata.get("_stream_id")
assert deltas[0].metadata.get("_stream_id") != deltas[1].metadata.get("_stream_id")
assert final[-1].content == "full retry response"
assert final[-1].metadata.get("_streamed") is True
provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_streamed_progress_is_not_repeated_before_tool_execution( async def test_streamed_progress_is_not_repeated_before_tool_execution(
self, self,

View File

@ -323,18 +323,24 @@ class TestFallbackOnStreamStalledAfterContent:
) )
streamed: list[str] = [] streamed: list[str] = []
recoveries: list[str] = []
async def _delta(text: str) -> None: async def _delta(text: str) -> None:
streamed.append(text) streamed.append(text)
async def _recover() -> None:
recoveries.append("recover")
result = await fb.chat_stream( result = await fb.chat_stream(
messages=[{"role": "user", "content": "hi"}], messages=[{"role": "user", "content": "hi"}],
on_content_delta=_delta, on_content_delta=_delta,
on_stream_recover=_recover,
) )
assert result.finish_reason == "stop" assert result.finish_reason == "stop"
assert result.content == "fallback ok" assert result.content == "fallback ok"
factory.assert_called_once_with(_fallback("fallback-a")) factory.assert_called_once_with(_fallback("fallback-a"))
assert "stream stalled" in streamed assert streamed == ["stream stalled", "fallback ok"]
assert recoveries == ["recover"]
class TestFailoverOnTransientError: class TestFailoverOnTransientError:

View File

@ -199,6 +199,49 @@ async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(mon
assert provider.last_kwargs.get("on_content_delta") is None assert provider.last_kwargs.get("on_content_delta") is None
@pytest.mark.asyncio
async def test_chat_stream_with_retry_retries_timeout_in_new_stream_segment(
monkeypatch,
) -> None:
first = LLMResponse(
content="Error calling LLM: stream stalled for more than 30 seconds",
finish_reason="error",
error_kind="timeout",
)
first._test_stream_delta = "partial" # type: ignore[attr-defined]
second = LLMResponse(content="full retry response")
second._test_stream_delta = "full retry response" # type: ignore[attr-defined]
provider = ScriptedProvider([first, second])
deltas: list[str] = []
recoveries: list[str] = []
delays: list[int] = []
async def _fake_sleep(delay: int) -> None:
delays.append(delay)
async def _on_delta(delta: str) -> None:
deltas.append(delta)
async def _on_stream_recover() -> None:
recoveries.append("recover")
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_stream_with_retry(
messages=[{"role": "user", "content": "hello"}],
on_content_delta=_on_delta,
on_stream_recover=_on_stream_recover,
)
assert response.content == "full retry response"
assert response.finish_reason == "stop"
assert provider.calls == 2
assert deltas == ["partial", "full retry response"]
assert recoveries == ["recover"]
assert delays == [1]
assert provider.last_kwargs.get("on_content_delta") is not None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_chat_with_retry_uses_provider_generation_defaults() -> None: async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
"""When callers omit generation params, provider.generation defaults are used.""" """When callers omit generation params, provider.generation defaults are used."""