diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py
index e92d864f2..5cb7b4f0e 100644
--- a/nanobot/agent/runner.py
+++ b/nanobot/agent/runner.py
@@ -134,6 +134,36 @@ class AgentRunner:
                 continue
             messages.append(injection)
 
+    async def _try_drain_injections(
+        self,
+        spec: AgentRunSpec,
+        messages: list[dict[str, Any]],
+        assistant_message: dict[str, Any] | None,
+        injection_cycles: int,
+        *,
+        phase: str = "after error",
+    ) -> tuple[bool, int]:
+        """Drain pending injections into *messages*; returns (should_continue, cycles).
+
+        Gives (False, cycles) without draining once _MAX_INJECTION_CYCLES is reached,
+        or when nothing is pending. Otherwise appends *assistant_message* (if not None)
+        and then the injections, logs *phase*, and returns (True, cycles + 1).
+        """
+        if injection_cycles >= _MAX_INJECTION_CYCLES:
+            return False, injection_cycles
+        injections = await self._drain_injections(spec)
+        if not injections:
+            return False, injection_cycles
+        injection_cycles += 1
+        if assistant_message is not None:
+            messages.append(assistant_message)
+        self._append_injected_messages(messages, injections)
+        logger.info(
+            "Injected {} follow-up message(s) {} ({}/{})",
+            len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES,
+        )
+        return True, injection_cycles
+
     async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
         """Drain pending user messages via the injection callback.
 
@@ -287,6 +317,13 @@ class AgentRunner:
                     context.error = error
                     context.stop_reason = stop_reason
                     await hook.after_iteration(context)
+                    should_continue, injection_cycles = await self._try_drain_injections(
+                        spec, messages, None, injection_cycles,
+                        phase="after tool error",
+                    )
+                    if should_continue:
+                        had_injections = True
+                        continue
                     break
                 await self._emit_checkpoint(
                     spec,
@@ -379,36 +416,31 @@ class AgentRunner:
             # Check for mid-turn injections BEFORE signaling stream end.
             # If injections are found we keep the stream alive (resuming=True)
             # so streaming channels don't prematurely finalize the card.
-            _injected_after_final = False
-            if injection_cycles < _MAX_INJECTION_CYCLES:
-                injections = await self._drain_injections(spec)
-                if injections:
-                    had_injections = True
-                    injection_cycles += 1
-                    _injected_after_final = True
-                    if assistant_message is not None:
-                        messages.append(assistant_message)
-                        await self._emit_checkpoint(
-                            spec,
-                            {
-                                "phase": "final_response",
-                                "iteration": iteration,
-                                "model": spec.model,
-                                "assistant_message": assistant_message,
-                                "completed_tool_results": [],
-                                "pending_tool_calls": [],
-                            },
-                        )
-                    self._append_injected_messages(messages, injections)
-                    logger.info(
-                        "Injected {} follow-up message(s) after final response ({}/{})",
-                        len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
+            should_continue, injection_cycles = await self._try_drain_injections(
+                spec, messages, assistant_message, injection_cycles,
+                phase="after final response",
+            )
+            if should_continue:
+                had_injections = True
+                # The helper has already appended assistant_message to messages;
+                # record it in a checkpoint here, then keep the stream alive.
+                if assistant_message is not None:
+                    await self._emit_checkpoint(
+                        spec,
+                        {
+                            "phase": "final_response",
+                            "iteration": iteration,
+                            "model": spec.model,
+                            "assistant_message": assistant_message,
+                            "completed_tool_results": [],
+                            "pending_tool_calls": [],
+                        },
                     )
 
             if hook.wants_streaming():
-                await hook.on_stream_end(context, resuming=_injected_after_final)
+                await hook.on_stream_end(context, resuming=should_continue)
 
-            if _injected_after_final:
+            if should_continue:
                 await hook.after_iteration(context)
                 continue
 
@@ -421,6 +453,13 @@ class AgentRunner:
                 context.error = error
                 context.stop_reason = stop_reason
                 await hook.after_iteration(context)
+                should_continue, injection_cycles = await self._try_drain_injections(
+                    spec, messages, None, injection_cycles,
+                    phase="after LLM error",
+                )
+                if should_continue:
+                    had_injections = True
+                    continue
                 break
             if is_blank_text(clean):
                 final_content = EMPTY_FINAL_RESPONSE_MESSAGE
@@ -431,6 +470,13 @@ class AgentRunner:
                 context.error = error
                 context.stop_reason = stop_reason
                 await hook.after_iteration(context)
+                should_continue, injection_cycles = await self._try_drain_injections(
+                    spec, messages, None, injection_cycles,
+                    phase="after empty response",
+                )
+                if should_continue:
+                    had_injections = True
+                    continue
                 break
 
             messages.append(assistant_message or build_assistant_message(
@@ -467,6 +513,15 @@ class AgentRunner:
                 max_iterations=spec.max_iterations,
             )
             self._append_final_message(messages, final_content)
+            # Drain any remaining injections so they are appended to the
+            # conversation history instead of being re-published as
+            # independent inbound messages by _dispatch's finally block.
+            # should_continue is ignored: the for-loop is already exhausted.
+            # NOTE: nothing is drained if injection_cycles already hit the cap.
+            _, injection_cycles = await self._try_drain_injections(
+                spec, messages, None, injection_cycles,
+                phase="after max_iterations",
+            )
 
         return AgentRunResult(
             final_content=final_content,
diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py
index a62457aa8..4a943165c 100644
--- a/tests/agent/test_runner.py
+++ b/tests/agent/test_runner.py
@@ -2410,3 +2410,236 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
     contents = [m.content for m in msgs]
     assert "leftover-1" in contents
     assert "leftover-2" in contents
+
+
+@pytest.mark.asyncio
+async def test_drain_injections_on_fatal_tool_error():
+    """Pending injections should be drained even when a fatal tool error occurs."""
+    from nanobot.agent.runner import AgentRunSpec, AgentRunner
+    from nanobot.bus.events import InboundMessage
+
+    provider = MagicMock()
+    call_count = {"n": 0}
+
+    async def chat_with_retry(*, messages, **kwargs):
+        call_count["n"] += 1
+        if call_count["n"] == 1:
+            return LLMResponse(
+                content="",
+                tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})],
+                usage={},
+            )
+        # Second call: respond normally to the injected follow-up
+        return LLMResponse(content="reply to follow-up", tool_calls=[], usage={})
+
+    provider.chat_with_retry = chat_with_retry
+    tools = MagicMock()
+    tools.get_definitions.return_value = []
+    tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))
+
+    injection_queue = asyncio.Queue()
+
+    async def inject_cb():
+        items = []
+        while not injection_queue.empty():
+            items.append(await injection_queue.get())
+        return items
+
+    await injection_queue.put(
+        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
+    )
+
+    runner = AgentRunner(provider)
+    result = await runner.run(AgentRunSpec(
+        initial_messages=[{"role": "user", "content": "hello"}],
+        tools=tools,
+        model="test-model",
+        max_iterations=5,
+        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
+        fail_on_tool_error=True,
+        injection_callback=inject_cb,
+    ))
+
+    assert result.had_injections is True
+    assert result.final_content == "reply to follow-up"
+    # The injected follow-up must appear exactly once in the message history
+    injected = [
+        m for m in result.messages
+        if m.get("role") == "user" and m.get("content") == "follow-up after error"
+    ]
+    assert len(injected) == 1
+
+
+@pytest.mark.asyncio
+async def test_drain_injections_on_llm_error():
+    """Pending injections should be drained when the LLM returns an error finish_reason."""
+    from nanobot.agent.runner import AgentRunSpec, AgentRunner
+    from nanobot.bus.events import InboundMessage
+
+    provider = MagicMock()
+    call_count = {"n": 0}
+
+    async def chat_with_retry(*, messages, **kwargs):
+        call_count["n"] += 1
+        if call_count["n"] == 1:
+            return LLMResponse(
+                content=None,
+                tool_calls=[],
+                finish_reason="error",
+                usage={},
+            )
+        # Second call: respond normally to the injected follow-up
+        return LLMResponse(content="recovered answer", tool_calls=[], usage={})
+
+    provider.chat_with_retry = chat_with_retry
+    tools = MagicMock()
+    tools.get_definitions.return_value = []
+
+    injection_queue = asyncio.Queue()
+
+    async def inject_cb():
+        items = []
+        while not injection_queue.empty():
+            items.append(await injection_queue.get())
+        return items
+
+    await injection_queue.put(
+        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
+    )
+
+    runner = AgentRunner(provider)
+    result = await runner.run(AgentRunSpec(
+        initial_messages=[
+            {"role": "user", "content": "hello"},
+            {"role": "assistant", "content": "previous response"},
+            {"role": "user", "content": "trigger error"},
+        ],
+        tools=tools,
+        model="test-model",
+        max_iterations=5,
+        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
+        injection_callback=inject_cb,
+    ))
+
+    assert result.had_injections is True
+    assert result.final_content == "recovered answer"
+    injected = [
+        m for m in result.messages
+        if m.get("role") == "user" and "follow-up after LLM error" in str(m.get("content", ""))
+    ]
+    assert len(injected) == 1
+
+
+@pytest.mark.asyncio
+async def test_drain_injections_on_empty_final_response():
+    """Pending injections should be drained when the runner exits due to empty response."""
+    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES
+    from nanobot.bus.events import InboundMessage
+
+    provider = MagicMock()
+    call_count = {"n": 0}
+
+    async def chat_with_retry(*, messages, **kwargs):
+        call_count["n"] += 1
+        if call_count["n"] <= _MAX_EMPTY_RETRIES + 1:
+            return LLMResponse(content="", tool_calls=[], usage={})
+        # After retries exhausted + injection drain, respond normally
+        return LLMResponse(content="answer after empty", tool_calls=[], usage={})
+
+    provider.chat_with_retry = chat_with_retry
+    tools = MagicMock()
+    tools.get_definitions.return_value = []
+
+    injection_queue = asyncio.Queue()
+
+    async def inject_cb():
+        items = []
+        while not injection_queue.empty():
+            items.append(await injection_queue.get())
+        return items
+
+    await injection_queue.put(
+        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
+    )
+
+    runner = AgentRunner(provider)
+    result = await runner.run(AgentRunSpec(
+        initial_messages=[
+            {"role": "user", "content": "hello"},
+            {"role": "assistant", "content": "previous response"},
+            {"role": "user", "content": "trigger empty"},
+        ],
+        tools=tools,
+        model="test-model",
+        max_iterations=10,
+        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
+        injection_callback=inject_cb,
+    ))
+
+    assert result.had_injections is True
+    assert result.final_content == "answer after empty"
+    injected = [
+        m for m in result.messages
+        if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", ""))
+    ]
+    assert len(injected) == 1
+
+
+@pytest.mark.asyncio
+async def test_drain_injections_on_max_iterations():
+    """Pending injections should be drained when the runner hits max_iterations.
+
+    Unlike other error paths, max_iterations cannot continue the loop, so
+    injections are appended to messages but not processed by the LLM.
+    The key point is they are consumed from the queue to prevent re-publish.
+    """
+    from nanobot.agent.runner import AgentRunSpec, AgentRunner
+    from nanobot.bus.events import InboundMessage
+
+    provider = MagicMock()
+    call_count = {"n": 0}
+
+    async def chat_with_retry(*, messages, **kwargs):
+        call_count["n"] += 1
+        return LLMResponse(
+            content="",
+            tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})],
+            usage={},
+        )
+
+    provider.chat_with_retry = chat_with_retry
+    tools = MagicMock()
+    tools.get_definitions.return_value = []
+    tools.execute = AsyncMock(return_value="file content")
+
+    injection_queue = asyncio.Queue()
+
+    async def inject_cb():
+        items = []
+        while not injection_queue.empty():
+            items.append(await injection_queue.get())
+        return items
+
+    await injection_queue.put(
+        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
+    )
+
+    runner = AgentRunner(provider)
+    result = await runner.run(AgentRunSpec(
+        initial_messages=[{"role": "user", "content": "hello"}],
+        tools=tools,
+        model="test-model",
+        max_iterations=2,
+        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
+        injection_callback=inject_cb,
+    ))
+
+    assert result.stop_reason == "max_iterations"
+    # The injection was consumed from the queue (preventing re-publish)
+    assert injection_queue.empty()
+    # The injection message is appended to conversation history
+    injected = [
+        m for m in result.messages
+        if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
+    ]
+    assert len(injected) == 1