From 6b140076fb8366d236f998a0b136d3fc97489c7a Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 30 Mar 2026 11:54:24 +0800 Subject: [PATCH] fix(agent): preserve inbound metadata in streaming callbacks The on_stream and on_stream_end closures in _dispatch hardcoded their metadata dicts, dropping channel-specific fields like message_thread_id. Copy msg.metadata first, then add internal streaming flags, matching the pattern already used by _bus_progress. --- nanobot/agent/loop.py | 20 +++--- tests/agent/test_loop.py | 138 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 11 deletions(-) create mode 100644 tests/agent/test_loop.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 63ee92ca5..555c6b9f5 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -321,25 +321,23 @@ class AgentLoop: return f"{stream_base_id}:{stream_segment}" async def on_stream(delta: str) -> None: + meta = dict(msg.metadata or {}) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content=delta, - metadata={ - "_stream_delta": True, - "_stream_id": _current_stream_id(), - }, + content=delta, metadata=meta, )) async def on_stream_end(*, resuming: bool = False) -> None: nonlocal stream_segment + meta = dict(msg.metadata or {}) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, - content="", - metadata={ - "_stream_end": True, - "_resuming": resuming, - "_stream_id": _current_stream_id(), - }, + content="", metadata=meta, )) stream_segment += 1 diff --git a/tests/agent/test_loop.py b/tests/agent/test_loop.py new file mode 100644 index 000000000..3c8276e14 --- /dev/null +++ b/tests/agent/test_loop.py @@ -0,0 +1,138 @@ +"""Tests for AgentLoop._dispatch streaming metadata passthrough.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus + + +def _make_inbound(**meta) -> InboundMessage: + meta.setdefault("_wants_stream", True) + return InboundMessage( + channel="telegram", + sender_id="user1", + chat_id="chat1", + content="hello", + metadata=meta, + ) + + +@pytest.mark.asyncio +async def test_on_stream_forwards_message_metadata() -> None: + """on_stream should include original message metadata (e.g. message_thread_id).""" + from nanobot.agent.loop import AgentLoop + + bus = MessageBus() + msg = _make_inbound(message_thread_id="42") + + loop = AgentLoop.__new__(AgentLoop) + loop.bus = bus + loop._session_locks = {} + loop._concurrency_gate = None + + async def fake_process_message(msg_in, **kwargs): + on_stream = kwargs.get("on_stream") + on_stream_end = kwargs.get("on_stream_end") + if on_stream: + await on_stream("hello") + if on_stream_end: + await on_stream_end() + return OutboundMessage( + channel=msg_in.channel, chat_id=msg_in.chat_id, + content="done", metadata=msg_in.metadata, + ) + + with patch.object(loop, "_process_message", side_effect=fake_process_message): + await loop._dispatch(msg) + + # Collect all outbound messages (stream delta, stream end, final response) + outbound: list[OutboundMessage] = [] + while not bus.outbound.empty(): + outbound.append(await bus.outbound.get()) + + stream_msg = next(m for m in outbound if m.metadata.get("_stream_delta")) + assert stream_msg.metadata["message_thread_id"] == "42" + assert stream_msg.metadata["_stream_delta"] is True + assert "_stream_id" in stream_msg.metadata + + +@pytest.mark.asyncio +async def test_on_stream_end_forwards_message_metadata() -> None: + """on_stream_end should include original message metadata.""" + from nanobot.agent.loop import AgentLoop + + bus = MessageBus() + msg = _make_inbound(message_thread_id="42") + + loop = AgentLoop.__new__(AgentLoop) + loop.bus = bus + loop._session_locks = {} + loop._concurrency_gate = None + + async def fake_process_message(msg_in, **kwargs): + on_stream = kwargs.get("on_stream") + on_stream_end = kwargs.get("on_stream_end") + if on_stream: + await on_stream("hello") + if on_stream_end: + await on_stream_end() + return OutboundMessage( + channel=msg_in.channel, chat_id=msg_in.chat_id, + content="done", metadata=msg_in.metadata, + ) + + with patch.object(loop, "_process_message", side_effect=fake_process_message): + await loop._dispatch(msg) + + outbound: list[OutboundMessage] = [] + while not bus.outbound.empty(): + outbound.append(await bus.outbound.get()) + + end_msg = next(m for m in outbound if m.metadata.get("_stream_end")) + assert end_msg.metadata["message_thread_id"] == "42" + assert end_msg.metadata["_stream_end"] is True + assert end_msg.metadata["_resuming"] is False + assert "_stream_id" in end_msg.metadata + + +@pytest.mark.asyncio +async def test_streaming_preserves_arbitrary_metadata_keys() -> None: + """Both streaming callbacks should forward all original metadata keys untouched.""" + from nanobot.agent.loop import AgentLoop + + bus = MessageBus() + msg = _make_inbound(message_thread_id="99", custom_flag="abc", reply_to_id="msg77") + + loop = AgentLoop.__new__(AgentLoop) + loop.bus = bus + loop._session_locks = {} + loop._concurrency_gate = None + + async def fake_process_message(msg_in, **kwargs): + on_stream = kwargs.get("on_stream") + on_stream_end = kwargs.get("on_stream_end") + if on_stream: + await on_stream("hi") + if on_stream_end: + await on_stream_end() + return OutboundMessage( + channel=msg_in.channel, chat_id=msg_in.chat_id, + content="done", metadata=msg_in.metadata, + ) + + with patch.object(loop, "_process_message", side_effect=fake_process_message): + await loop._dispatch(msg) + + outbound: list[OutboundMessage] = [] + while not bus.outbound.empty(): + outbound.append(await bus.outbound.get()) + + stream_msg = next(m for m in outbound if m.metadata.get("_stream_delta")) + for key in ("message_thread_id", "custom_flag", "reply_to_id"): + assert stream_msg.metadata[key] == msg.metadata[key] + + end_msg = next(m for m in outbound if m.metadata.get("_stream_end")) + for key in ("message_thread_id", "custom_flag", "reply_to_id"): + assert end_msg.metadata[key] == msg.metadata[key]