feat(matrix): streaming support (#2447)

* Added streaming message support with incremental updates for Matrix channel

* Improve Matrix message handling and add tests

* Adjust Matrix streaming edit interval to 2 seconds

---------

Co-authored-by: natan <natan@podbielski>
This commit is contained in:
npodbielski 2026-03-27 08:12:14 +01:00 committed by Xubin Ren
parent 351e3720b6
commit b94d4c0509
2 changed files with 323 additions and 9 deletions

View File

@ -3,6 +3,8 @@
import asyncio
import logging
import mimetypes
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, TypeAlias
@ -28,8 +30,8 @@ try:
RoomSendError,
RoomTypingError,
SyncError,
UploadError,
)
UploadError, RoomSendResponse,
)
from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError
except ImportError as e:
@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner(
link_rel="noopener noreferrer",
)
@dataclass
class _StreamBuf:
"""
Represents a buffer for managing LLM response stream data.
:ivar text: Stores the text content of the buffer.
:type text: str
:ivar event_id: Identifier for the associated event. None indicates no
specific event association.
:type event_id: str | None
:ivar last_edit: Timestamp of the most recent edit to the buffer.
:type last_edit: float
"""
text: str = ""
event_id: str | None = None
last_edit: float = 0.0
def _render_markdown_html(text: str) -> str | None:
"""Render markdown to sanitized HTML; returns None for plain text."""
@ -114,12 +132,36 @@ def _render_markdown_html(text: str) -> str | None:
return formatted
def _build_matrix_text_content(text: str) -> dict[str, object]:
"""Build Matrix m.text payload with optional HTML formatted_body."""
def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[str, object]:
"""
Constructs and returns a dictionary representing the matrix text content with optional
HTML formatting and reference to an existing event for replacement. This function is
primarily used to create content payloads compatible with the Matrix messaging protocol.
:param text: The plain text content to include in the message.
:type text: str
:param event_id: Optional ID of the event to replace. If provided, the function will
include information indicating that the message is a replacement of the specified
event.
:type event_id: str | None
:return: A dictionary containing the matrix text content, potentially enriched with
HTML formatting and replacement metadata if applicable.
:rtype: dict[str, object]
"""
content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
if html := _render_markdown_html(text):
content["format"] = MATRIX_HTML_FORMAT
content["formatted_body"] = html
if event_id:
content["m.new_content"] = {
"body": text,
"msgtype": "m.text"
}
content["m.relates_to"] = {
"rel_type": "m.replace",
"event_id": event_id
}
return content
@ -159,7 +201,8 @@ class MatrixConfig(Base):
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
allow_room_mentions: bool = False
allow_room_mentions: bool = False,
streaming: bool = False
class MatrixChannel(BaseChannel):
@ -167,6 +210,8 @@ class MatrixChannel(BaseChannel):
name = "matrix"
display_name = "Matrix"
_STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls
monotonic_time = time.monotonic
@classmethod
def default_config(cls) -> dict[str, Any]:
@ -192,6 +237,8 @@ class MatrixChannel(BaseChannel):
)
self._server_upload_limit_bytes: int | None = None
self._server_upload_limit_checked = False
self._stream_bufs: dict[str, _StreamBuf] = {}
async def start(self) -> None:
"""Start Matrix client and begin sync loop."""
@ -297,14 +344,17 @@ class MatrixChannel(BaseChannel):
room = getattr(self.client, "rooms", {}).get(room_id)
return bool(getattr(room, "encrypted", False))
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
async def _send_room_content(self, room_id: str,
content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
"""Send m.room.message with E2EE options."""
if not self.client:
return
return None
kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
if self.config.e2ee_enabled:
kwargs["ignore_unverified_devices"] = True
await self.client.room_send(**kwargs)
response = await self.client.room_send(**kwargs)
return response
async def _resolve_server_upload_limit_bytes(self) -> int | None:
"""Query homeserver upload limit once per channel lifecycle."""
@ -414,6 +464,47 @@ class MatrixChannel(BaseChannel):
if not is_progress:
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
meta = metadata or {}
relates_to = self._build_thread_relates_to(metadata)
if meta.get("_stream_end"):
buf = self._stream_bufs.pop(chat_id, None)
if not buf or not buf.event_id or not buf.text:
return
await self._stop_typing_keepalive(chat_id, clear_typing=True)
content = _build_matrix_text_content(buf.text, buf.event_id)
if relates_to:
content["m.relates_to"] = relates_to
await self._send_room_content(chat_id, content)
return
buf = self._stream_bufs.get(chat_id)
if buf is None:
buf = _StreamBuf()
self._stream_bufs[chat_id] = buf
buf.text += delta
if not buf.text.strip():
return
now = self.monotonic_time()
if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
try:
content = _build_matrix_text_content(buf.text, buf.event_id)
response = await self._send_room_content(chat_id, content)
buf.last_edit = now
if not buf.event_id:
# we are editing the same message all the time, so only the first time the event id needs to be set
buf.event_id = response.event_id
except Exception:
await self._stop_typing_keepalive(metadata["room_id"], clear_typing=True)
pass
def _register_event_callbacks(self) -> None:
self.client.add_event_callback(self._on_message, RoomMessageText)
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)

View File

@ -3,6 +3,9 @@ from pathlib import Path
from types import SimpleNamespace
import pytest
from nio import RoomSendResponse
from nanobot.channels.matrix import _build_matrix_text_content
# Check optional matrix dependencies before importing
try:
@ -65,6 +68,7 @@ class _FakeAsyncClient:
self.raise_on_send = False
self.raise_on_typing = False
self.raise_on_upload = False
self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="")
def add_event_callback(self, callback, event_type) -> None:
self.callbacks.append((callback, event_type))
@ -87,7 +91,7 @@ class _FakeAsyncClient:
message_type: str,
content: dict[str, object],
ignore_unverified_devices: object = _ROOM_SEND_UNSET,
) -> None:
) -> RoomSendResponse:
call: dict[str, object] = {
"room_id": room_id,
"message_type": message_type,
@ -98,6 +102,7 @@ class _FakeAsyncClient:
self.room_send_calls.append(call)
if self.raise_on_send:
raise RuntimeError("send failed")
return self.room_send_response
async def room_typing(
self,
@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None:
source={"content": {"m.mentions": {"room": True}}},
)
channel.config.allow_room_mentions = False
await channel._on_message(room, room_mention_event)
assert handled == []
assert client.typing_calls == []
@ -1322,3 +1328,220 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None:
"body": text,
"m.mentions": {},
}
def test_build_matrix_text_content_basic_text() -> None:
"""Test basic text content without HTML formatting."""
result = _build_matrix_text_content("Hello, World!")
expected = {
"msgtype": "m.text",
"body": "Hello, World!",
"m.mentions": {}
}
assert expected == result
def test_build_matrix_text_content_with_markdown() -> None:
"""Test text content with markdown that renders to HTML."""
text = "*Hello* **World**"
result = _build_matrix_text_content(text)
assert "msgtype" in result
assert "body" in result
assert result["body"] == text
assert "format" in result
assert result["format"] == "org.matrix.custom.html"
assert "formatted_body" in result
assert isinstance(result["formatted_body"], str)
assert len(result["formatted_body"]) > 0
def test_build_matrix_text_content_with_event_id() -> None:
"""Test text content with event_id for message replacement."""
event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
result = _build_matrix_text_content("Updated message", event_id)
assert "msgtype" in result
assert "body" in result
assert result["m.new_content"]
assert result["m.new_content"]["body"] == "Updated message"
assert result["m.relates_to"]["rel_type"] == "m.replace"
assert result["m.relates_to"]["event_id"] == event_id
def test_build_matrix_text_content_no_event_id() -> None:
"""Test that when event_id is not provided, no extra properties are added."""
result = _build_matrix_text_content("Regular message")
# Basic required properties should be present
assert "msgtype" in result
assert "body" in result
assert result["body"] == "Regular message"
# Extra properties for replacement should NOT be present
assert "m.relates_to" not in result
assert "m.new_content" not in result
assert "format" not in result
assert "formatted_body" not in result
def test_build_matrix_text_content_plain_text_no_html() -> None:
"""Test plain text that should not include HTML formatting."""
result = _build_matrix_text_content("Simple plain text")
assert "msgtype" in result
assert "body" in result
assert "format" not in result
assert "formatted_body" not in result
@pytest.mark.asyncio
async def test_send_room_content_returns_room_send_response():
"""Test that _send_room_content returns the response from client.room_send."""
client = _FakeAsyncClient("", "", "", None)
channel = MatrixChannel(_make_config(), MessageBus())
channel.client = client
room_id = "!test_room:matrix.org"
content = {"msgtype": "m.text", "body": "Hello World"}
result = await channel._send_room_content(room_id, content)
assert result is client.room_send_response
@pytest.mark.asyncio
async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
await channel.send_delta("!room:matrix.org", "Hello")
assert "!room:matrix.org" in channel._stream_bufs
buf = channel._stream_bufs["!room:matrix.org"]
assert buf.text == "Hello"
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
assert len(client.room_send_calls) == 1
assert client.room_send_calls[0]["content"]["body"] == "Hello"
@pytest.mark.asyncio
async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
now = 100.0
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
await channel.send_delta("!room:matrix.org", "Hello")
assert len(client.room_send_calls) == 1
await channel.send_delta("!room:matrix.org", " world")
assert len(client.room_send_calls) == 1
buf = channel._stream_bufs["!room:matrix.org"]
assert buf.text == "Hello world"
assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
@pytest.mark.asyncio
async def test_send_delta_edits_again_after_interval(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"
times = [100.0, 102.0, 104.0, 106.0, 108.0]
times.reverse()
monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop())
await channel.send_delta("!room:matrix.org", "Hello")
await channel.send_delta("!room:matrix.org", " world")
assert len(client.room_send_calls) == 2
first_content = client.room_send_calls[0]["content"]
second_content = client.room_send_calls[1]["content"]
assert "body" in first_content
assert first_content["body"] == "Hello"
assert "m.relates_to" not in first_content
assert "body" in second_content
assert "m.relates_to" in second_content
assert second_content["body"] == "Hello world"
assert second_content["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo",
}
@pytest.mark.asyncio
async def test_send_delta_stream_end_replaces_existing_message() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf(
text="Final text",
event_id="event-1",
last_edit=100.0,
)
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
assert "!room:matrix.org" not in channel._stream_bufs
assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS)
assert len(client.room_send_calls) == 1
assert client.room_send_calls[0]["content"]["body"] == "Final text"
assert client.room_send_calls[0]["content"]["m.relates_to"] == {
"rel_type": "m.replace",
"event_id": "event-1",
}
@pytest.mark.asyncio
async def test_send_delta_stream_end_noop_when_buffer_missing() -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
await channel.send_delta("!room:matrix.org", "", {"_stream_end": True})
assert client.room_send_calls == []
assert client.typing_calls == []
@pytest.mark.asyncio
async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
client.raise_on_send = True
channel.client = client
now = 100.0
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"})
assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
assert len(client.room_send_calls) == 1
assert len(client.typing_calls) == 1
@pytest.mark.asyncio
async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
now = 100.0
monkeypatch.setattr(channel, "monotonic_time", lambda: now)
await channel.send_delta("!room:matrix.org", " ")
assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == " "
assert client.room_send_calls == []