mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-27 13:25:52 +00:00
fix(channel): coalesce queued stream deltas to reduce API calls
When the LLM generates output faster than a channel can process it, the asyncio.Queue accumulates multiple _stream_delta messages. Each delta triggers a separate API call (~700 ms each), causing a visible delay after the LLM has finished. Solution: in _dispatch_outbound, drain all queued deltas for the same (channel, chat_id) before sending and combine them into a single API call. Non-matching messages are preserved in a pending buffer for subsequent processing. This reduces N API calls to 1 when the queue holds N accumulated deltas.
This commit is contained in:
parent
1331084873
commit
5ff9146a24
@ -118,12 +118,20 @@ class ChannelManager:
|
|||||||
"""Dispatch outbound messages to the appropriate channel."""
|
"""Dispatch outbound messages to the appropriate channel."""
|
||||||
logger.info("Outbound dispatcher started")
|
logger.info("Outbound dispatcher started")
|
||||||
|
|
||||||
|
# Buffer for messages that couldn't be processed during delta coalescing
|
||||||
|
# (since asyncio.Queue doesn't support push_front)
|
||||||
|
pending: list[OutboundMessage] = []
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
msg = await asyncio.wait_for(
|
# First check pending buffer before waiting on queue
|
||||||
self.bus.consume_outbound(),
|
if pending:
|
||||||
timeout=1.0
|
msg = pending.pop(0)
|
||||||
)
|
else:
|
||||||
|
msg = await asyncio.wait_for(
|
||||||
|
self.bus.consume_outbound(),
|
||||||
|
timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
if msg.metadata.get("_progress"):
|
if msg.metadata.get("_progress"):
|
||||||
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
|
||||||
@ -131,6 +139,12 @@ class ChannelManager:
|
|||||||
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
|
||||||
|
# to reduce API calls and improve streaming latency
|
||||||
|
if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
|
||||||
|
msg, extra_pending = self._coalesce_stream_deltas(msg)
|
||||||
|
pending.extend(extra_pending)
|
||||||
|
|
||||||
channel = self.channels.get(msg.channel)
|
channel = self.channels.get(msg.channel)
|
||||||
if channel:
|
if channel:
|
||||||
await self._send_with_retry(channel, msg)
|
await self._send_with_retry(channel, msg)
|
||||||
@ -150,6 +164,54 @@ class ChannelManager:
|
|||||||
elif not msg.metadata.get("_streamed"):
|
elif not msg.metadata.get("_streamed"):
|
||||||
await channel.send(msg)
|
await channel.send(msg)
|
||||||
|
|
||||||
|
def _coalesce_stream_deltas(
    self, first_msg: OutboundMessage
) -> tuple[OutboundMessage, list[OutboundMessage]]:
    """Merge the run of ``_stream_delta`` messages queued right after ``first_msg``.

    Messages are pulled from ``self.bus.outbound`` without waiting. Each one
    that targets the same (channel, chat_id) as ``first_msg`` and carries
    ``_stream_delta`` has its content appended to the merged message.

    Coalescing stops at the first of:

    * the queue being empty;
    * a merged delta that also carries ``_stream_end`` (the flag is copied
      onto the merged message);
    * the first message that is not a delta for the same target. That
      message is handed back to the caller and nothing after it is read.

    Stopping at the first non-matching message keeps delivery order intact:
    a later delta is never merged ahead of an earlier message, and the
    caller never has to buffer more than one message that a subsequent
    coalescing pass (which only looks at the queue) could jump over.

    Args:
        first_msg: The delta already taken from the queue that starts the run.

    Returns:
        A tuple ``(merged_message, leftover)`` where ``merged_message`` has
        the concatenated content and ``first_msg``'s metadata (plus
        ``_stream_end`` if the run ended the stream), and ``leftover`` holds
        at most one message that was dequeued but not merged.
    """
    target_key = (first_msg.channel, first_msg.chat_id)
    parts = [first_msg.content]
    # Copy so the caller's message metadata is never mutated.
    final_metadata = dict(first_msg.metadata or {})
    non_matching: list[OutboundMessage] = []

    # A message that already ends the stream has nothing to merge with;
    # leave the queue untouched in that case.
    stream_open = not final_metadata.get("_stream_end")

    while stream_open:
        try:
            next_msg = self.bus.outbound.get_nowait()
        except asyncio.QueueEmpty:
            break

        next_metadata = next_msg.metadata or {}
        same_target = (next_msg.channel, next_msg.chat_id) == target_key

        if not (same_target and next_metadata.get("_stream_delta")):
            # Boundary: return it for normal processing and stop, so that
            # nothing queued behind it can overtake it.
            non_matching.append(next_msg)
            break

        parts.append(next_msg.content)
        if next_metadata.get("_stream_end"):
            final_metadata["_stream_end"] = True
            stream_open = False

    merged = OutboundMessage(
        channel=first_msg.channel,
        chat_id=first_msg.chat_id,
        content="".join(parts),
        metadata=final_metadata,
    )
    return merged, non_matching
|
||||||
|
|
||||||
async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None:
|
async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None:
|
||||||
"""Send a message with retry on failure using exponential backoff.
|
"""Send a message with retry on failure using exponential backoff.
|
||||||
|
|
||||||
|
|||||||
262
tests/channels/test_channel_manager_delta_coalescing.py
Normal file
262
tests/channels/test_channel_manager_delta_coalescing.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
"""Tests for ChannelManager delta coalescing to reduce streaming latency."""
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.channels.base import BaseChannel
|
||||||
|
from nanobot.channels.manager import ChannelManager
|
||||||
|
from nanobot.config.schema import Config
|
||||||
|
|
||||||
|
|
||||||
|
class MockChannel(BaseChannel):
    """Channel test double that records ``send`` / ``send_delta`` calls on AsyncMocks."""

    name = "mock"
    display_name = "Mock"

    def __init__(self, config, bus):
        super().__init__(config, bus)
        # Tests inspect these mocks to see what the channel was asked to send.
        self._send_mock = AsyncMock()
        self._send_delta_mock = AsyncMock()

    async def start(self):
        """No-op; the mock has nothing to start."""

    async def stop(self):
        """No-op; the mock has nothing to stop."""

    async def send(self, msg):
        """Forward ``msg`` to the recording mock and return its result."""
        result = await self._send_mock(msg)
        return result

    async def send_delta(self, chat_id, delta, metadata=None):
        """Forward the delta arguments to the recording mock and return its result."""
        result = await self._send_delta_mock(chat_id, delta, metadata)
        return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def config():
    """Provide a default-constructed ``Config`` for each test."""
    default_config = Config()
    return default_config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def bus():
    """Provide a fresh ``MessageBus`` for each test."""
    message_bus = MessageBus()
    return message_bus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def manager(config, bus):
    """Provide a ``ChannelManager`` with a ``MockChannel`` registered as "mock"."""
    channel_manager = ChannelManager(config, bus)
    mock_channel = MockChannel({}, bus)
    channel_manager.channels["mock"] = mock_channel
    return channel_manager
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeltaCoalescing:
    """Unit tests for ``ChannelManager._coalesce_stream_deltas``."""

    @staticmethod
    def _outbound(chat_id, content, **flags):
        """Build a "mock"-channel OutboundMessage whose metadata is exactly ``flags``."""
        return OutboundMessage(
            channel="mock",
            chat_id=chat_id,
            content=content,
            metadata=dict(flags),
        )

    @pytest.mark.asyncio
    async def test_single_delta_not_coalesced(self, manager, bus):
        """A lone delta goes out unchanged."""
        await bus.publish_outbound(self._outbound("chat1", "Hello", _stream_delta=True))

        async def handle_next():
            # Consume one message, coalesce it if it is a delta, then send it.
            try:
                current = await asyncio.wait_for(bus.consume_outbound(), timeout=0.1)
                if current.metadata.get("_stream_delta"):
                    current, leftovers = manager._coalesce_stream_deltas(current)
                    # Re-queue anything that was not merged (none expected here).
                    for leftover in leftovers:
                        await bus.publish_outbound(leftover)
                target = manager.channels.get(current.channel)
                if target:
                    await target.send_delta(
                        current.chat_id, current.content, current.metadata
                    )
            except asyncio.TimeoutError:
                pass

        await handle_next()

        delta_mock = manager.channels["mock"]._send_delta_mock
        delta_mock.assert_called_once_with("chat1", "Hello", {"_stream_delta": True})

    @pytest.mark.asyncio
    async def test_multiple_deltas_coalesced(self, manager, bus):
        """Back-to-back deltas for one chat collapse into a single message."""
        for piece in ["Hello", " ", "world", "!"]:
            await bus.publish_outbound(self._outbound("chat1", piece, _stream_delta=True))

        head = await bus.consume_outbound()
        merged, pending = manager._coalesce_stream_deltas(head)

        # Every queued delta ended up in the merged content.
        assert merged.content == "Hello world!"
        assert merged.metadata.get("_stream_delta") is True
        # Nothing was left over.
        assert len(pending) == 0

    @pytest.mark.asyncio
    async def test_deltas_different_chats_not_coalesced(self, manager, bus):
        """A delta for another chat is handed back instead of merged."""
        await bus.publish_outbound(self._outbound("chat1", "Hello", _stream_delta=True))
        await bus.publish_outbound(self._outbound("chat2", "World", _stream_delta=True))

        head = await bus.consume_outbound()
        merged, pending = manager._coalesce_stream_deltas(head)

        # chat1's message must not absorb chat2's content.
        assert merged.content == "Hello"
        assert merged.chat_id == "chat1"
        # chat2's delta comes back untouched.
        assert len(pending) == 1
        assert pending[0].chat_id == "chat2"
        assert pending[0].content == "World"

    @pytest.mark.asyncio
    async def test_stream_end_terminates_coalescing(self, manager, bus):
        """A delta flagged ``_stream_end`` is merged and marks the result as ended."""
        await bus.publish_outbound(self._outbound("chat1", "Hello", _stream_delta=True))
        await bus.publish_outbound(
            self._outbound("chat1", " world", _stream_delta=True, _stream_end=True)
        )

        head = await bus.consume_outbound()
        merged, pending = manager._coalesce_stream_deltas(head)

        assert merged.content == "Hello world"
        assert merged.metadata.get("_stream_end") is True
        assert len(pending) == 0

    @pytest.mark.asyncio
    async def test_non_delta_message_preserved(self, manager, bus):
        """A non-delta message is returned in the pending list, not merged."""
        await bus.publish_outbound(self._outbound("chat1", "Delta", _stream_delta=True))
        # Empty metadata: this one is a regular message, not a delta.
        await bus.publish_outbound(self._outbound("chat1", "Final message"))

        head = await bus.consume_outbound()
        merged, pending = manager._coalesce_stream_deltas(head)

        assert merged.content == "Delta"
        assert len(pending) == 1
        assert pending[0].content == "Final message"
        assert pending[0].metadata.get("_stream_delta") is None

    @pytest.mark.asyncio
    async def test_empty_queue_stops_coalescing(self, manager, bus):
        """With nothing else queued, the first message is returned as the result."""
        await bus.publish_outbound(
            self._outbound("chat1", "Only message", _stream_delta=True)
        )

        head = await bus.consume_outbound()
        merged, pending = manager._coalesce_stream_deltas(head)

        assert merged.content == "Only message"
        assert len(pending) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDispatchOutboundWithCoalescing:
    """Tests that replay one dispatcher iteration by hand, including coalescing."""

    @pytest.mark.asyncio
    async def test_dispatch_coalesces_and_processes_pending(self, manager, bus):
        """Two deltas followed by a regular message: deltas merge, the rest is buffered."""
        queued = [
            ("A", {"_stream_delta": True}),
            ("B", {"_stream_delta": True}),
            ("Final", {}),  # regular message
        ]
        for text, meta in queued:
            await bus.publish_outbound(OutboundMessage(
                channel="mock",
                chat_id="chat1",
                content=text,
                metadata=meta,
            ))

        # State the dispatcher loop would carry between iterations.
        pending = []
        processed = []

        # One iteration: take from the buffer first, otherwise from the queue.
        msg = pending.pop(0) if pending else await bus.consume_outbound()

        is_open_delta = msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end")
        if is_open_delta:
            msg, extra_pending = manager._coalesce_stream_deltas(msg)
            pending.extend(extra_pending)

        channel = manager.channels.get(msg.channel)
        if channel:
            await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
            processed.append(("delta", msg.content))

        # A and B went out as one delta.
        assert processed == [("delta", "AB")]
        # The regular message is waiting in the buffer.
        assert len(pending) == 1
        assert pending[0].content == "Final"
|
||||||
Loading…
x
Reference in New Issue
Block a user