fix(agent): drain injection queue on error/edge-case exit paths

When the agent runner exits due to LLM error, tool error, empty response,
or max_iterations, it breaks out of the iteration loop without draining
the pending injection queue. This causes leftover messages to be
re-published as independent inbound messages, resulting in duplicate or
confusing replies to the user.

Extract the injection drain logic into a `_try_drain_injections` helper
and call it before each break in the error/edge-case paths. If injections
are found, continue the loop instead of breaking. For max_iterations
(where the loop is already exhausted), drain injections so they are not
re-published, but do not continue the loop.
This commit is contained in:
chengyongru 2026-04-13 23:33:25 +08:00 committed by Xubin Ren
parent 3c06db7e4e
commit d849a3fa06
2 changed files with 314 additions and 26 deletions

View File

@ -134,6 +134,36 @@ class AgentRunner:
continue
messages.append(injection)
async def _try_drain_injections(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
    assistant_message: dict[str, Any] | None,
    injection_cycles: int,
    *,
    phase: str = "after error",
) -> tuple[bool, int]:
    """Drain pending injections. Returns (should_continue, updated_cycles).

    When pending injections exist and the cycle budget
    (_MAX_INJECTION_CYCLES) is not yet spent, append them to *messages*
    and report (True, cycles + 1) so the caller resumes the iteration
    loop. In every other case report (False, cycles) unchanged.
    """
    # Budget spent: leave the queue untouched and tell the caller to stop.
    if injection_cycles >= _MAX_INJECTION_CYCLES:
        return False, injection_cycles
    pending = await self._drain_injections(spec)
    if not pending:
        return False, injection_cycles
    updated = injection_cycles + 1
    # Keep the assistant turn (if any) ahead of the injected user input
    # so the conversation history stays well-ordered.
    if assistant_message is not None:
        messages.append(assistant_message)
    self._append_injected_messages(messages, pending)
    logger.info(
        "Injected {} follow-up message(s) {} ({}/{})",
        len(pending), phase, updated, _MAX_INJECTION_CYCLES,
    )
    return True, updated
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback.
@ -287,6 +317,13 @@ class AgentRunner:
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after tool error",
)
if should_continue:
had_injections = True
continue
break
await self._emit_checkpoint(
spec,
@ -379,36 +416,31 @@ class AgentRunner:
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
_injected_after_final = False
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
_injected_after_final = True
if assistant_message is not None:
messages.append(assistant_message)
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after final response ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, assistant_message, injection_cycles,
phase="after final response",
)
if should_continue:
had_injections = True
# Emit checkpoint for the assistant message that was appended
# by _try_drain_injections, then keep the stream alive.
if assistant_message is not None:
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=_injected_after_final)
await hook.on_stream_end(context, resuming=should_continue)
if _injected_after_final:
if should_continue:
await hook.after_iteration(context)
continue
@ -421,6 +453,13 @@ class AgentRunner:
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after LLM error",
)
if should_continue:
had_injections = True
continue
break
if is_blank_text(clean):
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
@ -431,6 +470,13 @@ class AgentRunner:
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after empty response",
)
if should_continue:
had_injections = True
continue
break
messages.append(assistant_message or build_assistant_message(
@ -467,6 +513,15 @@ class AgentRunner:
max_iterations=spec.max_iterations,
)
self._append_final_message(messages, final_content)
# Drain any remaining injections so they are appended to the
# conversation history instead of being re-published as
# independent inbound messages by _dispatch's finally block.
# We ignore should_continue here because the for-loop has already
# exhausted all iterations.
_, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after max_iterations",
)
return AgentRunResult(
final_content=final_content,

View File

@ -2410,3 +2410,236 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
contents = [m.content for m in msgs]
assert "leftover-1" in contents
assert "leftover-2" in contents
@pytest.mark.asyncio
async def test_drain_injections_on_fatal_tool_error():
    """Pending injections should be drained even when a fatal tool error occurs."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    calls = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        calls["n"] += 1
        if calls["n"] == 1:
            # First turn: ask for a tool call that will raise.
            return LLMResponse(
                content="",
                tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})],
                usage={},
            )
        # Afterwards respond normally to the injected follow-up.
        return LLMResponse(content="reply to follow-up", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))

    injection_queue = asyncio.Queue()
    await injection_queue.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
    )

    async def inject_cb():
        drained = []
        while not injection_queue.empty():
            drained.append(injection_queue.get_nowait())
        return drained

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "reply to follow-up"
    # The injection must have landed in the conversation history.
    hits = sum(
        1 for m in result.messages
        if m.get("role") == "user" and m.get("content") == "follow-up after error"
    )
    assert hits == 1
@pytest.mark.asyncio
async def test_drain_injections_on_llm_error():
    """Pending injections should be drained when the LLM returns an error finish_reason."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    calls = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        calls["n"] += 1
        if calls["n"] == 1:
            # First turn: simulate an LLM-side failure.
            return LLMResponse(content=None, tool_calls=[], finish_reason="error", usage={})
        # Afterwards respond normally to the injected follow-up.
        return LLMResponse(content="recovered answer", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    injection_queue = asyncio.Queue()
    await injection_queue.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
    )

    async def inject_cb():
        drained = []
        while not injection_queue.empty():
            drained.append(injection_queue.get_nowait())
        return drained

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "previous response"},
            {"role": "user", "content": "trigger error"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "recovered answer"
    hits = sum(
        1 for m in result.messages
        if m.get("role") == "user" and "follow-up after LLM error" in str(m.get("content", ""))
    )
    assert hits == 1
@pytest.mark.asyncio
async def test_drain_injections_on_empty_final_response():
    """Pending injections should be drained when the runner exits due to empty response."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES
    from nanobot.bus.events import InboundMessage

    calls = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        calls["n"] += 1
        # Stay empty through the initial attempt plus every retry.
        if calls["n"] <= _MAX_EMPTY_RETRIES + 1:
            return LLMResponse(content="", tool_calls=[], usage={})
        # Once retries are exhausted and the injection is drained, answer.
        return LLMResponse(content="answer after empty", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    injection_queue = asyncio.Queue()
    await injection_queue.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
    )

    async def inject_cb():
        drained = []
        while not injection_queue.empty():
            drained.append(injection_queue.get_nowait())
        return drained

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "previous response"},
            {"role": "user", "content": "trigger empty"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "answer after empty"
    hits = sum(
        1 for m in result.messages
        if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", ""))
    )
    assert hits == 1
@pytest.mark.asyncio
async def test_drain_injections_on_max_iterations():
    """Pending injections should be drained when the runner hits max_iterations.

    Unlike other error paths, max_iterations cannot continue the loop, so
    injections are appended to messages but not processed by the LLM.
    The key point is they are consumed from the queue to prevent re-publish.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        # Always request another tool call so the run can only stop by
        # exhausting max_iterations.
        return LLMResponse(
            content="",
            tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})],
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    injection_queue = asyncio.Queue()
    await injection_queue.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
    )

    async def inject_cb():
        drained = []
        while not injection_queue.empty():
            drained.append(injection_queue.get_nowait())
        return drained

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.stop_reason == "max_iterations"
    # Consumed from the queue, so _dispatch will not re-publish it.
    assert injection_queue.empty()
    # And appended to the conversation history.
    hits = sum(
        1 for m in result.messages
        if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
    )
    assert hits == 1