diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index 98926735e..dcece1043 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -3,6 +3,8 @@ import asyncio import logging import mimetypes +import time +from dataclasses import dataclass from pathlib import Path from typing import Any, Literal, TypeAlias @@ -28,8 +30,8 @@ try: RoomSendError, RoomTypingError, SyncError, - UploadError, - ) + UploadError, RoomSendResponse, +) from nio.crypto.attachments import decrypt_attachment from nio.exceptions import EncryptionError except ImportError as e: @@ -97,6 +99,22 @@ MATRIX_HTML_CLEANER = nh3.Cleaner( link_rel="noopener noreferrer", ) +@dataclass +class _StreamBuf: + """ + Represents a buffer for managing LLM response stream data. + + :ivar text: Stores the text content of the buffer. + :type text: str + :ivar event_id: Identifier for the associated event. None indicates no + specific event association. + :type event_id: str | None + :ivar last_edit: Timestamp of the most recent edit to the buffer. + :type last_edit: float + """ + text: str = "" + event_id: str | None = None + last_edit: float = 0.0 def _render_markdown_html(text: str) -> str | None: """Render markdown to sanitized HTML; returns None for plain text.""" @@ -114,12 +132,36 @@ def _render_markdown_html(text: str) -> str | None: return formatted -def _build_matrix_text_content(text: str) -> dict[str, object]: - """Build Matrix m.text payload with optional HTML formatted_body.""" +def _build_matrix_text_content(text: str, event_id: str | None = None) -> dict[str, object]: + """ + Constructs and returns a dictionary representing the matrix text content with optional + HTML formatting and reference to an existing event for replacement. This function is + primarily used to create content payloads compatible with the Matrix messaging protocol. + + :param text: The plain text content to include in the message. + :type text: str + :param event_id: Optional ID of the event to replace. If provided, the function will + include information indicating that the message is a replacement of the specified + event. + :type event_id: str | None + :return: A dictionary containing the matrix text content, potentially enriched with + HTML formatting and replacement metadata if applicable. + :rtype: dict[str, object] + """ content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}} if html := _render_markdown_html(text): content["format"] = MATRIX_HTML_FORMAT content["formatted_body"] = html + if event_id: + content["m.new_content"] = { + "body": text, + "msgtype": "m.text" + } + content["m.relates_to"] = { + "rel_type": "m.replace", + "event_id": event_id + } + return content @@ -159,7 +201,8 @@ class MatrixConfig(Base): allow_from: list[str] = Field(default_factory=list) group_policy: Literal["open", "mention", "allowlist"] = "open" group_allow_from: list[str] = Field(default_factory=list) - allow_room_mentions: bool = False + allow_room_mentions: bool = False, + streaming: bool = False class MatrixChannel(BaseChannel): @@ -167,6 +210,8 @@ class MatrixChannel(BaseChannel): name = "matrix" display_name = "Matrix" + _STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls + monotonic_time = time.monotonic @classmethod def default_config(cls) -> dict[str, Any]: @@ -192,6 +237,8 @@ class MatrixChannel(BaseChannel): ) self._server_upload_limit_bytes: int | None = None self._server_upload_limit_checked = False + self._stream_bufs: dict[str, _StreamBuf] = {} + async def start(self) -> None: """Start Matrix client and begin sync loop.""" @@ -297,14 +344,17 @@ class MatrixChannel(BaseChannel): room = getattr(self.client, "rooms", {}).get(room_id) return bool(getattr(room, "encrypted", False)) - async def _send_room_content(self, room_id: str, content: dict[str, Any]) -> None: + async def _send_room_content(self, room_id: str, + content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError: """Send m.room.message with E2EE options.""" if not self.client: - return + return None kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content} + if self.config.e2ee_enabled: kwargs["ignore_unverified_devices"] = True - await self.client.room_send(**kwargs) + response = await self.client.room_send(**kwargs) + return response async def _resolve_server_upload_limit_bytes(self) -> int | None: """Query homeserver upload limit once per channel lifecycle.""" @@ -414,6 +464,47 @@ class MatrixChannel(BaseChannel): if not is_progress: await self._stop_typing_keepalive(msg.chat_id, clear_typing=True) + async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None: + meta = metadata or {} + relates_to = self._build_thread_relates_to(metadata) + + if meta.get("_stream_end"): + buf = self._stream_bufs.pop(chat_id, None) + if not buf or not buf.event_id or not buf.text: + return + + await self._stop_typing_keepalive(chat_id, clear_typing=True) + + content = _build_matrix_text_content(buf.text, buf.event_id) + if relates_to: + content["m.relates_to"] = relates_to + await self._send_room_content(chat_id, content) + return + + buf = self._stream_bufs.get(chat_id) + if buf is None: + buf = _StreamBuf() + self._stream_bufs[chat_id] = buf + buf.text += delta + + if not buf.text.strip(): + return + + now = self.monotonic_time() + + if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL: + try: + content = _build_matrix_text_content(buf.text, buf.event_id) + response = await self._send_room_content(chat_id, content) + buf.last_edit = now + if not buf.event_id: + # we are editing the same message all the time, so only the first time the event id needs to be set + buf.event_id = response.event_id + except Exception: + await self._stop_typing_keepalive(metadata["room_id"], clear_typing=True) + pass + + def _register_event_callbacks(self) -> None: self.client.add_event_callback(self._on_message, RoomMessageText) self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index dd5e97d90..3ad65e76b 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -3,6 +3,9 @@ from pathlib import Path from types import SimpleNamespace import pytest +from nio import RoomSendResponse + +from nanobot.channels.matrix import _build_matrix_text_content # Check optional matrix dependencies before importing try: @@ -65,6 +68,7 @@ class _FakeAsyncClient: self.raise_on_send = False self.raise_on_typing = False self.raise_on_upload = False + self.room_send_response: RoomSendResponse | None = RoomSendResponse(event_id="", room_id="") def add_event_callback(self, callback, event_type) -> None: self.callbacks.append((callback, event_type)) @@ -87,7 +91,7 @@ class _FakeAsyncClient: message_type: str, content: dict[str, object], ignore_unverified_devices: object = _ROOM_SEND_UNSET, - ) -> None: + ) -> RoomSendResponse: call: dict[str, object] = { "room_id": room_id, "message_type": message_type, @@ -98,6 +102,7 @@ class _FakeAsyncClient: self.room_send_calls.append(call) if self.raise_on_send: raise RuntimeError("send failed") + return self.room_send_response async def room_typing( self, @@ -520,6 +525,7 @@ async def test_on_message_room_mention_requires_opt_in() -> None: source={"content": {"m.mentions": {"room": True}}}, ) + channel.config.allow_room_mentions = False await channel._on_message(room, room_mention_event) assert handled == [] assert client.typing_calls == [] @@ -1322,3 +1328,220 @@ async def test_send_keeps_plaintext_only_for_plain_text() -> None: "body": text, "m.mentions": {}, } + + +def test_build_matrix_text_content_basic_text() -> None: + """Test basic text content without HTML formatting.""" + result = _build_matrix_text_content("Hello, World!") + expected = { + "msgtype": "m.text", + "body": "Hello, World!", + "m.mentions": {} + } + assert expected == result + + +def test_build_matrix_text_content_with_markdown() -> None: + """Test text content with markdown that renders to HTML.""" + text = "*Hello* **World**" + result = _build_matrix_text_content(text) + assert "msgtype" in result + assert "body" in result + assert result["body"] == text + assert "format" in result + assert result["format"] == "org.matrix.custom.html" + assert "formatted_body" in result + assert isinstance(result["formatted_body"], str) + assert len(result["formatted_body"]) > 0 + + +def test_build_matrix_text_content_with_event_id() -> None: + """Test text content with event_id for message replacement.""" + event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + result = _build_matrix_text_content("Updated message", event_id) + assert "msgtype" in result + assert "body" in result + assert result["m.new_content"] + assert result["m.new_content"]["body"] == "Updated message" + assert result["m.relates_to"]["rel_type"] == "m.replace" + assert result["m.relates_to"]["event_id"] == event_id + + +def test_build_matrix_text_content_no_event_id() -> None: + """Test that when event_id is not provided, no extra properties are added.""" + result = _build_matrix_text_content("Regular message") + + # Basic required properties should be present + assert "msgtype" in result + assert "body" in result + assert result["body"] == "Regular message" + + # Extra properties for replacement should NOT be present + assert "m.relates_to" not in result + assert "m.new_content" not in result + assert "format" not in result + assert "formatted_body" not in result + + +def test_build_matrix_text_content_plain_text_no_html() -> None: + """Test plain text that should not include HTML formatting.""" + result = _build_matrix_text_content("Simple plain text") + assert "msgtype" in result + assert "body" in result + assert "format" not in result + assert "formatted_body" not in result + + +@pytest.mark.asyncio +async def test_send_room_content_returns_room_send_response(): + """Test that _send_room_content returns the response from client.room_send.""" + client = _FakeAsyncClient("", "", "", None) + channel = MatrixChannel(_make_config(), MessageBus()) + channel.client = client + + room_id = "!test_room:matrix.org" + content = {"msgtype": "m.text", "body": "Hello World"} + + result = await channel._send_room_content(room_id, content) + + assert result is client.room_send_response + + +@pytest.mark.asyncio +async def test_send_delta_creates_stream_buffer_and_sends_initial_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + await channel.send_delta("!room:matrix.org", "Hello") + + assert "!room:matrix.org" in channel._stream_bufs + buf = channel._stream_bufs["!room:matrix.org"] + assert buf.text == "Hello" + assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "Hello" + + +@pytest.mark.asyncio +async def test_send_delta_appends_without_sending_before_edit_interval(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", "Hello") + assert len(client.room_send_calls) == 1 + + await channel.send_delta("!room:matrix.org", " world") + assert len(client.room_send_calls) == 1 + + buf = channel._stream_bufs["!room:matrix.org"] + assert buf.text == "Hello world" + assert buf.event_id == "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + +@pytest.mark.asyncio +async def test_send_delta_edits_again_after_interval(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + client.room_send_response.event_id = "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo" + + times = [100.0, 102.0, 104.0, 106.0, 108.0] + times.reverse() + monkeypatch.setattr(channel, "monotonic_time", lambda: times and times.pop()) + + await channel.send_delta("!room:matrix.org", "Hello") + await channel.send_delta("!room:matrix.org", " world") + + assert len(client.room_send_calls) == 2 + first_content = client.room_send_calls[0]["content"] + second_content = client.room_send_calls[1]["content"] + + assert "body" in first_content + assert first_content["body"] == "Hello" + assert "m.relates_to" not in first_content + + assert "body" in second_content + assert "m.relates_to" in second_content + assert second_content["body"] == "Hello world" + assert second_content["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "$8E2XVyINbEhcuAxvxd1d9JhQosNPzkVoU8TrbCAvyHo", + } + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_replaces_existing_message() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._stream_bufs["!room:matrix.org"] = matrix_module._StreamBuf( + text="Final text", + event_id="event-1", + last_edit=100.0, + ) + + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True}) + + assert "!room:matrix.org" not in channel._stream_bufs + assert client.typing_calls[-1] == ("!room:matrix.org", False, TYPING_NOTICE_TIMEOUT_MS) + assert len(client.room_send_calls) == 1 + assert client.room_send_calls[0]["content"]["body"] == "Final text" + assert client.room_send_calls[0]["content"]["m.relates_to"] == { + "rel_type": "m.replace", + "event_id": "event-1", + } + + +@pytest.mark.asyncio +async def test_send_delta_stream_end_noop_when_buffer_missing() -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel.send_delta("!room:matrix.org", "", {"_stream_end": True}) + + assert client.room_send_calls == [] + assert client.typing_calls == [] + + +@pytest.mark.asyncio +async def test_send_delta_on_error_stops_typing(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + client.raise_on_send = True + channel.client = client + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", "Hello", {"room_id": "!room:matrix.org"}) + + assert "!room:matrix.org" in channel._stream_bufs + assert channel._stream_bufs["!room:matrix.org"].text == "Hello" + assert len(client.room_send_calls) == 1 + + assert len(client.typing_calls) == 1 + + +@pytest.mark.asyncio +async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None: + channel = MatrixChannel(_make_config(), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + now = 100.0 + monkeypatch.setattr(channel, "monotonic_time", lambda: now) + + await channel.send_delta("!room:matrix.org", " ") + + assert "!room:matrix.org" in channel._stream_bufs + assert channel._stream_bufs["!room:matrix.org"].text == " " + assert client.room_send_calls == [] \ No newline at end of file