feat(agent): mid-turn message injection for responsive follow-ups (#2985)

* feat(agent): add mid-turn message injection for responsive follow-ups

Allow user messages sent during an active agent turn to be injected
into the running LLM context instead of being queued behind a
per-session lock. Inspired by Claude Code's mid-turn queue drain
mechanism (query.ts:1547-1643).

Key design decisions:
- Messages are injected as natural user messages between iterations,
  no tool cancellation or special system prompt needed
- Two drain checkpoints: after tool execution and after final LLM
  response ("last-mile" to prevent dropping late arrivals)
- Bounded by MAX_INJECTION_CYCLES (5) to prevent consuming the
  iteration budget on rapid follow-ups
- had_injections flag bypasses _sent_in_turn suppression so follow-up
  responses are always delivered

Closes #1609

* fix(agent): harden mid-turn injection with streaming fix, bounded queue, and message safety

- Fix streaming protocol violation: Checkpoint 2 now checks for injections
  BEFORE calling on_stream_end, passing resuming=True when injections found
  so streaming channels (Feishu) don't prematurely finalize the card
- Bound pending queue to maxsize=20 with QueueFull handling
- Add warning log when injection batch exceeds _MAX_INJECTIONS_PER_TURN
- Re-publish leftover queue messages to bus in _dispatch finally block to
  prevent silent message loss on early exit (max_iterations, tool_error, cancel)
- Fix PEP 8 blank line before dataclass and logger.info indentation
- Add 12 new tests covering drain, checkpoints, cycle cap, queue routing,
  cleanup, and leftover re-publish
This commit is contained in:
chengyongru 2026-04-11 02:11:02 +08:00 committed by GitHub
parent df6f9dd171
commit bc4cc49a59
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 615 additions and 60 deletions

View File

@ -237,6 +237,10 @@ class AgentLoop:
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._background_tasks: list[asyncio.Task] = []
self._session_locks: dict[str, asyncio.Lock] = {}
# Per-session pending queues for mid-turn message injection.
# When a session has an active task, new messages for that session
# are routed here instead of creating a new task.
self._pending_queues: dict[str, asyncio.Queue] = {}
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
self._concurrency_gate: asyncio.Semaphore | None = (
@ -348,13 +352,16 @@ class AgentLoop:
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
) -> tuple[str | None, list[str], list[dict]]:
pending_queue: asyncio.Queue | None = None,
) -> tuple[str | None, list[str], list[dict], str, bool]:
"""Run the agent iteration loop.
*on_stream*: called with each content delta during streaming.
*on_stream_end(resuming)*: called when a streaming session finishes.
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
Returns (final_content, tools_used, messages, stop_reason, had_injections).
"""
loop_hook = _LoopHook(
self,
@ -376,6 +383,18 @@ class AgentLoop:
return
self._set_runtime_checkpoint(session, payload)
async def _drain_pending() -> list[InboundMessage]:
    """Collect every queued follow-up message without blocking."""
    if pending_queue is None:
        return []
    drained: list[InboundMessage] = []
    while True:
        try:
            drained.append(pending_queue.get_nowait())
        except asyncio.QueueEmpty:
            # Queue exhausted — hand back whatever arrived so far.
            return drained
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
@ -392,13 +411,14 @@ class AgentLoop:
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
injection_callback=_drain_pending,
))
self._last_usage = result.usage
if result.stop_reason == "max_iterations":
logger.warning("Max iterations ({}) reached", self.max_iterations)
elif result.stop_reason == "error":
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
return result.final_content, result.tools_used, result.messages
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@ -429,67 +449,112 @@ class AgentLoop:
if result:
await self.bus.publish_outbound(result)
continue
# If this session already has an active pending queue (i.e. a task
# is processing this session), route the message there for mid-turn
# injection instead of creating a competing task.
if msg.session_key in self._pending_queues:
try:
self._pending_queues[msg.session_key].put_nowait(msg)
except asyncio.QueueFull:
logger.warning(
"Pending queue full for session {}, dropping follow-up",
msg.session_key,
)
else:
logger.info(
"Routed follow-up message to pending queue for session {}",
msg.session_key,
)
continue
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task)
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent."""
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
session_key = msg.session_key
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext()
async with lock, gate:
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
# Split one answer into distinct stream segments.
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
# Register a pending queue so follow-up messages for this session are
# routed here (mid-turn injection) instead of spawning a new task.
pending = asyncio.Queue(maxsize=20)
self._pending_queues[session_key] = pending
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
try:
async with lock, gate:
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
# Split one answer into distinct stream segments.
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata=meta,
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata=meta,
))
stream_segment += 1
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
pending_queue=pending,
)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata=meta,
content="", metadata=msg.metadata or {},
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata=meta,
))
stream_segment += 1
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata=msg.metadata or {},
content="Sorry, I encountered an error.",
))
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", msg.session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
))
finally:
# Drain any messages still in the pending queue and re-publish
# them to the bus so they are processed as fresh inbound messages
# rather than silently lost.
queue = self._pending_queues.pop(session_key, None)
if queue is not None:
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await self.bus.publish_inbound(item)
leftover += 1
if leftover:
logger.info(
"Re-published {} leftover message(s) to bus for session {}",
leftover, session_key,
)
async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections."""
@ -521,6 +586,7 @@ class AgentLoop:
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
@ -546,7 +612,7 @@ class AgentLoop:
session_summary=pending,
current_role=current_role,
)
final_content, _, all_msgs = await self._run_agent_loop(
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
)
@ -598,7 +664,7 @@ class AgentLoop:
channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta,
))
final_content, _, all_msgs = await self._run_agent_loop(
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
initial_messages,
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
@ -606,6 +672,7 @@ class AgentLoop:
session=session,
channel=msg.channel, chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
pending_queue=pending_queue,
)
if final_content is None or not final_content.strip():
@ -616,8 +683,13 @@ class AgentLoop:
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
# When follow-up messages were injected mid-turn, the LLM's final
# response addresses those follow-ups. Always send the response in
# this case, even if MessageTool was used earlier in the turn — the
# follow-up response is new content the user hasn't seen.
if not had_injections:
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)

View File

@ -31,7 +31,11 @@ from nanobot.utils.runtime import (
_DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_MAX_EMPTY_RETRIES = 2
_MAX_INJECTIONS_PER_TURN = 3
_MAX_INJECTION_CYCLES = 5
_SNIP_SAFETY_BUFFER = 1024
@dataclass(slots=True)
class AgentRunSpec:
"""Configuration for a single agent execution."""
@ -56,6 +60,7 @@ class AgentRunSpec:
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
checkpoint_callback: Any | None = None
injection_callback: Any | None = None
@dataclass(slots=True)
@ -69,6 +74,7 @@ class AgentRunResult:
stop_reason: str = "completed"
error: str | None = None
tool_events: list[dict[str, str]] = field(default_factory=list)
had_injections: bool = False
class AgentRunner:
@ -77,6 +83,38 @@ class AgentRunner:
def __init__(self, provider: LLMProvider):
self.provider = provider
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]:
    """Drain pending user messages via the injection callback.

    Returns the drained message texts, capped at
    ``_MAX_INJECTIONS_PER_TURN`` (the most recent messages win, since
    later follow-ups supersede earlier ones). Messages beyond the cap
    are logged so they are not silently lost. Returns an empty list
    when no callback is configured, the callback fails, or nothing is
    pending.
    """
    if spec.injection_callback is None:
        return []
    try:
        items = await spec.injection_callback()
    except Exception:
        # The callback is caller-supplied; never let it abort the turn.
        logger.exception("injection_callback failed")
        return []
    if not items:
        return []
    # Items are InboundMessage objects from _drain_pending; fall back to
    # str() for anything else. Guard against content=None, which would
    # previously crash on .strip(); blank content is skipped entirely.
    texts: list[str] = []
    for item in items:
        raw = getattr(item, "content", item)
        if raw is None:
            continue
        text = raw if isinstance(raw, str) else str(raw)
        if text.strip():
            texts.append(text)
    if len(texts) > _MAX_INJECTIONS_PER_TURN:
        dropped = len(texts) - _MAX_INJECTIONS_PER_TURN
        logger.warning(
            "Injection batch has {} messages, capping to {} ({} dropped)",
            len(texts), _MAX_INJECTIONS_PER_TURN, dropped,
        )
        texts = texts[-_MAX_INJECTIONS_PER_TURN:]
    return texts
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages)
@ -88,6 +126,8 @@ class AgentRunner:
tool_events: list[dict[str, str]] = []
external_lookup_counts: dict[str, int] = {}
empty_content_retries = 0
had_injections = False
injection_cycles = 0
for iteration in range(spec.max_iterations):
try:
@ -181,6 +221,18 @@ class AgentRunner:
},
)
empty_content_retries = 0
# Checkpoint 1: drain injections after tools, before next LLM call
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
for text in injections:
messages.append({"role": "user", "content": text})
logger.info(
"Injected {} follow-up message(s) after tool execution ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
await hook.after_iteration(context)
continue
@ -216,8 +268,29 @@ class AgentRunner:
context.tool_calls = list(response.tool_calls)
clean = hook.finalize_content(context, response.content)
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
_injected_after_final = False
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
_injected_after_final = True
for text in injections:
messages.append({"role": "user", "content": text})
logger.info(
"Injected {} follow-up message(s) after final response ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.on_stream_end(context, resuming=_injected_after_final)
if _injected_after_final:
await hook.after_iteration(context)
continue
if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
@ -283,6 +356,7 @@ class AgentRunner:
stop_reason=stop_reason,
error=error,
tool_events=tool_events,
had_injections=had_injections,
)
def _build_request_kwargs(

View File

@ -278,7 +278,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, tools_used, messages = await loop._run_agent_loop(
content, tools_used, messages, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
@ -302,7 +302,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, _, _ = await loop._run_agent_loop(
content, _, _, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
@ -344,7 +344,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
content, tools_used, _ = await loop._run_agent_loop([])
content, tools_used, _, _, _ = await loop._run_agent_loop([])
assert content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."

View File

@ -798,7 +798,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
final_content, _, _ = await loop._run_agent_loop([])
final_content, _, _, _, _ = await loop._run_agent_loop([])
assert final_content == (
"I reached the maximum number of tool call iterations (2) "
@ -825,7 +825,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
async def on_stream_end(*, resuming: bool = False) -> None:
endings.append(resuming)
final_content, _, _ = await loop._run_agent_loop(
final_content, _, _, _, _ = await loop._run_agent_loop(
[],
on_stream=on_stream,
on_stream_end=on_stream_end,
@ -849,7 +849,7 @@ async def test_loop_retries_think_only_final_response(tmp_path):
loop.provider.chat_with_retry = chat_with_retry
final_content, _, _ = await loop._run_agent_loop([])
final_content, _, _, _, _ = await loop._run_agent_loop([])
assert final_content == "Recovered answer"
assert call_count["n"] == 2
@ -999,3 +999,412 @@ async def test_runner_passes_cached_tokens_to_hook_context():
assert len(captured_usage) == 1
assert captured_usage[0]["cached_tokens"] == 150
# ── Mid-turn injection tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_injections_returns_empty_when_no_callback():
    """No injection_callback → empty list."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    runner = AgentRunner(MagicMock())
    spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="m",
        max_iterations=1,
        max_tool_result_chars=1000,
        injection_callback=None,
    )
    assert await runner._drain_injections(spec) == []
@pytest.mark.asyncio
async def test_drain_injections_extracts_content_from_inbound_messages():
    """Should extract .content from InboundMessage objects."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    pending = [
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text)
        for text in ("hello", "world")
    ]

    async def callback():
        return pending

    spec = AgentRunSpec(
        initial_messages=[], tools=tool_registry, model="m",
        max_iterations=1, max_tool_result_chars=1000,
        injection_callback=callback,
    )
    assert await AgentRunner(MagicMock())._drain_injections(spec) == ["hello", "world"]
@pytest.mark.asyncio
async def test_drain_injections_caps_at_max_and_logs_warning():
    """When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
    from nanobot.bus.events import InboundMessage

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    overflow = _MAX_INJECTIONS_PER_TURN + 3

    async def callback():
        return [
            InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
            for i in range(overflow)
        ]

    spec = AgentRunSpec(
        initial_messages=[], tools=tool_registry, model="m",
        max_iterations=1, max_tool_result_chars=1000,
        injection_callback=callback,
    )
    drained = await AgentRunner(MagicMock())._drain_injections(spec)
    assert len(drained) == _MAX_INJECTIONS_PER_TURN
    # Only the newest _MAX_INJECTIONS_PER_TURN messages survive the cap.
    assert drained[0] == "msg3"
    assert drained[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}"
@pytest.mark.asyncio
async def test_drain_injections_skips_empty_content():
    """Messages with blank content should be filtered out."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    pending = [
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text)
        for text in ("", " ", "valid")
    ]

    async def callback():
        return pending

    spec = AgentRunSpec(
        initial_messages=[], tools=tool_registry, model="m",
        max_iterations=1, max_tool_result_chars=1000,
        injection_callback=callback,
    )
    assert await AgentRunner(MagicMock())._drain_injections(spec) == ["valid"]
@pytest.mark.asyncio
async def test_drain_injections_handles_callback_exception():
    """If the callback raises, return empty list (error is logged)."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    async def exploding_callback():
        raise RuntimeError("boom")

    spec = AgentRunSpec(
        initial_messages=[], tools=tool_registry, model="m",
        max_iterations=1, max_tool_result_chars=1000,
        injection_callback=exploding_callback,
    )
    assert await AgentRunner(MagicMock())._drain_injections(spec) == []
@pytest.mark.asyncio
async def test_checkpoint1_injects_after_tool_execution():
    """Follow-up messages are injected after tool execution, before next LLM call."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    seen_messages: list[list] = []
    responses = [
        LLMResponse(
            content="using tool",
            tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})],
            usage={},
        ),
        LLMResponse(content="final answer", tool_calls=[], usage={}),
    ]

    async def chat_with_retry(*, messages, **kwargs):
        seen_messages.append(list(messages))
        # First call asks for a tool; every later call gives the final answer.
        return responses[min(len(seen_messages), 2) - 1]

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(return_value="file content")

    # A follow-up is already waiting before the run starts.
    followups = asyncio.Queue()
    followups.put_nowait(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
    )

    async def inject_cb():
        drained = []
        while not followups.empty():
            drained.append(followups.get_nowait())
        return drained

    result = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    ))
    assert result.had_injections is True
    assert result.final_content == "final answer"
    # The second LLM call must carry the injected user message.
    assert len(seen_messages) == 2
    injected = [
        m for m in seen_messages[-1]
        if m.get("role") == "user" and m.get("content") == "follow-up question"
    ]
    assert len(injected) == 1
@pytest.mark.asyncio
async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
    """After final response, if injections exist, stream_end should get resuming=True."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.bus.events import InboundMessage

    resuming_flags: list[bool] = []

    class RecordingHook(AgentHook):
        def wants_streaming(self) -> bool:
            return True

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            resuming_flags.append(resuming)

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            return content

    llm_calls = {"n": 0}

    async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs):
        llm_calls["n"] += 1
        text = "first answer" if llm_calls["n"] == 1 else "second answer"
        return LLMResponse(content=text, tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_stream_with_retry = chat_stream_with_retry
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    # A follow-up "arrives" while the first response is being produced.
    followups = asyncio.Queue()
    followups.put_nowait(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up")
    )

    async def inject_cb():
        drained = []
        while not followups.empty():
            drained.append(followups.get_nowait())
        return drained

    result = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=RecordingHook(),
        injection_callback=inject_cb,
    ))
    assert result.had_injections is True
    assert result.final_content == "second answer"
    assert llm_calls["n"] == 2
    # First stream segment stays open (resuming) because a follow-up was found.
    assert resuming_flags[0] is True
    # The last segment closes the stream for real.
    assert resuming_flags[-1] is False
@pytest.mark.asyncio
async def test_injection_cycles_capped_at_max():
    """Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
    from nanobot.bus.events import InboundMessage

    call_count = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        call_count["n"] += 1
        return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    drain_count = {"n": 0}

    async def inject_cb():
        drain_count["n"] += 1
        # Keep feeding follow-ups until the cap is reached, then go quiet.
        if drain_count["n"] > _MAX_INJECTION_CYCLES:
            return []
        return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")]

    result = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "start"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=20,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    ))
    assert result.had_injections is True
    # Capped: _MAX_INJECTION_CYCLES injection rounds plus one final round.
    assert call_count["n"] == _MAX_INJECTION_CYCLES + 1
@pytest.mark.asyncio
async def test_no_injections_flag_is_false_by_default():
    """had_injections should be False when no injection callback or no messages."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry
    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    outcome = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    ))
    assert outcome.had_injections is False
@pytest.mark.asyncio
async def test_pending_queue_cleanup_on_dispatch(tmp_path):
    """_pending_queues should be cleaned up after _dispatch completes."""
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)

    async def chat_with_retry(**kwargs):
        return LLMResponse(content="done", tool_calls=[], usage={})

    loop.provider.chat_with_retry = chat_with_retry
    inbound = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello")
    # No queue registered before dispatch...
    assert inbound.session_key not in loop._pending_queues
    await loop._dispatch(inbound)
    # ...and none left behind afterwards.
    assert inbound.session_key not in loop._pending_queues
@pytest.mark.asyncio
async def test_followup_routed_to_pending_queue(tmp_path):
    """When a session has an active dispatch, follow-up messages go to pending queue."""
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    # Pretend a dispatch is in flight by registering a queue for the session.
    pending = asyncio.Queue(maxsize=20)
    loop._pending_queues["cli:c"] = pending
    followup = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
    # Mirror the routing branch in run(): a registered session queue receives
    # the message instead of a new task being spawned.
    assert followup.session_key in loop._pending_queues
    loop._pending_queues[followup.session_key].put_nowait(followup)
    assert not pending.empty()
    assert pending.get_nowait().content == "follow-up"
@pytest.mark.asyncio
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
    """Messages left in the pending queue after _dispatch are re-published to the bus.

    This tests the finally-block cleanup that prevents message loss when
    the runner exits early (e.g., max_iterations, tool_error) with messages
    still in the queue.
    """
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    bus = loop.bus
    session_key = "cli:c"
    pending = asyncio.Queue(maxsize=20)
    loop._pending_queues[session_key] = pending
    for text in ("leftover-1", "leftover-2"):
        pending.put_nowait(
            InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text)
        )
    # Replay the cleanup logic from the finally block.
    queue = loop._pending_queues.pop(session_key, None)
    assert queue is not None
    republished = 0
    while not queue.empty():
        await bus.publish_inbound(queue.get_nowait())
        republished += 1
    assert republished == 2
    # Both messages must now be waiting on the inbound bus.
    received = []
    while not bus.inbound.empty():
        received.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5))
    contents = [m.content for m in received]
    assert "leftover-1" in contents
    assert "leftover-2" in contents

View File

@ -107,7 +107,7 @@ class TestMessageToolSuppressLogic:
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [