From bc4cc49a59485f56ceda0735239f602a427ffd87 Mon Sep 17 00:00:00 2001 From: chengyongru <61816729+chengyongru@users.noreply.github.com> Date: Sat, 11 Apr 2026 02:11:02 +0800 Subject: [PATCH] feat(agent): mid-turn message injection for responsive follow-ups (#2985) * feat(agent): add mid-turn message injection for responsive follow-ups Allow user messages sent during an active agent turn to be injected into the running LLM context instead of being queued behind a per-session lock. Inspired by Claude Code's mid-turn queue drain mechanism (query.ts:1547-1643). Key design decisions: - Messages are injected as natural user messages between iterations, no tool cancellation or special system prompt needed - Two drain checkpoints: after tool execution and after final LLM response ("last-mile" to prevent dropping late arrivals) - Bounded by MAX_INJECTION_CYCLES (5) to prevent consuming the iteration budget on rapid follow-ups - had_injections flag bypasses _sent_in_turn suppression so follow-up responses are always delivered Closes #1609 * fix(agent): harden mid-turn injection with streaming fix, bounded queue, and message safety - Fix streaming protocol violation: Checkpoint 2 now checks for injections BEFORE calling on_stream_end, passing resuming=True when injections found so streaming channels (Feishu) don't prematurely finalize the card - Bound pending queue to maxsize=20 with QueueFull handling - Add warning log when injection batch exceeds _MAX_INJECTIONS_PER_TURN - Re-publish leftover queue messages to bus in _dispatch finally block to prevent silent message loss on early exit (max_iterations, tool_error, cancel) - Fix PEP 8 blank line before dataclass and logger.info indentation - Add 12 new tests covering drain, checkpoints, cycle cap, queue routing, cleanup, and leftover re-publish --- nanobot/agent/loop.py | 176 ++++++--- nanobot/agent/runner.py | 76 +++- tests/agent/test_hook_composite.py | 6 +- tests/agent/test_runner.py | 415 +++++++++++++++++++++- tests/tools/test_message_tool_suppress.py | 2 +- 5 files changed, 615 insertions(+), 60 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index d69defaa1..dd92cac88 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -237,6 +237,10 @@ class AgentLoop: self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._background_tasks: list[asyncio.Task] = [] self._session_locks: dict[str, asyncio.Lock] = {} + # Per-session pending queues for mid-turn message injection. + # When a session has an active task, new messages for that session + # are routed here instead of creating a new task. + self._pending_queues: dict[str, asyncio.Queue] = {} # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3. _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3")) self._concurrency_gate: asyncio.Semaphore | None = ( @@ -348,13 +352,16 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", message_id: str | None = None, - ) -> tuple[str | None, list[str], list[dict]]: + pending_queue: asyncio.Queue | None = None, + ) -> tuple[str | None, list[str], list[dict], str, bool]: """Run the agent iteration loop. *on_stream*: called with each content delta during streaming. *on_stream_end(resuming)*: called when a streaming session finishes. ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. + + Returns (final_content, tools_used, messages, stop_reason, had_injections). """ loop_hook = _LoopHook( self, @@ -376,6 +383,18 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) + async def _drain_pending() -> list[InboundMessage]: + """Non-blocking drain of follow-up messages from the pending queue.""" + if pending_queue is None: + return [] + items: list[InboundMessage] = [] + while True: + try: + items.append(pending_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return items + result = await self.runner.run(AgentRunSpec( initial_messages=initial_messages, tools=self.tools, @@ -392,13 +411,14 @@ class AgentLoop: provider_retry_mode=self.provider_retry_mode, progress_callback=on_progress, checkpoint_callback=_checkpoint, + injection_callback=_drain_pending, )) self._last_usage = result.usage if result.stop_reason == "max_iterations": logger.warning("Max iterations ({}) reached", self.max_iterations) elif result.stop_reason == "error": logger.error("LLM returned error: {}", (result.final_content or "")[:200]) - return result.final_content, result.tools_used, result.messages + return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -429,67 +449,112 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + # If this session already has an active pending queue (i.e. a task + # is processing this session), route the message there for mid-turn + # injection instead of creating a competing task. + if msg.session_key in self._pending_queues: + try: + self._pending_queues[msg.session_key].put_nowait(msg) + except asyncio.QueueFull: + logger.warning( + "Pending queue full for session {}, dropping follow-up", + msg.session_key, + ) + else: + logger.info( + "Routed follow-up message to pending queue for session {}", + msg.session_key, + ) + continue task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(msg.session_key, []).append(task) task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" - lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) + session_key = msg.session_key + lock = self._session_locks.setdefault(session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() - async with lock, gate: - try: - on_stream = on_stream_end = None - if msg.metadata.get("_wants_stream"): - # Split one answer into distinct stream segments. - stream_base_id = f"{msg.session_key}:{time.time_ns()}" - stream_segment = 0 - def _current_stream_id() -> str: - return f"{stream_base_id}:{stream_segment}" + # Register a pending queue so follow-up messages for this session are + # routed here (mid-turn injection) instead of spawning a new task. + pending = asyncio.Queue(maxsize=20) + self._pending_queues[session_key] = pending - async def on_stream(delta: str) -> None: - meta = dict(msg.metadata or {}) - meta["_stream_delta"] = True - meta["_stream_id"] = _current_stream_id() + try: + async with lock, gate: + try: + on_stream = on_stream_end = None + if msg.metadata.get("_wants_stream"): + # Split one answer into distinct stream segments. + stream_base_id = f"{msg.session_key}:{time.time_ns()}" + stream_segment = 0 + + def _current_stream_id() -> str: + return f"{stream_base_id}:{stream_segment}" + + async def on_stream(delta: str) -> None: + meta = dict(msg.metadata or {}) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content=delta, + metadata=meta, + )) + + async def on_stream_end(*, resuming: bool = False) -> None: + nonlocal stream_segment + meta = dict(msg.metadata or {}) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="", + metadata=meta, + )) + stream_segment += 1 + + response = await self._process_message( + msg, on_stream=on_stream, on_stream_end=on_stream_end, + pending_queue=pending, + ) + if response is not None: + await self.bus.publish_outbound(response) + elif msg.channel == "cli": await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content=delta, - metadata=meta, + content="", metadata=msg.metadata or {}, )) - - async def on_stream_end(*, resuming: bool = False) -> None: - nonlocal stream_segment - meta = dict(msg.metadata or {}) - meta["_stream_end"] = True - meta["_resuming"] = resuming - meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="", - metadata=meta, - )) - stream_segment += 1 - - response = await self._process_message( - msg, on_stream=on_stream, on_stream_end=on_stream_end, - ) - if response is not None: - await self.bus.publish_outbound(response) - elif msg.channel == "cli": + except asyncio.CancelledError: + logger.info("Task cancelled for session {}", session_key) + raise + except Exception: + logger.exception("Error processing message for session {}", session_key) await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content="", metadata=msg.metadata or {}, + content="Sorry, I encountered an error.", )) - except asyncio.CancelledError: - logger.info("Task cancelled for session {}", msg.session_key) - raise - except Exception: - logger.exception("Error processing message for session {}", msg.session_key) - await self.bus.publish_outbound(OutboundMessage( - channel=msg.channel, chat_id=msg.chat_id, - content="Sorry, I encountered an error.", - )) + finally: + # Drain any messages still in the pending queue and re-publish + # them to the bus so they are processed as fresh inbound messages + # rather than silently lost. + queue = self._pending_queues.pop(session_key, None) + if queue is not None: + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await self.bus.publish_inbound(item) + leftover += 1 + if leftover: + logger.info( + "Re-published {} leftover message(s) to bus for session {}", + leftover, session_key, + ) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" @@ -521,6 +586,7 @@ class AgentLoop: on_progress: Callable[[str], Awaitable[None]] | None = None, on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None, + pending_queue: asyncio.Queue | None = None, ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") @@ -546,7 +612,7 @@ class AgentLoop: session_summary=pending, current_role=current_role, ) - final_content, _, all_msgs = await self._run_agent_loop( + final_content, _, all_msgs, _, _ = await self._run_agent_loop( messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) @@ -598,7 +664,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, content=content, metadata=meta, )) - final_content, _, all_msgs = await self._run_agent_loop( + final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, on_stream=on_stream, @@ -606,6 +672,7 @@ class AgentLoop: session=session, channel=msg.channel, chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), + pending_queue=pending_queue, ) if final_content is None or not final_content.strip(): @@ -616,8 +683,13 @@ class AgentLoop: self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: - return None + # When follow-up messages were injected mid-turn, the LLM's final + # response addresses those follow-ups. Always send the response in + # this case, even if MessageTool was used earlier in the turn — the + # follow-up response is new content the user hasn't seen. + if not had_injections: + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + return None preview = final_content[:120] + "..." if len(final_content) > 120 else final_content logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index fbc2a4788..a56ceea40 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -31,7 +31,11 @@ from nanobot.utils.runtime import ( _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." _MAX_EMPTY_RETRIES = 2 +_MAX_INJECTIONS_PER_TURN = 3 +_MAX_INJECTION_CYCLES = 5 _SNIP_SAFETY_BUFFER = 1024 + + @dataclass(slots=True) class AgentRunSpec: """Configuration for a single agent execution.""" @@ -56,6 +60,7 @@ class AgentRunSpec: provider_retry_mode: str = "standard" progress_callback: Any | None = None checkpoint_callback: Any | None = None + injection_callback: Any | None = None @dataclass(slots=True) @@ -69,6 +74,7 @@ class AgentRunResult: stop_reason: str = "completed" error: str | None = None tool_events: list[dict[str, str]] = field(default_factory=list) + had_injections: bool = False class AgentRunner: @@ -77,6 +83,38 @@ class AgentRunner: def __init__(self, provider: LLMProvider): self.provider = provider + async def _drain_injections(self, spec: AgentRunSpec) -> list[str]: + """Drain pending user messages via the injection callback. + + Returns all drained message contents (capped by + ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is + nothing to inject. Messages beyond the cap are logged so they + are not silently lost. + """ + if spec.injection_callback is None: + return [] + try: + items = await spec.injection_callback() + except Exception: + logger.exception("injection_callback failed") + return [] + if not items: + return [] + # items are InboundMessage objects from _drain_pending + texts: list[str] = [] + for item in items: + text = getattr(item, "content", str(item)) + if text.strip(): + texts.append(text) + if len(texts) > _MAX_INJECTIONS_PER_TURN: + dropped = len(texts) - _MAX_INJECTIONS_PER_TURN + logger.warning( + "Injection batch has {} messages, capping to {} ({} dropped)", + len(texts), _MAX_INJECTIONS_PER_TURN, dropped, + ) + texts = texts[-_MAX_INJECTIONS_PER_TURN:] + return texts + async def run(self, spec: AgentRunSpec) -> AgentRunResult: hook = spec.hook or AgentHook() messages = list(spec.initial_messages) @@ -88,6 +126,8 @@ class AgentRunner: tool_events: list[dict[str, str]] = [] external_lookup_counts: dict[str, int] = {} empty_content_retries = 0 + had_injections = False + injection_cycles = 0 for iteration in range(spec.max_iterations): try: @@ -181,6 +221,18 @@ class AgentRunner: }, ) empty_content_retries = 0 + # Checkpoint 1: drain injections after tools, before next LLM call + if injection_cycles < _MAX_INJECTION_CYCLES: + injections = await self._drain_injections(spec) + if injections: + had_injections = True + injection_cycles += 1 + for text in injections: + messages.append({"role": "user", "content": text}) + logger.info( + "Injected {} follow-up message(s) after tool execution ({}/{})", + len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + ) await hook.after_iteration(context) continue @@ -216,8 +268,29 @@ class AgentRunner: context.tool_calls = list(response.tool_calls) clean = hook.finalize_content(context, response.content) + # Check for mid-turn injections BEFORE signaling stream end. + # If injections are found we keep the stream alive (resuming=True) + # so streaming channels don't prematurely finalize the card. + _injected_after_final = False + if injection_cycles < _MAX_INJECTION_CYCLES: + injections = await self._drain_injections(spec) + if injections: + had_injections = True + injection_cycles += 1 + _injected_after_final = True + for text in injections: + messages.append({"role": "user", "content": text}) + logger.info( + "Injected {} follow-up message(s) after final response ({}/{})", + len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + ) + if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) + await hook.on_stream_end(context, resuming=_injected_after_final) + + if _injected_after_final: + await hook.after_iteration(context) + continue if response.finish_reason == "error": final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE @@ -283,6 +356,7 @@ class AgentRunner: stop_reason=stop_reason, error=error, tool_events=tool_events, + had_injections=had_injections, ) def _build_request_kwargs( diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 590d8db64..ecba87805 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -278,7 +278,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, tools_used, messages = await loop._run_agent_loop( + content, tools_used, messages, _, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -302,7 +302,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, _, _ = await loop._run_agent_loop( + content, _, _, _, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -344,7 +344,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - content, tools_used, _ = await loop._run_agent_loop([]) + content, tools_used, _, _, _ = await loop._run_agent_loop([]) assert content == ( "I reached the maximum number of tool call iterations (2) " "without completing the task. You can try breaking the task into smaller steps." diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a700f495b..f6cba5304 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -798,7 +798,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - final_content, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _, _ = await loop._run_agent_loop([]) assert final_content == ( "I reached the maximum number of tool call iterations (2) " @@ -825,7 +825,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp async def on_stream_end(*, resuming: bool = False) -> None: endings.append(resuming) - final_content, _, _ = await loop._run_agent_loop( + final_content, _, _, _, _ = await loop._run_agent_loop( [], on_stream=on_stream, on_stream_end=on_stream_end, @@ -849,7 +849,7 @@ async def test_loop_retries_think_only_final_response(tmp_path): loop.provider.chat_with_retry = chat_with_retry - final_content, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _, _ = await loop._run_agent_loop([]) assert final_content == "Recovered answer" assert call_count["n"] == 2 @@ -999,3 +999,412 @@ async def test_runner_passes_cached_tokens_to_hook_context(): assert len(captured_usage) == 1 assert captured_usage[0]["cached_tokens"] == 150 + + +# ── Mid-turn injection tests ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_drain_injections_returns_empty_when_no_callback(): + """No injection_callback → empty list.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=None, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_drain_injections_extracts_content_from_inbound_messages(): + """Should extract .content from InboundMessage objects.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_drain_injections_caps_at_max_and_logs_warning(): + """When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") + for i in range(_MAX_INJECTIONS_PER_TURN + 3) + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert len(result) == _MAX_INJECTIONS_PER_TURN + # Should keep the LAST _MAX_INJECTIONS_PER_TURN items + assert result[0] == "msg3" + assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}" + + +@pytest.mark.asyncio +async def test_drain_injections_skips_empty_content(): + """Messages with blank content should be filtered out.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == ["valid"] + + +@pytest.mark.asyncio +async def test_drain_injections_handles_callback_exception(): + """If the callback raises, return empty list (error is logged).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def cb(): + raise RuntimeError("boom") + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_checkpoint1_injects_after_tool_execution(): + """Follow-up messages are injected after tool execution, before next LLM call.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse( + content="using tool", + tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})], + usage={}, + ) + return LLMResponse(content="final answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + # Put a follow-up message in the queue before the run starts + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "final answer" + # The second call should have the injected user message + assert call_count["n"] == 2 + last_messages = captured_messages[-1] + injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): + """After final response, if injections exist, stream_end should get resuming=True.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + stream_end_calls = [] + + class TrackingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + stream_end_calls.append(resuming) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return content + + async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + # Inject a follow-up that arrives during the first response + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=TrackingHook(), + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "second answer" + assert call_count["n"] == 2 + # First stream_end should have resuming=True (because injections found) + assert stream_end_calls[0] is True + # Second (final) stream_end should have resuming=False + assert stream_end_calls[-1] is False + + +@pytest.mark.asyncio +async def test_injection_cycles_capped_at_max(): + """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + # Only inject for the first _MAX_INJECTION_CYCLES drains + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "start"}], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + + +@pytest.mark.asyncio +async def test_no_injections_flag_is_false_by_default(): + """had_injections should be False when no injection callback or no messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.had_injections is False + + +@pytest.mark.asyncio +async def test_pending_queue_cleanup_on_dispatch(tmp_path): + """_pending_queues should be cleaned up after _dispatch completes.""" + loop = _make_loop(tmp_path) + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + loop.provider.chat_with_retry = chat_with_retry + + from nanobot.bus.events import InboundMessage + + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello") + # The queue should not exist before dispatch + assert msg.session_key not in loop._pending_queues + + await loop._dispatch(msg) + + # The queue should be cleaned up after dispatch + assert msg.session_key not in loop._pending_queues + + +@pytest.mark.asyncio +async def test_followup_routed_to_pending_queue(tmp_path): + """When a session has an active dispatch, follow-up messages go to pending queue.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + + # Simulate an active dispatch by manually adding a pending queue + pending = asyncio.Queue(maxsize=20) + loop._pending_queues["cli:c"] = pending + + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + + # Directly test the routing logic from run() — if session_key is in + # _pending_queues, the message should be put into the queue. + assert msg.session_key in loop._pending_queues + loop._pending_queues[msg.session_key].put_nowait(msg) + + assert not pending.empty() + queued_msg = pending.get_nowait() + assert queued_msg.content == "follow-up" + + +@pytest.mark.asyncio +async def test_dispatch_republishes_leftover_queue_messages(tmp_path): + """Messages left in the pending queue after _dispatch are re-published to the bus. + + This tests the finally-block cleanup that prevents message loss when + the runner exits early (e.g., max_iterations, tool_error) with messages + still in the queue. + """ + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + bus = loop.bus + + # Simulate a completed dispatch by manually registering a queue + # with leftover messages, then running the cleanup logic directly. + pending = asyncio.Queue(maxsize=20) + session_key = "cli:c" + loop._pending_queues[session_key] = pending + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1")) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2")) + + # Execute the cleanup logic from the finally block + queue = loop._pending_queues.pop(session_key, None) + assert queue is not None + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await bus.publish_inbound(item) + leftover += 1 + + assert leftover == 2 + + # Verify the messages are now on the bus + msgs = [] + while not bus.inbound.empty(): + msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5)) + contents = [m.content for m in msgs] + assert "leftover-1" in contents + assert "leftover-2" in contents diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index 26d12085f..a922e95ed 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -107,7 +107,7 @@ class TestMessageToolSuppressLogic: async def on_progress(content: str, *, tool_hint: bool = False) -> None: progress.append((content, tool_hint)) - final_content, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) assert final_content == "Done" assert progress == [