refactor(runner): consolidate all injection drain paths and deduplicate tests

- Migrate "after tools" inline drain to use _try_drain_injections,
  completing the refactoring (all 6 drain sites now use the helper).
- Move checkpoint emission into _try_drain_injections via optional
  iteration parameter, eliminating the leaky split between helper
  and caller for the final-response path.
- Extract _make_injection_callback() test helper to replace 7
  identical inject_cb function bodies.
- Add test_injection_cycle_cap_on_error_path to verify the cycle
  cap is enforced on error exit paths.
This commit is contained in:
chengyongru 2026-04-13 23:51:23 +08:00 committed by Xubin Ren
parent d849a3fa06
commit a1e1eed2f1
2 changed files with 90 additions and 68 deletions

View File

@ -142,12 +142,14 @@ class AgentRunner:
injection_cycles: int,
*,
phase: str = "after error",
iteration: int | None = None,
) -> tuple[bool, int]:
"""Drain pending injections. Returns (should_continue, updated_cycles).
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
append them to *messages* and return (True, cycles+1) so the caller
continues the iteration loop. Otherwise return (False, cycles).
append them to *messages* (and emit a checkpoint if *assistant_message*
and *iteration* are both provided) and return (True, cycles+1) so the
caller continues the iteration loop. Otherwise return (False, cycles).
"""
if injection_cycles >= _MAX_INJECTION_CYCLES:
return False, injection_cycles
@ -157,6 +159,18 @@ class AgentRunner:
injection_cycles += 1
if assistant_message is not None:
messages.append(assistant_message)
if iteration is not None:
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) {} ({}/{})",
@ -339,16 +353,12 @@ class AgentRunner:
empty_content_retries = 0
length_recovery_count = 0
# Checkpoint 1: drain injections after tools, before next LLM call
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after tool execution ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
_drained, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after tool execution",
)
if _drained:
had_injections = True
await hook.after_iteration(context)
continue
@ -419,23 +429,10 @@ class AgentRunner:
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, assistant_message, injection_cycles,
phase="after final response",
iteration=iteration,
)
if should_continue:
had_injections = True
# Emit checkpoint for the assistant message that was appended
# by _try_drain_injections, then keep the stream alive.
if assistant_message is not None:
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=should_continue)

View File

@ -18,6 +18,16 @@ from nanobot.providers.base import LLMResponse, ToolCallRequest
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
def _make_injection_callback(queue: asyncio.Queue):
"""Return an async callback that drains *queue* into a list of dicts."""
async def inject_cb():
items = []
while not queue.empty():
items.append(await queue.get())
return items
return inject_cb
def _make_loop(tmp_path):
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
@ -1888,12 +1898,7 @@ async def test_checkpoint1_injects_after_tool_execution():
tools.execute = AsyncMock(return_value="file content")
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
# Put a follow-up message in the queue before the run starts
await injection_queue.put(
@ -1951,12 +1956,7 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
tools.get_definitions.return_value = []
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
# Inject a follow-up that arrives during the first response
await injection_queue.put(
@ -2005,12 +2005,7 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
tools.get_definitions.return_value = []
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
@ -2438,12 +2433,7 @@ async def test_drain_injections_on_fatal_tool_error():
tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
@ -2496,12 +2486,7 @@ async def test_drain_injections_on_llm_error():
tools.get_definitions.return_value = []
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
@ -2551,12 +2536,7 @@ async def test_drain_injections_on_empty_final_response():
tools.get_definitions.return_value = []
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
@ -2613,12 +2593,7 @@ async def test_drain_injections_on_max_iterations():
tools.execute = AsyncMock(return_value="file content")
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
inject_cb = _make_injection_callback(injection_queue)
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
@ -2643,3 +2618,53 @@ async def test_drain_injections_on_max_iterations():
if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
]
assert len(injected) == 1
@pytest.mark.asyncio
async def test_injection_cycle_cap_on_error_path():
    """Injection cycles should be capped even when every iteration hits an LLM error."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
    from nanobot.bus.events import InboundMessage

    provider = MagicMock()
    llm_calls = {"n": 0}

    # Every LLM round ends with finish_reason="error" so the runner always
    # takes the error exit path.
    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        return LLMResponse(
            content=None,
            tool_calls=[],
            finish_reason="error",
            usage={},
        )

    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    drains = {"n": 0}

    # Yield one follow-up per drain up to the cap, then nothing.
    async def inject_cb():
        drains["n"] += 1
        if drains["n"] > _MAX_INJECTION_CYCLES:
            return []
        return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drains['n']}")]

    spec = AgentRunSpec(
        initial_messages=[
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "previous"},
            {"role": "user", "content": "trigger error"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=20,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.had_injections is True
    # Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks
    assert llm_calls["n"] == _MAX_INJECTION_CYCLES + 1
    assert drains["n"] == _MAX_INJECTION_CYCLES