mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-15 07:29:52 +00:00
refactor(runner): consolidate all injection drain paths and deduplicate tests
- Migrate "after tools" inline drain to use _try_drain_injections, completing the refactoring (all 6 drain sites now use the helper). - Move checkpoint emission into _try_drain_injections via optional iteration parameter, eliminating the leaky split between helper and caller for the final-response path. - Extract _make_injection_callback() test helper to replace 7 identical inject_cb function bodies. - Add test_injection_cycle_cap_on_error_path to verify the cycle cap is enforced on error exit paths.
This commit is contained in:
parent
d849a3fa06
commit
a1e1eed2f1
@ -142,12 +142,14 @@ class AgentRunner:
|
||||
injection_cycles: int,
|
||||
*,
|
||||
phase: str = "after error",
|
||||
iteration: int | None = None,
|
||||
) -> tuple[bool, int]:
|
||||
"""Drain pending injections. Returns (should_continue, updated_cycles).
|
||||
|
||||
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
|
||||
append them to *messages* and return (True, cycles+1) so the caller
|
||||
continues the iteration loop. Otherwise return (False, cycles).
|
||||
append them to *messages* (and emit a checkpoint if *assistant_message*
|
||||
and *iteration* are both provided) and return (True, cycles+1) so the
|
||||
caller continues the iteration loop. Otherwise return (False, cycles).
|
||||
"""
|
||||
if injection_cycles >= _MAX_INJECTION_CYCLES:
|
||||
return False, injection_cycles
|
||||
@ -157,6 +159,18 @@ class AgentRunner:
|
||||
injection_cycles += 1
|
||||
if assistant_message is not None:
|
||||
messages.append(assistant_message)
|
||||
if iteration is not None:
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) {} ({}/{})",
|
||||
@ -339,16 +353,12 @@ class AgentRunner:
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
_drained, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, None, injection_cycles,
|
||||
phase="after tool execution",
|
||||
)
|
||||
if _drained:
|
||||
had_injections = True
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -419,23 +429,10 @@ class AgentRunner:
|
||||
should_continue, injection_cycles = await self._try_drain_injections(
|
||||
spec, messages, assistant_message, injection_cycles,
|
||||
phase="after final response",
|
||||
iteration=iteration,
|
||||
)
|
||||
if should_continue:
|
||||
had_injections = True
|
||||
# Emit checkpoint for the assistant message that was appended
|
||||
# by _try_drain_injections, then keep the stream alive.
|
||||
if assistant_message is not None:
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=should_continue)
|
||||
|
||||
@ -18,6 +18,16 @@ from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||
|
||||
|
||||
def _make_injection_callback(queue: asyncio.Queue):
|
||||
"""Return an async callback that drains *queue* into a list of dicts."""
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not queue.empty():
|
||||
items.append(await queue.get())
|
||||
return items
|
||||
return inject_cb
|
||||
|
||||
|
||||
def _make_loop(tmp_path):
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -1888,12 +1898,7 @@ async def test_checkpoint1_injects_after_tool_execution():
|
||||
tools.execute = AsyncMock(return_value="file content")
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
# Put a follow-up message in the queue before the run starts
|
||||
await injection_queue.put(
|
||||
@ -1951,12 +1956,7 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
# Inject a follow-up that arrives during the first response
|
||||
await injection_queue.put(
|
||||
@ -2005,12 +2005,7 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||
@ -2438,12 +2433,7 @@ async def test_drain_injections_on_fatal_tool_error():
|
||||
tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
|
||||
@ -2496,12 +2486,7 @@ async def test_drain_injections_on_llm_error():
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
|
||||
@ -2551,12 +2536,7 @@ async def test_drain_injections_on_empty_final_response():
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
|
||||
@ -2613,12 +2593,7 @@ async def test_drain_injections_on_max_iterations():
|
||||
tools.execute = AsyncMock(return_value="file content")
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
inject_cb = _make_injection_callback(injection_queue)
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
|
||||
@ -2643,3 +2618,53 @@ async def test_drain_injections_on_max_iterations():
|
||||
if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
|
||||
]
|
||||
assert len(injected) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_injection_cycle_cap_on_error_path():
    """Injection cycles should be capped even when every iteration hits an LLM error."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
    from nanobot.bus.events import InboundMessage

    # Track one entry per LLM round; every round ends with an error finish.
    llm_rounds = []

    async def chat_with_retry(*, messages, **kwargs):
        llm_rounds.append(messages)
        return LLMResponse(
            content=None,
            tool_calls=[],
            finish_reason="error",
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    # Track one entry per drain; keep feeding injections up to the cap.
    drains = []

    async def inject_cb():
        drains.append(None)
        if len(drains) > _MAX_INJECTION_CYCLES:
            return []
        return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{len(drains)}")]

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[
            {"role": "user", "content": "hello"},
            {"role": "assistant", "content": "previous"},
            {"role": "user", "content": "trigger error"},
        ],
        tools=tools,
        model="test-model",
        max_iterations=20,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.had_injections is True
    # Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks
    assert len(llm_rounds) == _MAX_INJECTION_CYCLES + 1
    assert len(drains) == _MAX_INJECTION_CYCLES
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user