mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix: continue recovered streams in a new segment
maintainer edit: when a stream timed out after content had already been emitted, the retried response was returned to the caller without its deltas being streamed, while the channel still treated the final outbound message as already streamed. End the current stream segment before retry/fallback recovery so that subsequent deltas are delivered in a new segment.
This commit is contained in:
parent
2c5a4e0703
commit
bc4bb508a1
@ -754,11 +754,15 @@ class AgentRunner:
|
|||||||
context.streamed_reasoning = True
|
context.streamed_reasoning = True
|
||||||
await hook.emit_reasoning(delta)
|
await hook.emit_reasoning(delta)
|
||||||
|
|
||||||
|
async def _stream_recover() -> None:
|
||||||
|
await hook.on_stream_end(context, resuming=True)
|
||||||
|
|
||||||
coro = self.provider.chat_stream_with_retry(
|
coro = self.provider.chat_stream_with_retry(
|
||||||
**kwargs,
|
**kwargs,
|
||||||
on_content_delta=_stream,
|
on_content_delta=_stream,
|
||||||
on_thinking_delta=_thinking,
|
on_thinking_delta=_thinking,
|
||||||
on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None,
|
on_tool_call_delta=_tool_call_delta if live_file_edits is not None else None,
|
||||||
|
on_stream_recover=_stream_recover,
|
||||||
)
|
)
|
||||||
elif wants_progress_streaming:
|
elif wants_progress_streaming:
|
||||||
stream_buf = ""
|
stream_buf = ""
|
||||||
|
|||||||
@ -631,6 +631,7 @@ class LLMProvider(ABC):
|
|||||||
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_content_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
on_thinking_delta: Callable[[str], Awaitable[None]] | None = None,
|
||||||
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
on_tool_call_delta: Callable[[dict[str, Any]], Awaitable[None]] | None = None,
|
||||||
|
on_stream_recover: Callable[[], Awaitable[None]] | None = None,
|
||||||
retry_mode: str = "standard",
|
retry_mode: str = "standard",
|
||||||
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
on_retry_wait: Callable[[str], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@ -651,6 +652,12 @@ class LLMProvider(ABC):
|
|||||||
if on_content_delta:
|
if on_content_delta:
|
||||||
await on_content_delta(text)
|
await on_content_delta(text)
|
||||||
|
|
||||||
|
async def _recover_stream() -> None:
|
||||||
|
nonlocal has_streamed_content
|
||||||
|
if on_stream_recover:
|
||||||
|
await on_stream_recover()
|
||||||
|
has_streamed_content = False
|
||||||
|
|
||||||
kw: dict[str, Any] = dict(
|
kw: dict[str, Any] = dict(
|
||||||
messages=messages, tools=tools, model=model,
|
messages=messages, tools=tools, model=model,
|
||||||
max_tokens=max_tokens, temperature=temperature,
|
max_tokens=max_tokens, temperature=temperature,
|
||||||
@ -659,6 +666,8 @@ class LLMProvider(ABC):
|
|||||||
on_thinking_delta=on_thinking_delta,
|
on_thinking_delta=on_thinking_delta,
|
||||||
on_tool_call_delta=on_tool_call_delta,
|
on_tool_call_delta=on_tool_call_delta,
|
||||||
)
|
)
|
||||||
|
if on_stream_recover and getattr(self, "supports_stream_recover_callback", False):
|
||||||
|
kw["on_stream_recover"] = _recover_stream
|
||||||
return await self._run_with_retry(
|
return await self._run_with_retry(
|
||||||
self._safe_chat_stream,
|
self._safe_chat_stream,
|
||||||
kw,
|
kw,
|
||||||
@ -666,6 +675,7 @@ class LLMProvider(ABC):
|
|||||||
retry_mode=retry_mode,
|
retry_mode=retry_mode,
|
||||||
on_retry_wait=on_retry_wait,
|
on_retry_wait=on_retry_wait,
|
||||||
should_retry_guard=lambda: not has_streamed_content,
|
should_retry_guard=lambda: not has_streamed_content,
|
||||||
|
on_stream_recover=_recover_stream if on_stream_recover else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_with_retry(
|
async def chat_with_retry(
|
||||||
@ -813,6 +823,7 @@ class LLMProvider(ABC):
|
|||||||
retry_mode: str,
|
retry_mode: str,
|
||||||
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
on_retry_wait: Callable[[str], Awaitable[None]] | None,
|
||||||
should_retry_guard: Callable[[], bool] | None = None,
|
should_retry_guard: Callable[[], bool] | None = None,
|
||||||
|
on_stream_recover: Callable[[], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
attempt = 0
|
attempt = 0
|
||||||
delays = list(self._CHAT_RETRY_DELAYS)
|
delays = list(self._CHAT_RETRY_DELAYS)
|
||||||
@ -829,15 +840,22 @@ class LLMProvider(ABC):
|
|||||||
if should_retry_guard is not None and not should_retry_guard():
|
if should_retry_guard is not None and not should_retry_guard():
|
||||||
is_timeout = (response.error_kind or "").lower() == "timeout"
|
is_timeout = (response.error_kind or "").lower() == "timeout"
|
||||||
if is_timeout:
|
if is_timeout:
|
||||||
logger.warning(
|
if on_stream_recover:
|
||||||
"LLM stream stalled after content was emitted; "
|
logger.warning(
|
||||||
"suppressing delta callbacks and retrying"
|
"LLM stream stalled after content was emitted; "
|
||||||
)
|
"starting a new stream segment and retrying"
|
||||||
kw.setdefault("on_content_delta", None)
|
)
|
||||||
kw["on_content_delta"] = None
|
await on_stream_recover()
|
||||||
kw["on_thinking_delta"] = None
|
else:
|
||||||
kw["on_tool_call_delta"] = None
|
logger.warning(
|
||||||
should_retry_guard = None
|
"LLM stream stalled after content was emitted; "
|
||||||
|
"suppressing delta callbacks and retrying"
|
||||||
|
)
|
||||||
|
kw.setdefault("on_content_delta", None)
|
||||||
|
kw["on_content_delta"] = None
|
||||||
|
kw["on_thinking_delta"] = None
|
||||||
|
kw["on_tool_call_delta"] = None
|
||||||
|
should_retry_guard = None
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"LLM stream failed after content was emitted; skipping retry"
|
"LLM stream failed after content was emitted; skipping retry"
|
||||||
|
|||||||
@ -71,6 +71,8 @@ class FallbackProvider(LLMProvider):
|
|||||||
wasting requests on a known-bad endpoint.
|
wasting requests on a known-bad endpoint.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
supports_stream_recover_callback = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
primary: LLMProvider,
|
primary: LLMProvider,
|
||||||
@ -116,6 +118,7 @@ class FallbackProvider(LLMProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
async def chat_stream(self, **kwargs: Any) -> LLMResponse:
|
||||||
|
on_stream_recover = kwargs.pop("on_stream_recover", None)
|
||||||
if not self._has_fallbacks:
|
if not self._has_fallbacks:
|
||||||
return await self._primary.chat_stream(**kwargs)
|
return await self._primary.chat_stream(**kwargs)
|
||||||
|
|
||||||
@ -130,7 +133,10 @@ class FallbackProvider(LLMProvider):
|
|||||||
|
|
||||||
kwargs["on_content_delta"] = _tracking_delta
|
kwargs["on_content_delta"] = _tracking_delta
|
||||||
return await self._try_with_fallback(
|
return await self._try_with_fallback(
|
||||||
lambda p, kw: p.chat_stream(**kw), kwargs, has_streamed=has_streamed
|
lambda p, kw: p.chat_stream(**kw),
|
||||||
|
kwargs,
|
||||||
|
has_streamed=has_streamed,
|
||||||
|
on_stream_recover=on_stream_recover,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _try_with_fallback(
|
async def _try_with_fallback(
|
||||||
@ -138,6 +144,7 @@ class FallbackProvider(LLMProvider):
|
|||||||
call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
|
call: Callable[[LLMProvider, dict[str, Any]], Awaitable[LLMResponse]],
|
||||||
kwargs: dict[str, Any],
|
kwargs: dict[str, Any],
|
||||||
has_streamed: list[bool] | None,
|
has_streamed: list[bool] | None,
|
||||||
|
on_stream_recover: Callable[[], Awaitable[None]] | None = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
primary_model = kwargs.get("model") or self._primary.get_default_model()
|
primary_model = kwargs.get("model") or self._primary.get_default_model()
|
||||||
|
|
||||||
@ -157,7 +164,10 @@ class FallbackProvider(LLMProvider):
|
|||||||
primary_model,
|
primary_model,
|
||||||
)
|
)
|
||||||
has_streamed[0] = False
|
has_streamed[0] = False
|
||||||
kwargs["on_content_delta"] = None
|
if on_stream_recover:
|
||||||
|
await on_stream_recover()
|
||||||
|
else:
|
||||||
|
kwargs["on_content_delta"] = None
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Primary model error but content already streamed; skipping failover"
|
"Primary model error but content already streamed; skipping failover"
|
||||||
@ -187,7 +197,20 @@ class FallbackProvider(LLMProvider):
|
|||||||
for idx, fallback in enumerate(self._fallback_presets):
|
for idx, fallback in enumerate(self._fallback_presets):
|
||||||
fallback_model = fallback.model
|
fallback_model = fallback.model
|
||||||
if has_streamed is not None and has_streamed[0]:
|
if has_streamed is not None and has_streamed[0]:
|
||||||
break
|
is_timeout = (
|
||||||
|
last_response is not None
|
||||||
|
and (last_response.error_kind or "").lower() == "timeout"
|
||||||
|
)
|
||||||
|
if is_timeout and on_stream_recover:
|
||||||
|
logger.warning(
|
||||||
|
"Fallback model '{}' stream stalled after content was emitted; "
|
||||||
|
"starting a new stream segment and trying next fallback",
|
||||||
|
self._fallback_presets[idx - 1].model if idx > 0 else primary_model,
|
||||||
|
)
|
||||||
|
has_streamed[0] = False
|
||||||
|
await on_stream_recover()
|
||||||
|
else:
|
||||||
|
break
|
||||||
if idx == 0 and primary_skipped:
|
if idx == 0 and primary_skipped:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Primary model '{}' circuit open, trying fallback '{}'",
|
"Primary model '{}' circuit open, trying fallback '{}'",
|
||||||
|
|||||||
@ -492,6 +492,61 @@ class TestToolEventProgress:
|
|||||||
assert turn_end_msgs[0].content == ""
|
assert turn_end_msgs[0].content == ""
|
||||||
provider.chat_with_retry.assert_not_awaited()
|
provider.chat_with_retry.assert_not_awaited()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_timeout_recovery_continues_in_new_segment(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Recovered streaming output should use a new stream segment."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.supports_progress_deltas = True
|
||||||
|
provider.get_default_model.return_value = "openai-codex/gpt-5.5"
|
||||||
|
|
||||||
|
async def chat_stream_with_retry(*, on_content_delta, on_stream_recover, **kwargs):
|
||||||
|
await on_content_delta("partial")
|
||||||
|
await on_stream_recover()
|
||||||
|
await on_content_delta("full retry response")
|
||||||
|
return LLMResponse(content="full retry response", tool_calls=[])
|
||||||
|
|
||||||
|
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||||
|
provider.chat_with_retry = AsyncMock()
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5")
|
||||||
|
_attach_webui_runtime_events(loop, bus)
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop._dispatch(InboundMessage(
|
||||||
|
channel="websocket",
|
||||||
|
sender_id="u1",
|
||||||
|
chat_id="chat1",
|
||||||
|
content="say hello",
|
||||||
|
metadata={"_wants_stream": True},
|
||||||
|
))
|
||||||
|
|
||||||
|
outbound = []
|
||||||
|
while bus.outbound_size > 0:
|
||||||
|
outbound.append(await bus.consume_outbound())
|
||||||
|
|
||||||
|
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
|
||||||
|
stream_end = [m for m in outbound if m.metadata.get("_stream_end")]
|
||||||
|
final = [
|
||||||
|
m for m in outbound
|
||||||
|
if not m.metadata.get("_stream_delta")
|
||||||
|
and not m.metadata.get("_stream_end")
|
||||||
|
and not m.metadata.get("_turn_end")
|
||||||
|
and not m.metadata.get("_goal_status")
|
||||||
|
]
|
||||||
|
|
||||||
|
assert [m.content for m in deltas] == ["partial", "full retry response"]
|
||||||
|
assert [m.metadata.get("_resuming") for m in stream_end] == [True, False]
|
||||||
|
assert deltas[0].metadata.get("_stream_id") == stream_end[0].metadata.get("_stream_id")
|
||||||
|
assert deltas[1].metadata.get("_stream_id") == stream_end[1].metadata.get("_stream_id")
|
||||||
|
assert deltas[0].metadata.get("_stream_id") != deltas[1].metadata.get("_stream_id")
|
||||||
|
assert final[-1].content == "full retry response"
|
||||||
|
assert final[-1].metadata.get("_streamed") is True
|
||||||
|
provider.chat_with_retry.assert_not_awaited()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_streamed_progress_is_not_repeated_before_tool_execution(
|
async def test_streamed_progress_is_not_repeated_before_tool_execution(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -323,18 +323,24 @@ class TestFallbackOnStreamStalledAfterContent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
streamed: list[str] = []
|
streamed: list[str] = []
|
||||||
|
recoveries: list[str] = []
|
||||||
|
|
||||||
async def _delta(text: str) -> None:
|
async def _delta(text: str) -> None:
|
||||||
streamed.append(text)
|
streamed.append(text)
|
||||||
|
|
||||||
|
async def _recover() -> None:
|
||||||
|
recoveries.append("recover")
|
||||||
|
|
||||||
result = await fb.chat_stream(
|
result = await fb.chat_stream(
|
||||||
messages=[{"role": "user", "content": "hi"}],
|
messages=[{"role": "user", "content": "hi"}],
|
||||||
on_content_delta=_delta,
|
on_content_delta=_delta,
|
||||||
|
on_stream_recover=_recover,
|
||||||
)
|
)
|
||||||
assert result.finish_reason == "stop"
|
assert result.finish_reason == "stop"
|
||||||
assert result.content == "fallback ok"
|
assert result.content == "fallback ok"
|
||||||
factory.assert_called_once_with(_fallback("fallback-a"))
|
factory.assert_called_once_with(_fallback("fallback-a"))
|
||||||
assert "stream stalled" in streamed
|
assert streamed == ["stream stalled", "fallback ok"]
|
||||||
|
assert recoveries == ["recover"]
|
||||||
|
|
||||||
|
|
||||||
class TestFailoverOnTransientError:
|
class TestFailoverOnTransientError:
|
||||||
|
|||||||
@ -199,6 +199,49 @@ async def test_chat_stream_with_retry_retries_timeout_after_emitting_content(mon
|
|||||||
assert provider.last_kwargs.get("on_content_delta") is None
|
assert provider.last_kwargs.get("on_content_delta") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_chat_stream_with_retry_retries_timeout_in_new_stream_segment(
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
first = LLMResponse(
|
||||||
|
content="Error calling LLM: stream stalled for more than 30 seconds",
|
||||||
|
finish_reason="error",
|
||||||
|
error_kind="timeout",
|
||||||
|
)
|
||||||
|
first._test_stream_delta = "partial" # type: ignore[attr-defined]
|
||||||
|
second = LLMResponse(content="full retry response")
|
||||||
|
second._test_stream_delta = "full retry response" # type: ignore[attr-defined]
|
||||||
|
provider = ScriptedProvider([first, second])
|
||||||
|
deltas: list[str] = []
|
||||||
|
recoveries: list[str] = []
|
||||||
|
delays: list[int] = []
|
||||||
|
|
||||||
|
async def _fake_sleep(delay: int) -> None:
|
||||||
|
delays.append(delay)
|
||||||
|
|
||||||
|
async def _on_delta(delta: str) -> None:
|
||||||
|
deltas.append(delta)
|
||||||
|
|
||||||
|
async def _on_stream_recover() -> None:
|
||||||
|
recoveries.append("recover")
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
|
||||||
|
|
||||||
|
response = await provider.chat_stream_with_retry(
|
||||||
|
messages=[{"role": "user", "content": "hello"}],
|
||||||
|
on_content_delta=_on_delta,
|
||||||
|
on_stream_recover=_on_stream_recover,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.content == "full retry response"
|
||||||
|
assert response.finish_reason == "stop"
|
||||||
|
assert provider.calls == 2
|
||||||
|
assert deltas == ["partial", "full retry response"]
|
||||||
|
assert recoveries == ["recover"]
|
||||||
|
assert delays == [1]
|
||||||
|
assert provider.last_kwargs.get("on_content_delta") is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
async def test_chat_with_retry_uses_provider_generation_defaults() -> None:
|
||||||
"""When callers omit generation params, provider.generation defaults are used."""
|
"""When callers omit generation params, provider.generation defaults are used."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user