diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b51444650..0f72e39f6 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -324,6 +324,12 @@ class AgentLoop: return format_tool_hints(tool_calls) + def _effective_session_key(self, msg: InboundMessage) -> str: + """Return the session key used for task routing and mid-turn injections.""" + if self._unified_session and not msg.session_key_override: + return UNIFIED_SESSION_KEY + return msg.session_key + async def _run_agent_loop( self, initial_messages: list[dict], @@ -430,30 +436,32 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + effective_key = self._effective_session_key(msg) # If this session already has an active pending queue (i.e. a task # is processing this session), route the message there for mid-turn # injection instead of creating a competing task. - if msg.session_key in self._pending_queues: + if effective_key in self._pending_queues: + pending_msg = msg + if effective_key != msg.session_key: + pending_msg = dataclasses.replace( + msg, + session_key_override=effective_key, + ) try: - self._pending_queues[msg.session_key].put_nowait(msg) + self._pending_queues[effective_key].put_nowait(pending_msg) except asyncio.QueueFull: logger.warning( "Pending queue full for session {}, dropping follow-up", - msg.session_key, + effective_key, ) else: logger.info( "Routed follow-up message to pending queue for session {}", - msg.session_key, + effective_key, ) continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled - effective_key = ( - UNIFIED_SESSION_KEY - if self._unified_session and not msg.session_key_override - else msg.session_key - ) task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(effective_key, []).append(task) task.add_done_callback( @@ -465,9 +473,9 @@ class AgentLoop: async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" - if self._unified_session and not msg.session_key_override: - msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) - session_key = msg.session_key + session_key = self._effective_session_key(msg) + if session_key != msg.session_key: + msg = dataclasses.replace(msg, session_key_override=session_key) lock = self._session_locks.setdefault(session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index f7187191e..0ba0e6bc6 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -309,6 +309,14 @@ class AgentRunner: await hook.after_iteration(context) continue + assistant_message: dict[str, Any] | None = None + if response.finish_reason != "error" and not is_blank_text(clean): + assistant_message = build_assistant_message( + clean, + reasoning_content=response.reasoning_content, + thinking_blocks=response.thinking_blocks, + ) + # Check for mid-turn injections BEFORE signaling stream end. # If injections are found we keep the stream alive (resuming=True) # so streaming channels don't prematurely finalize the card. @@ -319,6 +327,19 @@ class AgentRunner: had_injections = True injection_cycles += 1 _injected_after_final = True + if assistant_message is not None: + messages.append(assistant_message) + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [], + }, + ) for text in injections: messages.append({"role": "user", "content": text}) logger.info( @@ -354,7 +375,7 @@ class AgentRunner: await hook.after_iteration(context) break - messages.append(build_assistant_message( + messages.append(assistant_message or build_assistant_message( clean, reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index ba503e988..a9e32e0f8 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -1862,6 +1862,66 @@ async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): assert stream_end_calls[-1] is False +@pytest.mark.asyncio +async def test_checkpoint2_preserves_final_response_in_history_before_followup(): + """A follow-up injected after a final answer must still see that answer in history.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + assert captured_messages[-1] == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "first answer"}, + {"role": "user", "content": "follow-up question"}, + ] + assert [ + {"role": message["role"], "content": message["content"]} + for message in result.messages + if message.get("role") == "assistant" + ] == [ + {"role": "assistant", "content": "first answer"}, + {"role": "assistant", "content": "second answer"}, + ] + + @pytest.mark.asyncio async def test_injection_cycles_capped_at_max(): """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" @@ -1953,25 +2013,33 @@ async def test_pending_queue_cleanup_on_dispatch(tmp_path): @pytest.mark.asyncio async def test_followup_routed_to_pending_queue(tmp_path): - """When a session has an active dispatch, follow-up messages go to pending queue.""" + """Unified-session follow-ups should route into the active pending queue.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY from nanobot.bus.events import InboundMessage loop = _make_loop(tmp_path) + loop._unified_session = True + loop._dispatch = AsyncMock() # type: ignore[method-assign] - # Simulate an active dispatch by manually adding a pending queue pending = asyncio.Queue(maxsize=20) - loop._pending_queues["cli:c"] = pending + loop._pending_queues[UNIFIED_SESSION_KEY] = pending - msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) - # Directly test the routing logic from run() — if session_key is in - # _pending_queues, the message should be put into the queue. - assert msg.session_key in loop._pending_queues - loop._pending_queues[msg.session_key].put_nowait(msg) + deadline = time.time() + 2 + while pending.empty() and time.time() < deadline: + await asyncio.sleep(0.01) + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 0 assert not pending.empty() queued_msg = pending.get_nowait() assert queued_msg.content == "follow-up" + assert queued_msg.session_key == UNIFIED_SESSION_KEY @pytest.mark.asyncio