fix(matrix): bound inbound media downloads

This commit is contained in:
hinotoi-agent 2026-05-30 09:48:41 +08:00 committed by Xubin Ren
parent 7c86223643
commit 4dd89f4c46
3 changed files with 174 additions and 30 deletions

View File

@ -8,23 +8,23 @@ from contextlib import suppress
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Literal, TypeAlias from typing import Any, Literal, TypeAlias
from urllib.parse import quote, urlparse
from pydantic import Field from pydantic import Field
from nanobot.security.workspace_policy import is_path_within from nanobot.security.workspace_policy import is_path_within
try: try:
import aiohttp
import nh3 import nh3
from mistune import create_markdown from mistune import create_markdown
from nio import ( from nio import (
AsyncClient, AsyncClient,
AsyncClientConfig, AsyncClientConfig,
DownloadError,
InviteEvent, InviteEvent,
JoinError, JoinError,
LoginResponse, LoginResponse,
MatrixRoom, MatrixRoom,
MemoryDownloadResponse,
RoomEncryptedMedia, RoomEncryptedMedia,
RoomMessage, RoomMessage,
RoomMessageMedia, RoomMessageMedia,
@ -64,6 +64,10 @@ _MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.f
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia) MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
class _MediaTooLargeError(Exception):
    """Signal that an inbound Matrix media item is bigger than the allowed cap.

    Kept separate from ordinary download failures so that callers can report
    the attachment as "too large" rather than as a failed download.
    """
MATRIX_MARKDOWN = create_markdown( MATRIX_MARKDOWN = create_markdown(
escape=True, escape=True,
plugins=["table", "strikethrough", "url", "superscript", "subscript"], plugins=["table", "strikethrough", "url", "superscript", "subscript"],
@ -192,6 +196,7 @@ class MatrixConfig(Base):
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled") e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
sync_stop_grace_seconds: int = 2 sync_stop_grace_seconds: int = 2
max_media_bytes: int = 20 * 1024 * 1024 max_media_bytes: int = 20 * 1024 * 1024
max_concurrent_media_downloads: int = 2
allow_from: list[str] = Field(default_factory=list) allow_from: list[str] = Field(default_factory=list)
group_policy: Literal["open", "mention", "allowlist"] = "open" group_policy: Literal["open", "mention", "allowlist"] = "open"
group_allow_from: list[str] = Field(default_factory=list) group_allow_from: list[str] = Field(default_factory=list)
@ -233,6 +238,9 @@ class MatrixChannel(BaseChannel):
self._server_upload_limit_checked = False self._server_upload_limit_checked = False
self._stream_bufs: dict[str, _StreamBuf] = {} self._stream_bufs: dict[str, _StreamBuf] = {}
self._started_at_ms: int = 0 self._started_at_ms: int = 0
self._media_download_semaphore = asyncio.Semaphore(
max(1, int(self.config.max_concurrent_media_downloads))
)
async def start(self) -> None: async def start(self) -> None:
@ -770,26 +778,48 @@ class MatrixChannel(BaseChannel):
event_prefix = (event_id[:24] or "evt").strip("_") event_prefix = (event_id[:24] or "evt").strip("_")
return self._media_dir() / f"{event_prefix}_{stem}{suffix}" return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
async def _download_media_bytes(self, mxc_url: str) -> bytes | None: async def _download_media_bytes(self, mxc_url: str, limit_bytes: int) -> bytes | None:
if not self.client: if not self.client or limit_bytes <= 0:
raise _MediaTooLargeError
parsed = urlparse(mxc_url)
if parsed.scheme != "mxc" or not parsed.netloc or not parsed.path.strip("/"):
return None return None
response = await self.client.download(mxc=mxc_url)
if isinstance(response, DownloadError): homeserver = str(getattr(self.client, "homeserver", "") or self.config.homeserver).rstrip("/")
self.logger.warning("download failed for {}: {}", mxc_url, response) media_url = (
f"{homeserver}/_matrix/client/v1/media/download/"
f"{quote(parsed.netloc, safe='')}/{quote(parsed.path.strip('/'), safe='')}"
)
token = getattr(self.client, "access_token", None) or self.config.access_token
headers = {"Authorization": f"Bearer {token}"} if token else None
timeout = aiohttp.ClientTimeout(total=None)
try:
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
async with session.get(media_url, params={"allow_remote": "true"}) as response:
if response.status >= 400:
self.logger.warning("download failed for {}: HTTP {}", mxc_url, response.status)
return None
content_length = response.headers.get("Content-Length")
if content_length is not None:
try:
if int(content_length) > limit_bytes:
raise _MediaTooLargeError
except ValueError:
pass
chunks = bytearray()
async for chunk in response.content.iter_chunked(64 * 1024):
chunks.extend(chunk)
if len(chunks) > limit_bytes:
raise _MediaTooLargeError
return bytes(chunks)
except _MediaTooLargeError:
raise
except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
self.logger.warning("download failed for {}", mxc_url, exc_info=True)
return None return None
body = getattr(response, "body", None)
if isinstance(body, (bytes, bytearray)):
return bytes(body)
if isinstance(response, MemoryDownloadResponse):
return bytes(response.body)
if isinstance(body, (str, Path)):
path = Path(body)
if path.is_file():
try:
return path.read_bytes()
except OSError:
return None
return None
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None: def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None) key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
@ -818,10 +848,14 @@ class MatrixChannel(BaseChannel):
limit_bytes = await self._effective_media_limit_bytes() limit_bytes = await self._effective_media_limit_bytes()
declared = self._event_declared_size_bytes(event) declared = self._event_declared_size_bytes(event)
if declared is not None and declared > limit_bytes: if declared is None or declared > limit_bytes:
return None, _ATTACH_TOO_LARGE.format(filename) return None, _ATTACH_TOO_LARGE.format(filename)
downloaded = await self._download_media_bytes(mxc_url) try:
async with self._media_download_semaphore:
downloaded = await self._download_media_bytes(mxc_url, limit_bytes)
except _MediaTooLargeError:
return None, _ATTACH_TOO_LARGE.format(filename)
if downloaded is None: if downloaded is None:
return None, fail return None, fail

View File

@ -82,6 +82,7 @@ msteams = [
matrix = [ matrix = [
"matrix-nio[e2e]>=0.25.2; sys_platform != 'win32'", "matrix-nio[e2e]>=0.25.2; sys_platform != 'win32'",
"aiohttp>=3.9.0,<4.0.0",
"mistune>=3.0.0,<4.0.0", "mistune>=3.0.0,<4.0.0",
"nh3>=0.2.17,<1.0.0", "nh3>=0.2.17,<1.0.0",
] ]

View File

@ -9,8 +9,6 @@ pytest.importorskip("nh3")
pytest.importorskip("mistune") pytest.importorskip("mistune")
from nio import RoomSendResponse, SyncError from nio import RoomSendResponse, SyncError
from nanobot.channels.matrix import _build_matrix_text_content
import nanobot.channels.matrix as matrix_module import nanobot.channels.matrix as matrix_module
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@ -18,8 +16,9 @@ from nanobot.channels.matrix import (
MATRIX_HTML_FORMAT, MATRIX_HTML_FORMAT,
TYPING_NOTICE_TIMEOUT_MS, TYPING_NOTICE_TIMEOUT_MS,
MatrixChannel, MatrixChannel,
MatrixConfig,
_build_matrix_text_content,
) )
from nanobot.channels.matrix import MatrixConfig
_ROOM_SEND_UNSET = object() _ROOM_SEND_UNSET = object()
@ -693,6 +692,13 @@ async def test_on_media_message_downloads_attachment_and_sets_metadata(
client.download_bytes = b"image" client.download_bytes = b"image"
channel.client = client channel.client = client
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
client.download_calls.append(mxc_url)
assert limit_bytes >= len(client.download_bytes)
return client.download_bytes
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = [] handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None: async def _fake_handle_message(**kwargs) -> None:
@ -857,9 +863,14 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) ->
channel = MatrixChannel(_make_config(), MessageBus()) channel = MatrixChannel(_make_config(), MessageBus())
client = _FakeAsyncClient("", "", "", None) client = _FakeAsyncClient("", "", "", None)
client.download_response = matrix_module.DownloadError("download failed")
channel.client = client channel.client = client
async def _download_media_bytes(mxc_url: str, _limit_bytes: int):
client.download_calls.append(mxc_url)
return None
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = [] handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None: async def _fake_handle_message(**kwargs) -> None:
@ -873,7 +884,7 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) ->
body="photo.png", body="photo.png",
url="mxc://example.org/mediaid", url="mxc://example.org/mediaid",
event_id="$event3", event_id="$event3",
source={"content": {"msgtype": "m.image"}}, source={"content": {"msgtype": "m.image", "info": {"size": 5}}},
) )
await channel._on_media_message(room, event) await channel._on_media_message(room, event)
@ -899,6 +910,13 @@ async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path)
client.download_bytes = b"cipher" client.download_bytes = b"cipher"
channel.client = client channel.client = client
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
client.download_calls.append(mxc_url)
assert limit_bytes >= len(client.download_bytes)
return client.download_bytes
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = [] handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None: async def _fake_handle_message(**kwargs) -> None:
@ -942,6 +960,13 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) ->
client.download_bytes = b"cipher" client.download_bytes = b"cipher"
channel.client = client channel.client = client
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
client.download_calls.append(mxc_url)
assert limit_bytes >= len(client.download_bytes)
return client.download_bytes
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
handled: list[dict[str, object]] = [] handled: list[dict[str, object]] = []
async def _fake_handle_message(**kwargs) -> None: async def _fake_handle_message(**kwargs) -> None:
@ -958,7 +983,7 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) ->
key={"k": "key"}, key={"k": "key"},
hashes={"sha256": "hash"}, hashes={"sha256": "hash"},
iv="iv", iv="iv",
source={"content": {"msgtype": "m.file"}}, source={"content": {"msgtype": "m.file", "info": {"size": 6}}},
) )
await channel._on_media_message(room, event) await channel._on_media_message(room, event)
@ -1756,7 +1781,7 @@ async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
assert "!room:matrix.org" in channel._stream_bufs assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == "Hello" assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
assert len(client.room_send_calls) == 1 assert len(client.room_send_calls) == 1
assert len(client.typing_calls) == 1 assert len(client.typing_calls) == 1
@ -1773,4 +1798,88 @@ async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
assert "!room:matrix.org" in channel._stream_bufs assert "!room:matrix.org" in channel._stream_bufs
assert channel._stream_bufs["!room:matrix.org"].text == " " assert channel._stream_bufs["!room:matrix.org"].text == " "
assert client.room_send_calls == []
@pytest.mark.asyncio
async def test_fetch_media_rejects_missing_declared_size(monkeypatch, tmp_path) -> None:
    """An event that declares no size is reported as too large without downloading."""
    channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
    channel.client = _FakeAsyncClient("https://matrix.org", "", "", None)
    monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)

    async def _download_should_not_run(*_args, **_kwargs):
        # Reaching this stub means the size gate did not stop the fetch.
        raise AssertionError("download should be rejected before fetching bytes")

    monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run)

    room = SimpleNamespace(room_id="!room:matrix.org")
    # No "info" section in the content, hence no declared size.
    sizeless_event = SimpleNamespace(
        sender="@alice:matrix.org",
        event_id="$event1",
        body="payload.bin",
        url="mxc://example.org/media",
        source={"content": {"msgtype": "m.file"}},
    )

    attachment, marker = await channel._fetch_media_attachment(room, sizeless_event)

    assert attachment is None
    assert marker == "[attachment: payload.bin - too large]"
@pytest.mark.asyncio
async def test_fetch_media_rejects_declared_oversized_before_download(monkeypatch, tmp_path) -> None:
    """A declared size above the configured cap is rejected without fetching bytes."""
    channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
    channel.client = _FakeAsyncClient("https://matrix.org", "", "", None)
    monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)

    async def _download_should_not_run(*_args, **_kwargs):
        # Reaching this stub means the size gate did not stop the fetch.
        raise AssertionError("download should be rejected before fetching bytes")

    monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run)

    room = SimpleNamespace(room_id="!room:matrix.org")
    # Declared size is 9 bytes, one more than the 8-byte cap configured above.
    oversized_event = SimpleNamespace(
        sender="@alice:matrix.org",
        event_id="$event1",
        body="payload.bin",
        url="mxc://example.org/media",
        source={"content": {"msgtype": "m.file", "info": {"size": 9}}},
    )

    attachment, marker = await channel._fetch_media_attachment(room, oversized_event)

    assert attachment is None
    assert marker == "[attachment: payload.bin - too large]"
@pytest.mark.asyncio
async def test_fetch_media_maps_streaming_cap_to_too_large(monkeypatch, tmp_path) -> None:
    """A ``_MediaTooLargeError`` from the downloader surfaces as the too-large marker."""
    channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
    channel.client = _FakeAsyncClient("https://matrix.org", "", "", None)
    monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)

    async def _download_too_large(_mxc_url: str, _limit_bytes: int):
        # Stand-in for the downloader hitting its cap mid-transfer.
        raise matrix_module._MediaTooLargeError

    monkeypatch.setattr(channel, "_download_media_bytes", _download_too_large)

    room = SimpleNamespace(room_id="!room:matrix.org")
    # Declared size equals the cap, so the stubbed downloader is reached.
    within_limit_event = SimpleNamespace(
        sender="@alice:matrix.org",
        event_id="$event1",
        body="payload.bin",
        url="mxc://example.org/media",
        source={"content": {"msgtype": "m.file", "info": {"size": 8}}},
    )

    attachment, marker = await channel._fetch_media_attachment(room, within_limit_event)

    assert attachment is None
    assert marker == "[attachment: payload.bin - too large]"
assert client.room_send_calls == []