From 68712fc489bd2441ab7f889c79d6048c969beefb Mon Sep 17 00:00:00 2001 From: mytechdream <1838492264@qq.com> Date: Sat, 30 May 2026 20:11:05 +0800 Subject: [PATCH] fix(matrix): handle SAS device verification --- docs/chat-apps.md | 2 + nanobot/channels/matrix.py | 100 ++++++++++++++ tests/channels/test_matrix_channel.py | 192 ++++++++++++++++++++++++++ 3 files changed, 294 insertions(+) diff --git a/docs/chat-apps.md b/docs/chat-apps.md index 58429c16b..da28ef21c 100644 --- a/docs/chat-apps.md +++ b/docs/chat-apps.md @@ -244,6 +244,7 @@ for reliable encryption, password login is recommended instead. If the "userId": "@nanobot:matrix.org", "password": "mypasswordhere", "e2eeEnabled": true, + "sasVerification": true, "allowFrom": ["@your_user:matrix.org"], "groupPolicy": "open", "groupAllowFrom": [], @@ -263,6 +264,7 @@ for reliable encryption, password login is recommended instead. If the | `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). | | `allowRoomMentions` | Accept `@room` mentions in mention mode. | | `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. | +| `sasVerification` | Auto-complete SAS device verification requests from allowed users (default `false`). Useful for Element X, which does not expose manual trust for third-party devices. | | `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. | diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index d2a1d95d3..ed429a0a3 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -23,6 +23,12 @@ try: AsyncClientConfig, InviteEvent, JoinError, + KeyVerificationCancel, + KeyVerificationEvent, + KeyVerificationKey, + KeyVerificationMac, + KeyVerificationStart, + LocalProtocolError, LoginResponse, MatrixRoom, RoomEncryptedMedia, @@ -33,6 +39,7 @@ try: RoomSendResponse, RoomTypingError, SyncError, + ToDeviceError, UploadError, ) from nio.crypto.attachments import decrypt_attachment @@ -194,6 +201,7 @@ class MatrixConfig(Base): access_token: str = "" device_id: str = "" e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled") + sas_verification: bool = Field(default=False, alias="sasVerification") sync_stop_grace_seconds: int = 2 max_media_bytes: int = 20 * 1024 * 1024 max_concurrent_media_downloads: int = 2 @@ -268,6 +276,7 @@ class MatrixChannel(BaseChannel): ) self._register_event_callbacks() + self._register_to_device_callbacks() self._register_response_callbacks() if not self.config.e2ee_enabled: @@ -572,11 +581,102 @@ class MatrixChannel(BaseChannel): self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER) self.client.add_event_callback(self._on_room_invite, InviteEvent) + def _register_to_device_callbacks(self) -> None: + if self.config.e2ee_enabled and self.config.sas_verification: + self.client.add_to_device_callback( + self._on_key_verification_event, + (KeyVerificationEvent,), + ) + def _register_response_callbacks(self) -> None: self.client.add_response_callback(self._on_sync_error, SyncError) self.client.add_response_callback(self._on_join_error, JoinError) self.client.add_response_callback(self._on_send_error, RoomSendError) + def _is_sas_sender_allowed(self, sender: str) -> bool: + return bool(sender and self.is_allowed(sender)) + + async def _on_key_verification_event(self, event: KeyVerificationEvent) -> None: + try: + await self._handle_key_verification_event(event) + except asyncio.CancelledError: + raise + except Exception: + self.logger.exception("Matrix SAS verification handling failed") + + async def _handle_key_verification_event(self, event: KeyVerificationEvent) -> None: + if not (self.config.e2ee_enabled and self.config.sas_verification): + return + if not self.client: + return + + sender = str(getattr(event, "sender", "") or "") + transaction_id = str(getattr(event, "transaction_id", "") or "") + if not transaction_id or not self._is_sas_sender_allowed(sender): + return + + if isinstance(event, KeyVerificationStart): + if "emoji" not in (getattr(event, "short_authentication_string", None) or []): + self.logger.info( + "Ignoring Matrix SAS verification from {} without emoji support", + sender, + ) + return + + response = await self.client.accept_key_verification(transaction_id) + if isinstance(response, ToDeviceError): + self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response) + return + + sas = getattr(self.client, "key_verifications", {}).get(transaction_id) + if sas is None: + self.logger.warning( + "Matrix SAS state missing after accept for transaction {}", + transaction_id, + ) + return + + response = await self.client.to_device(sas.share_key()) + if isinstance(response, ToDeviceError): + self.logger.warning( + "Matrix SAS key share failed for {}: {}", + sender, + response, + ) + return + + if isinstance(event, KeyVerificationKey): + response = await self.client.confirm_short_auth_string(transaction_id) + if isinstance(response, ToDeviceError): + self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response) + return + + if isinstance(event, KeyVerificationMac): + sas = getattr(self.client, "key_verifications", {}).get(transaction_id) + if sas is None: + self.logger.warning( + "Matrix SAS state missing for MAC transaction {}", + transaction_id, + ) + return + try: + response = await self.client.to_device(sas.get_mac()) + except LocalProtocolError as e: + self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e) + return + if isinstance(response, ToDeviceError): + self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response) + else: + self.logger.info("Matrix SAS verification completed for {}", sender) + return + + if isinstance(event, KeyVerificationCancel): + self.logger.info( + "Matrix SAS verification cancelled by {}: {}", + sender, + getattr(event, "reason", ""), + ) + def _is_fatal_auth_response(self, response: Any) -> bool: code = getattr(response, "status_code", None) is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"} diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index 2bf6a28cd..4ec0edd9e 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -50,7 +50,15 @@ class _FakeAsyncClient: self.stop_sync_forever_called = False self.join_calls: list[str] = [] self.callbacks: list[tuple[object, object]] = [] + self.to_device_callbacks: list[tuple[object, object]] = [] self.response_callbacks: list[tuple[object, object]] = [] + self.key_verifications: dict[str, object] = {} + self.accept_key_verification_calls: list[str] = [] + self.confirm_short_auth_string_calls: list[str] = [] + self.to_device_calls: list[object] = [] + self.accept_key_verification_response: object | None = None + self.confirm_short_auth_string_response: object | None = None + self.to_device_response: object | None = None self.rooms: dict[str, object] = {} self.room_send_calls: list[dict[str, object]] = [] self.typing_calls: list[tuple[str, bool, int]] = [] @@ -70,6 +78,9 @@ class _FakeAsyncClient: def add_event_callback(self, callback, event_type) -> None: self.callbacks.append((callback, event_type)) + def add_to_device_callback(self, callback, event_type) -> None: + self.to_device_callbacks.append((callback, event_type)) + def add_response_callback(self, callback, response_type) -> None: self.response_callbacks.append((callback, response_type)) @@ -82,6 +93,18 @@ class _FakeAsyncClient: async def join(self, room_id: str) -> None: self.join_calls.append(room_id) + async def accept_key_verification(self, transaction_id: str): + self.accept_key_verification_calls.append(transaction_id) + return self.accept_key_verification_response + + async def confirm_short_auth_string(self, transaction_id: str): + self.confirm_short_auth_string_calls.append(transaction_id) + return self.confirm_short_auth_string_response + + async def to_device(self, message): + self.to_device_calls.append(message) + return self.to_device_response + async def room_send( self, room_id: str, @@ -166,6 +189,61 @@ class _FakeAsyncClient: return None +class _FakeSas: + def __init__(self) -> None: + self.share_key_called = False + self.get_mac_called = False + + def share_key(self): + self.share_key_called = True + return {"type": "share_key"} + + def get_mac(self): + self.get_mac_called = True + return {"type": "mac"} + + +class _FakeKeyVerificationStart: + def __init__( + self, + *, + sender: str = "@alice:matrix.org", + transaction_id: str = "tx1", + short_authentication_string: list[str] | None = None, + ) -> None: + self.sender = sender + self.transaction_id = transaction_id + self.short_authentication_string = short_authentication_string or ["emoji"] + + +class _FakeKeyVerificationKey: + def __init__( + self, + *, + sender: str = "@alice:matrix.org", + transaction_id: str = "tx1", + ) -> None: + self.sender = sender + self.transaction_id = transaction_id + + +class _FakeKeyVerificationMac: + def __init__( + self, + *, + sender: str = "@alice:matrix.org", + transaction_id: str = "tx1", + ) -> None: + self.sender = sender + self.transaction_id = transaction_id + + +def _patch_key_verification_events(monkeypatch) -> None: + monkeypatch.setattr(matrix_module, "KeyVerificationStart", _FakeKeyVerificationStart) + monkeypatch.setattr(matrix_module, "KeyVerificationKey", _FakeKeyVerificationKey) + monkeypatch.setattr(matrix_module, "KeyVerificationMac", _FakeKeyVerificationMac) + + def _make_config(**kwargs) -> MatrixConfig: kwargs.setdefault("allow_from", ["*"]) return MatrixConfig( @@ -209,6 +287,7 @@ async def test_start_skips_load_store_when_device_id_missing( assert clients[0].config.encryption_enabled is True assert clients[0].load_store_called is False assert len(clients[0].callbacks) == 3 + assert clients[0].to_device_callbacks == [] assert len(clients[0].response_callbacks) == 3 await channel.stop() @@ -227,6 +306,119 @@ async def test_register_event_callbacks_uses_media_base_filter() -> None: assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER +def test_register_to_device_callbacks_when_sas_verification_enabled() -> None: + channel = MatrixChannel(_make_config(sas_verification=True), MessageBus()) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._register_to_device_callbacks() + + assert client.to_device_callbacks == [ + (channel._on_key_verification_event, (matrix_module.KeyVerificationEvent,)) + ] + + +def test_register_to_device_callbacks_skips_when_e2ee_disabled() -> None: + channel = MatrixChannel( + _make_config(e2ee_enabled=False, sas_verification=True), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + channel._register_to_device_callbacks() + + assert client.to_device_callbacks == [] + + +@pytest.mark.asyncio +async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> None: + _patch_key_verification_events(monkeypatch) + channel = MatrixChannel( + _make_config(allow_from=["@alice:matrix.org"], sas_verification=True), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + sas = _FakeSas() + client.key_verifications["tx1"] = sas + channel.client = client + + await channel._handle_key_verification_event(_FakeKeyVerificationStart()) + + assert client.accept_key_verification_calls == ["tx1"] + assert sas.share_key_called is True + assert client.to_device_calls == [{"type": "share_key"}] + + +@pytest.mark.asyncio +async def test_sas_verification_ignores_denied_sender(monkeypatch) -> None: + _patch_key_verification_events(monkeypatch) + channel = MatrixChannel( + _make_config(allow_from=["@alice:matrix.org"], sas_verification=True), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + client.key_verifications["tx1"] = _FakeSas() + channel.client = client + + await channel._handle_key_verification_event( + _FakeKeyVerificationStart(sender="@mallory:matrix.org") + ) + + assert client.accept_key_verification_calls == [] + assert client.to_device_calls == [] + + +@pytest.mark.asyncio +async def test_sas_verification_ignores_when_disabled(monkeypatch) -> None: + _patch_key_verification_events(monkeypatch) + channel = MatrixChannel( + _make_config(allow_from=["@alice:matrix.org"], sas_verification=False), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + client.key_verifications["tx1"] = _FakeSas() + channel.client = client + + await channel._handle_key_verification_event(_FakeKeyVerificationStart()) + + assert client.accept_key_verification_calls == [] + assert client.to_device_calls == [] + + +@pytest.mark.asyncio +async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None: + _patch_key_verification_events(monkeypatch) + channel = MatrixChannel( + _make_config(allow_from=["@alice:matrix.org"], sas_verification=True), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + channel.client = client + + await channel._handle_key_verification_event(_FakeKeyVerificationKey()) + + assert client.confirm_short_auth_string_calls == ["tx1"] + + +@pytest.mark.asyncio +async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None: + _patch_key_verification_events(monkeypatch) + channel = MatrixChannel( + _make_config(allow_from=["@alice:matrix.org"], sas_verification=True), + MessageBus(), + ) + client = _FakeAsyncClient("", "", "", None) + sas = _FakeSas() + client.key_verifications["tx1"] = sas + channel.client = client + + await channel._handle_key_verification_event(_FakeKeyVerificationMac()) + + assert sas.get_mac_called is True + assert client.to_device_calls == [{"type": "mac"}] + + def test_media_event_filter_does_not_match_text_events() -> None: assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)