fix(matrix): handle SAS device verification

This commit is contained in:
mytechdream 2026-05-30 20:11:05 +08:00 committed by Xubin Ren
parent 0cc58a80a4
commit 68712fc489
3 changed files with 294 additions and 0 deletions

View File

@ -244,6 +244,7 @@ for reliable encryption, password login is recommended instead. If the
"userId": "@nanobot:matrix.org", "userId": "@nanobot:matrix.org",
"password": "mypasswordhere", "password": "mypasswordhere",
"e2eeEnabled": true, "e2eeEnabled": true,
"sasVerification": true,
"allowFrom": ["@your_user:matrix.org"], "allowFrom": ["@your_user:matrix.org"],
"groupPolicy": "open", "groupPolicy": "open",
"groupAllowFrom": [], "groupAllowFrom": [],
@ -263,6 +264,7 @@ for reliable encryption, password login is recommended instead. If the
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). | | `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
| `allowRoomMentions` | Accept `@room` mentions in mention mode. | | `allowRoomMentions` | Accept `@room` mentions in mention mode. |
| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. | | `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. |
| `sasVerification` | Auto-complete SAS device verification requests from allowed users (default `false`). Useful for Element X, which does not expose manual trust for third-party devices. |
| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. | | `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. |

View File

@ -23,6 +23,12 @@ try:
AsyncClientConfig, AsyncClientConfig,
InviteEvent, InviteEvent,
JoinError, JoinError,
KeyVerificationCancel,
KeyVerificationEvent,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
LocalProtocolError,
LoginResponse, LoginResponse,
MatrixRoom, MatrixRoom,
RoomEncryptedMedia, RoomEncryptedMedia,
@ -33,6 +39,7 @@ try:
RoomSendResponse, RoomSendResponse,
RoomTypingError, RoomTypingError,
SyncError, SyncError,
ToDeviceError,
UploadError, UploadError,
) )
from nio.crypto.attachments import decrypt_attachment from nio.crypto.attachments import decrypt_attachment
@ -194,6 +201,7 @@ class MatrixConfig(Base):
access_token: str = "" access_token: str = ""
device_id: str = "" device_id: str = ""
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled") e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
sas_verification: bool = Field(default=False, alias="sasVerification")
sync_stop_grace_seconds: int = 2 sync_stop_grace_seconds: int = 2
max_media_bytes: int = 20 * 1024 * 1024 max_media_bytes: int = 20 * 1024 * 1024
max_concurrent_media_downloads: int = 2 max_concurrent_media_downloads: int = 2
@ -268,6 +276,7 @@ class MatrixChannel(BaseChannel):
) )
self._register_event_callbacks() self._register_event_callbacks()
self._register_to_device_callbacks()
self._register_response_callbacks() self._register_response_callbacks()
if not self.config.e2ee_enabled: if not self.config.e2ee_enabled:
@ -572,11 +581,102 @@ class MatrixChannel(BaseChannel):
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
self.client.add_event_callback(self._on_room_invite, InviteEvent) self.client.add_event_callback(self._on_room_invite, InviteEvent)
def _register_to_device_callbacks(self) -> None:
if self.config.e2ee_enabled and self.config.sas_verification:
self.client.add_to_device_callback(
self._on_key_verification_event,
(KeyVerificationEvent,),
)
def _register_response_callbacks(self) -> None: def _register_response_callbacks(self) -> None:
self.client.add_response_callback(self._on_sync_error, SyncError) self.client.add_response_callback(self._on_sync_error, SyncError)
self.client.add_response_callback(self._on_join_error, JoinError) self.client.add_response_callback(self._on_join_error, JoinError)
self.client.add_response_callback(self._on_send_error, RoomSendError) self.client.add_response_callback(self._on_send_error, RoomSendError)
def _is_sas_sender_allowed(self, sender: str) -> bool:
return bool(sender and self.is_allowed(sender))
async def _on_key_verification_event(self, event: KeyVerificationEvent) -> None:
try:
await self._handle_key_verification_event(event)
except asyncio.CancelledError:
raise
except Exception:
self.logger.exception("Matrix SAS verification handling failed")
async def _handle_key_verification_event(self, event: KeyVerificationEvent) -> None:
if not (self.config.e2ee_enabled and self.config.sas_verification):
return
if not self.client:
return
sender = str(getattr(event, "sender", "") or "")
transaction_id = str(getattr(event, "transaction_id", "") or "")
if not transaction_id or not self._is_sas_sender_allowed(sender):
return
if isinstance(event, KeyVerificationStart):
if "emoji" not in (getattr(event, "short_authentication_string", None) or []):
self.logger.info(
"Ignoring Matrix SAS verification from {} without emoji support",
sender,
)
return
response = await self.client.accept_key_verification(transaction_id)
if isinstance(response, ToDeviceError):
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
return
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
if sas is None:
self.logger.warning(
"Matrix SAS state missing after accept for transaction {}",
transaction_id,
)
return
response = await self.client.to_device(sas.share_key())
if isinstance(response, ToDeviceError):
self.logger.warning(
"Matrix SAS key share failed for {}: {}",
sender,
response,
)
return
if isinstance(event, KeyVerificationKey):
response = await self.client.confirm_short_auth_string(transaction_id)
if isinstance(response, ToDeviceError):
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
return
if isinstance(event, KeyVerificationMac):
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
if sas is None:
self.logger.warning(
"Matrix SAS state missing for MAC transaction {}",
transaction_id,
)
return
try:
response = await self.client.to_device(sas.get_mac())
except LocalProtocolError as e:
self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e)
return
if isinstance(response, ToDeviceError):
self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response)
else:
self.logger.info("Matrix SAS verification completed for {}", sender)
return
if isinstance(event, KeyVerificationCancel):
self.logger.info(
"Matrix SAS verification cancelled by {}: {}",
sender,
getattr(event, "reason", ""),
)
def _is_fatal_auth_response(self, response: Any) -> bool: def _is_fatal_auth_response(self, response: Any) -> bool:
code = getattr(response, "status_code", None) code = getattr(response, "status_code", None)
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"} is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}

View File

@ -50,7 +50,15 @@ class _FakeAsyncClient:
self.stop_sync_forever_called = False self.stop_sync_forever_called = False
self.join_calls: list[str] = [] self.join_calls: list[str] = []
self.callbacks: list[tuple[object, object]] = [] self.callbacks: list[tuple[object, object]] = []
self.to_device_callbacks: list[tuple[object, object]] = []
self.response_callbacks: list[tuple[object, object]] = [] self.response_callbacks: list[tuple[object, object]] = []
self.key_verifications: dict[str, object] = {}
self.accept_key_verification_calls: list[str] = []
self.confirm_short_auth_string_calls: list[str] = []
self.to_device_calls: list[object] = []
self.accept_key_verification_response: object | None = None
self.confirm_short_auth_string_response: object | None = None
self.to_device_response: object | None = None
self.rooms: dict[str, object] = {} self.rooms: dict[str, object] = {}
self.room_send_calls: list[dict[str, object]] = [] self.room_send_calls: list[dict[str, object]] = []
self.typing_calls: list[tuple[str, bool, int]] = [] self.typing_calls: list[tuple[str, bool, int]] = []
@ -70,6 +78,9 @@ class _FakeAsyncClient:
def add_event_callback(self, callback, event_type) -> None: def add_event_callback(self, callback, event_type) -> None:
self.callbacks.append((callback, event_type)) self.callbacks.append((callback, event_type))
def add_to_device_callback(self, callback, event_type) -> None:
self.to_device_callbacks.append((callback, event_type))
def add_response_callback(self, callback, response_type) -> None: def add_response_callback(self, callback, response_type) -> None:
self.response_callbacks.append((callback, response_type)) self.response_callbacks.append((callback, response_type))
@ -82,6 +93,18 @@ class _FakeAsyncClient:
async def join(self, room_id: str) -> None: async def join(self, room_id: str) -> None:
self.join_calls.append(room_id) self.join_calls.append(room_id)
async def accept_key_verification(self, transaction_id: str):
self.accept_key_verification_calls.append(transaction_id)
return self.accept_key_verification_response
async def confirm_short_auth_string(self, transaction_id: str):
self.confirm_short_auth_string_calls.append(transaction_id)
return self.confirm_short_auth_string_response
async def to_device(self, message):
self.to_device_calls.append(message)
return self.to_device_response
async def room_send( async def room_send(
self, self,
room_id: str, room_id: str,
@ -166,6 +189,61 @@ class _FakeAsyncClient:
return None return None
class _FakeSas:
def __init__(self) -> None:
self.share_key_called = False
self.get_mac_called = False
def share_key(self):
self.share_key_called = True
return {"type": "share_key"}
def get_mac(self):
self.get_mac_called = True
return {"type": "mac"}
class _FakeKeyVerificationStart:
def __init__(
self,
*,
sender: str = "@alice:matrix.org",
transaction_id: str = "tx1",
short_authentication_string: list[str] | None = None,
) -> None:
self.sender = sender
self.transaction_id = transaction_id
self.short_authentication_string = short_authentication_string or ["emoji"]
class _FakeKeyVerificationKey:
def __init__(
self,
*,
sender: str = "@alice:matrix.org",
transaction_id: str = "tx1",
) -> None:
self.sender = sender
self.transaction_id = transaction_id
class _FakeKeyVerificationMac:
def __init__(
self,
*,
sender: str = "@alice:matrix.org",
transaction_id: str = "tx1",
) -> None:
self.sender = sender
self.transaction_id = transaction_id
def _patch_key_verification_events(monkeypatch) -> None:
monkeypatch.setattr(matrix_module, "KeyVerificationStart", _FakeKeyVerificationStart)
monkeypatch.setattr(matrix_module, "KeyVerificationKey", _FakeKeyVerificationKey)
monkeypatch.setattr(matrix_module, "KeyVerificationMac", _FakeKeyVerificationMac)
def _make_config(**kwargs) -> MatrixConfig: def _make_config(**kwargs) -> MatrixConfig:
kwargs.setdefault("allow_from", ["*"]) kwargs.setdefault("allow_from", ["*"])
return MatrixConfig( return MatrixConfig(
@ -209,6 +287,7 @@ async def test_start_skips_load_store_when_device_id_missing(
assert clients[0].config.encryption_enabled is True assert clients[0].config.encryption_enabled is True
assert clients[0].load_store_called is False assert clients[0].load_store_called is False
assert len(clients[0].callbacks) == 3 assert len(clients[0].callbacks) == 3
assert clients[0].to_device_callbacks == []
assert len(clients[0].response_callbacks) == 3 assert len(clients[0].response_callbacks) == 3
await channel.stop() await channel.stop()
@ -227,6 +306,119 @@ async def test_register_event_callbacks_uses_media_base_filter() -> None:
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
def test_register_to_device_callbacks_when_sas_verification_enabled() -> None:
channel = MatrixChannel(_make_config(sas_verification=True), MessageBus())
client = _FakeAsyncClient("", "", "", None)
channel.client = client
channel._register_to_device_callbacks()
assert client.to_device_callbacks == [
(channel._on_key_verification_event, (matrix_module.KeyVerificationEvent,))
]
def test_register_to_device_callbacks_skips_when_e2ee_disabled() -> None:
channel = MatrixChannel(
_make_config(e2ee_enabled=False, sas_verification=True),
MessageBus(),
)
client = _FakeAsyncClient("", "", "", None)
channel.client = client
channel._register_to_device_callbacks()
assert client.to_device_callbacks == []
@pytest.mark.asyncio
async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> None:
_patch_key_verification_events(monkeypatch)
channel = MatrixChannel(
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
MessageBus(),
)
client = _FakeAsyncClient("", "", "", None)
sas = _FakeSas()
client.key_verifications["tx1"] = sas
channel.client = client
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
assert client.accept_key_verification_calls == ["tx1"]
assert sas.share_key_called is True
assert client.to_device_calls == [{"type": "share_key"}]
@pytest.mark.asyncio
async def test_sas_verification_ignores_denied_sender(monkeypatch) -> None:
_patch_key_verification_events(monkeypatch)
channel = MatrixChannel(
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
MessageBus(),
)
client = _FakeAsyncClient("", "", "", None)
client.key_verifications["tx1"] = _FakeSas()
channel.client = client
await channel._handle_key_verification_event(
_FakeKeyVerificationStart(sender="@mallory:matrix.org")
)
assert client.accept_key_verification_calls == []
assert client.to_device_calls == []
@pytest.mark.asyncio
async def test_sas_verification_ignores_when_disabled(monkeypatch) -> None:
_patch_key_verification_events(monkeypatch)
channel = MatrixChannel(
_make_config(allow_from=["@alice:matrix.org"], sas_verification=False),
MessageBus(),
)
client = _FakeAsyncClient("", "", "", None)
client.key_verifications["tx1"] = _FakeSas()
channel.client = client
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
assert client.accept_key_verification_calls == []
assert client.to_device_calls == []
@pytest.mark.asyncio
async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None:
_patch_key_verification_events(monkeypatch)
channel = MatrixChannel(
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
MessageBus(),
)
client = _FakeAsyncClient("", "", "", None)
channel.client = client
await channel._handle_key_verification_event(_FakeKeyVerificationKey())
assert client.confirm_short_auth_string_calls == ["tx1"]
@pytest.mark.asyncio
async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None:
_patch_key_verification_events(monkeypatch)
channel = MatrixChannel(
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
MessageBus(),
)
client = _FakeAsyncClient("", "", "", None)
sas = _FakeSas()
client.key_verifications["tx1"] = sas
channel.client = client
await channel._handle_key_verification_event(_FakeKeyVerificationMac())
assert sas.get_mac_called is True
assert client.to_device_calls == [{"type": "mac"}]
def test_media_event_filter_does_not_match_text_events() -> None: def test_media_event_filter_does_not_match_text_events() -> None:
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER) assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)