From cf8381f517084b5e9ec8e14539dcdc5b0eab2baa Mon Sep 17 00:00:00 2001 From: Xubin Ren Date: Sat, 11 Apr 2026 13:37:47 +0000 Subject: [PATCH] feat(agent): enhance message injection handling and content merging --- nanobot/agent/loop.py | 42 ++-- nanobot/agent/runner.py | 85 ++++++-- tests/agent/test_runner.py | 235 +++++++++++++++++++++- tests/tools/test_message_tool_suppress.py | 37 ++++ 4 files changed, 358 insertions(+), 41 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0f72e39f6..675865350 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import Consolidator, Dream -from nanobot.agent.runner import AgentRunSpec, AgentRunner +from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.skills import BUILTIN_SKILLS_DIR @@ -370,16 +370,30 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - async def _drain_pending() -> list[InboundMessage]: + async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]: """Non-blocking drain of follow-up messages from the pending queue.""" if pending_queue is None: return [] - items: list[InboundMessage] = [] - while True: + items: list[dict[str, Any]] = [] + while len(items) < limit: try: - items.append(pending_queue.get_nowait()) + pending_msg = pending_queue.get_nowait() except asyncio.QueueEmpty: break + user_content = self.context._build_user_content( + pending_msg.content, + pending_msg.media if pending_msg.media else None, + ) + runtime_ctx = self.context._build_runtime_context( + pending_msg.channel, + pending_msg.chat_id, + self.context.timezone, + ) + if isinstance(user_content, str): + merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}" + else: + merged = [{"type": "text", "text": runtime_ctx}] + user_content + items.append({"role": "user", "content": merged}) return items result = await self.runner.run(AgentRunSpec( @@ -451,7 +465,7 @@ class AgentLoop: self._pending_queues[effective_key].put_nowait(pending_msg) except asyncio.QueueFull: logger.warning( - "Pending queue full for session {}, dropping follow-up", + "Pending queue full for session {}, falling back to queued task", effective_key, ) else: @@ -459,7 +473,7 @@ class AgentLoop: "Routed follow-up message to pending queue for session {}", effective_key, ) - continue + continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled task = asyncio.create_task(self._dispatch(msg)) @@ -697,12 +711,14 @@ class AgentLoop: self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) - # When follow-up messages were injected mid-turn, the LLM's final - # response addresses those follow-ups. Always send the response in - # this case, even if MessageTool was used earlier in the turn — the - # follow-up response is new content the user hasn't seen. - if not had_injections: - if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + # When follow-up messages were injected mid-turn, a later natural + # language reply may address those follow-ups and should not be + # suppressed just because MessageTool was used earlier in the turn. + # However, if the turn falls back to the empty-final-response + # placeholder, suppress it when the real user-visible output already + # came from MessageTool. + if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: + if not had_injections or stop_reason == "empty_final_response": return None preview = final_content[:120] + "..." if len(final_content) > 120 else final_content diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 0ba0e6bc6..164921bb4 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field +import inspect from pathlib import Path from typing import Any @@ -94,37 +95,89 @@ class AgentRunner: def __init__(self, provider: LLMProvider): self.provider = provider - async def _drain_injections(self, spec: AgentRunSpec) -> list[str]: + @staticmethod + def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: + if isinstance(left, str) and isinstance(right, str): + return f"{left}\n\n{right}" if left else right + + def _to_blocks(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [ + item if isinstance(item, dict) else {"type": "text", "text": str(item)} + for item in value + ] + if value is None: + return [] + return [{"type": "text", "text": str(value)}] + + return _to_blocks(left) + _to_blocks(right) + + @classmethod + def _append_injected_messages( + cls, + messages: list[dict[str, Any]], + injections: list[dict[str, Any]], + ) -> None: + """Append injected user messages while preserving role alternation.""" + for injection in injections: + if ( + messages + and injection.get("role") == "user" + and messages[-1].get("role") == "user" + ): + merged = dict(messages[-1]) + merged["content"] = cls._merge_message_content( + merged.get("content"), + injection.get("content"), + ) + messages[-1] = merged + continue + messages.append(injection) + + async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: """Drain pending user messages via the injection callback. - Returns all drained message contents (capped by + Returns normalized user messages (capped by ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is - nothing to inject. Messages beyond the cap are logged so they + nothing to inject. Messages beyond the cap are logged so they are not silently lost. """ if spec.injection_callback is None: return [] try: - items = await spec.injection_callback() + signature = inspect.signature(spec.injection_callback) + accepts_limit = ( + "limit" in signature.parameters + or any( + parameter.kind is inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + ) + if accepts_limit: + items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN) + else: + items = await spec.injection_callback() except Exception: logger.exception("injection_callback failed") return [] if not items: return [] - # items are InboundMessage objects from _drain_pending - texts: list[str] = [] + injected_messages: list[dict[str, Any]] = [] for item in items: + if isinstance(item, dict) and item.get("role") == "user" and "content" in item: + injected_messages.append(item) + continue text = getattr(item, "content", str(item)) if text.strip(): - texts.append(text) - if len(texts) > _MAX_INJECTIONS_PER_TURN: - dropped = len(texts) - _MAX_INJECTIONS_PER_TURN + injected_messages.append({"role": "user", "content": text}) + if len(injected_messages) > _MAX_INJECTIONS_PER_TURN: + dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN logger.warning( - "Injection batch has {} messages, capping to {} ({} dropped)", - len(texts), _MAX_INJECTIONS_PER_TURN, dropped, + "Injection callback returned {} messages, capping to {} ({} dropped)", + len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped, ) - texts = texts[-_MAX_INJECTIONS_PER_TURN:] - return texts + injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN] + return injected_messages async def run(self, spec: AgentRunSpec) -> AgentRunResult: hook = spec.hook or AgentHook() @@ -247,8 +300,7 @@ class AgentRunner: if injections: had_injections = True injection_cycles += 1 - for text in injections: - messages.append({"role": "user", "content": text}) + self._append_injected_messages(messages, injections) logger.info( "Injected {} follow-up message(s) after tool execution ({}/{})", len(injections), injection_cycles, _MAX_INJECTION_CYCLES, @@ -340,8 +392,7 @@ class AgentRunner: "pending_tool_calls": [], }, ) - for text in injections: - messages.append({"role": "user", "content": text}) + self._append_injected_messages(messages, injections) logger.info( "Injected {} follow-up message(s) after final response ({}/{})", len(injections), injection_cycles, _MAX_INJECTION_CYCLES, diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index a9e32e0f8..b9047b674 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import base64 import os import time from unittest.mock import AsyncMock, MagicMock, patch @@ -1633,7 +1634,7 @@ async def test_drain_injections_returns_empty_when_no_callback(): @pytest.mark.asyncio async def test_drain_injections_extracts_content_from_inbound_messages(): """Should extract .content from InboundMessage objects.""" - from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.bus.events import InboundMessage provider = MagicMock() @@ -1655,12 +1656,15 @@ async def test_drain_injections_extracts_content_from_inbound_messages(): injection_callback=cb, ) result = await runner._drain_injections(spec) - assert result == ["hello", "world"] + assert result == [ + {"role": "user", "content": "hello"}, + {"role": "user", "content": "world"}, + ] @pytest.mark.asyncio -async def test_drain_injections_caps_at_max_and_logs_warning(): - """When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept.""" +async def test_drain_injections_passes_limit_to_callback_when_supported(): + """Limit-aware callbacks can preserve overflow in their own queue.""" from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN from nanobot.bus.events import InboundMessage @@ -1668,14 +1672,16 @@ async def test_drain_injections_caps_at_max_and_logs_warning(): runner = AgentRunner(provider) tools = MagicMock() tools.get_definitions.return_value = [] + seen_limits: list[int] = [] msgs = [ InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") for i in range(_MAX_INJECTIONS_PER_TURN + 3) ] - async def cb(): - return msgs + async def cb(*, limit: int): + seen_limits.append(limit) + return msgs[:limit] spec = AgentRunSpec( initial_messages=[], tools=tools, model="m", @@ -1683,10 +1689,12 @@ async def test_drain_injections_caps_at_max_and_logs_warning(): injection_callback=cb, ) result = await runner._drain_injections(spec) - assert len(result) == _MAX_INJECTIONS_PER_TURN - # Should keep the LAST _MAX_INJECTIONS_PER_TURN items - assert result[0] == "msg3" - assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}" + assert seen_limits == [_MAX_INJECTIONS_PER_TURN] + assert result == [ + {"role": "user", "content": "msg0"}, + {"role": "user", "content": "msg1"}, + {"role": "user", "content": "msg2"}, + ] @pytest.mark.asyncio @@ -1715,7 +1723,7 @@ async def test_drain_injections_skips_empty_content(): injection_callback=cb, ) result = await runner._drain_injections(spec) - assert result == ["valid"] + assert result == [{"role": "user", "content": "valid"}] @pytest.mark.asyncio @@ -1922,6 +1930,129 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup() ] +@pytest.mark.asyncio +async def test_loop_injected_followup_preserves_image_media(tmp_path): + """Mid-turn follow-ups with images should keep multimodal content.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + image_path = tmp_path / "followup.png" + image_path.write_bytes(base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII=" + )) + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="", + media=[str(image_path)], + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "second answer" + assert had_injections is True + assert call_count["n"] == 2 + injected_user_messages = [ + message for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), list) + ] + assert injected_user_messages + assert any( + block.get("type") == "image_url" + for block in injected_user_messages[-1]["content"] + if isinstance(block, dict) + ) + + +@pytest.mark.asyncio +async def test_runner_merges_multiple_injected_user_messages_without_losing_media(): + """Multiple injected follow-ups should not create lossy consecutive user messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def inject_cb(): + if call_count["n"] == 1: + return [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + {"type": "text", "text": "look at this"}, + ], + }, + {"role": "user", "content": "and answer briefly"}, + ] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + second_call = captured_messages[-1] + user_messages = [message for message in second_call if message.get("role") == "user"] + assert len(user_messages) == 2 + injected = user_messages[-1] + assert isinstance(injected["content"], list) + assert any( + block.get("type") == "image_url" + for block in injected["content"] + if isinstance(block, dict) + ) + assert any( + block.get("type") == "text" and block.get("text") == "and answer briefly" + for block in injected["content"] + if isinstance(block, dict) + ) + + @pytest.mark.asyncio async def test_injection_cycles_capped_at_max(): """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" @@ -2042,6 +2173,88 @@ async def test_followup_routed_to_pending_queue(tmp_path): assert queued_msg.session_key == UNIFIED_SESSION_KEY +@pytest.mark.asyncio +async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path): + """Pending queue should leave overflow messages queued for later drains.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + total_followups = _MAX_INJECTIONS_PER_TURN + 2 + for idx in range(total_followups): + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content=f"follow-up-{idx}", + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "answer-3" + assert had_injections is True + assert call_count["n"] == 3 + flattened_user_content = "\n".join( + message["content"] + for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), str) + ) + for idx in range(total_followups): + assert f"follow-up-{idx}" in flattened_user_content + assert pending_queue.empty() + + +@pytest.mark.asyncio +async def test_pending_queue_full_falls_back_to_queued_task(tmp_path): + """QueueFull should preserve the message by dispatching a queued task.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + loop._dispatch = AsyncMock() # type: ignore[method-assign] + + pending = asyncio.Queue(maxsize=1) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")) + loop._pending_queues["cli:c"] = pending + + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) + + deadline = time.time() + 2 + while loop._dispatch.await_count == 0 and time.time() < deadline: + await asyncio.sleep(0.01) + + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 1 + dispatched_msg = loop._dispatch.await_args.args[0] + assert dispatched_msg.content == "follow-up" + assert pending.qsize() == 1 + + @pytest.mark.asyncio async def test_dispatch_republishes_leftover_queue_messages(tmp_path): """Messages left in the pending queue after _dispatch are re-published to the bus. diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index a922e95ed..434b2ca71 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -1,5 +1,6 @@ """Test message tool suppress logic for final replies.""" +import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic: assert result is not None assert "Hello" in result.content + @pytest.mark.asyncio + async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback( + self, tmp_path: Path + ) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"}, + ) + calls = iter([ + LLMResponse(content="First answer", tool_calls=[]), + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="", tool_calls=[]), + LLMResponse(content="", tool_calls=[]), + LLMResponse(content="", tool_calls=[]), + ]) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + pending_queue = asyncio.Queue() + await pending_queue.put( + InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up") + ) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start") + result = await loop._process_message(msg, pending_queue=pending_queue) + + assert len(sent) == 1 + assert sent[0].content == "Tool reply" + assert result is None + async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})