mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
feat(matrix): streaming support (#2447)
* Added streaming message support with incremental updates for Matrix channel * Improve Matrix message handling and add tests * Adjust Matrix streaming edit interval to 2 seconds --------- Co-authored-by: natan <natan@podbielski>
This commit is contained in:
parent
351e3720b6
commit
b94d4c0509
@ -3,6 +3,8 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import mimetypes
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypeAlias
|
||||
|
||||
@ -28,8 +30,8 @@ try:
|
||||
RoomSendError,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError,
|
||||
)
|
||||
UploadError, RoomSendResponse,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner(
|
||||
link_rel="noopener noreferrer",
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class _StreamBuf:
|
||||
"""
|
||||
Represents a buffer for managing LLM response stream data.
|
||||
|
||||
:ivar text: Stores the text content of the buffer.
|
||||
:type text: str
|
||||
:ivar event_id: Identifier for the associated event. None indicates no
|
||||
specific event association.
|
||||
:type event_id: str | None
|
||||
:ivar last_edit: Timestamp of the most recent edit to the buffer.
|
||||
:type last_edit: float
|
||||
"""
|
||||
text: str = ""
|
||||
event_id: str | None = None
|
||||
last_edit: float = 0.0
|
||||
|
||||
def _render_markdown_html(text: str) -> str | None:
|
||||
"""Render markdown to sanitized HTML; returns None for plain text."""
|
||||
@ -114,12 +132,36 @@ def _render_markdown_html(text: str) -> str | None:
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_matrix_text_content(text: str) -> dict[str, object]:
|
||||
"""Build Matrix m.text payload with optional HTML formatted_body."""
|
||||
def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[str, object]:
    """Build a Matrix ``m.text`` payload, optionally as an edit of ``event_id``.

    The payload always carries the plain-text ``body`` and an empty
    ``m.mentions``. When the markdown renderer returns HTML, ``format`` and
    ``formatted_body`` are added as well.

    :param text: The message text; also used as the plain-text ``body``.
    :param event_id: Optional ID of an already-sent event to replace. When
        truthy, the payload becomes an ``m.replace`` edit of that event.
    :return: The event content dictionary. For an edit it additionally holds
        ``m.new_content`` (a full copy of the message fields) and
        ``m.relates_to`` (the ``m.replace`` relation).
    """
    content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
    if html := _render_markdown_html(text):
        content["format"] = MATRIX_HTML_FORMAT
        content["formatted_body"] = html

    if event_id:
        # Clients that understand edits render ``m.new_content`` and treat the
        # top-level fields as a fallback only, so the replacement must mirror
        # the whole message (HTML rendering included), not just the body.
        new_content = dict(content)
        new_content["m.mentions"] = {}  # own dict, not aliased with the fallback's
        content["m.new_content"] = new_content
        content["m.relates_to"] = {
            "rel_type": "m.replace",
            "event_id": event_id,
        }

    return content
|
||||
|
||||
|
||||
@ -159,7 +201,8 @@ class MatrixConfig(Base):
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
allow_room_mentions: bool = False
|
||||
# Default must be the plain bool False (no trailing comma, which would make it a tuple).
allow_room_mentions: bool = False
|
||||
streaming: bool = False
|
||||
|
||||
|
||||
class MatrixChannel(BaseChannel):
|
||||
@ -167,6 +210,8 @@ class MatrixChannel(BaseChannel):
|
||||
|
||||
name = "matrix"
|
||||
display_name = "Matrix"
|
||||
_STREAM_EDIT_INTERVAL = 2  # min seconds between successive sends/edits of a streamed message
monotonic_time = time.monotonic  # clock read by send_delta for throttling; tests replace it with a fake
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -192,6 +237,8 @@ class MatrixChannel(BaseChannel):
|
||||
)
|
||||
self._server_upload_limit_bytes: int | None = None
|
||||
self._server_upload_limit_checked = False
|
||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start Matrix client and begin sync loop."""
|
||||
@ -297,14 +344,17 @@ class MatrixChannel(BaseChannel):
|
||||
room = getattr(self.client, "rooms", {}).get(room_id)
|
||||
return bool(getattr(room, "encrypted", False))
|
||||
|
||||
async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None:
|
||||
async def _send_room_content(self, room_id: str,
                             content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
    """Post ``content`` to ``room_id`` as an ``m.room.message`` event.

    Returns whatever ``client.room_send`` returns, or ``None`` when no client
    is attached to the channel.
    """
    client = self.client
    if not client:
        return None

    send_args: dict[str, Any] = dict(
        room_id=room_id,
        message_type="m.room.message",
        content=content,
    )
    if self.config.e2ee_enabled:
        # Only passed when E2EE is on; otherwise room_send keeps its default.
        send_args["ignore_unverified_devices"] = True

    return await client.room_send(**send_args)
|
||||
|
||||
async def _resolve_server_upload_limit_bytes(self) -> int | None:
|
||||
"""Query homeserver upload limit once per channel lifecycle."""
|
||||
@ -414,6 +464,47 @@ class MatrixChannel(BaseChannel):
|
||||
if not is_progress:
|
||||
await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
|
||||
|
||||
async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
    """Stream an LLM reply into ``chat_id`` by repeatedly editing one message.

    Each call appends ``delta`` to the chat's buffer. The first non-blank text
    is sent as a new message; afterwards that same event is edited via
    ``m.replace``, at most once every ``_STREAM_EDIT_INTERVAL`` seconds.
    A call whose metadata has a truthy ``_stream_end`` sends the complete
    text as a final edit and discards the buffer.

    :param chat_id: Room ID the reply is streamed to.
    :param delta: Text fragment to append (not used on the stream-end call).
    :param metadata: Optional message metadata. ``_stream_end`` marks the last
        call; the whole mapping is also handed to ``_build_thread_relates_to``.
    """
    meta = metadata or {}
    relates_to = self._build_thread_relates_to(metadata)

    if meta.get("_stream_end"):
        buf = self._stream_bufs.pop(chat_id, None)
        if not buf or not buf.event_id or not buf.text:
            return

        await self._stop_typing_keepalive(chat_id, clear_typing=True)

        # The final flush is always an edit of the streamed event, so its
        # ``m.relates_to`` must remain the ``m.replace`` relation. Replacing it
        # with the thread relation would stop it from being an edit at all.
        content = _build_matrix_text_content(buf.text, buf.event_id)
        await self._send_room_content(chat_id, content)
        return

    buf = self._stream_bufs.get(chat_id)
    if buf is None:
        buf = _StreamBuf()
        self._stream_bufs[chat_id] = buf
    buf.text += delta

    if not buf.text.strip():
        return

    now = self.monotonic_time()
    if buf.last_edit and (now - buf.last_edit) < self._STREAM_EDIT_INTERVAL:
        # Throttled: keep buffering until the edit interval has elapsed.
        return

    log = logging.getLogger(__name__)
    try:
        content = _build_matrix_text_content(buf.text, buf.event_id)
        if relates_to and not buf.event_id:
            # Only the initial message carries the thread relation; later
            # sends are m.replace edits of that event.
            content["m.relates_to"] = relates_to
        response = await self._send_room_content(chat_id, content)
    except Exception:
        # Best effort: a failed streaming update must not abort the stream.
        # Use chat_id here because metadata may be None.
        log.exception("Matrix streaming update failed for %s", chat_id)
        await self._stop_typing_keepalive(chat_id, clear_typing=True)
        return

    buf.last_edit = now
    if buf.event_id:
        return

    # Every later update edits the same message, so the event ID is recorded
    # once. The response may be None (no client) or an error response without
    # an event_id attribute, hence the defensive lookup.
    event_id = getattr(response, "event_id", None)
    if event_id:
        buf.event_id = event_id
    else:
        log.warning("Matrix streaming send for %s returned no event id: %r", chat_id, response)
        await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||
|
||||
|
||||
def _register_event_callbacks(self) -> None:
|
||||
self.client.add_event_callback(self._on_message, RoomMessageText)
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
|
||||
@ -3,6 +3,9 @@ from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from nio import RoomSendResponse
|
||||
|
||||
from nanobot.channels.matrix import _build_matrix_text_content
|
||||
|
||||
# Check optional matrix dependencies before importing
|
||||
try:
|
||||
@ -65,6 +68,7 @@ class _FakeAsyncClient:
|
||||
self.raise_on_send = False
|
||||
self.raise_on_typing = False
|
||||
self.raise_on_upload = False
|
||||
self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="")
|
||||
|
||||
def add_event_callback(self, callback, event_type) -> None:
|
||||
self.callbacks.append((callback, event_type))
|
||||
@ -87,7 +91,7 @@ class _FakeAsyncClient:
|
||||
message_type: str,
|
||||
content: dict[str, object],
|
||||
ignore_unverified_devices: object = _ROOM_SEND_UNSET,
|
||||
) -> None:
|
||||
) -> RoomSendResponse:
|
||||
call: dict[str, object] = {
|
||||
"room_id": room_id,
|
||||
"message_type": message_type,
|
||||
@ -98,6 +102,7 @@ class _FakeAsyncClient:
|
||||
self.room_send_calls.append(call)
|
||||
if self.raise_on_send:
|
||||
raise RuntimeError("send failed")
|
||||
return self.room_send_response
|
||||
|
||||
async def room_typing(
|
||||
self,
|
||||
@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None:
|
||||
source={"content": {"m.mentions": {"room": True}}},
|
||||
)
|
||||
|
||||
channel.config.allow_room_mentions = False
|
||||
await channel._on_message(room, room_mention_event)
|
||||
assert handled == []
|
||||
assert client.typing_calls == []
|
||||
@ -1322,3 +1328,220 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None:
|
||||
"body": text,
|
||||
"m.mentions": {},
|
||||
}
|
||||
|
||||
|
||||
def test_build_matrix_text_content_basic_text() -> None:
    """Plain text yields exactly msgtype, body and empty mentions."""
    payload = _build_matrix_text_content("Hello, World!")

    assert payload == {"msgtype": "m.text", "body": "Hello, World!", "m.mentions": {}}
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_markdown() -> None:
    """Markdown input keeps the raw body and gains an HTML rendering."""
    markdown = "*Hello* **World**"

    payload = _build_matrix_text_content(markdown)

    # All four keys must be present before their values are inspected.
    assert {"msgtype", "body", "format", "formatted_body"} <= payload.keys()
    assert payload["body"] == markdown
    assert payload["format"] == "org.matrix.custom.html"
    rendered = payload["formatted_body"]
    assert isinstance(rendered, str)
    assert rendered != ""
|
||||
|
||||
|
||||
def test_build_matrix_text_content_with_event_id() -> None:
    """Passing an event ID turns the payload into an m.replace edit."""
    target_event = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    payload = _build_matrix_text_content("Updated message", target_event)

    assert {"msgtype", "body"} <= payload.keys()

    replacement = payload["m.new_content"]
    assert replacement
    assert replacement["body"] == "Updated message"

    relation = payload["m.relates_to"]
    assert relation["rel_type"] == "m.replace"
    assert relation["event_id"] == target_event
|
||||
|
||||
|
||||
def test_build_matrix_text_content_no_event_id() -> None:
    """Without an event ID, no replacement or formatting keys are emitted."""
    payload = _build_matrix_text_content("Regular message")

    # The basic required properties are there.
    assert {"msgtype", "body"} <= payload.keys()
    assert payload["body"] == "Regular message"

    # Replacement metadata and HTML formatting must both be absent.
    for unexpected in ("m.relates_to", "m.new_content", "format", "formatted_body"):
        assert unexpected not in payload
|
||||
|
||||
|
||||
def test_build_matrix_text_content_plain_text_no_html() -> None:
    """Plain text produces a payload without HTML formatting keys."""
    keys = _build_matrix_text_content("Simple plain text").keys()

    assert {"msgtype", "body"} <= keys
    assert keys.isdisjoint({"format", "formatted_body"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_room_content_returns_room_send_response():
    """_send_room_content hands back exactly what client.room_send returned."""
    fake_client = _FakeAsyncClient("", "", "", None)
    channel = MatrixChannel(_make_config(), MessageBus())
    channel.client = fake_client

    outcome = await channel._send_room_content(
        "!test_room:matrix.org",
        {"msgtype": "m.text", "body": "Hello World"},
    )

    assert outcome is fake_client.room_send_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None:
    """The first delta creates the chat's buffer and is sent right away."""
    room = "!room:matrix.org"
    sent_event = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client
    fake_client.room_send_response.event_id = sent_event

    await channel.send_delta(room, "Hello")

    assert room in channel._stream_bufs
    stream = channel._stream_bufs[room]
    assert stream.text == "Hello"
    assert stream.event_id == sent_event

    sends = fake_client.room_send_calls
    assert len(sends) == 1
    assert sends[0]["content"]["body"] == "Hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None:
    """A delta arriving inside the edit interval is buffered, not sent."""
    room = "!room:matrix.org"
    sent_event = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client
    fake_client.room_send_response.event_id = sent_event

    # Freeze the clock so no time passes between the two deltas.
    frozen = 100.0
    monkeypatch.setattr(channel, "monotonic_time", lambda: frozen)

    await channel.send_delta(room, "Hello")
    assert len(fake_client.room_send_calls) == 1

    await channel.send_delta(room, " world")
    assert len(fake_client.room_send_calls) == 1

    stream = channel._stream_bufs[room]
    assert stream.text == "Hello world"
    assert stream.event_id == sent_event
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_edits_again_after_interval(monkeypatch) -> None:
    """Once the edit interval has elapsed, the next delta is sent as an edit."""
    room = "!room:matrix.org"
    sent_event = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo"

    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client
    fake_client.room_send_response.event_id = sent_event

    # Stored newest-first so pop() hands out 100.0, 102.0, ... in order.
    readings = [108.0, 106.0, 104.0, 102.0, 100.0]
    monkeypatch.setattr(channel, "monotonic_time", lambda: readings and readings.pop())

    await channel.send_delta(room, "Hello")
    await channel.send_delta(room, " world")

    assert len(fake_client.room_send_calls) == 2
    initial, edit = (call["content"] for call in fake_client.room_send_calls)

    assert "body" in initial
    assert initial["body"] == "Hello"
    assert "m.relates_to" not in initial

    assert "body" in edit
    assert "m.relates_to" in edit
    assert edit["body"] == "Hello world"
    assert edit["m.relates_to"] == {
        "rel_type": "m.replace",
        "event_id": sent_event,
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_stream_end_replaces_existing_message() -> None:
    """Stream end sends the buffered text as one final edit and drops the buffer."""
    room = "!room:matrix.org"

    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client

    # Pretend a message was already streamed and is waiting for its last edit.
    pending = matrix_module._StreamBuf(text="Final text", event_id="event-1", last_edit=100.0)
    channel._stream_bufs[room] = pending

    await channel.send_delta(room, "", {"_stream_end": True})

    assert room not in channel._stream_bufs
    assert fake_client.typing_calls[-1] == (room, False, TYPING_NOTICE_TIMEOUT_MS)

    assert len(fake_client.room_send_calls) == 1
    final_content = fake_client.room_send_calls[0]["content"]
    assert final_content["body"] == "Final text"
    assert final_content["m.relates_to"] == {
        "rel_type": "m.replace",
        "event_id": "event-1",
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_stream_end_noop_when_buffer_missing() -> None:
    """Stream end without a buffer sends nothing and leaves typing untouched."""
    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client

    end_marker = {"_stream_end": True}
    await channel.send_delta("!room:matrix.org", "", end_marker)

    assert fake_client.room_send_calls == []
    assert fake_client.typing_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
    """A failing send keeps the buffered text and triggers one typing call."""
    room = "!room:matrix.org"

    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    fake_client.raise_on_send = True
    channel.client = fake_client

    frozen = 100.0
    monkeypatch.setattr(channel, "monotonic_time", lambda: frozen)

    await channel.send_delta(room, "Hello", {"room_id": room})

    # The text stays buffered even though the send raised.
    assert room in channel._stream_bufs
    assert channel._stream_bufs[room].text == "Hello"
    assert len(fake_client.room_send_calls) == 1

    assert len(fake_client.typing_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
    """A whitespace-only delta is buffered but nothing is sent to the room."""
    room = "!room:matrix.org"

    channel = MatrixChannel(_make_config(), MessageBus())
    fake_client = _FakeAsyncClient("", "", "", None)
    channel.client = fake_client

    frozen = 100.0
    monkeypatch.setattr(channel, "monotonic_time", lambda: frozen)

    await channel.send_delta(room, " ")

    assert room in channel._stream_bufs
    assert channel._stream_bufs[room].text == " "
    assert fake_client.room_send_calls == []
|
||||
Loading…
x
Reference in New Issue
Block a user