nanobot/tests/agent/test_loop.py
chengyongru 6b140076fb fix(agent): preserve inbound metadata in streaming callbacks
The on_stream and on_stream_end closures in _dispatch hardcoded
their metadata dicts, dropping channel-specific fields like
message_thread_id. Copy msg.metadata first, then add internal
streaming flags, matching the pattern already used by _bus_progress.
2026-03-30 11:54:24 +08:00

139 lines
4.7 KiB
Python

"""Tests for AgentLoop._dispatch streaming metadata passthrough."""
from unittest.mock import AsyncMock, patch
import pytest
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
def _make_inbound(**meta) -> InboundMessage:
meta.setdefault("_wants_stream", True)
return InboundMessage(
channel="telegram",
sender_id="user1",
chat_id="chat1",
content="hello",
metadata=meta,
)
@pytest.mark.asyncio
async def test_on_stream_forwards_message_metadata() -> None:
"""on_stream should include original message metadata (e.g. message_thread_id)."""
from nanobot.agent.loop import AgentLoop
bus = MessageBus()
msg = _make_inbound(message_thread_id="42")
loop = AgentLoop.__new__(AgentLoop)
loop.bus = bus
loop._session_locks = {}
loop._concurrency_gate = None
async def fake_process_message(msg_in, **kwargs):
on_stream = kwargs.get("on_stream")
on_stream_end = kwargs.get("on_stream_end")
if on_stream:
await on_stream("hello")
if on_stream_end:
await on_stream_end()
return OutboundMessage(
channel=msg_in.channel, chat_id=msg_in.chat_id,
content="done", metadata=msg_in.metadata,
)
with patch.object(loop, "_process_message", side_effect=fake_process_message):
await loop._dispatch(msg)
# Collect all outbound messages (stream delta, stream end, final response)
outbound: list[OutboundMessage] = []
while not bus.outbound.empty():
outbound.append(await bus.outbound.get())
stream_msg = next(m for m in outbound if m.metadata.get("_stream_delta"))
assert stream_msg.metadata["message_thread_id"] == "42"
assert stream_msg.metadata["_stream_delta"] is True
assert "_stream_id" in stream_msg.metadata
@pytest.mark.asyncio
async def test_on_stream_end_forwards_message_metadata() -> None:
"""on_stream_end should include original message metadata."""
from nanobot.agent.loop import AgentLoop
bus = MessageBus()
msg = _make_inbound(message_thread_id="42")
loop = AgentLoop.__new__(AgentLoop)
loop.bus = bus
loop._session_locks = {}
loop._concurrency_gate = None
async def fake_process_message(msg_in, **kwargs):
on_stream = kwargs.get("on_stream")
on_stream_end = kwargs.get("on_stream_end")
if on_stream:
await on_stream("hello")
if on_stream_end:
await on_stream_end()
return OutboundMessage(
channel=msg_in.channel, chat_id=msg_in.chat_id,
content="done", metadata=msg_in.metadata,
)
with patch.object(loop, "_process_message", side_effect=fake_process_message):
await loop._dispatch(msg)
outbound: list[OutboundMessage] = []
while not bus.outbound.empty():
outbound.append(await bus.outbound.get())
end_msg = next(m for m in outbound if m.metadata.get("_stream_end"))
assert end_msg.metadata["message_thread_id"] == "42"
assert end_msg.metadata["_stream_end"] is True
assert end_msg.metadata["_resuming"] is False
assert "_stream_id" in end_msg.metadata
@pytest.mark.asyncio
async def test_streaming_preserves_arbitrary_metadata_keys() -> None:
"""Both streaming callbacks should forward all original metadata keys untouched."""
from nanobot.agent.loop import AgentLoop
bus = MessageBus()
msg = _make_inbound(message_thread_id="99", custom_flag="abc", reply_to_id="msg77")
loop = AgentLoop.__new__(AgentLoop)
loop.bus = bus
loop._session_locks = {}
loop._concurrency_gate = None
async def fake_process_message(msg_in, **kwargs):
on_stream = kwargs.get("on_stream")
on_stream_end = kwargs.get("on_stream_end")
if on_stream:
await on_stream("hi")
if on_stream_end:
await on_stream_end()
return OutboundMessage(
channel=msg_in.channel, chat_id=msg_in.chat_id,
content="done", metadata=msg_in.metadata,
)
with patch.object(loop, "_process_message", side_effect=fake_process_message):
await loop._dispatch(msg)
outbound: list[OutboundMessage] = []
while not bus.outbound.empty():
outbound.append(await bus.outbound.get())
stream_msg = next(m for m in outbound if m.metadata.get("_stream_delta"))
for key in ("message_thread_id", "custom_flag", "reply_to_id"):
assert stream_msg.metadata[key] == msg.metadata[key]
end_msg = next(m for m in outbound if m.metadata.get("_stream_end"))
for key in ("message_thread_id", "custom_flag", "reply_to_id"):
assert end_msg.metadata[key] == msg.metadata[key]