mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(agent): gate provider progress deltas
This commit is contained in:
parent
05e0106592
commit
653de4a7ef
@ -649,6 +649,7 @@ class AgentLoop:
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
stream_progress_deltas=on_stream is not None,
|
||||
retry_wait_callback=on_retry_wait,
|
||||
checkpoint_callback=_checkpoint,
|
||||
injection_callback=_drain_pending,
|
||||
|
||||
@ -76,6 +76,7 @@ class AgentRunSpec:
|
||||
context_block_limit: int | None = None
|
||||
provider_retry_mode: str = "standard"
|
||||
progress_callback: Any | None = None
|
||||
stream_progress_deltas: bool = True
|
||||
retry_wait_callback: Any | None = None
|
||||
checkpoint_callback: Any | None = None
|
||||
injection_callback: Any | None = None
|
||||
@ -615,6 +616,7 @@ class AgentRunner:
|
||||
wants_streaming = hook.wants_streaming()
|
||||
wants_progress_streaming = (
|
||||
not wants_streaming
|
||||
and spec.stream_progress_deltas
|
||||
and spec.progress_callback is not None
|
||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||
)
|
||||
|
||||
@ -130,11 +130,44 @@ class TestToolEventProgress:
|
||||
assert finish["result"] == "file.txt"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bus_progress_streams_provider_deltas_for_codex_style_provider(
|
||||
async def test_non_streaming_channel_does_not_publish_codex_progress_deltas(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Providers that opt in can stream content deltas through _progress messages."""
|
||||
"""Non-streaming channels should get one final reply, not token progress spam."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.supports_progress_deltas = True
|
||||
provider.get_default_model.return_value = "openai-codex/gpt-5.5"
|
||||
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello", tool_calls=[]))
|
||||
provider.chat_stream_with_retry = AsyncMock()
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5")
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
|
||||
await loop._dispatch(InboundMessage(
|
||||
channel="whatsapp",
|
||||
sender_id="u1",
|
||||
chat_id="chat1",
|
||||
content="say hello",
|
||||
))
|
||||
|
||||
outbound = []
|
||||
while bus.outbound_size > 0:
|
||||
outbound.append(await bus.consume_outbound())
|
||||
|
||||
assert [m.content for m in outbound] == ["Hello"]
|
||||
assert not any(m.metadata.get("_progress") for m in outbound)
|
||||
assert not any(m.metadata.get("_streamed") for m in outbound)
|
||||
provider.chat_stream_with_retry.assert_not_awaited()
|
||||
provider.chat_with_retry.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_channel_streams_provider_deltas_for_codex_style_provider(
|
||||
self,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Streaming channels still receive provider deltas through _stream_delta messages."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.supports_progress_deltas = True
|
||||
@ -156,18 +189,27 @@ class TestToolEventProgress:
|
||||
sender_id="u1",
|
||||
chat_id="chat1",
|
||||
content="say hello",
|
||||
metadata={"_wants_stream": True},
|
||||
))
|
||||
|
||||
outbound = []
|
||||
while bus.outbound_size > 0:
|
||||
outbound.append(await bus.consume_outbound())
|
||||
|
||||
progress = [m for m in outbound if m.metadata.get("_progress")]
|
||||
final = [m for m in outbound if not m.metadata.get("_progress")]
|
||||
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
|
||||
stream_end = [m for m in outbound if m.metadata.get("_stream_end")]
|
||||
final = [
|
||||
m for m in outbound
|
||||
if not m.metadata.get("_stream_delta")
|
||||
and not m.metadata.get("_stream_end")
|
||||
and not m.metadata.get("_turn_end")
|
||||
]
|
||||
|
||||
assert [m.content for m in progress] == ["Hel", "lo"]
|
||||
assert final[-2].content == "Hello"
|
||||
assert (final[-1].metadata or {}).get("_turn_end") is True
|
||||
assert [m.content for m in deltas] == ["Hel", "lo"]
|
||||
assert len(stream_end) == 1
|
||||
assert final[-1].content == "Hello"
|
||||
assert final[-1].metadata.get("_streamed") is True
|
||||
assert outbound[-1].metadata.get("_turn_end") is True
|
||||
provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -197,8 +239,12 @@ class TestToolEventProgress:
|
||||
loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
|
||||
streamed: list[str] = []
|
||||
progress: list[tuple[str, bool, list[dict] | None]] = []
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
streamed.append(delta)
|
||||
|
||||
async def on_progress(
|
||||
content: str,
|
||||
*,
|
||||
@ -207,14 +253,15 @@ class TestToolEventProgress:
|
||||
) -> None:
|
||||
progress.append((content, tool_hint, tool_events))
|
||||
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop(
|
||||
[],
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert [item[0] for item in progress[:3]] == [
|
||||
"I will",
|
||||
" inspect it.",
|
||||
'custom_tool("foo.txt")',
|
||||
]
|
||||
assert streamed == ["I will", " inspect it."]
|
||||
assert progress[0][0] == 'custom_tool("foo.txt")'
|
||||
assert all(item[0] != "I will inspect it." for item in progress)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
79
tests/agent/test_runner_progress_deltas.py
Normal file
79
tests/agent/test_runner_progress_deltas.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Tests for provider progress delta routing in the shared runner."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_can_disable_provider_progress_delta_streaming():
|
||||
"""AgentLoop disables token progress streaming for non-streaming channels."""
|
||||
provider = MagicMock()
|
||||
provider.supports_progress_deltas = True
|
||||
provider.chat_with_retry = AsyncMock(
|
||||
return_value=LLMResponse(content="done", tool_calls=[], usage={})
|
||||
)
|
||||
provider.chat_stream_with_retry = AsyncMock()
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
progress_cb = AsyncMock()
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
progress_callback=progress_cb,
|
||||
stream_progress_deltas=False,
|
||||
))
|
||||
|
||||
assert result.final_content == "done"
|
||||
provider.chat_with_retry.assert_awaited_once()
|
||||
provider.chat_stream_with_retry.assert_not_awaited()
|
||||
progress_cb.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_streams_provider_progress_deltas_by_default():
|
||||
"""Direct runner users keep the existing opt-in provider progress behavior."""
|
||||
provider = MagicMock()
|
||||
provider.supports_progress_deltas = True
|
||||
|
||||
async def chat_stream_with_retry(*, on_content_delta, **kwargs):
|
||||
await on_content_delta("he")
|
||||
await on_content_delta("llo")
|
||||
return LLMResponse(content="hello", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
provider.chat_with_retry = AsyncMock()
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
progress_cb = AsyncMock()
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "hi"},
|
||||
],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
progress_callback=progress_cb,
|
||||
))
|
||||
|
||||
assert result.final_content == "hello"
|
||||
assert [call.args[0] for call in progress_cb.await_args_list] == ["he", "llo"]
|
||||
provider.chat_with_retry.assert_not_awaited()
|
||||
Loading…
x
Reference in New Issue
Block a user