mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 06:45:55 +00:00
refactor(runner): consolidate all injection drain paths and deduplicate tests
- Migrate "after tools" inline drain to use _try_drain_injections, completing the refactoring (all 6 drain sites now use the helper). - Move checkpoint emission into _try_drain_injections via optional iteration parameter, eliminating the leaky split between helper and caller for the final-response path. - Extract _make_injection_callback() test helper to replace 7 identical inject_cb function bodies. - Add test_injection_cycle_cap_on_error_path to verify the cycle cap is enforced on error exit paths.
This commit is contained in:
parent
d849a3fa06
commit
a1e1eed2f1
@ -142,12 +142,14 @@ class AgentRunner:
|
|||||||
injection_cycles: int,
|
injection_cycles: int,
|
||||||
*,
|
*,
|
||||||
phase: str = "after error",
|
phase: str = "after error",
|
||||||
|
iteration: int | None = None,
|
||||||
) -> tuple[bool, int]:
|
) -> tuple[bool, int]:
|
||||||
"""Drain pending injections. Returns (should_continue, updated_cycles).
|
"""Drain pending injections. Returns (should_continue, updated_cycles).
|
||||||
|
|
||||||
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
|
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
|
||||||
append them to *messages* and return (True, cycles+1) so the caller
|
append them to *messages* (and emit a checkpoint if *assistant_message*
|
||||||
continues the iteration loop. Otherwise return (False, cycles).
|
and *iteration* are both provided) and return (True, cycles+1) so the
|
||||||
|
caller continues the iteration loop. Otherwise return (False, cycles).
|
||||||
"""
|
"""
|
||||||
if injection_cycles >= _MAX_INJECTION_CYCLES:
|
if injection_cycles >= _MAX_INJECTION_CYCLES:
|
||||||
return False, injection_cycles
|
return False, injection_cycles
|
||||||
@ -157,6 +159,18 @@ class AgentRunner:
|
|||||||
injection_cycles += 1
|
injection_cycles += 1
|
||||||
if assistant_message is not None:
|
if assistant_message is not None:
|
||||||
messages.append(assistant_message)
|
messages.append(assistant_message)
|
||||||
|
if iteration is not None:
|
||||||
|
await self._emit_checkpoint(
|
||||||
|
spec,
|
||||||
|
{
|
||||||
|
"phase": "final_response",
|
||||||
|
"iteration": iteration,
|
||||||
|
"model": spec.model,
|
||||||
|
"assistant_message": assistant_message,
|
||||||
|
"completed_tool_results": [],
|
||||||
|
"pending_tool_calls": [],
|
||||||
|
},
|
||||||
|
)
|
||||||
self._append_injected_messages(messages, injections)
|
self._append_injected_messages(messages, injections)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Injected {} follow-up message(s) {} ({}/{})",
|
"Injected {} follow-up message(s) {} ({}/{})",
|
||||||
@ -339,16 +353,12 @@ class AgentRunner:
|
|||||||
empty_content_retries = 0
|
empty_content_retries = 0
|
||||||
length_recovery_count = 0
|
length_recovery_count = 0
|
||||||
# Checkpoint 1: drain injections after tools, before next LLM call
|
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
_drained, injection_cycles = await self._try_drain_injections(
|
||||||
injections = await self._drain_injections(spec)
|
spec, messages, None, injection_cycles,
|
||||||
if injections:
|
phase="after tool execution",
|
||||||
had_injections = True
|
)
|
||||||
injection_cycles += 1
|
if _drained:
|
||||||
self._append_injected_messages(messages, injections)
|
had_injections = True
|
||||||
logger.info(
|
|
||||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
|
||||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
|
||||||
)
|
|
||||||
await hook.after_iteration(context)
|
await hook.after_iteration(context)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -419,23 +429,10 @@ class AgentRunner:
|
|||||||
should_continue, injection_cycles = await self._try_drain_injections(
|
should_continue, injection_cycles = await self._try_drain_injections(
|
||||||
spec, messages, assistant_message, injection_cycles,
|
spec, messages, assistant_message, injection_cycles,
|
||||||
phase="after final response",
|
phase="after final response",
|
||||||
|
iteration=iteration,
|
||||||
)
|
)
|
||||||
if should_continue:
|
if should_continue:
|
||||||
had_injections = True
|
had_injections = True
|
||||||
# Emit checkpoint for the assistant message that was appended
|
|
||||||
# by _try_drain_injections, then keep the stream alive.
|
|
||||||
if assistant_message is not None:
|
|
||||||
await self._emit_checkpoint(
|
|
||||||
spec,
|
|
||||||
{
|
|
||||||
"phase": "final_response",
|
|
||||||
"iteration": iteration,
|
|
||||||
"model": spec.model,
|
|
||||||
"assistant_message": assistant_message,
|
|
||||||
"completed_tool_results": [],
|
|
||||||
"pending_tool_calls": [],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if hook.wants_streaming():
|
if hook.wants_streaming():
|
||||||
await hook.on_stream_end(context, resuming=should_continue)
|
await hook.on_stream_end(context, resuming=should_continue)
|
||||||
|
|||||||
@ -18,6 +18,16 @@ from nanobot.providers.base import LLMResponse, ToolCallRequest
|
|||||||
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
_MAX_TOOL_RESULT_CHARS = AgentDefaults().max_tool_result_chars
|
||||||
|
|
||||||
|
|
||||||
|
def _make_injection_callback(queue: asyncio.Queue):
|
||||||
|
"""Return an async callback that drains *queue* into a list of dicts."""
|
||||||
|
async def inject_cb():
|
||||||
|
items = []
|
||||||
|
while not queue.empty():
|
||||||
|
items.append(await queue.get())
|
||||||
|
return items
|
||||||
|
return inject_cb
|
||||||
|
|
||||||
|
|
||||||
def _make_loop(tmp_path):
|
def _make_loop(tmp_path):
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -1888,12 +1898,7 @@ async def test_checkpoint1_injects_after_tool_execution():
|
|||||||
tools.execute = AsyncMock(return_value="file content")
|
tools.execute = AsyncMock(return_value="file content")
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
# Put a follow-up message in the queue before the run starts
|
# Put a follow-up message in the queue before the run starts
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
@ -1951,12 +1956,7 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
|||||||
tools.get_definitions.return_value = []
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
# Inject a follow-up that arrives during the first response
|
# Inject a follow-up that arrives during the first response
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
@ -2005,12 +2005,7 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
|
|||||||
tools.get_definitions.return_value = []
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||||
@ -2438,12 +2433,7 @@ async def test_drain_injections_on_fatal_tool_error():
|
|||||||
tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))
|
tools.execute = AsyncMock(side_effect=RuntimeError("tool exploded"))
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after error")
|
||||||
@ -2496,12 +2486,7 @@ async def test_drain_injections_on_llm_error():
|
|||||||
tools.get_definitions.return_value = []
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after LLM error")
|
||||||
@ -2551,12 +2536,7 @@ async def test_drain_injections_on_empty_final_response():
|
|||||||
tools.get_definitions.return_value = []
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after empty")
|
||||||
@ -2613,12 +2593,7 @@ async def test_drain_injections_on_max_iterations():
|
|||||||
tools.execute = AsyncMock(return_value="file content")
|
tools.execute = AsyncMock(return_value="file content")
|
||||||
|
|
||||||
injection_queue = asyncio.Queue()
|
injection_queue = asyncio.Queue()
|
||||||
|
inject_cb = _make_injection_callback(injection_queue)
|
||||||
async def inject_cb():
|
|
||||||
items = []
|
|
||||||
while not injection_queue.empty():
|
|
||||||
items.append(await injection_queue.get())
|
|
||||||
return items
|
|
||||||
|
|
||||||
await injection_queue.put(
|
await injection_queue.put(
|
||||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up after max iters")
|
||||||
@ -2643,3 +2618,53 @@ async def test_drain_injections_on_max_iterations():
|
|||||||
if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
|
if m.get("role") == "user" and m.get("content") == "follow-up after max iters"
|
||||||
]
|
]
|
||||||
assert len(injected) == 1
|
assert len(injected) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_injection_cycle_cap_on_error_path():
|
||||||
|
"""Injection cycles should be capped even when every iteration hits an LLM error."""
|
||||||
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
call_count = {"n": 0}
|
||||||
|
|
||||||
|
async def chat_with_retry(*, messages, **kwargs):
|
||||||
|
call_count["n"] += 1
|
||||||
|
return LLMResponse(
|
||||||
|
content=None,
|
||||||
|
tool_calls=[],
|
||||||
|
finish_reason="error",
|
||||||
|
usage={},
|
||||||
|
)
|
||||||
|
|
||||||
|
provider.chat_with_retry = chat_with_retry
|
||||||
|
tools = MagicMock()
|
||||||
|
tools.get_definitions.return_value = []
|
||||||
|
|
||||||
|
drain_count = {"n": 0}
|
||||||
|
|
||||||
|
async def inject_cb():
|
||||||
|
drain_count["n"] += 1
|
||||||
|
if drain_count["n"] <= _MAX_INJECTION_CYCLES:
|
||||||
|
return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")]
|
||||||
|
return []
|
||||||
|
|
||||||
|
runner = AgentRunner(provider)
|
||||||
|
result = await runner.run(AgentRunSpec(
|
||||||
|
initial_messages=[
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "assistant", "content": "previous"},
|
||||||
|
{"role": "user", "content": "trigger error"},
|
||||||
|
],
|
||||||
|
tools=tools,
|
||||||
|
model="test-model",
|
||||||
|
max_iterations=20,
|
||||||
|
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||||
|
injection_callback=inject_cb,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert result.had_injections is True
|
||||||
|
# Should cap: _MAX_INJECTION_CYCLES drained rounds + 1 final round that breaks
|
||||||
|
assert call_count["n"] == _MAX_INJECTION_CYCLES + 1
|
||||||
|
assert drain_count["n"] == _MAX_INJECTION_CYCLES
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user