mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(matrix): bound inbound media downloads
This commit is contained in:
parent
7c86223643
commit
4dd89f4c46
@ -8,23 +8,23 @@ from contextlib import suppress
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, TypeAlias
|
from typing import Any, Literal, TypeAlias
|
||||||
|
from urllib.parse import quote, urlparse
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.security.workspace_policy import is_path_within
|
from nanobot.security.workspace_policy import is_path_within
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
import aiohttp
|
||||||
import nh3
|
import nh3
|
||||||
from mistune import create_markdown
|
from mistune import create_markdown
|
||||||
from nio import (
|
from nio import (
|
||||||
AsyncClient,
|
AsyncClient,
|
||||||
AsyncClientConfig,
|
AsyncClientConfig,
|
||||||
DownloadError,
|
|
||||||
InviteEvent,
|
InviteEvent,
|
||||||
JoinError,
|
JoinError,
|
||||||
LoginResponse,
|
LoginResponse,
|
||||||
MatrixRoom,
|
MatrixRoom,
|
||||||
MemoryDownloadResponse,
|
|
||||||
RoomEncryptedMedia,
|
RoomEncryptedMedia,
|
||||||
RoomMessage,
|
RoomMessage,
|
||||||
RoomMessageMedia,
|
RoomMessageMedia,
|
||||||
@ -64,6 +64,10 @@ _MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.f
|
|||||||
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
|
||||||
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
|
||||||
|
|
||||||
|
|
||||||
|
class _MediaTooLargeError(Exception):
|
||||||
|
"""Raised when an inbound Matrix media download exceeds the configured cap."""
|
||||||
|
|
||||||
MATRIX_MARKDOWN = create_markdown(
|
MATRIX_MARKDOWN = create_markdown(
|
||||||
escape=True,
|
escape=True,
|
||||||
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
plugins=["table", "strikethrough", "url", "superscript", "subscript"],
|
||||||
@ -192,6 +196,7 @@ class MatrixConfig(Base):
|
|||||||
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
|
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
|
||||||
sync_stop_grace_seconds: int = 2
|
sync_stop_grace_seconds: int = 2
|
||||||
max_media_bytes: int = 20 * 1024 * 1024
|
max_media_bytes: int = 20 * 1024 * 1024
|
||||||
|
max_concurrent_media_downloads: int = 2
|
||||||
allow_from: list[str] = Field(default_factory=list)
|
allow_from: list[str] = Field(default_factory=list)
|
||||||
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
group_policy: Literal["open", "mention", "allowlist"] = "open"
|
||||||
group_allow_from: list[str] = Field(default_factory=list)
|
group_allow_from: list[str] = Field(default_factory=list)
|
||||||
@ -233,6 +238,9 @@ class MatrixChannel(BaseChannel):
|
|||||||
self._server_upload_limit_checked = False
|
self._server_upload_limit_checked = False
|
||||||
self._stream_bufs: dict[str, _StreamBuf] = {}
|
self._stream_bufs: dict[str, _StreamBuf] = {}
|
||||||
self._started_at_ms: int = 0
|
self._started_at_ms: int = 0
|
||||||
|
self._media_download_semaphore = asyncio.Semaphore(
|
||||||
|
max(1, int(self.config.max_concurrent_media_downloads))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@ -770,26 +778,48 @@ class MatrixChannel(BaseChannel):
|
|||||||
event_prefix = (event_id[:24] or "evt").strip("_")
|
event_prefix = (event_id[:24] or "evt").strip("_")
|
||||||
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
|
||||||
|
|
||||||
async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
|
async def _download_media_bytes(self, mxc_url: str, limit_bytes: int) -> bytes | None:
|
||||||
if not self.client:
|
if not self.client or limit_bytes <= 0:
|
||||||
|
raise _MediaTooLargeError
|
||||||
|
|
||||||
|
parsed = urlparse(mxc_url)
|
||||||
|
if parsed.scheme != "mxc" or not parsed.netloc or not parsed.path.strip("/"):
|
||||||
return None
|
return None
|
||||||
response = await self.client.download(mxc=mxc_url)
|
|
||||||
if isinstance(response, DownloadError):
|
homeserver = str(getattr(self.client, "homeserver", "") or self.config.homeserver).rstrip("/")
|
||||||
self.logger.warning("download failed for {}: {}", mxc_url, response)
|
media_url = (
|
||||||
|
f"{homeserver}/_matrix/client/v1/media/download/"
|
||||||
|
f"{quote(parsed.netloc, safe='')}/{quote(parsed.path.strip('/'), safe='')}"
|
||||||
|
)
|
||||||
|
token = getattr(self.client, "access_token", None) or self.config.access_token
|
||||||
|
headers = {"Authorization": f"Bearer {token}"} if token else None
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
|
||||||
|
async with session.get(media_url, params={"allow_remote": "true"}) as response:
|
||||||
|
if response.status >= 400:
|
||||||
|
self.logger.warning("download failed for {}: HTTP {}", mxc_url, response.status)
|
||||||
|
return None
|
||||||
|
content_length = response.headers.get("Content-Length")
|
||||||
|
if content_length is not None:
|
||||||
|
try:
|
||||||
|
if int(content_length) > limit_bytes:
|
||||||
|
raise _MediaTooLargeError
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
chunks = bytearray()
|
||||||
|
async for chunk in response.content.iter_chunked(64 * 1024):
|
||||||
|
chunks.extend(chunk)
|
||||||
|
if len(chunks) > limit_bytes:
|
||||||
|
raise _MediaTooLargeError
|
||||||
|
return bytes(chunks)
|
||||||
|
except _MediaTooLargeError:
|
||||||
|
raise
|
||||||
|
except (aiohttp.ClientError, asyncio.TimeoutError, OSError):
|
||||||
|
self.logger.warning("download failed for {}", mxc_url, exc_info=True)
|
||||||
return None
|
return None
|
||||||
body = getattr(response, "body", None)
|
|
||||||
if isinstance(body, (bytes, bytearray)):
|
|
||||||
return bytes(body)
|
|
||||||
if isinstance(response, MemoryDownloadResponse):
|
|
||||||
return bytes(response.body)
|
|
||||||
if isinstance(body, (str, Path)):
|
|
||||||
path = Path(body)
|
|
||||||
if path.is_file():
|
|
||||||
try:
|
|
||||||
return path.read_bytes()
|
|
||||||
except OSError:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
|
||||||
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
|
||||||
@ -818,10 +848,14 @@ class MatrixChannel(BaseChannel):
|
|||||||
|
|
||||||
limit_bytes = await self._effective_media_limit_bytes()
|
limit_bytes = await self._effective_media_limit_bytes()
|
||||||
declared = self._event_declared_size_bytes(event)
|
declared = self._event_declared_size_bytes(event)
|
||||||
if declared is not None and declared > limit_bytes:
|
if declared is None or declared > limit_bytes:
|
||||||
return None, _ATTACH_TOO_LARGE.format(filename)
|
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||||
|
|
||||||
downloaded = await self._download_media_bytes(mxc_url)
|
try:
|
||||||
|
async with self._media_download_semaphore:
|
||||||
|
downloaded = await self._download_media_bytes(mxc_url, limit_bytes)
|
||||||
|
except _MediaTooLargeError:
|
||||||
|
return None, _ATTACH_TOO_LARGE.format(filename)
|
||||||
if downloaded is None:
|
if downloaded is None:
|
||||||
return None, fail
|
return None, fail
|
||||||
|
|
||||||
|
|||||||
@ -82,6 +82,7 @@ msteams = [
|
|||||||
|
|
||||||
matrix = [
|
matrix = [
|
||||||
"matrix-nio[e2e]>=0.25.2; sys_platform != 'win32'",
|
"matrix-nio[e2e]>=0.25.2; sys_platform != 'win32'",
|
||||||
|
"aiohttp>=3.9.0,<4.0.0",
|
||||||
"mistune>=3.0.0,<4.0.0",
|
"mistune>=3.0.0,<4.0.0",
|
||||||
"nh3>=0.2.17,<1.0.0",
|
"nh3>=0.2.17,<1.0.0",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -9,8 +9,6 @@ pytest.importorskip("nh3")
|
|||||||
pytest.importorskip("mistune")
|
pytest.importorskip("mistune")
|
||||||
from nio import RoomSendResponse, SyncError
|
from nio import RoomSendResponse, SyncError
|
||||||
|
|
||||||
from nanobot.channels.matrix import _build_matrix_text_content
|
|
||||||
|
|
||||||
import nanobot.channels.matrix as matrix_module
|
import nanobot.channels.matrix as matrix_module
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -18,8 +16,9 @@ from nanobot.channels.matrix import (
|
|||||||
MATRIX_HTML_FORMAT,
|
MATRIX_HTML_FORMAT,
|
||||||
TYPING_NOTICE_TIMEOUT_MS,
|
TYPING_NOTICE_TIMEOUT_MS,
|
||||||
MatrixChannel,
|
MatrixChannel,
|
||||||
|
MatrixConfig,
|
||||||
|
_build_matrix_text_content,
|
||||||
)
|
)
|
||||||
from nanobot.channels.matrix import MatrixConfig
|
|
||||||
|
|
||||||
_ROOM_SEND_UNSET = object()
|
_ROOM_SEND_UNSET = object()
|
||||||
|
|
||||||
@ -693,6 +692,13 @@ async def test_on_media_message_downloads_attachment_and_sets_metadata(
|
|||||||
client.download_bytes = b"image"
|
client.download_bytes = b"image"
|
||||||
channel.client = client
|
channel.client = client
|
||||||
|
|
||||||
|
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
|
||||||
|
client.download_calls.append(mxc_url)
|
||||||
|
assert limit_bytes >= len(client.download_bytes)
|
||||||
|
return client.download_bytes
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
|
||||||
|
|
||||||
handled: list[dict[str, object]] = []
|
handled: list[dict[str, object]] = []
|
||||||
|
|
||||||
async def _fake_handle_message(**kwargs) -> None:
|
async def _fake_handle_message(**kwargs) -> None:
|
||||||
@ -857,9 +863,14 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) ->
|
|||||||
|
|
||||||
channel = MatrixChannel(_make_config(), MessageBus())
|
channel = MatrixChannel(_make_config(), MessageBus())
|
||||||
client = _FakeAsyncClient("", "", "", None)
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
client.download_response = matrix_module.DownloadError("download failed")
|
|
||||||
channel.client = client
|
channel.client = client
|
||||||
|
|
||||||
|
async def _download_media_bytes(mxc_url: str, _limit_bytes: int):
|
||||||
|
client.download_calls.append(mxc_url)
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
|
||||||
|
|
||||||
handled: list[dict[str, object]] = []
|
handled: list[dict[str, object]] = []
|
||||||
|
|
||||||
async def _fake_handle_message(**kwargs) -> None:
|
async def _fake_handle_message(**kwargs) -> None:
|
||||||
@ -873,7 +884,7 @@ async def test_on_media_message_handles_download_error(monkeypatch, tmp_path) ->
|
|||||||
body="photo.png",
|
body="photo.png",
|
||||||
url="mxc://example.org/mediaid",
|
url="mxc://example.org/mediaid",
|
||||||
event_id="$event3",
|
event_id="$event3",
|
||||||
source={"content": {"msgtype": "m.image"}},
|
source={"content": {"msgtype": "m.image", "info": {"size": 5}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._on_media_message(room, event)
|
await channel._on_media_message(room, event)
|
||||||
@ -899,6 +910,13 @@ async def test_on_media_message_decrypts_encrypted_media(monkeypatch, tmp_path)
|
|||||||
client.download_bytes = b"cipher"
|
client.download_bytes = b"cipher"
|
||||||
channel.client = client
|
channel.client = client
|
||||||
|
|
||||||
|
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
|
||||||
|
client.download_calls.append(mxc_url)
|
||||||
|
assert limit_bytes >= len(client.download_bytes)
|
||||||
|
return client.download_bytes
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
|
||||||
|
|
||||||
handled: list[dict[str, object]] = []
|
handled: list[dict[str, object]] = []
|
||||||
|
|
||||||
async def _fake_handle_message(**kwargs) -> None:
|
async def _fake_handle_message(**kwargs) -> None:
|
||||||
@ -942,6 +960,13 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) ->
|
|||||||
client.download_bytes = b"cipher"
|
client.download_bytes = b"cipher"
|
||||||
channel.client = client
|
channel.client = client
|
||||||
|
|
||||||
|
async def _download_media_bytes(mxc_url: str, limit_bytes: int) -> bytes:
|
||||||
|
client.download_calls.append(mxc_url)
|
||||||
|
assert limit_bytes >= len(client.download_bytes)
|
||||||
|
return client.download_bytes
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_media_bytes)
|
||||||
|
|
||||||
handled: list[dict[str, object]] = []
|
handled: list[dict[str, object]] = []
|
||||||
|
|
||||||
async def _fake_handle_message(**kwargs) -> None:
|
async def _fake_handle_message(**kwargs) -> None:
|
||||||
@ -958,7 +983,7 @@ async def test_on_media_message_handles_decrypt_error(monkeypatch, tmp_path) ->
|
|||||||
key={"k": "key"},
|
key={"k": "key"},
|
||||||
hashes={"sha256": "hash"},
|
hashes={"sha256": "hash"},
|
||||||
iv="iv",
|
iv="iv",
|
||||||
source={"content": {"msgtype": "m.file"}},
|
source={"content": {"msgtype": "m.file", "info": {"size": 6}}},
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._on_media_message(room, event)
|
await channel._on_media_message(room, event)
|
||||||
@ -1756,7 +1781,7 @@ async def test_send_delta_on_error_stops_typing(monkeypatch) -> None:
|
|||||||
assert "!room:matrix.org" in channel._stream_bufs
|
assert "!room:matrix.org" in channel._stream_bufs
|
||||||
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
|
assert channel._stream_bufs["!room:matrix.org"].text == "Hello"
|
||||||
assert len(client.room_send_calls) == 1
|
assert len(client.room_send_calls) == 1
|
||||||
|
|
||||||
assert len(client.typing_calls) == 1
|
assert len(client.typing_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
@ -1773,4 +1798,88 @@ async def test_send_delta_ignores_whitespace_only_delta(monkeypatch) -> None:
|
|||||||
|
|
||||||
assert "!room:matrix.org" in channel._stream_bufs
|
assert "!room:matrix.org" in channel._stream_bufs
|
||||||
assert channel._stream_bufs["!room:matrix.org"].text == " "
|
assert channel._stream_bufs["!room:matrix.org"].text == " "
|
||||||
assert client.room_send_calls == []
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_media_rejects_missing_declared_size(monkeypatch, tmp_path) -> None:
|
||||||
|
channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
|
||||||
|
client = _FakeAsyncClient("https://matrix.org", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)
|
||||||
|
|
||||||
|
async def _download_should_not_run(*_args, **_kwargs):
|
||||||
|
raise AssertionError("download should be rejected before fetching bytes")
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run)
|
||||||
|
event = SimpleNamespace(
|
||||||
|
sender="@alice:matrix.org",
|
||||||
|
event_id="$event1",
|
||||||
|
body="payload.bin",
|
||||||
|
url="mxc://example.org/media",
|
||||||
|
source={"content": {"msgtype": "m.file"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
attachment, marker = await channel._fetch_media_attachment(
|
||||||
|
SimpleNamespace(room_id="!room:matrix.org"),
|
||||||
|
event,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert attachment is None
|
||||||
|
assert marker == "[attachment: payload.bin - too large]"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_media_rejects_declared_oversized_before_download(monkeypatch, tmp_path) -> None:
|
||||||
|
channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
|
||||||
|
client = _FakeAsyncClient("https://matrix.org", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)
|
||||||
|
|
||||||
|
async def _download_should_not_run(*_args, **_kwargs):
|
||||||
|
raise AssertionError("download should be rejected before fetching bytes")
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_should_not_run)
|
||||||
|
event = SimpleNamespace(
|
||||||
|
sender="@alice:matrix.org",
|
||||||
|
event_id="$event1",
|
||||||
|
body="payload.bin",
|
||||||
|
url="mxc://example.org/media",
|
||||||
|
source={"content": {"msgtype": "m.file", "info": {"size": 9}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
attachment, marker = await channel._fetch_media_attachment(
|
||||||
|
SimpleNamespace(room_id="!room:matrix.org"),
|
||||||
|
event,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert attachment is None
|
||||||
|
assert marker == "[attachment: payload.bin - too large]"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fetch_media_maps_streaming_cap_to_too_large(monkeypatch, tmp_path) -> None:
|
||||||
|
channel = MatrixChannel(_make_config(max_media_bytes=8), MessageBus())
|
||||||
|
client = _FakeAsyncClient("https://matrix.org", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
monkeypatch.setattr("nanobot.channels.matrix.get_media_dir", lambda _name: tmp_path)
|
||||||
|
|
||||||
|
async def _download_too_large(_mxc_url: str, _limit_bytes: int):
|
||||||
|
raise matrix_module._MediaTooLargeError
|
||||||
|
|
||||||
|
monkeypatch.setattr(channel, "_download_media_bytes", _download_too_large)
|
||||||
|
event = SimpleNamespace(
|
||||||
|
sender="@alice:matrix.org",
|
||||||
|
event_id="$event1",
|
||||||
|
body="payload.bin",
|
||||||
|
url="mxc://example.org/media",
|
||||||
|
source={"content": {"msgtype": "m.file", "info": {"size": 8}}},
|
||||||
|
)
|
||||||
|
|
||||||
|
attachment, marker = await channel._fetch_media_attachment(
|
||||||
|
SimpleNamespace(room_id="!room:matrix.org"),
|
||||||
|
event,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert attachment is None
|
||||||
|
assert marker == "[attachment: payload.bin - too large]"
|
||||||
|
assert client.room_send_calls == []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user