fix(agent): gate provider progress deltas

This commit is contained in:
hanyuanling 2026-05-06 14:45:20 +08:00 committed by Xubin Ren
parent 05e0106592
commit 653de4a7ef
4 changed files with 142 additions and 13 deletions

View File

@ -649,6 +649,7 @@ class AgentLoop:
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
stream_progress_deltas=on_stream is not None,
retry_wait_callback=on_retry_wait,
checkpoint_callback=_checkpoint,
injection_callback=_drain_pending,

View File

@ -76,6 +76,7 @@ class AgentRunSpec:
context_block_limit: int | None = None
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
stream_progress_deltas: bool = True
retry_wait_callback: Any | None = None
checkpoint_callback: Any | None = None
injection_callback: Any | None = None
@ -615,6 +616,7 @@ class AgentRunner:
wants_streaming = hook.wants_streaming()
wants_progress_streaming = (
not wants_streaming
and spec.stream_progress_deltas
and spec.progress_callback is not None
and getattr(self.provider, "supports_progress_deltas", False) is True
)

View File

@ -130,11 +130,44 @@ class TestToolEventProgress:
assert finish["result"] == "file.txt"
@pytest.mark.asyncio
async def test_bus_progress_streams_provider_deltas_for_codex_style_provider(
async def test_non_streaming_channel_does_not_publish_codex_progress_deltas(
self,
tmp_path: Path,
) -> None:
"""Providers that opt in can stream content deltas through _progress messages."""
"""Non-streaming channels should get one final reply, not token progress spam."""
bus = MessageBus()
provider = MagicMock()
provider.supports_progress_deltas = True
provider.get_default_model.return_value = "openai-codex/gpt-5.5"
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello", tool_calls=[]))
provider.chat_stream_with_retry = AsyncMock()
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5")
loop.tools.get_definitions = MagicMock(return_value=[])
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
await loop._dispatch(InboundMessage(
channel="whatsapp",
sender_id="u1",
chat_id="chat1",
content="say hello",
))
outbound = []
while bus.outbound_size > 0:
outbound.append(await bus.consume_outbound())
assert [m.content for m in outbound] == ["Hello"]
assert not any(m.metadata.get("_progress") for m in outbound)
assert not any(m.metadata.get("_streamed") for m in outbound)
provider.chat_stream_with_retry.assert_not_awaited()
provider.chat_with_retry.assert_awaited_once()
@pytest.mark.asyncio
async def test_streaming_channel_streams_provider_deltas_for_codex_style_provider(
self,
tmp_path: Path,
) -> None:
"""Streaming channels still receive provider deltas through _stream_delta messages."""
bus = MessageBus()
provider = MagicMock()
provider.supports_progress_deltas = True
@ -156,18 +189,27 @@ class TestToolEventProgress:
sender_id="u1",
chat_id="chat1",
content="say hello",
metadata={"_wants_stream": True},
))
outbound = []
while bus.outbound_size > 0:
outbound.append(await bus.consume_outbound())
progress = [m for m in outbound if m.metadata.get("_progress")]
final = [m for m in outbound if not m.metadata.get("_progress")]
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
stream_end = [m for m in outbound if m.metadata.get("_stream_end")]
final = [
m for m in outbound
if not m.metadata.get("_stream_delta")
and not m.metadata.get("_stream_end")
and not m.metadata.get("_turn_end")
]
assert [m.content for m in progress] == ["Hel", "lo"]
assert final[-2].content == "Hello"
assert (final[-1].metadata or {}).get("_turn_end") is True
assert [m.content for m in deltas] == ["Hel", "lo"]
assert len(stream_end) == 1
assert final[-1].content == "Hello"
assert final[-1].metadata.get("_streamed") is True
assert outbound[-1].metadata.get("_turn_end") is True
provider.chat_with_retry.assert_not_awaited()
@pytest.mark.asyncio
@ -197,8 +239,12 @@ class TestToolEventProgress:
loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
loop.tools.execute = AsyncMock(return_value="ok")
streamed: list[str] = []
progress: list[tuple[str, bool, list[dict] | None]] = []
async def on_stream(delta: str) -> None:
streamed.append(delta)
async def on_progress(
content: str,
*,
@ -207,14 +253,15 @@ class TestToolEventProgress:
) -> None:
progress.append((content, tool_hint, tool_events))
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
final_content, _, _, _, _ = await loop._run_agent_loop(
[],
on_progress=on_progress,
on_stream=on_stream,
)
assert final_content == "Done"
assert [item[0] for item in progress[:3]] == [
"I will",
" inspect it.",
'custom_tool("foo.txt")',
]
assert streamed == ["I will", " inspect it."]
assert progress[0][0] == 'custom_tool("foo.txt")'
assert all(item[0] != "I will inspect it." for item in progress)
@pytest.mark.asyncio

View File

@ -0,0 +1,79 @@
"""Tests for provider progress delta routing in the shared runner."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMResponse
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
@pytest.mark.asyncio
async def test_runner_can_disable_provider_progress_delta_streaming():
"""AgentLoop disables token progress streaming for non-streaming channels."""
provider = MagicMock()
provider.supports_progress_deltas = True
provider.chat_with_retry = AsyncMock(
return_value=LLMResponse(content="done", tool_calls=[], usage={})
)
provider.chat_stream_with_retry = AsyncMock()
tools = MagicMock()
tools.get_definitions.return_value = []
progress_cb = AsyncMock()
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[
{"role": "system", "content": "system"},
{"role": "user", "content": "hi"},
],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
progress_callback=progress_cb,
stream_progress_deltas=False,
))
assert result.final_content == "done"
provider.chat_with_retry.assert_awaited_once()
provider.chat_stream_with_retry.assert_not_awaited()
progress_cb.assert_not_awaited()
@pytest.mark.asyncio
async def test_runner_streams_provider_progress_deltas_by_default():
"""Direct runner users keep the existing opt-in provider progress behavior."""
provider = MagicMock()
provider.supports_progress_deltas = True
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
await on_content_delta("he")
await on_content_delta("llo")
return LLMResponse(content="hello", tool_calls=[], usage={})
provider.chat_stream_with_retry = chat_stream_with_retry
provider.chat_with_retry = AsyncMock()
tools = MagicMock()
tools.get_definitions.return_value = []
progress_cb = AsyncMock()
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[
{"role": "system", "content": "system"},
{"role": "user", "content": "hi"},
],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
progress_callback=progress_cb,
))
assert result.final_content == "hello"
assert [call.args[0] for call in progress_cb.await_args_list] == ["he", "llo"]
provider.chat_with_retry.assert_not_awaited()