From 4dd89f4c468d4cb052aed596fbe42267d7d17103 Mon Sep 17 00:00:00 2001 From: hinotoi-agent Date: Sat, 30 May 2026 09:48:41 +0800 Subject: [PATCH] fix(matrix): bound inbound media downloads --- nanobot/channels/matrix.py | 78 +++++++++++----- pyproject.toml | 1 + tests/channels/test_matrix_channel.py | 125 ++++++++++++++++++++++++-- 3 files changed, 174 insertions(+), 30 deletions(-) diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index 41f7c8e4e..9bb665684 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -8,23 +8,23 @@ from contextlib import suppress from dataclasses import dataclass from pathlib import Path from typing import Any, Literal, TypeAlias +from urllib.parse import quote, urlparse from pydantic import Field from nanobot.security.workspace_policy import is_path_within try: + import aiohttp import nh3 from mistune import create_markdown from nio import ( AsyncClient, AsyncClientConfig, - DownloadError, InviteEvent, JoinError, LoginResponse, MatrixRoom, - MemoryDownloadResponse, RoomEncryptedMedia, RoomMessage, RoomMessageMedia, @@ -64,6 +64,10 @@ _MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.f MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia) MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia + +class _MediaTooLargeError(Exception): + """Raised when an inbound Matrix media download exceeds the configured cap.""" + MATRIX_MARKDOWN = create_markdown( escape=True, plugins=["table", "strikethrough", "url", "superscript", "subscript"], @@ -192,6 +196,7 @@ class MatrixConfig(Base): e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled") sync_stop_grace_seconds: int = 2 max_media_bytes: int = 20 * 1024 * 1024 + max_concurrent_media_downloads: int = 2 allow_from: list[str] = Field(default_factory=list) group_policy: Literal["open", "mention", "allowlist"] = "open" group_allow_from: list[str] = Field(default_factory=list) @@ -233,6 +238,9 @@ class MatrixChannel(BaseChannel): self._server_upload_limit_checked = False self._stream_bufs: dict[str, _StreamBuf] = {} self._started_at_ms: int = 0 + self._media_download_semaphore = asyncio.Semaphore( + max(1, int(self.config.max_concurrent_media_downloads)) + ) async def start(self) -> None: @@ -770,26 +778,48 @@ class MatrixChannel(BaseChannel): event_prefix = (event_id[:24] or "evt").strip("_") return self._media_dir() / f"{event_prefix}_{stem}{suffix}" - async def _download_media_bytes(self, mxc_url: str) -> bytes | None: - if not self.client: + async def _download_media_bytes(self, mxc_url: str, limit_bytes: int) -> bytes | None: + if not self.client or limit_bytes <= 0: + raise _MediaTooLargeError + + parsed = urlparse(mxc_url) + if parsed.scheme != "mxc" or not parsed.netloc or not parsed.path.strip("/"): return None - response = await self.client.download(mxc=mxc_url) - if isinstance(response, DownloadError): - self.logger.warning("download failed for {}: {}", mxc_url, response) + + homeserver = str(getattr(self.client, "homeserver", "") or self.config.homeserver).rstrip("/") + media_url = ( + f"{homeserver}/_matrix/client/v1/media/download/" + f"{quote(parsed.netloc, safe='')}/{quote(parsed.path.strip('/'), safe='')}" + ) + token = getattr(self.client, "access_token", None) or self.config.access_token + headers = {"Authorization": f"Bearer {token}"} if token else None + timeout = aiohttp.ClientTimeout(total=None) + + try: + async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session: + async with session.get(media_url, params={"allow_remote": "true"}) as response: + if response.status >= 400: + self.logger.warning("download failed for {}: HTTP {}", mxc_url, response.status) + return None + content_length = response.headers.get("Content-Length") + if content_length is not None: + try: + if int(content_length) > limit_bytes: + raise _MediaTooLargeError + except ValueError: + pass + + chunks = bytearray() + async for chunk in response.content.iter_chunked(64 * 1024): + chunks.extend(chunk) + if len(chunks) > limit_bytes: + raise _MediaTooLargeError + return bytes(chunks) + except _MediaTooLargeError: + raise + except (aiohttp.ClientError, asyncio.TimeoutError, OSError): + self.logger.warning("download failed for {}", mxc_url, exc_info=True) return None - body = getattr(response, "body", None) - if isinstance(body, (bytes, bytearray)): - return bytes(body) - if isinstance(response, MemoryDownloadResponse): - return bytes(response.body) - if isinstance(body, (str, Path)): - path = Path(body) - if path.is_file(): - try: - return path.read_bytes() - except OSError: - return None - return None def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None: key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None) @@ -818,10 +848,14 @@ class MatrixChannel(BaseChannel): limit_bytes = await self._effective_media_limit_bytes() declared = self._event_declared_size_bytes(event) - if declared is not None and declared > limit_bytes: + if declared is None or declared > limit_bytes: return None, _ATTACH_TOO_LARGE.format(filename) - downloaded = await self._download_media_bytes(mxc_url) + try: + async with self._media_download_semaphore: + downloaded = await self._download_media_bytes(mxc_url, limit_bytes) + except _MediaTooLargeError: + return None, _ATTACH_TOO_LARGE.format(filename) if downloaded is None: return None, fail diff --git a/pyproject.toml b/pyproject.toml index ee27548c7..058ab0d01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ msteams = [ matrix = [ "matrix-nio[e2e]>=0.25.2; sys_platform != 'win32'", + "aiohttp>=3.9.0,<4.0.0", "mistune>=3.0.0,<4.0.0", "nh3>=0.2.17,<1.0.0", ] diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index 8bd9f8154..7eb9f73a1 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -9,8 +9,6 @@ pytest.importorskip("nh3") pytest.importorskip("mistune") from nio import RoomSendResponse, SyncError -from nanobot.channels.matrix import _build_matrix_text_content - import nanobot.channels.matrix as matrix_module from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus @@ -18,8 +16,9 @@ from nanobot.channels.matrix import ( MATRIX_HTML_FORMAT, TYPING_NOTICE_TIMEOUT_MS, MatrixChannel, + MatrixConfig, + _build_matrix_text_content, ) -from nanobot.channels.matrix import MatrixConfig _ROOM_SEND_UNSET = object() @@ -693,6 +692,13 @@ async def test_on_media_message_downloads_attachment_and_sets_metadata( client.download_bytes = b"image" channel.client = client + async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes: + client.download_calls.append(mxc_url) + assert limit_bytes >= len(client.download_bytes) + return client.download_bytes + + monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes) + handled: list[dict[str, object]] = [] async def _fake_handle_message(**kwargs) -> None: @@ -857,9 +863,14 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) -> channel = MatrixChannel(_make_config(), MessageBus()) client = _FakeAsyncClient("", "", "", None) - client.download_response = matrix_module.DownloadError("download failed") channel.client = client + async def _download_media_bytes(mxc_url: str, _limit_bytes: int): + client.download_calls.append(mxc_url) + return None + + monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes) + handled: list[dict[str, object]] = [] async def _fake_handle_message(**kwargs) -> None: @@ -873,7 +884,7 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) -> body="photo.png", url="mxc://example.org/mediaid", event_id="$event3", - source={"content": {"msgtype": "m.image"}}, + source={"content": {"msgtype": "m.image", "info": {"size": 5}}}, ) await channel._on_media_message(room, event) @@ -899,6 +910,13 @@ async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path) client.download_bytes = b"cipher" channel.client = client + async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes: + client.download_calls.append(mxc_url) + assert limit_bytes >= len(client.download_bytes) + return client.download_bytes + + monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes) + handled: list[dict[str, object]] = [] async def _fake_handle_message(**kwargs) -> None: @@ -942,6 +960,13 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) -> client.download_bytes = b"cipher" channel.client = client + async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes: + client.download_calls.append(mxc_url) + assert limit_bytes >= len(client.download_bytes) + return client.download_bytes + + monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes) + handled: list[dict[str, object]] = [] async def _fake_handle_message(**kwargs) -> None: @@ -958,7 +983,7 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) -> key={"k": "key"}, hashes={"sha256": "hash"}, iv="iv", - source={"content": {"msgtype": "m.file"}}, + source={"content": {"msgtype": "m.file", "info": {"size": 6}}}, ) await channel._on_media_message(room, event) @@ -1756,7 +1781,7 @@ async def test_send_delta_on_error_stops_typing(monkeypatch) -> None: assert "!room:matrix.org" in channel._stream_bufs assert channel._stream_bufs["!room:matrix.org"].text == "Hello" assert len(client.room_send_calls) == 1 - + assert len(client.typing_calls) == 1 @@ -1773,4 +1798,88 @@ async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None: assert "!room:matrix.org" in channel._stream_bufs assert channel._stream_bufs["!room:matrix.org"].text == " " - assert client.room_send_calls == [] \ No newline at end of file + + +@pytest.mark.asyncio +async def test_fetch_media_rejects_missing_declared_size(monkeypatch, tmp_path) -> None: + channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus()) + client = _FakeAsyncClient("https://matrix.org", "", "", None) + channel.client = client + monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path) + + async def _download_should_not_run(*_args, **_kwargs): + raise AssertionError("download should be rejected before fetching bytes") + + monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run) + event = SimpleNamespace( + sender="@alice:matrix.org", + event_id="$event1", + body="payload.bin", + url="mxc://example.org/media", + source={"content": {"msgtype": "m.file"}}, + ) + + attachment, marker = await channel._fetch_media_attachment( + SimpleNamespace(room_id="!room:matrix.org"), + event, + ) + + assert attachment is None + assert marker == "[attachment: payload.bin - too large]" + + +@pytest.mark.asyncio +async def test_fetch_media_rejects_declared_oversized_before_download(monkeypatch, tmp_path) -> None: + channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus()) + client = _FakeAsyncClient("https://matrix.org", "", "", None) + channel.client = client + monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path) + + async def _download_should_not_run(*_args, **_kwargs): + raise AssertionError("download should be rejected before fetching bytes") + + monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run) + event = SimpleNamespace( + sender="@alice:matrix.org", + event_id="$event1", + body="payload.bin", + url="mxc://example.org/media", + source={"content": {"msgtype": "m.file", "info": {"size": 9}}}, + ) + + attachment, marker = await channel._fetch_media_attachment( + SimpleNamespace(room_id="!room:matrix.org"), + event, + ) + + assert attachment is None + assert marker == "[attachment: payload.bin - too large]" + + +@pytest.mark.asyncio +async def test_fetch_media_maps_streaming_cap_to_too_large(monkeypatch, tmp_path) -> None: + channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus()) + client = _FakeAsyncClient("https://matrix.org", "", "", None) + channel.client = client + monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path) + + async def _download_too_large(_mxc_url: str, _limit_bytes: int): + raise matrix_module._MediaTooLargeError + + monkeypatch.setattr(channel, "_download_media_bytes", _download_too_large) + event = SimpleNamespace( + sender="@alice:matrix.org", + event_id="$event1", + body="payload.bin", + url="mxc://example.org/media", + source={"content": {"msgtype": "m.file", "info": {"size": 8}}}, + ) + + attachment, marker = await channel._fetch_media_attachment( + SimpleNamespace(room_id="!room:matrix.org"), + event, + ) + + assert attachment is None + assert marker == "[attachment: payload.bin - too large]" + assert client.room_send_calls == []