mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-15 07:29:52 +00:00
fix(agent): drain injection queue on error/edge-case exit paths
When the agent runner exits due to LLM error, tool error, empty response, or max_iterations, it breaks out of the iteration loop without draining the pending injection queue. This causes leftover messages to be re-published as independent inbound messages, resulting in duplicate or confusing replies to the user. Extract the injection drain logic into a `_try_drain_injections` helper and call it before each break in the error/edge-case paths. If injections are found, continue the loop instead of breaking. For max_iterations (where the loop is exhausted), drain injections to prevent re-publish without continuing.
This commit is contained in:
parent
3c06db7e4e
commit
d849a3fa06
@ -134,6 +134,36 @@ class AgentRunner:
|
||||
continue
|
||||
messages.append(injection)
|
||||
|
||||
async def _try_drain_injections(
    self,
    spec: AgentRunSpec,
    messages: list[dict[str, Any]],
    assistant_message: dict[str, Any] | None,
    injection_cycles: int,
    *,
    phase: str = "after error",
) -> tuple[bool, int]:
    """Fold queued follow-up messages into *messages*, if the cycle budget allows.

    The injection callback is only consulted while *injection_cycles* is
    below ``_MAX_INJECTION_CYCLES``; once the cap is reached nothing is
    drained and the call is a no-op.

    When at least one injection is drained, *assistant_message* (if given)
    is appended to *messages* first, followed by the injected messages, and
    one informational log line tagged with *phase* is emitted.

    Returns:
        ``(True, injection_cycles + 1)`` when injections were appended, so
        the caller can keep iterating; ``(False, injection_cycles)`` when
        the cap was reached or nothing was pending.
    """
    budget_left = injection_cycles < _MAX_INJECTION_CYCLES
    # Skip the callback entirely once the budget is spent.
    pending = await self._drain_injections(spec) if budget_left else []
    if not pending:
        return False, injection_cycles

    next_cycle = injection_cycles + 1
    # The assistant turn must precede the follow-ups in the history.
    if assistant_message is not None:
        messages.append(assistant_message)
    self._append_injected_messages(messages, pending)
    logger.info(
        "Injected {} follow-up message(s) {} ({}/{})",
        len(pending), phase, next_cycle, _MAX_INJECTION_CYCLES,
    )
    return True, next_cycle
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||
"""Drain pending user messages via the injection callback.
|
||||
|
||||
@ -287,6 +317,13 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after tool error",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
@ -379,36 +416,31 @@ class AgentRunner:
|
||||
# Check for mid-turn injections BEFORE signaling stream end.
|
||||
# If injections are found we keep the stream alive (resuming=True)
|
||||
# so streaming channels don't prematurely finalize the card.
|
||||
_injected_after_final = False
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
_injected_after_final = True
|
||||
if assistant_message is not None:
|
||||
messages.append(assistant_message)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, assistant_message, injection_cycles,
|
||||
phase="after final response",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
# Emit checkpoint for the assistant message that was appended
|
||||
# by _try_drain_injections, then keep the stream alive.
|
||||
if assistant_message is not None:
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=_injected_after_final)
|
||||
await hook.on_stream_end(context, resuming=should_continue)
|
||||
|
||||
if _injected_after_final:
|
||||
if should_continue:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -421,6 +453,13 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after LLM error",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
if is_blank_text(clean):
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
@ -431,6 +470,13 @@ class AgentRunner:
|
||||
context.error = error
|
||||
context.stop_reason = stop_reason
|
||||
await hook.after_iteration(context)
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after empty response",
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
continue
|
||||
break
|
||||
|
||||
messages.append(assistant_message or build_assistant_message(
|
||||
@ -467,6 +513,15 @@ class AgentRunner:
|
||||
max_iterations=spec.max_iterations,
|
||||
)
|
||||
self._append_final_message(messages, final_content)
|
||||
# Drain any remaining injections so they are appended to the
|
||||
# conversation history instead of being re-published as
|
||||
# independent inbound messages by _dispatch's finally block.
|
||||
# We ignore should_continue here because the for-loop has already
|
||||
# exhausted all iterations.
|
||||
_, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after max_iterations",
|
||||
)
|
||||
|
||||
return AgentRunResult(
|
||||
final_content=final_content,
|
||||
|
||||
@ -2410,3 +2410,236 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
||||
contents = [m.content for m in msgs]
|
||||
assert "leftover-1" in contents
|
||||
assert "leftover-2" in contents
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_fatal_tool_error():
    """A queued follow-up is picked up and answered after a fatal tool failure."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after error"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        if llm_calls > 1:
            # Every later call: answer the injected follow-up normally.
            return LLMResponse(content="reply to follow-up", tool_calls=[], usage={})
        # First call: request a tool invocation that is rigged to blow up.
        failing_call = ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})
        return LLMResponse(content="", tool_calls=[failing_call], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))

    pending = asyncio.Queue()

    async def inject_cb():
        drained = []
        while not pending.empty():
            drained.append(await pending.get())
        return drained

    follow_up = InboundMessage(
        channel="cli", sender_id="u", chat_id="c", content=follow_up_text,
    )
    await pending.put(follow_up)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        fail_on_tool_error=True,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "reply to follow-up"
    # The follow-up must appear exactly once in the conversation history.
    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and entry.get("content") == follow_up_text:
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_llm_error():
    """A queued follow-up is picked up after the LLM reports finish_reason="error"."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after LLM error"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        if llm_calls > 1:
            # Every later call: answer the injected follow-up normally.
            return LLMResponse(content="recovered answer", tool_calls=[], usage={})
        # First call: simulate a provider-side error response.
        return LLMResponse(
            content=None,
            tool_calls=[],
            finish_reason="error",
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    pending = asyncio.Queue()

    async def inject_cb():
        drained = []
        while not pending.empty():
            drained.append(await pending.get())
        return drained

    follow_up = InboundMessage(
        channel="cli", sender_id="u", chat_id="c", content=follow_up_text,
    )
    await pending.put(follow_up)

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "previous response"},
        {"role": "user", "content": "trigger error"},
    ]
    spec = AgentRunSpec(
        initial_messages=history,
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "recovered answer"
    # The follow-up must appear exactly once among the user messages.
    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and follow_up_text in str(entry.get("content", "")):
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_empty_final_response():
    """A queued follow-up is picked up after the runner gives up on empty replies."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after empty"
    # The first (_MAX_EMPTY_RETRIES + 1) LLM calls return blank content.
    blank_replies = _MAX_EMPTY_RETRIES + 1
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        nonlocal llm_calls
        llm_calls += 1
        if llm_calls > blank_replies:
            # Blank replies used up: respond normally from here on.
            return LLMResponse(content="answer after empty", tool_calls=[], usage={})
        return LLMResponse(content="", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    pending = asyncio.Queue()

    async def inject_cb():
        drained = []
        while not pending.empty():
            drained.append(await pending.get())
        return drained

    follow_up = InboundMessage(
        channel="cli", sender_id="u", chat_id="c", content=follow_up_text,
    )
    await pending.put(follow_up)

    history = [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "previous response"},
        {"role": "user", "content": "trigger empty"},
    ]
    spec = AgentRunSpec(
        initial_messages=history,
        tools=tools,
        model="test-model",
        max_iterations=10,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    assert result.final_content == "answer after empty"
    # The follow-up must appear exactly once among the user messages.
    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and follow_up_text in str(entry.get("content", "")):
            matches += 1
    assert matches == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_drain_injections_on_max_iterations():
    """A queued follow-up is consumed when the run stops at max_iterations.

    The run ends with stop_reason "max_iterations", so the follow-up is only
    expected to land in the message history and leave the queue empty; this
    test does not check for an LLM reply to it.

    NOTE(review): the follow-up is queued before run() starts, so it may be
    drained by an earlier in-loop injection check rather than by the
    max_iterations exit path itself - confirm against the runner.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    follow_up_text = "follow-up after max iters"
    llm_calls = 0

    async def chat_with_retry(*, messages, **kwargs):
        # Always ask for another tool call so the loop never finishes on its own.
        nonlocal llm_calls
        llm_calls += 1
        next_call = ToolCallRequest(
            id=f"c{llm_calls}", name="read_file", arguments={"path": "x"},
        )
        return LLMResponse(content="", tool_calls=[next_call], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []
    tools.execute = AsyncMock(return_value="file content")

    injection_queue = asyncio.Queue()

    async def inject_cb():
        drained = []
        while not injection_queue.empty():
            drained.append(await injection_queue.get())
        return drained

    follow_up = InboundMessage(
        channel="cli", sender_id="u", chat_id="c", content=follow_up_text,
    )
    await injection_queue.put(follow_up)

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=2,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.stop_reason == "max_iterations"
    # Nothing may be left behind in the queue (that would trigger a re-publish).
    assert injection_queue.empty()
    # The follow-up must appear exactly once in the conversation history.
    matches = 0
    for entry in result.messages:
        if entry.get("role") == "user" and entry.get("content") == follow_up_text:
            matches += 1
    assert matches == 1
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user