From bc4bb508a13c45a102db4db142316ded8fbfc1cd Mon Sep 17 00:00:00 2001 From: chengyongru Date: Wed, 10 Jun 2026 15:53:54 +0800 Subject: [PATCH] fix: continue recovered streams in a new segment maintainer edit: streamed timeout recovery was returning the retried response internally while the channel still treated the final outbound as already streamed. End the current stream segment before retry/fallback recovery so subsequent deltas are delivered in a new segment. --- nanobot/agent/runner.py | 4 ++ nanobot/providers/base.py | 36 ++++++++++++----- nanobot/providers/fallback_provider.py | 29 ++++++++++++-- tests/agent/test_loop_progress.py | 55 ++++++++++++++++++++++++++ tests/agent/test_runner_fallback.py | 8 +++- tests/providers/test_provider_retry.py | 43 ++++++++++++++++++++ 6 files changed, 162 insertions(+), 13 deletions(-) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 5c9ff6e2d..53f6554ab 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -754,11 +754,15 @@ class AgentRunner: context.streamed_reasoning = True await hook.emit_reasoning(delta) + async def _stream_recover() -> None: + await hook.on_stream_end(context, resuming=True) + coro = self.provider.chat_stream_with_retry( **kwargs, on_content_delta=_stream, on_thinking_delta=_thinking, on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None, + on_stream_recover=_stream_recover, ) elif wants_progress_streaming: stream_buf = "" diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py index 640a5c910..802ac314a 100644 --- a/nanobot/providers/base.py +++ b/nanobot/providers/base.py @@ -631,6 +631,7 @@ class LLMProvider(ABC): on_content_delta: Callable[[str], Awaitable[None]] | None = None, on_thinking_delta: Callable[[str], Awaitable[None]] | None = None, on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None, + on_stream_recover: Callable[[], Awaitable[None]] | None = None, retry_mode: str = "standard", on_retry_wait: Callable[[str], Awaitable[None]] | None = None, ) -> LLMResponse: @@ -651,6 +652,12 @@ class LLMProvider(ABC): if on_content_delta: await on_content_delta(text) + async def _recover_stream() -> None: + nonlocal has_streamed_content + if on_stream_recover: + await on_stream_recover() + has_streamed_content = False + kw: dict[str, Any] = dict( messages=messages, tools=tools, model=model, max_tokens=max_tokens, temperature=temperature, @@ -659,6 +666,8 @@ class LLMProvider(ABC): on_thinking_delta=on_thinking_delta, on_tool_call_delta=on_tool_call_delta, ) + if on_stream_recover and getattr(self, "supports_stream_recover_callback", False): + kw["on_stream_recover"] = _recover_stream return await self._run_with_retry( self._safe_chat_stream, kw, @@ -666,6 +675,7 @@ class LLMProvider(ABC): retry_mode=retry_mode, on_retry_wait=on_retry_wait, should_retry_guard=lambda: not has_streamed_content, + on_stream_recover=_recover_stream if on_stream_recover else None, ) async def chat_with_retry( @@ -813,6 +823,7 @@ class LLMProvider(ABC): retry_mode: str, on_retry_wait: Callable[[str], Awaitable[None]] | None, should_retry_guard: Callable[[], bool] | None = None, + on_stream_recover: Callable[[], Awaitable[None]] | None = None, ) -> LLMResponse: attempt = 0 delays = list(self._CHAT_RETRY_DELAYS) @@ -829,15 +840,22 @@ class LLMProvider(ABC): if should_retry_guard is not None and not should_retry_guard(): is_timeout = (response.error_kind or "").lower() == "timeout" if is_timeout: - logger.warning( - "LLM stream stalled after content was emitted; " - "suppressing delta callbacks and retrying" - ) - kw.setdefault("on_content_delta", None) - kw["on_content_delta"] = None - kw["on_thinking_delta"] = None - kw["on_tool_call_delta"] = None - should_retry_guard = None + if on_stream_recover: + logger.warning( + "LLM stream stalled after content was emitted; " + "starting a new stream segment and retrying" + ) + await on_stream_recover() + else: + logger.warning( + "LLM stream stalled after content was emitted; " + "suppressing delta callbacks and retrying" + ) + kw.setdefault("on_content_delta", None) + kw["on_content_delta"] = None + kw["on_thinking_delta"] = None + kw["on_tool_call_delta"] = None + should_retry_guard = None else: logger.warning( "LLM stream failed after content was emitted; skipping retry" diff --git a/nanobot/providers/fallback_provider.py b/nanobot/providers/fallback_provider.py index d8ee4a5fa..2381d6175 100644 --- a/nanobot/providers/fallback_provider.py +++ b/nanobot/providers/fallback_provider.py @@ -71,6 +71,8 @@ class FallbackProvider(LLMProvider): wasting requests on a known-bad endpoint. """ + supports_stream_recover_callback = True + def __init__( self, primary: LLMProvider, @@ -116,6 +118,7 @@ class FallbackProvider(LLMProvider): ) async def chat_stream(self, **kwargs: Any) -> LLMResponse: + on_stream_recover = kwargs.pop("on_stream_recover", None) if not self._has_fallbacks: return await self._primary.chat_stream(**kwargs) @@ -130,7 +133,10 @@ class FallbackProvider(LLMProvider): kwargs["on_content_delta"] = _tracking_delta return await self._try_with_fallback( - lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed + lambda p, kw: p.chat_stream(**kw), + kwargs, + has_streamed=has_streamed, + on_stream_recover=on_stream_recover, ) async def _try_with_fallback( @@ -138,6 +144,7 @@ class FallbackProvider(LLMProvider): call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]], kwargs: dict[str, Any], has_streamed: list[bool] | None, + on_stream_recover: Callable[[], Awaitable[None]] | None = None, ) -> LLMResponse: primary_model = kwargs.get("model") or self._primary.get_default_model() @@ -157,7 +164,10 @@ class FallbackProvider(LLMProvider): primary_model, ) has_streamed[0] = False - kwargs["on_content_delta"] = None + if on_stream_recover: + await on_stream_recover() + else: + kwargs["on_content_delta"] = None else: logger.warning( "Primary model error but content already streamed; skipping failover" @@ -187,7 +197,20 @@ class FallbackProvider(LLMProvider): for idx, fallback in enumerate(self._fallback_presets): fallback_model = fallback.model if has_streamed is not None and has_streamed[0]: - break + is_timeout = ( + last_response is not None + and (last_response.error_kind or "").lower() == "timeout" + ) + if is_timeout and on_stream_recover: + logger.warning( + "Fallback model '{}' stream stalled after content was emitted; " + "starting a new stream segment and trying next fallback", + self._fallback_presets[idx - 1].model if idx > 0 else primary_model, + ) + has_streamed[0] = False + await on_stream_recover() + else: + break if idx == 0 and primary_skipped: logger.info( "Primary model '{}' circuit open, trying fallback '{}'", diff --git a/tests/agent/test_loop_progress.py b/tests/agent/test_loop_progress.py index bbac2e6af..19473cc7f 100644 --- a/tests/agent/test_loop_progress.py +++ b/tests/agent/test_loop_progress.py @@ -492,6 +492,61 @@ class TestToolEventProgress: assert turn_end_msgs[0].content == "" provider.chat_with_retry.assert_not_awaited() + @pytest.mark.asyncio + async def test_stream_timeout_recovery_continues_in_new_segment( + self, + tmp_path: Path, + ) -> None: + """Recovered streaming output should use a new stream segment.""" + bus = MessageBus() + provider = MagicMock() + provider.supports_progress_deltas = True + provider.get_default_model.return_value = "openai-codex/gpt-5.5" + + async def chat_stream_with_retry(*, on_content_delta, on_stream_recover, **kwargs): + await on_content_delta("partial") + await on_stream_recover() + await on_content_delta("full retry response") + return LLMResponse(content="full retry response", tool_calls=[]) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5") + _attach_webui_runtime_events(loop, bus) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + await loop._dispatch(InboundMessage( + channel="websocket", + sender_id="u1", + chat_id="chat1", + content="say hello", + metadata={"_wants_stream": True}, + )) + + outbound = [] + while bus.outbound_size > 0: + outbound.append(await bus.consume_outbound()) + + deltas = [m for m in outbound if m.metadata.get("_stream_delta")] + stream_end = [m for m in outbound if m.metadata.get("_stream_end")] + final = [ + m for m in outbound + if not m.metadata.get("_stream_delta") + and not m.metadata.get("_stream_end") + and not m.metadata.get("_turn_end") + and not m.metadata.get("_goal_status") + ] + + assert [m.content for m in deltas] == ["partial", "full retry response"] + assert [m.metadata.get("_resuming") for m in stream_end] == [True, False] + assert deltas[0].metadata.get("_stream_id") == stream_end[0].metadata.get("_stream_id") + assert deltas[1].metadata.get("_stream_id") == stream_end[1].metadata.get("_stream_id") + assert deltas[0].metadata.get("_stream_id") != deltas[1].metadata.get("_stream_id") + assert final[-1].content == "full retry response" + assert final[-1].metadata.get("_streamed") is True + provider.chat_with_retry.assert_not_awaited() + @pytest.mark.asyncio async def test_streamed_progress_is_not_repeated_before_tool_execution( self, diff --git a/tests/agent/test_runner_fallback.py b/tests/agent/test_runner_fallback.py index 70d44e71d..d7e536c0c 100644 --- a/tests/agent/test_runner_fallback.py +++ b/tests/agent/test_runner_fallback.py @@ -323,18 +323,24 @@ class TestFallbackOnStreamStalledAfterContent: ) streamed: list[str] = [] + recoveries: list[str] = [] async def _delta(text: str) -> None: streamed.append(text) + async def _recover() -> None: + recoveries.append("recover") + result = await fb.chat_stream( messages=[{"role": "user", "content": "hi"}], on_content_delta=_delta, + on_stream_recover=_recover, ) assert result.finish_reason == "stop" assert result.content == "fallback ok" factory.assert_called_once_with(_fallback("fallback-a")) - assert "stream stalled" in streamed + assert streamed == ["stream stalled", "fallback ok"] + assert recoveries == ["recover"] class TestFailoverOnTransientError: diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py index 07c3b1b18..9483fee9b 100644 --- a/tests/providers/test_provider_retry.py +++ b/tests/providers/test_provider_retry.py @@ -199,6 +199,49 @@ async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(mon assert provider.last_kwargs.get("on_content_delta") is None +@pytest.mark.asyncio +async def test_chat_stream_with_retry_retries_timeout_in_new_stream_segment( + monkeypatch, +) -> None: + first = LLMResponse( + content="Error calling LLM: stream stalled for more than 30 seconds", + finish_reason="error", + error_kind="timeout", + ) + first._test_stream_delta = "partial" # type: ignore[attr-defined] + second = LLMResponse(content="full retry response") + second._test_stream_delta = "full retry response" # type: ignore[attr-defined] + provider = ScriptedProvider([first, second]) + deltas: list[str] = [] + recoveries: list[str] = [] + delays: list[int] = [] + + async def _fake_sleep(delay: int) -> None: + delays.append(delay) + + async def _on_delta(delta: str) -> None: + deltas.append(delta) + + async def _on_stream_recover() -> None: + recoveries.append("recover") + + monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) + + response = await provider.chat_stream_with_retry( + messages=[{"role": "user", "content": "hello"}], + on_content_delta=_on_delta, + on_stream_recover=_on_stream_recover, + ) + + assert response.content == "full retry response" + assert response.finish_reason == "stop" + assert provider.calls == 2 + assert deltas == ["partial", "full retry response"] + assert recoveries == ["recover"] + assert delays == [1] + assert provider.last_kwargs.get("on_content_delta") is not None + + @pytest.mark.asyncio async def test_chat_with_retry_uses_provider_generation_defaults() -> None: """When callers omit generation params, provider.generation defaults are used."""