fix(matrix): bound inbound media downloads

This commit is contained in:
hinotoi-agent 2026-05-30 09:48:41 +08:00 committed by Xubin Ren
parent 7c86223643
commit 4dd89f4c46
3 changed files with 174 additions and 30 deletions

View File

@ -8,23 +8,23 @@ from contextlib import suppress
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, TypeAlias
from urllib.parse import quote, urlparse
from pydantic import Field
from nanobot.security.workspace_policy import is_path_within
try:
import aiohttp
import nh3
from mistune import create_markdown
from nio import (
AsyncClient,
AsyncClientConfig,
DownloadError,
InviteEvent,
JoinError,
LoginResponse,
MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia,
RoomMessage,
RoomMessageMedia,
@ -64,6 +64,10 @@ _MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.f
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
class _MediaTooLargeError(Exception):
"""Raised when an inbound Matrix media download exceeds the configured cap."""
MATRIX_MARKDOWN = create_markdown(
escape=True,
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
@ -192,6 +196,7 @@ class MatrixConfig(Base):
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
sync_stop_grace_seconds: int = 2
max_media_bytes: int = 20 * 1024 * 1024
max_concurrent_media_downloads: int = 2
allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list)
@ -233,6 +238,9 @@ class MatrixChannel(BaseChannel):
self._server_upload_limit_checked = False
self._stream_bufs: dict[str, _StreamBuf] = {}
self._started_at_ms: int = 0
self._media_download_semaphore = asyncio.Semaphore(
max(1, int(self.config.max_concurrent_media_downloads))
)
async def start(self) -> None:
@ -770,26 +778,48 @@ class MatrixChannel(BaseChannel):
event_prefix = (event_id[:24] or "evt").strip("_")
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
if not self.client:
async def _download_media_bytes(self, mxc_url: str, limit_bytes: int) -> bytes | None:
if not self.client or limit_bytes <= 0:
raise _MediaTooLargeError
parsed = urlparse(mxc_url)
if parsed.scheme != "mxc" or not parsed.netloc or not parsed.path.strip("/"):
return None
response = await self.client.download(mxc=mxc_url)
if isinstance(response, DownloadError):
self.logger.warning("download failed for {}: {}", mxc_url, response)
homeserver = str(getattr(self.client, "homeserver", "") or self.config.homeserver).rstrip("/")
media_url = (
f"{homeserver}/_matrix/client/v1/media/download/"
f"{quote(parsed.netloc, safe='')}/{quote(parsed.path.strip('/'), safe='')}"
)
token = getattr(self.client, "access_token", None) or self.config.access_token
headers = {"Authorization": f"Bearer {token}"} if token else None
timeout = aiohttp.ClientTimeout(total=None)
try:
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
async with session.get(media_url, params={"allow_remote": "true"}) as response:
if response.status >= 400:
self.logger.warning("download failed for {}: HTTP {}", mxc_url, response.status)
return None
content_length = response.headers.get("Content-Length")
if content_length is not None:
try:
if int(content_length) > limit_bytes:
raise _MediaTooLargeError
except ValueError:
pass
chunks = bytearray()
async for chunk in response.content.iter_chunked(64 * 1024):
chunks.extend(chunk)
if len(chunks) > limit_bytes:
raise _MediaTooLargeError
return bytes(chunks)
except _MediaTooLargeError:
raise
except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
self.logger.warning("download failed for {}", mxc_url, exc_info=True)
return None
body = getattr(response, "body", None)
if isinstance(body, (bytes, bytearray)):
return bytes(body)
if isinstance(response, MemoryDownloadResponse):
return bytes(response.body)
if isinstance(body, (str, Path)):
path = Path(body)
if path.is_file():
try:
return path.read_bytes()
except OSError:
return None
return None
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
@ -818,10 +848,14 @@ class MatrixChannel(BaseChannel):
limit_bytes = await self._effective_media_limit_bytes()
declared = self._event_declared_size_bytes(event)
if declared is not None and declared > limit_bytes:
if declared is None or declared > limit_bytes:
return None, _ATTACH_TOO_LARGE.format(filename)
downloaded = await self._download_media_bytes(mxc_url)
try:
async with self._media_download_semaphore:
downloaded = await self._download_media_bytes(mxc_url, limit_bytes)
except _MediaTooLargeError:
return None, _ATTACH_TOO_LARGE.format(filename)
if downloaded is None:
return None, fail

View File

@ -82,6 +82,7 @@ msteams = [
matrix = [
"matrix-nio[e2e]>=0.25.2; sys_platform != 'win32'",
"aiohttp>=3.9.0,<4.0.0",
"mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0",
]

View File

@ -9,8 +9,6 @@ pytest.importorskip("nh3")
pytest.importorskip("mistune")
from nio import RoomSendResponse, SyncError
from nanobot.channels.matrix import _build_matrix_text_content
import nanobot.channels.matrix as matrix_module
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
@ -18,8 +16,9 @@ from nanobot.channels.matrix import (
MATRIX_HTML_FORMAT,
TYPING_NOTICE_TIMEOUT_MS,
MatrixChannel,
MatrixConfig,
_build_matrix_text_content,
)
from nanobot.channels.matrix import MatrixConfig
_ROOM_SEND_UNSET = object()
@ -693,6 +692,13 @@ async def test_on_media_message_downloads_attachment_and_sets_metadata(
client.download_bytes = b"image"
channel.client = client
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
    """Stand-in downloader: log the requested URL and hand back the fake payload."""
    client.download_calls.append(mxc_url)
    payload = client.download_bytes
    # The channel must pass a cap large enough for the fixture payload.
    assert limit_bytes >= len(payload)
    return payload
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None:
@ -857,9 +863,14 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) ->
channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None)
client.download_response = matrix_module.DownloadError("download failed")
channel.client = client
async def _download_media_bytes(mxc_url: str, _limit_bytes: int):
    """Stand-in downloader that records the request and reports a failed fetch."""
    # The cap is irrelevant here: this stub always yields "no data".
    client.download_calls.append(mxc_url)
    return None
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None:
@ -873,7 +884,7 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) ->
body="photo.png",
url="mxc://example.org/mediaid",
event_id="$event3",
source={"content": {"msgtype": "m.image"}},
source={"content": {"msgtype": "m.image", "info": {"size": 5}}},
)
await channel._on_media_message(room, event)
@ -899,6 +910,13 @@ async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path)
client.download_bytes = b"cipher"
channel.client = client
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
    """Stand-in downloader returning the fake client's stored bytes."""
    client.download_calls.append(mxc_url)
    stored = client.download_bytes
    # Guard against the channel handing over a cap smaller than the fixture.
    assert limit_bytes >= len(stored)
    return stored
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None:
@ -942,6 +960,13 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) ->
client.download_bytes = b"cipher"
channel.client = client
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
    """Stand-in downloader: note the call, then return the fake client's bytes."""
    client.download_calls.append(mxc_url)
    blob = client.download_bytes
    # The fixture bytes must fit within whatever cap the channel supplies.
    assert limit_bytes >= len(blob)
    return blob
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None:
@ -958,7 +983,7 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) ->
key={"k": "key"},
hashes={"sha256": "hash"},
iv="iv",
source={"content": {"msgtype": "m.file"}},
source={"content": {"msgtype": "m.file", "info": {"size": 6}}},
)
await channel._on_media_message(room, event)
@ -1756,7 +1781,7 @@ async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
assert len(client.room_send_calls) == 1
assert len(client.typing_calls) == 1
@ -1773,4 +1798,88 @@ async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == " "
assert client.room_send_calls == []
@pytest.mark.asyncio
async def test_fetch_media_rejects_missing_declared_size(monkeypatch, tmp_path) -> None:
    """An event without ``info.size`` gets the too-large marker and no download."""
    channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
    channel.client = _FakeAsyncClient("https://matrix.org", "", "", None)
    monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)

    async def _download_should_not_run(*_args, **_kwargs):
        # Reaching the downloader at all means the pre-check was skipped.
        raise AssertionError("download should be rejected before fetching bytes")

    monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run)

    room = SimpleNamespace(room_id="!room:matrix.org")
    # No "info" entry, hence no declared size.
    sizeless_event = SimpleNamespace(
        sender="@alice:matrix.org",
        event_id="$event1",
        body="payload.bin",
        url="mxc://example.org/media",
        source={"content": {"msgtype": "m.file"}},
    )

    attachment, marker = await channel._fetch_media_attachment(room, sizeless_event)

    assert attachment is None
    assert marker == "[attachment: payload.bin - too large]"
@pytest.mark.asyncio
async def test_fetch_media_rejects_declared_oversized_before_download(monkeypatch, tmp_path) -> None:
    """A declared size above the cap yields the too-large marker and no download."""
    channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
    channel.client = _FakeAsyncClient("https://matrix.org", "", "", None)
    monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)

    async def _download_should_not_run(*_args, **_kwargs):
        # Reaching the downloader at all means the pre-check was skipped.
        raise AssertionError("download should be rejected before fetching bytes")

    monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run)

    room = SimpleNamespace(room_id="!room:matrix.org")
    # Declares 9 bytes against a configured cap of 8.
    oversized_event = SimpleNamespace(
        sender="@alice:matrix.org",
        event_id="$event1",
        body="payload.bin",
        url="mxc://example.org/media",
        source={"content": {"msgtype": "m.file", "info": {"size": 9}}},
    )

    attachment, marker = await channel._fetch_media_attachment(room, oversized_event)

    assert attachment is None
    assert marker == "[attachment: payload.bin - too large]"
@pytest.mark.asyncio
async def test_fetch_media_maps_streaming_cap_to_too_large(monkeypatch, tmp_path) -> None:
    """A ``_MediaTooLargeError`` from the downloader becomes the too-large marker."""
    channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
    client = _FakeAsyncClient("https://matrix.org", "", "", None)
    channel.client = client
    monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)

    async def _download_too_large(_mxc_url: str, _limit_bytes: int):
        # Stands in for the downloader raising on an over-cap transfer.
        raise matrix_module._MediaTooLargeError

    monkeypatch.setattr(channel, "_download_media_bytes", _download_too_large)

    room = SimpleNamespace(room_id="!room:matrix.org")
    # Declares exactly the cap (8), so the size pre-check lets it through.
    at_cap_event = SimpleNamespace(
        sender="@alice:matrix.org",
        event_id="$event1",
        body="payload.bin",
        url="mxc://example.org/media",
        source={"content": {"msgtype": "m.file", "info": {"size": 8}}},
    )

    attachment, marker = await channel._fetch_media_attachment(room, at_cap_event)

    assert attachment is None
    assert marker == "[attachment: payload.bin - too large]"
    assert client.room_send_calls == []