mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
fix(agent): preserve inbound metadata in streaming callbacks
The on_stream and on_stream_end closures in _dispatch hardcoded their metadata dicts, dropping channel-specific fields like message_thread_id. Copy msg.metadata first, then add internal streaming flags, matching the pattern already used by _bus_progress.
This commit is contained in:
parent
80403d352b
commit
6b140076fb
@ -321,25 +321,23 @@ class AgentLoop:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata={
|
||||
"_stream_delta": True,
|
||||
"_stream_id": _current_stream_id(),
|
||||
},
|
||||
content=delta, metadata=meta,
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata={
|
||||
"_stream_end": True,
|
||||
"_resuming": resuming,
|
||||
"_stream_id": _current_stream_id(),
|
||||
},
|
||||
content="", metadata=meta,
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
|
||||
138
tests/agent/test_loop.py
Normal file
138
tests/agent/test_loop.py
Normal file
@ -0,0 +1,138 @@
|
||||
"""Tests for AgentLoop._dispatch streaming metadata passthrough."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
|
||||
def _make_inbound(**meta) -> InboundMessage:
|
||||
meta.setdefault("_wants_stream", True)
|
||||
return InboundMessage(
|
||||
channel="telegram",
|
||||
sender_id="user1",
|
||||
chat_id="chat1",
|
||||
content="hello",
|
||||
metadata=meta,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_stream_forwards_message_metadata() -> None:
|
||||
"""on_stream should include original message metadata (e.g. message_thread_id)."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
bus = MessageBus()
|
||||
msg = _make_inbound(message_thread_id="42")
|
||||
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop.bus = bus
|
||||
loop._session_locks = {}
|
||||
loop._concurrency_gate = None
|
||||
|
||||
async def fake_process_message(msg_in, **kwargs):
|
||||
on_stream = kwargs.get("on_stream")
|
||||
on_stream_end = kwargs.get("on_stream_end")
|
||||
if on_stream:
|
||||
await on_stream("hello")
|
||||
if on_stream_end:
|
||||
await on_stream_end()
|
||||
return OutboundMessage(
|
||||
channel=msg_in.channel, chat_id=msg_in.chat_id,
|
||||
content="done", metadata=msg_in.metadata,
|
||||
)
|
||||
|
||||
with patch.object(loop, "_process_message", side_effect=fake_process_message):
|
||||
await loop._dispatch(msg)
|
||||
|
||||
# Collect all outbound messages (stream delta, stream end, final response)
|
||||
outbound: list[OutboundMessage] = []
|
||||
while not bus.outbound.empty():
|
||||
outbound.append(await bus.outbound.get())
|
||||
|
||||
stream_msg = next(m for m in outbound if m.metadata.get("_stream_delta"))
|
||||
assert stream_msg.metadata["message_thread_id"] == "42"
|
||||
assert stream_msg.metadata["_stream_delta"] is True
|
||||
assert "_stream_id" in stream_msg.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_stream_end_forwards_message_metadata() -> None:
|
||||
"""on_stream_end should include original message metadata."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
bus = MessageBus()
|
||||
msg = _make_inbound(message_thread_id="42")
|
||||
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop.bus = bus
|
||||
loop._session_locks = {}
|
||||
loop._concurrency_gate = None
|
||||
|
||||
async def fake_process_message(msg_in, **kwargs):
|
||||
on_stream = kwargs.get("on_stream")
|
||||
on_stream_end = kwargs.get("on_stream_end")
|
||||
if on_stream:
|
||||
await on_stream("hello")
|
||||
if on_stream_end:
|
||||
await on_stream_end()
|
||||
return OutboundMessage(
|
||||
channel=msg_in.channel, chat_id=msg_in.chat_id,
|
||||
content="done", metadata=msg_in.metadata,
|
||||
)
|
||||
|
||||
with patch.object(loop, "_process_message", side_effect=fake_process_message):
|
||||
await loop._dispatch(msg)
|
||||
|
||||
outbound: list[OutboundMessage] = []
|
||||
while not bus.outbound.empty():
|
||||
outbound.append(await bus.outbound.get())
|
||||
|
||||
end_msg = next(m for m in outbound if m.metadata.get("_stream_end"))
|
||||
assert end_msg.metadata["message_thread_id"] == "42"
|
||||
assert end_msg.metadata["_stream_end"] is True
|
||||
assert end_msg.metadata["_resuming"] is False
|
||||
assert "_stream_id" in end_msg.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_preserves_arbitrary_metadata_keys() -> None:
|
||||
"""Both streaming callbacks should forward all original metadata keys untouched."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
|
||||
bus = MessageBus()
|
||||
msg = _make_inbound(message_thread_id="99", custom_flag="abc", reply_to_id="msg77")
|
||||
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
loop.bus = bus
|
||||
loop._session_locks = {}
|
||||
loop._concurrency_gate = None
|
||||
|
||||
async def fake_process_message(msg_in, **kwargs):
|
||||
on_stream = kwargs.get("on_stream")
|
||||
on_stream_end = kwargs.get("on_stream_end")
|
||||
if on_stream:
|
||||
await on_stream("hi")
|
||||
if on_stream_end:
|
||||
await on_stream_end()
|
||||
return OutboundMessage(
|
||||
channel=msg_in.channel, chat_id=msg_in.chat_id,
|
||||
content="done", metadata=msg_in.metadata,
|
||||
)
|
||||
|
||||
with patch.object(loop, "_process_message", side_effect=fake_process_message):
|
||||
await loop._dispatch(msg)
|
||||
|
||||
outbound: list[OutboundMessage] = []
|
||||
while not bus.outbound.empty():
|
||||
outbound.append(await bus.outbound.get())
|
||||
|
||||
stream_msg = next(m for m in outbound if m.metadata.get("_stream_delta"))
|
||||
for key in ("message_thread_id", "custom_flag", "reply_to_id"):
|
||||
assert stream_msg.metadata[key] == msg.metadata[key]
|
||||
|
||||
end_msg = next(m for m in outbound if m.metadata.get("_stream_end"))
|
||||
for key in ("message_thread_id", "custom_flag", "reply_to_id"):
|
||||
assert end_msg.metadata[key] == msg.metadata[key]
|
||||
Loading…
x
Reference in New Issue
Block a user