From 653de4a7efde76aead9de8e631df2c778c3a1460 Mon Sep 17 00:00:00 2001 From: hanyuanling Date: Wed, 6 May 2026 14:45:20 +0800 Subject: [PATCH] fix(agent): gate provider progress deltas --- nanobot/agent/loop.py | 1 + nanobot/agent/runner.py | 2 + tests/agent/test_loop_progress.py | 73 ++++++++++++++++---- tests/agent/test_runner_progress_deltas.py | 79 ++++++++++++++++++++++ 4 files changed, 142 insertions(+), 13 deletions(-) create mode 100644 tests/agent/test_runner_progress_deltas.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index d5e7681f1..07006b057 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -649,6 +649,7 @@ class AgentLoop: context_block_limit=self.context_block_limit, provider_retry_mode=self.provider_retry_mode, progress_callback=on_progress, + stream_progress_deltas=on_stream is not None, retry_wait_callback=on_retry_wait, checkpoint_callback=_checkpoint, injection_callback=_drain_pending, diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index b81df4168..7fe92ad51 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -76,6 +76,7 @@ class AgentRunSpec: context_block_limit: int | None = None provider_retry_mode: str = "standard" progress_callback: Any | None = None + stream_progress_deltas: bool = True retry_wait_callback: Any | None = None checkpoint_callback: Any | None = None injection_callback: Any | None = None @@ -615,6 +616,7 @@ class AgentRunner: wants_streaming = hook.wants_streaming() wants_progress_streaming = ( not wants_streaming + and spec.stream_progress_deltas and spec.progress_callback is not None and getattr(self.provider, "supports_progress_deltas", False) is True ) diff --git a/tests/agent/test_loop_progress.py b/tests/agent/test_loop_progress.py index d08448992..47a63ba02 100644 --- a/tests/agent/test_loop_progress.py +++ b/tests/agent/test_loop_progress.py @@ -130,11 +130,44 @@ class TestToolEventProgress: assert finish["result"] == "file.txt" @pytest.mark.asyncio - async def test_bus_progress_streams_provider_deltas_for_codex_style_provider( + async def test_non_streaming_channel_does_not_publish_codex_progress_deltas( self, tmp_path: Path, ) -> None: - """Providers that opt in can stream content deltas through _progress messages.""" + """Non-streaming channels should get one final reply, not token progress spam.""" + bus = MessageBus() + provider = MagicMock() + provider.supports_progress_deltas = True + provider.get_default_model.return_value = "openai-codex/gpt-5.5" + provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello", tool_calls=[])) + provider.chat_stream_with_retry = AsyncMock() + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5") + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + await loop._dispatch(InboundMessage( + channel="whatsapp", + sender_id="u1", + chat_id="chat1", + content="say hello", + )) + + outbound = [] + while bus.outbound_size > 0: + outbound.append(await bus.consume_outbound()) + + assert [m.content for m in outbound] == ["Hello"] + assert not any(m.metadata.get("_progress") for m in outbound) + assert not any(m.metadata.get("_streamed") for m in outbound) + provider.chat_stream_with_retry.assert_not_awaited() + provider.chat_with_retry.assert_awaited_once() + + @pytest.mark.asyncio + async def test_streaming_channel_streams_provider_deltas_for_codex_style_provider( + self, + tmp_path: Path, + ) -> None: + """Streaming channels still receive provider deltas through _stream_delta messages.""" bus = MessageBus() provider = MagicMock() provider.supports_progress_deltas = True @@ -156,18 +189,27 @@ class TestToolEventProgress: sender_id="u1", chat_id="chat1", content="say hello", + metadata={"_wants_stream": True}, )) outbound = [] while bus.outbound_size > 0: outbound.append(await bus.consume_outbound()) - progress = [m for m in outbound if m.metadata.get("_progress")] - final = [m for m in outbound if not m.metadata.get("_progress")] + deltas = [m for m in outbound if m.metadata.get("_stream_delta")] + stream_end = [m for m in outbound if m.metadata.get("_stream_end")] + final = [ + m for m in outbound + if not m.metadata.get("_stream_delta") + and not m.metadata.get("_stream_end") + and not m.metadata.get("_turn_end") + ] - assert [m.content for m in progress] == ["Hel", "lo"] - assert final[-2].content == "Hello" - assert (final[-1].metadata or {}).get("_turn_end") is True + assert [m.content for m in deltas] == ["Hel", "lo"] + assert len(stream_end) == 1 + assert final[-1].content == "Hello" + assert final[-1].metadata.get("_streamed") is True + assert outbound[-1].metadata.get("_turn_end") is True provider.chat_with_retry.assert_not_awaited() @pytest.mark.asyncio @@ -197,8 +239,12 @@ class TestToolEventProgress: loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None)) loop.tools.execute = AsyncMock(return_value="ok") + streamed: list[str] = [] progress: list[tuple[str, bool, list[dict] | None]] = [] + async def on_stream(delta: str) -> None: + streamed.append(delta) + async def on_progress( content: str, *, @@ -207,14 +253,15 @@ class TestToolEventProgress: ) -> None: progress.append((content, tool_hint, tool_events)) - final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + final_content, _, _, _, _ = await loop._run_agent_loop( + [], + on_progress=on_progress, + on_stream=on_stream, + ) assert final_content == "Done" - assert [item[0] for item in progress[:3]] == [ - "I will", - " inspect it.", - 'custom_tool("foo.txt")', - ] + assert streamed == ["I will", " inspect it."] + assert progress[0][0] == 'custom_tool("foo.txt")' assert all(item[0] != "I will inspect it." for item in progress) @pytest.mark.asyncio diff --git a/tests/agent/test_runner_progress_deltas.py b/tests/agent/test_runner_progress_deltas.py new file mode 100644 index 000000000..13d5ea799 --- /dev/null +++ b/tests/agent/test_runner_progress_deltas.py @@ -0,0 +1,79 @@ +"""Tests for provider progress delta routing in the shared runner.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from nanobot.agent.runner import AgentRunner, AgentRunSpec +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import LLMResponse + +_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars + + +@pytest.mark.asyncio +async def test_runner_can_disable_provider_progress_delta_streaming(): + """AgentLoop disables token progress streaming for non-streaming channels.""" + provider = MagicMock() + provider.supports_progress_deltas = True + provider.chat_with_retry = AsyncMock( + return_value=LLMResponse(content="done", tool_calls=[], usage={}) + ) + provider.chat_stream_with_retry = AsyncMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + progress_cb = AsyncMock() + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hi"}, + ], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + progress_callback=progress_cb, + stream_progress_deltas=False, + )) + + assert result.final_content == "done" + provider.chat_with_retry.assert_awaited_once() + provider.chat_stream_with_retry.assert_not_awaited() + progress_cb.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_runner_streams_provider_progress_deltas_by_default(): + """Direct runner users keep the existing opt-in provider progress behavior.""" + provider = MagicMock() + provider.supports_progress_deltas = True + + async def chat_stream_with_retry(*, on_content_delta, **kwargs): + await on_content_delta("he") + await on_content_delta("llo") + return LLMResponse(content="hello", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + provider.chat_with_retry = AsyncMock() + tools = MagicMock() + tools.get_definitions.return_value = [] + progress_cb = AsyncMock() + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[ + {"role": "system", "content": "system"}, + {"role": "user", "content": "hi"}, + ], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + progress_callback=progress_cb, + )) + + assert result.final_content == "hello" + assert [call.args[0] for call in progress_cb.await_args_list] == ["he", "llo"] + provider.chat_with_retry.assert_not_awaited()