feat(agent): enhance message injection handling and content merging

This commit is contained in:
Xubin Ren 2026-04-11 13:37:47 +00:00 committed by Xubin Ren
parent f6c39ec946
commit cf8381f517
4 changed files with 358 additions and 41 deletions

View File

@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.skills import BUILTIN_SKILLS_DIR
@ -370,16 +370,30 @@ class AgentLoop:
return return
self._set_runtime_checkpoint(session, payload) self._set_runtime_checkpoint(session, payload)
async def _drain_pending() -> list[InboundMessage]: async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
"""Non-blocking drain of follow-up messages from the pending queue.""" """Non-blocking drain of follow-up messages from the pending queue."""
if pending_queue is None: if pending_queue is None:
return [] return []
items: list[InboundMessage] = [] items: list[dict[str, Any]] = []
while True: while len(items) < limit:
try: try:
items.append(pending_queue.get_nowait()) pending_msg = pending_queue.get_nowait()
except asyncio.QueueEmpty: except asyncio.QueueEmpty:
break break
user_content = self.context._build_user_content(
pending_msg.content,
pending_msg.media if pending_msg.media else None,
)
runtime_ctx = self.context._build_runtime_context(
pending_msg.channel,
pending_msg.chat_id,
self.context.timezone,
)
if isinstance(user_content, str):
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
items.append({"role": "user", "content": merged})
return items return items
result = await self.runner.run(AgentRunSpec( result = await self.runner.run(AgentRunSpec(
@ -451,7 +465,7 @@ class AgentLoop:
self._pending_queues[effective_key].put_nowait(pending_msg) self._pending_queues[effective_key].put_nowait(pending_msg)
except asyncio.QueueFull: except asyncio.QueueFull:
logger.warning( logger.warning(
"Pending queue full for session {}, dropping follow-up", "Pending queue full for session {}, falling back to queued task",
effective_key, effective_key,
) )
else: else:
@ -459,7 +473,7 @@ class AgentLoop:
"Routed follow-up message to pending queue for session {}", "Routed follow-up message to pending queue for session {}",
effective_key, effective_key,
) )
continue continue
# Compute the effective session key before dispatching # Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled # This ensures /stop command can find tasks correctly when unified session is enabled
task = asyncio.create_task(self._dispatch(msg)) task = asyncio.create_task(self._dispatch(msg))
@ -697,12 +711,14 @@ class AgentLoop:
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
# When follow-up messages were injected mid-turn, the LLM's final # When follow-up messages were injected mid-turn, a later natural
# response addresses those follow-ups. Always send the response in # language reply may address those follow-ups and should not be
# this case, even if MessageTool was used earlier in the turn — the # suppressed just because MessageTool was used earlier in the turn.
# follow-up response is new content the user hasn't seen. # However, if the turn falls back to the empty-final-response
if not had_injections: # placeholder, suppress it when the real user-visible output already
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: # came from MessageTool.
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
if not had_injections or stop_reason == "empty_final_response":
return None return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass, field
import inspect
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -94,37 +95,89 @@ class AgentRunner:
def __init__(self, provider: LLMProvider): def __init__(self, provider: LLMProvider):
self.provider = provider self.provider = provider
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]: @staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
    """Combine the content of two adjacent messages into a single content value.

    Two plain strings are joined with a blank line between them; if either
    side is empty the other is returned unchanged, so no dangling separator
    is produced. Every other combination is normalized to a list of content
    blocks and concatenated with ``left`` first.

    Args:
        left: Content of the earlier message (string, list of blocks, or None).
        right: Content of the later message (string, list of blocks, or None).

    Returns:
        A string when both inputs are strings, otherwise a list of block dicts.
    """
    if isinstance(left, str) and isinstance(right, str):
        # Symmetric handling: an empty side must not leave a stray "\n\n".
        if not left:
            return right
        if not right:
            return left
        return f"{left}\n\n{right}"

    def _to_blocks(value: Any) -> list[dict[str, Any]]:
        """Normalize one content value to a list of block dicts."""
        # None and "" carry no content; emitting an empty text block would only
        # add noise (and may be rejected by strict provider APIs; verify per
        # provider).
        if value is None or (isinstance(value, str) and not value):
            return []
        if isinstance(value, list):
            # Dict items are kept as-is; anything else is wrapped as text.
            return [
                item if isinstance(item, dict) else {"type": "text", "text": str(item)}
                for item in value
            ]
        return [{"type": "text", "text": str(value)}]

    return _to_blocks(left) + _to_blocks(right)
@classmethod
def _append_injected_messages(
    cls,
    messages: list[dict[str, Any]],
    injections: list[dict[str, Any]],
) -> None:
    """Add injected messages to ``messages`` in place without stacking user turns.

    An injected ``user`` message that would directly follow another ``user``
    message is folded into that previous message via
    ``_merge_message_content``; every other injection is appended unchanged.

    Args:
        messages: Conversation history; mutated in place.
        injections: Messages to add, processed in order.
    """
    for incoming in injections:
        follows_user_turn = (
            bool(messages)
            and incoming.get("role") == "user"
            and messages[-1].get("role") == "user"
        )
        if not follows_user_turn:
            messages.append(incoming)
            continue

        # Swap the tail for a shallow copy so the dict object that was
        # previously in the list is not modified.
        combined = dict(messages[-1])
        combined["content"] = cls._merge_message_content(
            combined.get("content"),
            incoming.get("content"),
        )
        messages[-1] = combined
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback. """Drain pending user messages via the injection callback.
Returns all drained message contents (capped by Returns normalized user messages (capped by
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
nothing to inject. Messages beyond the cap are logged so they nothing to inject. Messages beyond the cap are logged so they
are not silently lost. are not silently lost.
""" """
if spec.injection_callback is None: if spec.injection_callback is None:
return [] return []
try: try:
items = await spec.injection_callback() signature = inspect.signature(spec.injection_callback)
accepts_limit = (
"limit" in signature.parameters
or any(
parameter.kind is inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
)
if accepts_limit:
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
else:
items = await spec.injection_callback()
except Exception: except Exception:
logger.exception("injection_callback failed") logger.exception("injection_callback failed")
return [] return []
if not items: if not items:
return [] return []
# items are InboundMessage objects from _drain_pending injected_messages: list[dict[str, Any]] = []
texts: list[str] = []
for item in items: for item in items:
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
injected_messages.append(item)
continue
text = getattr(item, "content", str(item)) text = getattr(item, "content", str(item))
if text.strip(): if text.strip():
texts.append(text) injected_messages.append({"role": "user", "content": text})
if len(texts) > _MAX_INJECTIONS_PER_TURN: if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
dropped = len(texts) - _MAX_INJECTIONS_PER_TURN dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
logger.warning( logger.warning(
"Injection batch has {} messages, capping to {} ({} dropped)", "Injection callback returned {} messages, capping to {} ({} dropped)",
len(texts), _MAX_INJECTIONS_PER_TURN, dropped, len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
) )
texts = texts[-_MAX_INJECTIONS_PER_TURN:] injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
return texts return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult: async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook() hook = spec.hook or AgentHook()
@ -247,8 +300,7 @@ class AgentRunner:
if injections: if injections:
had_injections = True had_injections = True
injection_cycles += 1 injection_cycles += 1
for text in injections: self._append_injected_messages(messages, injections)
messages.append({"role": "user", "content": text})
logger.info( logger.info(
"Injected {} follow-up message(s) after tool execution ({}/{})", "Injected {} follow-up message(s) after tool execution ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES, len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
@ -340,8 +392,7 @@ class AgentRunner:
"pending_tool_calls": [], "pending_tool_calls": [],
}, },
) )
for text in injections: self._append_injected_messages(messages, injections)
messages.append({"role": "user", "content": text})
logger.info( logger.info(
"Injected {} follow-up message(s) after final response ({}/{})", "Injected {} follow-up message(s) after final response ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES, len(injections), injection_cycles, _MAX_INJECTION_CYCLES,

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
import os import os
import time import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@ -1633,7 +1634,7 @@ async def test_drain_injections_returns_empty_when_no_callback():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_drain_injections_extracts_content_from_inbound_messages(): async def test_drain_injections_extracts_content_from_inbound_messages():
"""Should extract .content from InboundMessage objects.""" """Should extract .content from InboundMessage objects."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
provider = MagicMock() provider = MagicMock()
@ -1655,12 +1656,15 @@ async def test_drain_injections_extracts_content_from_inbound_messages():
injection_callback=cb, injection_callback=cb,
) )
result = await runner._drain_injections(spec) result = await runner._drain_injections(spec)
assert result == ["hello", "world"] assert result == [
{"role": "user", "content": "hello"},
{"role": "user", "content": "world"},
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_drain_injections_caps_at_max_and_logs_warning(): async def test_drain_injections_passes_limit_to_callback_when_supported():
"""When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept.""" """Limit-aware callbacks can preserve overflow in their own queue."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
from nanobot.bus.events import InboundMessage from nanobot.bus.events import InboundMessage
@ -1668,14 +1672,16 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
runner = AgentRunner(provider) runner = AgentRunner(provider)
tools = MagicMock() tools = MagicMock()
tools.get_definitions.return_value = [] tools.get_definitions.return_value = []
seen_limits: list[int] = []
msgs = [ msgs = [
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
for i in range(_MAX_INJECTIONS_PER_TURN + 3) for i in range(_MAX_INJECTIONS_PER_TURN + 3)
] ]
async def cb(): async def cb(*, limit: int):
return msgs seen_limits.append(limit)
return msgs[:limit]
spec = AgentRunSpec( spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m", initial_messages=[], tools=tools, model="m",
@ -1683,10 +1689,12 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
injection_callback=cb, injection_callback=cb,
) )
result = await runner._drain_injections(spec) result = await runner._drain_injections(spec)
assert len(result) == _MAX_INJECTIONS_PER_TURN assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
# Should keep the LAST _MAX_INJECTIONS_PER_TURN items assert result == [
assert result[0] == "msg3" {"role": "user", "content": "msg0"},
assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}" {"role": "user", "content": "msg1"},
{"role": "user", "content": "msg2"},
]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1715,7 +1723,7 @@ async def test_drain_injections_skips_empty_content():
injection_callback=cb, injection_callback=cb,
) )
result = await runner._drain_injections(spec) result = await runner._drain_injections(spec)
assert result == ["valid"] assert result == [{"role": "user", "content": "valid"}]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1922,6 +1930,129 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
] ]
@pytest.mark.asyncio
async def test_loop_injected_followup_preserves_image_media(tmp_path):
    """An image-only follow-up injected mid-turn must stay multimodal.

    A pending message with empty text and one image path is queued before the
    loop runs. The last model request is expected to contain a user message
    whose content is a block list that includes an ``image_url`` block.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    # PNG fixture written to disk so the follow-up can point at a real file.
    png_b64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
    )
    image_path = tmp_path / "followup.png"
    image_path.write_bytes(base64.b64decode(png_b64))

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    requests: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Snapshot the message list of each model request.
        requests.append(list(messages))
        answer = "first answer" if len(requests) == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    pending_queue = asyncio.Queue()
    followup = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="",
        media=[str(image_path)],
    )
    await pending_queue.put(followup)

    final_content, _, _, _, had_injections = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "second answer"
    assert had_injections is True
    assert len(requests) == 2

    # Only user messages with list content can carry image blocks.
    multimodal_user_messages = [
        entry
        for entry in requests[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), list)
    ]
    assert multimodal_user_messages
    last_blocks = multimodal_user_messages[-1]["content"]
    assert any(
        block.get("type") == "image_url"
        for block in last_blocks
        if isinstance(block, dict)
    )
@pytest.mark.asyncio
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
    """Two follow-ups injected together must collapse into one user message.

    The callback returns an image+text message followed by a plain-text one
    after the first model call. The second request must contain exactly two
    user messages, the last holding both the image block and the text of the
    second follow-up.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    provider = MagicMock()
    requests = []

    async def chat_with_retry(*, messages, **kwargs):
        # Copy each message dict so later changes cannot alter the snapshot.
        requests.append([dict(entry) for entry in messages])
        answer = "first answer" if len(requests) == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    image_followup = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
            {"type": "text", "text": "look at this"},
        ],
    }
    text_followup = {"role": "user", "content": "and answer briefly"}

    async def inject_cb():
        # Follow-ups are only available right after the first model call.
        if len(requests) != 1:
            return []
        return [image_followup, text_followup]

    runner = AgentRunner(provider)
    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await runner.run(spec)

    assert result.final_content == "second answer"
    assert len(requests) == 2

    user_messages = [entry for entry in requests[-1] if entry.get("role") == "user"]
    assert len(user_messages) == 2

    merged_blocks = user_messages[-1]["content"]
    assert isinstance(merged_blocks, list)
    dict_blocks = [block for block in merged_blocks if isinstance(block, dict)]
    assert any(block.get("type") == "image_url" for block in dict_blocks)
    assert any(
        block.get("type") == "text" and block.get("text") == "and answer briefly"
        for block in dict_blocks
    )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_injection_cycles_capped_at_max(): async def test_injection_cycles_capped_at_max():
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" """Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
@ -2042,6 +2173,88 @@ async def test_followup_routed_to_pending_queue(tmp_path):
assert queued_msg.session_key == UNIFIED_SESSION_KEY assert queued_msg.session_key == UNIFIED_SESSION_KEY
@pytest.mark.asyncio
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
    """Follow-ups beyond the per-turn cap must be picked up by a later drain.

    The queue is pre-filled with two more messages than the cap. The loop is
    expected to need three model calls, to have shown every follow-up to the
    model by the last call, and to leave the queue empty.
    """
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus
    from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    requests: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        # Snapshot each request; the reply is numbered by call order.
        requests.append([dict(entry) for entry in messages])
        return LLMResponse(content=f"answer-{len(requests)}", tool_calls=[], usage={})

    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    pending_queue = asyncio.Queue()
    total_followups = _MAX_INJECTIONS_PER_TURN + 2
    expected_texts = [f"follow-up-{idx}" for idx in range(total_followups)]
    for text in expected_texts:
        await pending_queue.put(
            InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text)
        )

    final_content, _, _, _, had_injections = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "answer-3"
    assert had_injections is True
    assert len(requests) == 3

    # Only string-valued user contents are searched for the follow-up texts.
    user_texts = [
        entry["content"]
        for entry in requests[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), str)
    ]
    flattened_user_content = "\n".join(user_texts)
    for text in expected_texts:
        assert text in flattened_user_content
    assert pending_queue.empty()
@pytest.mark.asyncio
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
    """A follow-up that hits a full pending queue must be dispatched, not lost.

    The session's pending queue has capacity one and is already full. After
    publishing another message, ``_dispatch`` must be awaited exactly once
    with that message and the queue must still hold only its original item.
    """
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    loop._dispatch = AsyncMock()  # type: ignore[method-assign]

    # Pre-fill the single slot so the next put_nowait raises QueueFull.
    full_queue = asyncio.Queue(maxsize=1)
    full_queue.put_nowait(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")
    )
    loop._pending_queues["cli:c"] = full_queue

    run_task = asyncio.create_task(loop.run())
    await loop.bus.publish_inbound(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
    )

    # Poll for up to two seconds until the dispatcher has been awaited.
    deadline = time.time() + 2
    while True:
        if loop._dispatch.await_count != 0 or time.time() >= deadline:
            break
        await asyncio.sleep(0.01)

    loop.stop()
    await asyncio.wait_for(run_task, timeout=2)

    assert loop._dispatch.await_count == 1
    dispatched_msg = loop._dispatch.await_args.args[0]
    assert dispatched_msg.content == "follow-up"
    assert full_queue.qsize() == 1
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_dispatch_republishes_leftover_queue_messages(tmp_path): async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
"""Messages left in the pending queue after _dispatch are re-published to the bus. """Messages left in the pending queue after _dispatch are re-published to the bus.

View File

@ -1,5 +1,6 @@
"""Test message tool suppress logic for final replies.""" """Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
assert result is not None assert result is not None
assert "Hello" in result.content assert "Hello" in result.content
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
    self, tmp_path: Path
) -> None:
    """No final reply is emitted when the message tool already answered.

    Scripted model output: a text answer, then a ``message`` tool call, then
    three empty responses. With one follow-up queued, the tool must have sent
    exactly one message and ``_process_message`` must return ``None``.
    """
    loop = _make_loop(tmp_path)

    tool_call = ToolCallRequest(
        id="call1",
        name="message",
        arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
    )
    scripted = iter([
        LLMResponse(content="First answer", tool_calls=[]),
        LLMResponse(content="", tool_calls=[tool_call]),
        # Three separate empty responses after the tool call.
        *(LLMResponse(content="", tool_calls=[]) for _ in range(3)),
    ])

    def _next_response(*args, **kwargs):
        return next(scripted)

    loop.provider.chat_with_retry = AsyncMock(side_effect=_next_response)
    loop.tools.get_definitions = MagicMock(return_value=[])

    sent: list[OutboundMessage] = []

    def _record(outbound):
        sent.append(outbound)

    mt = loop.tools.get("message")
    if isinstance(mt, MessageTool):
        mt.set_send_callback(AsyncMock(side_effect=_record))

    pending_queue = asyncio.Queue()
    followup = InboundMessage(
        channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up"
    )
    await pending_queue.put(followup)

    msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
    result = await loop._process_message(msg, pending_queue=pending_queue)

    assert len(sent) == 1
    assert sent[0].content == "Tool reply"
    assert result is None
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})