mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-22 11:00:19 +00:00
feat(agent): enhance message injection handling and content merging
This commit is contained in:
parent
f6c39ec946
commit
cf8381f517
@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
|
|||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||||
from nanobot.agent.memory import Consolidator, Dream
|
from nanobot.agent.memory import Consolidator, Dream
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||||
@ -370,16 +370,30 @@ class AgentLoop:
|
|||||||
return
|
return
|
||||||
self._set_runtime_checkpoint(session, payload)
|
self._set_runtime_checkpoint(session, payload)
|
||||||
|
|
||||||
async def _drain_pending() -> list[InboundMessage]:
|
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
|
||||||
"""Non-blocking drain of follow-up messages from the pending queue."""
|
"""Non-blocking drain of follow-up messages from the pending queue."""
|
||||||
if pending_queue is None:
|
if pending_queue is None:
|
||||||
return []
|
return []
|
||||||
items: list[InboundMessage] = []
|
items: list[dict[str, Any]] = []
|
||||||
while True:
|
while len(items) < limit:
|
||||||
try:
|
try:
|
||||||
items.append(pending_queue.get_nowait())
|
pending_msg = pending_queue.get_nowait()
|
||||||
except asyncio.QueueEmpty:
|
except asyncio.QueueEmpty:
|
||||||
break
|
break
|
||||||
|
user_content = self.context._build_user_content(
|
||||||
|
pending_msg.content,
|
||||||
|
pending_msg.media if pending_msg.media else None,
|
||||||
|
)
|
||||||
|
runtime_ctx = self.context._build_runtime_context(
|
||||||
|
pending_msg.channel,
|
||||||
|
pending_msg.chat_id,
|
||||||
|
self.context.timezone,
|
||||||
|
)
|
||||||
|
if isinstance(user_content, str):
|
||||||
|
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
|
||||||
|
else:
|
||||||
|
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||||
|
items.append({"role": "user", "content": merged})
|
||||||
return items
|
return items
|
||||||
|
|
||||||
result = await self.runner.run(AgentRunSpec(
|
result = await self.runner.run(AgentRunSpec(
|
||||||
@ -451,7 +465,7 @@ class AgentLoop:
|
|||||||
self._pending_queues[effective_key].put_nowait(pending_msg)
|
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||||
except asyncio.QueueFull:
|
except asyncio.QueueFull:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Pending queue full for session {}, dropping follow-up",
|
"Pending queue full for session {}, falling back to queued task",
|
||||||
effective_key,
|
effective_key,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -459,7 +473,7 @@ class AgentLoop:
|
|||||||
"Routed follow-up message to pending queue for session {}",
|
"Routed follow-up message to pending queue for session {}",
|
||||||
effective_key,
|
effective_key,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
# Compute the effective session key before dispatching
|
# Compute the effective session key before dispatching
|
||||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
@ -697,12 +711,14 @@ class AgentLoop:
|
|||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||||
|
|
||||||
# When follow-up messages were injected mid-turn, the LLM's final
|
# When follow-up messages were injected mid-turn, a later natural
|
||||||
# response addresses those follow-ups. Always send the response in
|
# language reply may address those follow-ups and should not be
|
||||||
# this case, even if MessageTool was used earlier in the turn — the
|
# suppressed just because MessageTool was used earlier in the turn.
|
||||||
# follow-up response is new content the user hasn't seen.
|
# However, if the turn falls back to the empty-final-response
|
||||||
if not had_injections:
|
# placeholder, suppress it when the real user-visible output already
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
# came from MessageTool.
|
||||||
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
|
if not had_injections or stop_reason == "empty_final_response":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -94,37 +95,89 @@ class AgentRunner:
|
|||||||
def __init__(self, provider: LLMProvider):
|
def __init__(self, provider: LLMProvider):
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
|
||||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[str]:
|
@staticmethod
|
||||||
|
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||||
|
if isinstance(left, str) and isinstance(right, str):
|
||||||
|
return f"{left}\n\n{right}" if left else right
|
||||||
|
|
||||||
|
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [
|
||||||
|
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
|
||||||
|
for item in value
|
||||||
|
]
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
return [{"type": "text", "text": str(value)}]
|
||||||
|
|
||||||
|
return _to_blocks(left) + _to_blocks(right)
|
||||||
|
|
||||||
|
@classmethod
def _append_injected_messages(
    cls,
    messages: list[dict[str, Any]],
    injections: list[dict[str, Any]],
) -> None:
    """Add injected messages to ``messages`` in place.

    An injected ``user`` message that would directly follow another
    ``user`` message is folded into it instead of being appended, so the
    list never gains two consecutive ``user`` entries from an injection.
    The existing tail entry is replaced by a merged copy, not mutated.
    """
    for incoming in injections:
        # Nothing to fold into, or the incoming message is not a user turn.
        if not messages or incoming.get("role") != "user":
            messages.append(incoming)
            continue

        tail = messages[-1]
        if tail.get("role") != "user":
            messages.append(incoming)
            continue

        # Replace the tail with a shallow copy whose content is the merge
        # of both payloads; the original tail dict stays untouched.
        combined = dict(tail)
        combined["content"] = cls._merge_message_content(
            combined.get("content"),
            incoming.get("content"),
        )
        messages[-1] = combined
|
||||||
|
|
||||||
|
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||||
"""Drain pending user messages via the injection callback.
|
"""Drain pending user messages via the injection callback.
|
||||||
|
|
||||||
Returns all drained message contents (capped by
|
Returns normalized user messages (capped by
|
||||||
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
|
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
|
||||||
nothing to inject. Messages beyond the cap are logged so they
|
nothing to inject. Messages beyond the cap are logged so they
|
||||||
are not silently lost.
|
are not silently lost.
|
||||||
"""
|
"""
|
||||||
if spec.injection_callback is None:
|
if spec.injection_callback is None:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
items = await spec.injection_callback()
|
signature = inspect.signature(spec.injection_callback)
|
||||||
|
accepts_limit = (
|
||||||
|
"limit" in signature.parameters
|
||||||
|
or any(
|
||||||
|
parameter.kind is inspect.Parameter.VAR_KEYWORD
|
||||||
|
for parameter in signature.parameters.values()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if accepts_limit:
|
||||||
|
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
|
||||||
|
else:
|
||||||
|
items = await spec.injection_callback()
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("injection_callback failed")
|
logger.exception("injection_callback failed")
|
||||||
return []
|
return []
|
||||||
if not items:
|
if not items:
|
||||||
return []
|
return []
|
||||||
# items are InboundMessage objects from _drain_pending
|
injected_messages: list[dict[str, Any]] = []
|
||||||
texts: list[str] = []
|
|
||||||
for item in items:
|
for item in items:
|
||||||
|
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
|
||||||
|
injected_messages.append(item)
|
||||||
|
continue
|
||||||
text = getattr(item, "content", str(item))
|
text = getattr(item, "content", str(item))
|
||||||
if text.strip():
|
if text.strip():
|
||||||
texts.append(text)
|
injected_messages.append({"role": "user", "content": text})
|
||||||
if len(texts) > _MAX_INJECTIONS_PER_TURN:
|
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
|
||||||
dropped = len(texts) - _MAX_INJECTIONS_PER_TURN
|
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Injection batch has {} messages, capping to {} ({} dropped)",
|
"Injection callback returned {} messages, capping to {} ({} dropped)",
|
||||||
len(texts), _MAX_INJECTIONS_PER_TURN, dropped,
|
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
|
||||||
)
|
)
|
||||||
texts = texts[-_MAX_INJECTIONS_PER_TURN:]
|
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
|
||||||
return texts
|
return injected_messages
|
||||||
|
|
||||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||||
hook = spec.hook or AgentHook()
|
hook = spec.hook or AgentHook()
|
||||||
@ -247,8 +300,7 @@ class AgentRunner:
|
|||||||
if injections:
|
if injections:
|
||||||
had_injections = True
|
had_injections = True
|
||||||
injection_cycles += 1
|
injection_cycles += 1
|
||||||
for text in injections:
|
self._append_injected_messages(messages, injections)
|
||||||
messages.append({"role": "user", "content": text})
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
||||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||||
@ -340,8 +392,7 @@ class AgentRunner:
|
|||||||
"pending_tool_calls": [],
|
"pending_tool_calls": [],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
for text in injections:
|
self._append_injected_messages(messages, injections)
|
||||||
messages.append({"role": "user", "content": text})
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
"Injected {} follow-up message(s) after final response ({}/{})",
|
||||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||||
|
|||||||
@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
@ -1633,7 +1634,7 @@ async def test_drain_injections_returns_empty_when_no_callback():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_drain_injections_extracts_content_from_inbound_messages():
|
async def test_drain_injections_extracts_content_from_inbound_messages():
|
||||||
"""Should extract .content from InboundMessage objects."""
|
"""Should extract .content from InboundMessage objects."""
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
provider = MagicMock()
|
provider = MagicMock()
|
||||||
@ -1655,12 +1656,15 @@ async def test_drain_injections_extracts_content_from_inbound_messages():
|
|||||||
injection_callback=cb,
|
injection_callback=cb,
|
||||||
)
|
)
|
||||||
result = await runner._drain_injections(spec)
|
result = await runner._drain_injections(spec)
|
||||||
assert result == ["hello", "world"]
|
assert result == [
|
||||||
|
{"role": "user", "content": "hello"},
|
||||||
|
{"role": "user", "content": "world"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_drain_injections_caps_at_max_and_logs_warning():
|
async def test_drain_injections_passes_limit_to_callback_when_supported():
|
||||||
"""When more than _MAX_INJECTIONS_PER_TURN items, only the last N are kept."""
|
"""Limit-aware callbacks can preserve overflow in their own queue."""
|
||||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
||||||
from nanobot.bus.events import InboundMessage
|
from nanobot.bus.events import InboundMessage
|
||||||
|
|
||||||
@ -1668,14 +1672,16 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
|
|||||||
runner = AgentRunner(provider)
|
runner = AgentRunner(provider)
|
||||||
tools = MagicMock()
|
tools = MagicMock()
|
||||||
tools.get_definitions.return_value = []
|
tools.get_definitions.return_value = []
|
||||||
|
seen_limits: list[int] = []
|
||||||
|
|
||||||
msgs = [
|
msgs = [
|
||||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
|
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
|
||||||
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
|
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
|
||||||
]
|
]
|
||||||
|
|
||||||
async def cb():
|
async def cb(*, limit: int):
|
||||||
return msgs
|
seen_limits.append(limit)
|
||||||
|
return msgs[:limit]
|
||||||
|
|
||||||
spec = AgentRunSpec(
|
spec = AgentRunSpec(
|
||||||
initial_messages=[], tools=tools, model="m",
|
initial_messages=[], tools=tools, model="m",
|
||||||
@ -1683,10 +1689,12 @@ async def test_drain_injections_caps_at_max_and_logs_warning():
|
|||||||
injection_callback=cb,
|
injection_callback=cb,
|
||||||
)
|
)
|
||||||
result = await runner._drain_injections(spec)
|
result = await runner._drain_injections(spec)
|
||||||
assert len(result) == _MAX_INJECTIONS_PER_TURN
|
assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
|
||||||
# Should keep the LAST _MAX_INJECTIONS_PER_TURN items
|
assert result == [
|
||||||
assert result[0] == "msg3"
|
{"role": "user", "content": "msg0"},
|
||||||
assert result[-1] == f"msg{_MAX_INJECTIONS_PER_TURN + 2}"
|
{"role": "user", "content": "msg1"},
|
||||||
|
{"role": "user", "content": "msg2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -1715,7 +1723,7 @@ async def test_drain_injections_skips_empty_content():
|
|||||||
injection_callback=cb,
|
injection_callback=cb,
|
||||||
)
|
)
|
||||||
result = await runner._drain_injections(spec)
|
result = await runner._drain_injections(spec)
|
||||||
assert result == ["valid"]
|
assert result == [{"role": "user", "content": "valid"}]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -1922,6 +1930,129 @@ async def test_checkpoint2_preserves_final_response_in_history_before_followup()
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_loop_injected_followup_preserves_image_media(tmp_path):
    """An image-only follow-up injected mid-turn must reach the model as multimodal content."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    # A tiny PNG on disk serves as the follow-up's media attachment.
    png_bytes = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
    )
    image_path = tmp_path / "followup.png"
    image_path.write_bytes(png_bytes)

    # Each provider call records a snapshot of the messages it received.
    snapshots: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        snapshots.append(list(messages))
        reply = "first answer" if len(snapshots) == 1 else "second answer"
        return LLMResponse(content=reply, tool_calls=[], usage={})

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    followup = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="",
        media=[str(image_path)],
    )
    pending_queue = asyncio.Queue()
    await pending_queue.put(followup)

    final_content, _a, _b, _c, had_injections = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "second answer"
    assert had_injections is True
    assert len(snapshots) == 2

    # Look at what the model saw on its last call: the injected follow-up
    # must be a user message whose content is a list containing an image.
    multimodal_user_messages = []
    for message in snapshots[-1]:
        if message.get("role") == "user" and isinstance(message.get("content"), list):
            multimodal_user_messages.append(message)
    assert multimodal_user_messages

    block_types = [
        block.get("type")
        for block in multimodal_user_messages[-1]["content"]
        if isinstance(block, dict)
    ]
    assert "image_url" in block_types
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
    """Two follow-ups injected together collapse into one user message that keeps the image."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    # One snapshot per provider call; its length doubles as the call counter.
    snapshots = []

    async def chat_with_retry(*, messages, **kwargs):
        snapshots.append([dict(message) for message in messages])
        reply = "first answer" if len(snapshots) == 1 else "second answer"
        return LLMResponse(content=reply, tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tools = MagicMock()
    tools.get_definitions.return_value = []

    image_followup = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
            {"type": "text", "text": "look at this"},
        ],
    }
    text_followup = {"role": "user", "content": "and answer briefly"}

    async def inject_cb():
        # Only offer follow-ups right after the first provider call.
        if len(snapshots) != 1:
            return []
        return [image_followup, text_followup]

    spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tools,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    )
    result = await AgentRunner(provider).run(spec)

    assert result.final_content == "second answer"
    assert len(snapshots) == 2

    # The second call sees the original user turn plus ONE merged follow-up.
    user_messages = [m for m in snapshots[-1] if m.get("role") == "user"]
    assert len(user_messages) == 2

    merged_content = user_messages[-1]["content"]
    assert isinstance(merged_content, list)

    dict_blocks = [block for block in merged_content if isinstance(block, dict)]
    assert any(block.get("type") == "image_url" for block in dict_blocks)
    assert any(
        block.get("type") == "text" and block.get("text") == "and answer briefly"
        for block in dict_blocks
    )
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_injection_cycles_capped_at_max():
|
async def test_injection_cycles_capped_at_max():
|
||||||
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
||||||
@ -2042,6 +2173,88 @@ async def test_followup_routed_to_pending_queue(tmp_path):
|
|||||||
assert queued_msg.session_key == UNIFIED_SESSION_KEY
|
assert queued_msg.session_key == UNIFIED_SESSION_KEY
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
    """Follow-ups beyond the per-turn cap stay queued and are injected on a later cycle."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    # One snapshot per provider call; the reply text embeds the call number.
    snapshots: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        snapshots.append([dict(message) for message in messages])
        return LLMResponse(content=f"answer-{len(snapshots)}", tool_calls=[], usage={})

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Queue two more follow-ups than a single drain is allowed to take.
    total_followups = _MAX_INJECTIONS_PER_TURN + 2
    pending_queue = asyncio.Queue()
    for number in range(total_followups):
        followup = InboundMessage(
            channel="cli",
            sender_id="u",
            chat_id="c",
            content=f"follow-up-{number}",
        )
        await pending_queue.put(followup)

    final_content, _a, _b, _c, had_injections = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "answer-3"
    assert had_injections is True
    assert len(snapshots) == 3

    # Every follow-up must show up in the text the model saw on its last call.
    user_texts = []
    for message in snapshots[-1]:
        if message.get("role") == "user" and isinstance(message.get("content"), str):
            user_texts.append(message["content"])
    flattened_user_content = "\n".join(user_texts)

    for number in range(total_followups):
        assert f"follow-up-{number}" in flattened_user_content
    assert pending_queue.empty()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
    """A follow-up that cannot fit in a full pending queue is dispatched as its own task."""
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    loop._dispatch = AsyncMock()  # type: ignore[method-assign]

    # Pre-fill a one-slot pending queue so the next follow-up hits QueueFull.
    occupant = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")
    pending = asyncio.Queue(maxsize=1)
    pending.put_nowait(occupant)
    loop._pending_queues["cli:c"] = pending

    run_task = asyncio.create_task(loop.run())
    msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
    await loop.bus.publish_inbound(msg)

    # Poll for up to two seconds until the fallback dispatch has been awaited.
    deadline = time.time() + 2
    while True:
        if loop._dispatch.await_count != 0:
            break
        if not time.time() < deadline:
            break
        await asyncio.sleep(0.01)

    loop.stop()
    await asyncio.wait_for(run_task, timeout=2)

    assert loop._dispatch.await_count == 1
    (dispatched_msg,) = loop._dispatch.await_args.args[:1]
    assert dispatched_msg.content == "follow-up"
    # The message that was already queued is left where it was.
    assert pending.qsize() == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
||||||
"""Messages left in the pending queue after _dispatch are re-published to the bus.
|
"""Messages left in the pending queue after _dispatch are re-published to the bus.
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
"""Test message tool suppress logic for final replies."""
|
"""Test message tool suppress logic for final replies."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
|
|||||||
assert result is not None
|
assert result is not None
|
||||||
assert "Hello" in result.content
|
assert "Hello" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
    self, tmp_path: Path
) -> None:
    """After an injected follow-up answered via MessageTool, empty model turns yield no extra reply."""
    loop = _make_loop(tmp_path)

    # Scripted provider output: a normal answer, then a message-tool call,
    # then three empty responses.
    tool_call = ToolCallRequest(
        id="call1",
        name="message",
        arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
    )
    scripted = iter(
        [
            LLMResponse(content="First answer", tool_calls=[]),
            LLMResponse(content="", tool_calls=[tool_call]),
            *[LLMResponse(content="", tool_calls=[]) for _ in range(3)],
        ]
    )

    def _next_response(*_args, **_kwargs):
        return next(scripted)

    loop.provider.chat_with_retry = AsyncMock(side_effect=_next_response)
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Capture everything the message tool sends out.
    sent: list[OutboundMessage] = []

    def _record(m):
        sent.append(m)

    mt = loop.tools.get("message")
    if isinstance(mt, MessageTool):
        mt.set_send_callback(AsyncMock(side_effect=_record))

    followup = InboundMessage(
        channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up"
    )
    pending_queue = asyncio.Queue()
    await pending_queue.put(followup)

    msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
    result = await loop._process_message(msg, pending_queue=pending_queue)

    assert len(sent) == 1
    assert sent[0].content == "Tool reply"
    assert result is None
|
||||||
|
|
||||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user