mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
fix(agent): gate provider progress deltas
This commit is contained in:
parent
05e0106592
commit
653de4a7ef
@ -649,6 +649,7 @@ class AgentLoop:
|
|||||||
context_block_limit=self.context_block_limit,
|
context_block_limit=self.context_block_limit,
|
||||||
provider_retry_mode=self.provider_retry_mode,
|
provider_retry_mode=self.provider_retry_mode,
|
||||||
progress_callback=on_progress,
|
progress_callback=on_progress,
|
||||||
|
stream_progress_deltas=on_stream is not None,
|
||||||
retry_wait_callback=on_retry_wait,
|
retry_wait_callback=on_retry_wait,
|
||||||
checkpoint_callback=_checkpoint,
|
checkpoint_callback=_checkpoint,
|
||||||
injection_callback=_drain_pending,
|
injection_callback=_drain_pending,
|
||||||
|
|||||||
@ -76,6 +76,7 @@ class AgentRunSpec:
|
|||||||
context_block_limit: int | None = None
|
context_block_limit: int | None = None
|
||||||
provider_retry_mode: str = "standard"
|
provider_retry_mode: str = "standard"
|
||||||
progress_callback: Any | None = None
|
progress_callback: Any | None = None
|
||||||
|
stream_progress_deltas: bool = True
|
||||||
retry_wait_callback: Any | None = None
|
retry_wait_callback: Any | None = None
|
||||||
checkpoint_callback: Any | None = None
|
checkpoint_callback: Any | None = None
|
||||||
injection_callback: Any | None = None
|
injection_callback: Any | None = None
|
||||||
@ -615,6 +616,7 @@ class AgentRunner:
|
|||||||
wants_streaming = hook.wants_streaming()
|
wants_streaming = hook.wants_streaming()
|
||||||
wants_progress_streaming = (
|
wants_progress_streaming = (
|
||||||
not wants_streaming
|
not wants_streaming
|
||||||
|
and spec.stream_progress_deltas
|
||||||
and spec.progress_callback is not None
|
and spec.progress_callback is not None
|
||||||
and getattr(self.provider, "supports_progress_deltas", False) is True
|
and getattr(self.provider, "supports_progress_deltas", False) is True
|
||||||
)
|
)
|
||||||
|
|||||||
@ -130,11 +130,44 @@ class TestToolEventProgress:
|
|||||||
assert finish["result"] == "file.txt"
|
assert finish["result"] == "file.txt"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bus_progress_streams_provider_deltas_for_codex_style_provider(
|
async def test_non_streaming_channel_does_not_publish_codex_progress_deltas(
|
||||||
self,
|
self,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Providers that opt in can stream content deltas through _progress messages."""
|
"""Non-streaming channels should get one final reply, not token progress spam."""
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.supports_progress_deltas = True
|
||||||
|
provider.get_default_model.return_value = "openai-codex/gpt-5.5"
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Hello", tool_calls=[]))
|
||||||
|
provider.chat_stream_with_retry = AsyncMock()
|
||||||
|
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5")
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
await loop._dispatch(InboundMessage(
|
||||||
|
channel="whatsapp",
|
||||||
|
sender_id="u1",
|
||||||
|
chat_id="chat1",
|
||||||
|
content="say hello",
|
||||||
|
))
|
||||||
|
|
||||||
|
outbound = []
|
||||||
|
while bus.outbound_size > 0:
|
||||||
|
outbound.append(await bus.consume_outbound())
|
||||||
|
|
||||||
|
assert [m.content for m in outbound] == ["Hello"]
|
||||||
|
assert not any(m.metadata.get("_progress") for m in outbound)
|
||||||
|
assert not any(m.metadata.get("_streamed") for m in outbound)
|
||||||
|
provider.chat_stream_with_retry.assert_not_awaited()
|
||||||
|
provider.chat_with_retry.assert_awaited_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_streaming_channel_streams_provider_deltas_for_codex_style_provider(
|
||||||
|
self,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""Streaming channels still receive provider deltas through _stream_delta messages."""
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
provider.supports_progress_deltas = True
|
provider.supports_progress_deltas = True
|
||||||
@ -156,18 +189,27 @@ class TestToolEventProgress:
|
|||||||
sender_id="u1",
|
sender_id="u1",
|
||||||
chat_id="chat1",
|
chat_id="chat1",
|
||||||
content="say hello",
|
content="say hello",
|
||||||
|
metadata={"_wants_stream": True},
|
||||||
))
|
))
|
||||||
|
|
||||||
outbound = []
|
outbound = []
|
||||||
while bus.outbound_size > 0:
|
while bus.outbound_size > 0:
|
||||||
outbound.append(await bus.consume_outbound())
|
outbound.append(await bus.consume_outbound())
|
||||||
|
|
||||||
progress = [m for m in outbound if m.metadata.get("_progress")]
|
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
|
||||||
final = [m for m in outbound if not m.metadata.get("_progress")]
|
stream_end = [m for m in outbound if m.metadata.get("_stream_end")]
|
||||||
|
final = [
|
||||||
|
m for m in outbound
|
||||||
|
if not m.metadata.get("_stream_delta")
|
||||||
|
and not m.metadata.get("_stream_end")
|
||||||
|
and not m.metadata.get("_turn_end")
|
||||||
|
]
|
||||||
|
|
||||||
assert [m.content for m in progress] == ["Hel", "lo"]
|
assert [m.content for m in deltas] == ["Hel", "lo"]
|
||||||
assert final[-2].content == "Hello"
|
assert len(stream_end) == 1
|
||||||
assert (final[-1].metadata or {}).get("_turn_end") is True
|
assert final[-1].content == "Hello"
|
||||||
|
assert final[-1].metadata.get("_streamed") is True
|
||||||
|
assert outbound[-1].metadata.get("_turn_end") is True
|
||||||
provider.chat_with_retry.assert_not_awaited()
|
provider.chat_with_retry.assert_not_awaited()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -197,8 +239,12 @@ class TestToolEventProgress:
|
|||||||
loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
|
loop.tools.prepare_call = MagicMock(return_value=(None, {"path": "foo.txt"}, None))
|
||||||
loop.tools.execute = AsyncMock(return_value="ok")
|
loop.tools.execute = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
streamed: list[str] = []
|
||||||
progress: list[tuple[str, bool, list[dict] | None]] = []
|
progress: list[tuple[str, bool, list[dict] | None]] = []
|
||||||
|
|
||||||
|
async def on_stream(delta: str) -> None:
|
||||||
|
streamed.append(delta)
|
||||||
|
|
||||||
async def on_progress(
|
async def on_progress(
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
@ -207,14 +253,15 @@ class TestToolEventProgress:
|
|||||||
) -> None:
|
) -> None:
|
||||||
progress.append((content, tool_hint, tool_events))
|
progress.append((content, tool_hint, tool_events))
|
||||||
|
|
||||||
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
final_content, _, _, _, _ = await loop._run_agent_loop(
|
||||||
|
[],
|
||||||
|
on_progress=on_progress,
|
||||||
|
on_stream=on_stream,
|
||||||
|
)
|
||||||
|
|
||||||
assert final_content == "Done"
|
assert final_content == "Done"
|
||||||
assert [item[0] for item in progress[:3]] == [
|
assert streamed == ["I will", " inspect it."]
|
||||||
"I will",
|
assert progress[0][0] == 'custom_tool("foo.txt")'
|
||||||
" inspect it.",
|
|
||||||
'custom_tool("foo.txt")',
|
|
||||||
]
|
|
||||||
assert all(item[0] != "I will inspect it." for item in progress)
|
assert all(item[0] != "I will inspect it." for item in progress)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
79
tests/agent/test_runner_progress_deltas.py
Normal file
79
tests/agent/test_runner_progress_deltas.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for provider progress delta routing in the shared runner."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
# Tool-result character limit read from a default-constructed AgentDefaults;
# the tests in this module pass it as max_tool_result_chars on every AgentRunSpec.
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_can_disable_provider_progress_delta_streaming():
    """With stream_progress_deltas=False the runner makes one plain chat call.

    The mocked provider sets supports_progress_deltas, yet the streaming entry
    point must stay unused and the progress callback must never be awaited.
    """
    progress_sink = AsyncMock()

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    # The provider opts in to progress deltas; only the spec flag turns them off.
    canned_reply = LLMResponse(content="done", tool_calls=[], usage={})
    llm = MagicMock()
    llm.supports_progress_deltas = True
    llm.chat_stream_with_retry = AsyncMock()
    llm.chat_with_retry = AsyncMock(return_value=canned_reply)

    conversation = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hi"},
    ]
    spec = AgentRunSpec(
        initial_messages=conversation,
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_sink,
        stream_progress_deltas=False,
    )

    outcome = await AgentRunner(llm).run(spec)

    assert outcome.final_content == "done"
    llm.chat_with_retry.assert_awaited_once()
    llm.chat_stream_with_retry.assert_not_awaited()
    progress_sink.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_streams_provider_progress_deltas_by_default():
    """A spec that leaves stream_progress_deltas unset still streams deltas.

    The fake streaming call emits two content deltas; both must reach the
    progress callback in order, and the non-streaming call must stay unused.
    """
    progress_sink = AsyncMock()

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    async def _fake_stream(*, on_content_delta, **_unused):
        # Emit the reply in two pieces before returning the full response.
        for piece in ("he", "llo"):
            await on_content_delta(piece)
        return LLMResponse(content="hello", tool_calls=[], usage={})

    llm = MagicMock()
    llm.supports_progress_deltas = True
    llm.chat_stream_with_retry = _fake_stream
    llm.chat_with_retry = AsyncMock()

    conversation = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "hi"},
    ]
    spec = AgentRunSpec(
        initial_messages=conversation,
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        progress_callback=progress_sink,
    )

    outcome = await AgentRunner(llm).run(spec)

    forwarded = [awaited.args[0] for awaited in progress_sink.await_args_list]
    assert outcome.final_content == "hello"
    assert forwarded == ["he", "llo"]
    llm.chat_with_retry.assert_not_awaited()
|
||||||
Loading…
x
Reference in New Issue
Block a user