feat(agent): enhance message injection handling and content merging

This commit is contained in:
Xubin Ren 2026-04-11 13:37:47 +00:00 committed by Xubin Ren
parent f6c39ec946
commit cf8381f517
4 changed files with 358 additions and 41 deletions

View File

@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
@ -370,16 +370,30 @@ class AgentLoop:
return
self._set_runtime_checkpoint(session, payload)
async def _drain_pending() -> list[InboundMessage]:
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
"""Non-blocking drain of follow-up messages from the pending queue."""
if pending_queue is None:
return []
items: list[InboundMessage] = []
while True:
items: list[dict[str, Any]] = []
while len(items) < limit:
try:
items.append(pending_queue.get_nowait())
pending_msg = pending_queue.get_nowait()
except asyncio.QueueEmpty:
break
user_content = self.context._build_user_content(
pending_msg.content,
pending_msg.media if pending_msg.media else None,
)
runtime_ctx = self.context._build_runtime_context(
pending_msg.channel,
pending_msg.chat_id,
self.context.timezone,
)
if isinstance(user_content, str):
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
items.append({"role": "user", "content": merged})
return items
result = await self.runner.run(AgentRunSpec(
@ -451,7 +465,7 @@ class AgentLoop:
self._pending_queues[effective_key].put_nowait(pending_msg)
except asyncio.QueueFull:
logger.warning(
"Pending queue full for session {}, dropping follow-up",
"Pending queue full for session {}, falling back to queued task",
effective_key,
)
else:
@ -459,7 +473,7 @@ class AgentLoop:
"Routed follow-up message to pending queue for session {}",
effective_key,
)
continue
continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
task = asyncio.create_task(self._dispatch(msg))
@ -697,12 +711,14 @@ class AgentLoop:
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
# When follow-up messages were injected mid-turn, the LLM's final
# response addresses those follow-ups. Always send the response in
# this case, even if MessageTool was used earlier in the turn — the
# follow-up response is new content the user hasn't seen.
if not had_injections:
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
# When follow-up messages were injected mid-turn, a later natural
# language reply may address those follow-ups and should not be
# suppressed just because MessageTool was used earlier in the turn.
# However, if the turn falls back to the empty-final-response
# placeholder, suppress it when the real user-visible output already
# came from MessageTool.
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
if not had_injections or stop_reason == "empty_final_response":
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import inspect
from pathlib import Path
from typing import Any
@ -94,37 +95,89 @@ class AgentRunner:
def __init__(self, provider: LLMProvider):
self.provider = provider
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]:
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
if isinstance(left, str) and isinstance(right, str):
return f"{left}\n\n{right}" if left else right
def _to_blocks(value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return [
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
for item in value
]
if value is None:
return []
return [{"type": "text", "text": str(value)}]
return _to_blocks(left) + _to_blocks(right)
@classmethod
def _append_injected_messages(
cls,
messages: list[dict[str, Any]],
injections: list[dict[str, Any]],
) -> None:
"""Append injected user messages while preserving role alternation."""
for injection in injections:
if (
messages
and injection.get("role") == "user"
and messages[-1].get("role") == "user"
):
merged = dict(messages[-1])
merged["content"] = cls._merge_message_content(
merged.get("content"),
injection.get("content"),
)
messages[-1] = merged
continue
messages.append(injection)
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback.
Returns all drained message contents (capped by
Returns normalized user messages (capped by
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
nothing to inject. Messages beyond the cap are logged so they
nothing to inject. Messages beyond the cap are logged so they
are not silently lost.
"""
if spec.injection_callback is None:
return []
try:
items = await spec.injection_callback()
signature = inspect.signature(spec.injection_callback)
accepts_limit = (
"limit" in signature.parameters
or any(
parameter.kind is inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
)
if accepts_limit:
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
else:
items = await spec.injection_callback()
except Exception:
logger.exception("injection_callback failed")
return []
if not items:
return []
# items are InboundMessage objects from _drain_pending
texts: list[str] = []
injected_messages: list[dict[str, Any]] = []
for item in items:
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
injected_messages.append(item)
continue
text = getattr(item, "content", str(item))
if text.strip():
texts.append(text)
if len(texts) > _MAX_INJECTIONS_PER_TURN:
dropped = len(texts) - _MAX_INJECTIONS_PER_TURN
injected_messages.append({"role": "user", "content": text})
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
logger.warning(
"Injection batch has {} messages, capping to {} ({} dropped)",
len(texts), _MAX_INJECTIONS_PER_TURN, dropped,
"Injection callback returned {} messages, capping to {} ({} dropped)",
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
)
texts = texts[-_MAX_INJECTIONS_PER_TURN:]
return texts
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
@ -247,8 +300,7 @@ class AgentRunner:
if injections:
had_injections = True
injection_cycles += 1
for text in injections:
messages.append({"role": "user", "content": text})
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after tool execution ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
@ -340,8 +392,7 @@ class AgentRunner:
"pending_tool_calls": [],
},
)
for text in injections:
messages.append({"role": "user", "content": text})
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after final response ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import base64
import os
import time
from unittest.mock import AsyncMock, MagicMock, patch
@ -1633,7 +1634,7 @@ async def test_drain_injections_returns_empty_when_no_callback():
@pytest.mark.asyncio
async def test_drain_injections_extracts_content_from_inbound_messages():
"""Should extract .content from InboundMessage objects."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.bus.events import InboundMessage
provider = MagicMock()
@ -1655,12 +1656,15 @@ async def test_drain_injections_extracts_content_from_inbound_messages():
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert result == ["hello", "world"]
assert result == [
{"role": "user", "content": "hello"},
{"role": "user", "content": "world"},
]
@pytest.mark.asyncio
async def test_drain_injections_caps_at_max_and_logs_warning():
"""When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept."""
async def test_drain_injections_passes_limit_to_callback_when_supported():
"""Limit-aware callbacks can preserve overflow in their own queue."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
from nanobot.bus.events import InboundMessage
@ -1668,14 +1672,16 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
runner = AgentRunner(provider)
tools = MagicMock()
tools.get_definitions.return_value = []
seen_limits: list[int] = []
msgs = [
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
]
async def cb():
return msgs
async def cb(*, limit: int):
seen_limits.append(limit)
return msgs[:limit]
spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m",
@ -1683,10 +1689,12 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert len(result) == _MAX_INJECTIONS_PER_TURN
# Should keep the LAST _MAX_INJECTIONS_PER_TURN items
assert result[0] == "msg3"
assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}"
assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
assert result == [
{"role": "user", "content": "msg0"},
{"role": "user", "content": "msg1"},
{"role": "user", "content": "msg2"},
]
@pytest.mark.asyncio
@ -1715,7 +1723,7 @@ async def test_drain_injections_skips_empty_content():
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert result == ["valid"]
assert result == [{"role": "user", "content": "valid"}]
@pytest.mark.asyncio
@ -1922,6 +1930,129 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
]
@pytest.mark.asyncio
async def test_loop_injected_followup_preserves_image_media(tmp_path):
"""Mid-turn follow-ups with images should keep multimodal content."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
image_path = tmp_path / "followup.png"
image_path.write_bytes(base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
))
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_messages: list[list[dict]] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append(list(messages))
if call_count["n"] == 1:
return LLMResponse(content="first answer", tool_calls=[], usage={})
return LLMResponse(content="second answer", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
loop.tools.get_definitions = MagicMock(return_value=[])
pending_queue = asyncio.Queue()
await pending_queue.put(InboundMessage(
channel="cli",
sender_id="u",
chat_id="c",
content="",
media=[str(image_path)],
))
final_content, _, _, _, had_injections = await loop._run_agent_loop(
[{"role": "user", "content": "hello"}],
channel="cli",
chat_id="c",
pending_queue=pending_queue,
)
assert final_content == "second answer"
assert had_injections is True
assert call_count["n"] == 2
injected_user_messages = [
message for message in captured_messages[-1]
if message.get("role") == "user" and isinstance(message.get("content"), list)
]
assert injected_user_messages
assert any(
block.get("type") == "image_url"
for block in injected_user_messages[-1]["content"]
if isinstance(block, dict)
)
@pytest.mark.asyncio
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
"""Multiple injected follow-ups should not create lossy consecutive user messages."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = {"n": 0}
captured_messages = []
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append([dict(message) for message in messages])
if call_count["n"] == 1:
return LLMResponse(content="first answer", tool_calls=[], usage={})
return LLMResponse(content="second answer", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
async def inject_cb():
if call_count["n"] == 1:
return [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
{"type": "text", "text": "look at this"},
],
},
{"role": "user", "content": "and answer briefly"},
]
return []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hello"}],
tools=tools,
model="test-model",
max_iterations=5,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
injection_callback=inject_cb,
))
assert result.final_content == "second answer"
assert call_count["n"] == 2
second_call = captured_messages[-1]
user_messages = [message for message in second_call if message.get("role") == "user"]
assert len(user_messages) == 2
injected = user_messages[-1]
assert isinstance(injected["content"], list)
assert any(
block.get("type") == "image_url"
for block in injected["content"]
if isinstance(block, dict)
)
assert any(
block.get("type") == "text" and block.get("text") == "and answer briefly"
for block in injected["content"]
if isinstance(block, dict)
)
@pytest.mark.asyncio
async def test_injection_cycles_capped_at_max():
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
@ -2042,6 +2173,88 @@ async def test_followup_routed_to_pending_queue(tmp_path):
assert queued_msg.session_key == UNIFIED_SESSION_KEY
@pytest.mark.asyncio
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
"""Pending queue should leave overflow messages queued for later drains."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_messages: list[list[dict]] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append([dict(message) for message in messages])
return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
loop.tools.get_definitions = MagicMock(return_value=[])
pending_queue = asyncio.Queue()
total_followups = _MAX_INJECTIONS_PER_TURN + 2
for idx in range(total_followups):
await pending_queue.put(InboundMessage(
channel="cli",
sender_id="u",
chat_id="c",
content=f"follow-up-{idx}",
))
final_content, _, _, _, had_injections = await loop._run_agent_loop(
[{"role": "user", "content": "hello"}],
channel="cli",
chat_id="c",
pending_queue=pending_queue,
)
assert final_content == "answer-3"
assert had_injections is True
assert call_count["n"] == 3
flattened_user_content = "\n".join(
message["content"]
for message in captured_messages[-1]
if message.get("role") == "user" and isinstance(message.get("content"), str)
)
for idx in range(total_followups):
assert f"follow-up-{idx}" in flattened_user_content
assert pending_queue.empty()
@pytest.mark.asyncio
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
"""QueueFull should preserve the message by dispatching a queued task."""
from nanobot.bus.events import InboundMessage
loop = _make_loop(tmp_path)
loop._dispatch = AsyncMock() # type: ignore[method-assign]
pending = asyncio.Queue(maxsize=1)
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued"))
loop._pending_queues["cli:c"] = pending
run_task = asyncio.create_task(loop.run())
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
await loop.bus.publish_inbound(msg)
deadline = time.time() + 2
while loop._dispatch.await_count == 0 and time.time() < deadline:
await asyncio.sleep(0.01)
loop.stop()
await asyncio.wait_for(run_task, timeout=2)
assert loop._dispatch.await_count == 1
dispatched_msg = loop._dispatch.await_args.args[0]
assert dispatched_msg.content == "follow-up"
assert pending.qsize() == 1
@pytest.mark.asyncio
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
"""Messages left in the pending queue after _dispatch are re-published to the bus.

View File

@ -1,5 +1,6 @@
"""Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
assert result is not None
assert "Hello" in result.content
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
self, tmp_path: Path
) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
)
calls = iter([
LLMResponse(content="First answer", tool_calls=[]),
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="", tool_calls=[]),
LLMResponse(content="", tool_calls=[]),
LLMResponse(content="", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
pending_queue = asyncio.Queue()
await pending_queue.put(
InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
)
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
result = await loop._process_message(msg, pending_queue=pending_queue)
assert len(sent) == 1
assert sent[0].content == "Tool reply"
assert result is None
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})