mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-02 07:45:54 +00:00
fix(agent): drain injection queue on error/edge-case exit paths
When the agent runner exits due to LLM error, tool error, empty response, or max_iterations, it breaks out of the iteration loop without draining the pending injection queue. This causes leftover messages to be re-published as independent inbound messages, resulting in duplicate or confusing replies to the user. Extract the injection drain logic into a `_try_drain_injections` helper and call it before each break in the error/edge-case paths. If injections are found, continue the loop instead of breaking. For max_iterations (where the loop is already exhausted), the injections are still drained — so they are not re-published — but the loop is not continued.
This commit is contained in:
parent
3c06db7e4e
commit
d849a3fa06
@ -134,6 +134,36 @@ class AgentRunner:
|
|||||||
continue
|
continue
|
||||||
messages.append(injection)
|
messages.append(injection)
|
||||||
|
|
||||||
|
async def _try_drain_injections(
|
||||||
|
self,
|
||||||
|
spec: AgentRunSpec,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
assistant_message: dict[str, Any] | None,
|
||||||
|
injection_cycles: int,
|
||||||
|
*,
|
||||||
|
phase: str = "after error",
|
||||||
|
) -> tuple[bool, int]:
|
||||||
|
"""Drain pending injections. Returns (should_continue, updated_cycles).
|
||||||
|
|
||||||
|
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
|
||||||
|
append them to *messages* and return (True, cycles+1) so the caller
|
||||||
|
continues the iteration loop. Otherwise return (False, cycles).
|
||||||
|
"""
|
||||||
|
if injection_cycles >= _MAX_INJECTION_CYCLES:
|
||||||
|
return False, injection_cycles
|
||||||
|
injections = await self._drain_injections(spec)
|
||||||
|
if not injections:
|
||||||
|
return False, injection_cycles
|
||||||
|
injection_cycles += 1
|
||||||
|
if assistant_message is not None:
|
||||||
|
messages.append(assistant_message)
|
||||||
|
self._append_injected_messages(messages, injections)
|
||||||
|
logger.info(
|
||||||
|
"Injected {} follow-up message(s) {} ({}/{})",
|
||||||
|
len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES,
|
||||||
|
)
|
||||||
|
return True, injection_cycles
|
||||||
|
|
||||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||||
"""Drain pending user messages via the injection callback.
|
"""Drain pending user messages via the injection callback.
|
||||||
|
|
||||||
@ -287,6 +317,13 @@ class AgentRunner:
|
|||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after tool error",
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
await self._emit_checkpoint(
|
await self._emit_checkpoint(
|
||||||
spec,
|
spec,
|
||||||
@ -379,36 +416,31 @@ class AgentRunner:
|
|||||||
# Check for mid-turn injections BEFORE signaling stream end.
|
# Check for mid-turn injections BEFORE signaling stream end.
|
||||||
# If injections are found we keep the stream alive (resuming=True)
|
# If injections are found we keep the stream alive (resuming=True)
|
||||||
# so streaming channels don't prematurely finalize the card.
|
# so streaming channels don't prematurely finalize the card.
|
||||||
_injected_after_final = False
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
spec, messages, assistant_message, injection_cycles,
|
||||||
injections = await self._drain_injections(spec)
|
phase="after final response",
|
||||||
if injections:
|
)
|
||||||
had_injections = True
|
if should_continue:
|
||||||
injection_cycles += 1
|
had_injections = True
|
||||||
_injected_after_final = True
|
# Emit checkpoint for the assistant message that was appended
|
||||||
if assistant_message is not None:
|
# by _try_drain_injections, then keep the stream alive.
|
||||||
messages.append(assistant_message)
|
if assistant_message is not None:
|
||||||
await self._emit_checkpoint(
|
await self._emit_checkpoint(
|
||||||
spec,
|
spec,
|
||||||
{
|
{
|
||||||
"phase": "final_response",
|
"phase": "final_response",
|
||||||
"iteration": iteration,
|
"iteration": iteration,
|
||||||
"model": spec.model,
|
"model": spec.model,
|
||||||
"assistant_message": assistant_message,
|
"assistant_message": assistant_message,
|
||||||
"completed_tool_results": [],
|
"completed_tool_results": [],
|
||||||
"pending_tool_calls": [],
|
"pending_tool_calls": [],
|
||||||
},
|
},
|
||||||
)
|
|
||||||
self._append_injected_messages(messages, injections)
|
|
||||||
logger.info(
|
|
||||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
|
||||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=_injected_after_final)
|
await hook.on_stream_end(context, resuming=should_continue)
|
||||||
|
|
||||||
if _injected_after_final:
|
if should_continue:
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -421,6 +453,13 @@ class AgentRunner:
|
|||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after LLM error",
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
if is_blank_text(clean):
|
if is_blank_text(clean):
|
||||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
@ -431,6 +470,13 @@ class AgentRunner:
|
|||||||
context.error = error
|
context.error = error
|
||||||
context.stop_reason = stop_reason
|
context.stop_reason = stop_reason
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after empty response",
|
||||||
|
)
|
||||||
|
if should_continue:
|
||||||
|
had_injections = True
|
||||||
|
continue
|
||||||
break
|
break
|
||||||
|
|
||||||
messages.append(assistant_message or build_assistant_message(
|
messages.append(assistant_message or build_assistant_message(
|
||||||
@ -467,6 +513,15 @@ class AgentRunner:
|
|||||||
max_iterations=spec.max_iterations,
|
max_iterations=spec.max_iterations,
|
||||||
)
|
)
|
||||||
self._append_final_message(messages, final_content)
|
self._append_final_message(messages, final_content)
|
||||||
|
# Drain any remaining injections so they are appended to the
|
||||||
|
# conversation history instead of being re-published as
|
||||||
|
# independent inbound messages by _dispatch's finally block.
|
||||||
|
# We ignore should_continue here because the for-loop has already
|
||||||
|
# exhausted all iterations.
|
||||||
|
_, injection_cycles = await self._try_drain_injections(
|
||||||
|
spec, messages, None, injection_cycles,
|
||||||
|
phase="after max_iterations",
|
||||||
|
)
|
||||||
|
|
||||||
return AgentRunResult(
|
return AgentRunResult(
|
||||||
final_content=final_content,
|
final_content=final_content,
|
||||||
|
|||||||
@ -2410,3 +2410,236 @@ async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
|||||||
contents = [m.content for m in msgs]
|
contents = [m.content for m in msgs]
|
||||||
assert "leftover-1" in contents
|
assert "leftover-1" in contents
|
||||||
assert "leftover-2" in contents
|
assert "leftover-2" in contents
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drain_injections_on_fatal_tool_error():
|
||||||
|
"""Pending injections should be drained even when a fatal tool error occurs."""
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[ToolCallRequest(id="c1", name="exec", arguments={"cmd": "bad"})],
|
||||||
|
usage={},
|
||||||
|
)
|
||||||
|
# Second call: respond normally to the injected follow-up
|
||||||
|
return LLMResponse(content="reply to follow-up", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))
|
||||||
|
|
||||||
|
injection_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def inject_cb():
|
||||||
|
items = []
|
||||||
|
while not injection_queue.empty():
|
||||||
|
items.append(await injection_queue.get())
|
||||||
|
return items
|
||||||
|
|
||||||
|
await injection_queue.put(
|
||||||
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hello"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=5,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
fail_on_tool_error=True,
|
||||||
|
injection_callback=inject_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.had_injections is True
|
||||||
|
assert result.final_content == "reply to follow-up"
|
||||||
|
# The injection should be in the messages history
|
||||||
|
injected = [
|
||||||
|
m for m in result.messages
|
||||||
|
if m.get("role") == "user" and m.get("content") == "follow-up after error"
|
||||||
|
]
|
||||||
|
assert len(injected) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drain_injections_on_llm_error():
|
||||||
|
"""Pending injections should be drained when the LLM returns an error finish_reason."""
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] == 1:
|
||||||
|
return LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[],
|
||||||
|
finish_reason="error",
|
||||||
|
usage={},
|
||||||
|
)
|
||||||
|
# Second call: respond normally to the injected follow-up
|
||||||
|
return LLMResponse(content="recovered answer", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
injection_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def inject_cb():
|
||||||
|
items = []
|
||||||
|
while not injection_queue.empty():
|
||||||
|
items.append(await injection_queue.get())
|
||||||
|
return items
|
||||||
|
|
||||||
|
await injection_queue.put(
|
||||||
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "previous response"},
|
||||||
|
{"role": "user", "content": "trigger error"},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=5,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
injection_callback=inject_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.had_injections is True
|
||||||
|
assert result.final_content == "recovered answer"
|
||||||
|
injected = [
|
||||||
|
m for m in result.messages
|
||||||
|
if m.get("role") == "user" and "follow-up after LLM error" in str(m.get("content", ""))
|
||||||
|
]
|
||||||
|
assert len(injected) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drain_injections_on_empty_final_response():
|
||||||
|
"""Pending injections should be drained when the runner exits due to empty response."""
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_EMPTY_RETRIES
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
if call_count["n"] <= _MAX_EMPTY_RETRIES + 1:
|
||||||
|
return LLMResponse(content="", tool_calls=[], usage={})
|
||||||
|
# After retries exhausted + injection drain, respond normally
|
||||||
|
return LLMResponse(content="answer after empty", tool_calls=[], usage={})
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
injection_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def inject_cb():
|
||||||
|
items = []
|
||||||
|
while not injection_queue.empty():
|
||||||
|
items.append(await injection_queue.get())
|
||||||
|
return items
|
||||||
|
|
||||||
|
await injection_queue.put(
|
||||||
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "previous response"},
|
||||||
|
{"role": "user", "content": "trigger empty"},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=10,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
injection_callback=inject_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.had_injections is True
|
||||||
|
assert result.final_content == "answer after empty"
|
||||||
|
injected = [
|
||||||
|
m for m in result.messages
|
||||||
|
if m.get("role") == "user" and "follow-up after empty" in str(m.get("content", ""))
|
||||||
|
]
|
||||||
|
assert len(injected) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_drain_injections_on_max_iterations():
|
||||||
|
"""Pending injections should be drained when the runner hits max_iterations.
|
||||||
|
|
||||||
|
Unlike other error paths, max_iterations cannot continue the loop, so
|
||||||
|
injections are appended to messages but not processed by the LLM.
|
||||||
|
The key point is they are consumed from the queue to prevent re-publish.
|
||||||
|
"""
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[ToolCallRequest(id=f"c{call_count['n']}", name="read_file", arguments={"path": "x"})],
|
||||||
|
usage={},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
tools.execute = AsyncMock(return_value="file content")
|
||||||
|
|
||||||
|
injection_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def inject_cb():
|
||||||
|
items = []
|
||||||
|
while not injection_queue.empty():
|
||||||
|
items.append(await injection_queue.get())
|
||||||
|
return items
|
||||||
|
|
||||||
|
await injection_queue.put(
|
||||||
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
|
||||||
|
)
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[{"role": "user", "content": "hello"}],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=2,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
injection_callback=inject_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.stop_reason == "max_iterations"
|
||||||
|
# The injection was consumed from the queue (preventing re-publish)
|
||||||
|
assert injection_queue.empty()
|
||||||
|
# The injection message is appended to conversation history
|
||||||
|
injected = [
|
||||||
|
m for m in result.messages
|
||||||
|
if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
|
||||||
|
]
|
||||||
|
assert len(injected) == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user