diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py
index 51d449a96..b47db948b 100644
--- a/tests/agent/test_runner.py
+++ b/tests/agent/test_runner.py
@@ -2918,3 +2918,45 @@ def test_snip_history_no_user_at_all_falls_back_gracefully(monkeypatch):
     assert non_system[0]["role"] in ("user", "tool"), (
         f"Safety net should ensure first non-system is user/tool, got {non_system[0]['role']}"
     )
+
+
+@pytest.mark.asyncio
+async def test_runner_binds_on_retry_wait_to_retry_callback_not_progress():
+    """Regression: provider retry heartbeats must route through
+    ``retry_wait_callback``, not ``progress_callback``. Binding them to
+    the progress callback (as an earlier runtime refactor did) caused
+    internal retry diagnostics like "Model request failed, retry in 1s"
+    to leak to end-user channels as normal progress updates.
+    """
+    from nanobot.agent.runner import AgentRunSpec, AgentRunner
+
+    captured: dict = {}
+
+    async def chat_with_retry(**kwargs):
+        captured.update(kwargs)
+        return LLMResponse(content="done", tool_calls=[], usage={})
+
+    provider = MagicMock()
+    provider.chat_with_retry = chat_with_retry
+    tools = MagicMock()
+    tools.get_definitions.return_value = []
+
+    progress_cb = AsyncMock()
+    retry_wait_cb = AsyncMock()
+
+    runner = AgentRunner(provider)
+    await runner.run(AgentRunSpec(
+        initial_messages=[
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "hi"},
+        ],
+        tools=tools,
+        model="test-model",
+        max_iterations=1,
+        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
+        progress_callback=progress_cb,
+        retry_wait_callback=retry_wait_cb,
+    ))
+
+    assert captured["on_retry_wait"] is retry_wait_cb
+    assert captured["on_retry_wait"] is not progress_cb
diff --git a/tests/channels/test_channel_manager_delta_coalescing.py b/tests/channels/test_channel_manager_delta_coalescing.py
index 0fa97f5b8..3c150f903 100644
--- a/tests/channels/test_channel_manager_delta_coalescing.py
+++ b/tests/channels/test_channel_manager_delta_coalescing.py
@@ -296,3 +296,50 @@ class TestDispatchOutboundWithCoalescing:
         # Should have pending regular message
         assert len(pending) == 1
         assert pending[0].content == "Final"
+
+
+class TestRetryWaitFiltering:
+    """Internal provider retry heartbeats must never reach channels."""
+
+    @pytest.mark.asyncio
+    async def test_retry_wait_message_dropped(self, manager, bus):
+        """A ``_retry_wait`` message must be filtered before channel dispatch.
+
+        Regression: provider retry diagnostics like
+        ``Model request failed, retry in 1s (attempt 1).`` were being
+        delivered to end-user channels because the runner bound
+        ``on_retry_wait`` to the progress callback.
+        """
+        retry_msg = OutboundMessage(
+            channel="mock",
+            chat_id="chat1",
+            content="Model request failed, retry in 1s (attempt 1).",
+            metadata={"_retry_wait": True},
+        )
+        real_msg = OutboundMessage(
+            channel="mock",
+            chat_id="chat1",
+            content="final answer",
+            metadata={},
+        )
+        await bus.publish_outbound(retry_msg)
+        await bus.publish_outbound(real_msg)
+
+        task = asyncio.create_task(manager._dispatch_outbound())
+        try:
+            for _ in range(30):
+                if manager.channels["mock"]._send_mock.await_count >= 1:
+                    break
+                await asyncio.sleep(0.05)
+        finally:
+            task.cancel()
+            try:
+                await task
+            except asyncio.CancelledError:
+                pass
+
+        send_mock = manager.channels["mock"]._send_mock
+        assert send_mock.await_count == 1
+        sent = send_mock.await_args_list[0].args[0]
+        assert sent.content == "final answer"
+        assert not sent.metadata.get("_retry_wait")