diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index ed429a0a3..abfa2b13a 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -28,7 +28,6 @@ try: KeyVerificationKey, KeyVerificationMac, KeyVerificationStart, - LocalProtocolError, LoginResponse, MatrixRoom, RoomEncryptedMedia, @@ -626,26 +625,14 @@ class MatrixChannel(BaseChannel): response = await self.client.accept_key_verification(transaction_id) if isinstance(response, ToDeviceError): self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response) - return - - sas = getattr(self.client, "key_verifications", {}).get(transaction_id) - if sas is None: - self.logger.warning( - "Matrix SAS state missing after accept for transaction {}", - transaction_id, - ) - return - - response = await self.client.to_device(sas.share_key()) - if isinstance(response, ToDeviceError): - self.logger.warning( - "Matrix SAS key share failed for {}: {}", - sender, - response, - ) return if isinstance(event, KeyVerificationKey): + responses = await self.client.send_to_device_messages() + if any(isinstance(response, ToDeviceError) for response in responses): + self.logger.warning("Matrix SAS key share failed for {}", sender) + return + response = await self.client.confirm_short_auth_string(transaction_id) if isinstance(response, ToDeviceError): self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response) @@ -653,20 +640,7 @@ class MatrixChannel(BaseChannel): if isinstance(event, KeyVerificationMac): sas = getattr(self.client, "key_verifications", {}).get(transaction_id) - if sas is None: - self.logger.warning( - "Matrix SAS state missing for MAC transaction {}", - transaction_id, - ) - return - try: - response = await self.client.to_device(sas.get_mac()) - except LocalProtocolError as e: - self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e) - return - if isinstance(response, ToDeviceError): - self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response) - else: + if sas is not None and getattr(sas, "verified", False): self.logger.info("Matrix SAS verification completed for {}", sender) return diff --git a/tests/channels/test_matrix_channel.py b/tests/channels/test_matrix_channel.py index 4ec0edd9e..c8fc58c48 100644 --- a/tests/channels/test_matrix_channel.py +++ b/tests/channels/test_matrix_channel.py @@ -53,11 +53,14 @@ class _FakeAsyncClient: self.to_device_callbacks: list[tuple[object, object]] = [] self.response_callbacks: list[tuple[object, object]] = [] self.key_verifications: dict[str, object] = {} + self.operation_calls: list[str] = [] self.accept_key_verification_calls: list[str] = [] self.confirm_short_auth_string_calls: list[str] = [] + self.send_to_device_messages_calls = 0 self.to_device_calls: list[object] = [] self.accept_key_verification_response: object | None = None self.confirm_short_auth_string_response: object | None = None + self.send_to_device_messages_response: list[object] = [] self.to_device_response: object | None = None self.rooms: dict[str, object] = {} self.room_send_calls: list[dict[str, object]] = [] @@ -94,14 +97,22 @@ class _FakeAsyncClient: self.join_calls.append(room_id) async def accept_key_verification(self, transaction_id: str): + self.operation_calls.append(f"accept:{transaction_id}") self.accept_key_verification_calls.append(transaction_id) return self.accept_key_verification_response async def confirm_short_auth_string(self, transaction_id: str): + self.operation_calls.append(f"confirm:{transaction_id}") self.confirm_short_auth_string_calls.append(transaction_id) return self.confirm_short_auth_string_response + async def send_to_device_messages(self): + self.operation_calls.append("send_pending") + self.send_to_device_messages_calls += 1 + return self.send_to_device_messages_response + async def to_device(self, message): + self.operation_calls.append("to_device") self.to_device_calls.append(message) return self.to_device_response @@ -190,9 +201,10 @@ class _FakeAsyncClient: class _FakeSas: - def __init__(self) -> None: + def __init__(self, *, verified: bool = False) -> None: self.share_key_called = False self.get_mac_called = False + self.verified = verified def share_key(self): self.share_key_called = True @@ -346,8 +358,8 @@ async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> Non await channel._handle_key_verification_event(_FakeKeyVerificationStart()) assert client.accept_key_verification_calls == ["tx1"] - assert sas.share_key_called is True - assert client.to_device_calls == [{"type": "share_key"}] + assert sas.share_key_called is False + assert client.to_device_calls == [] @pytest.mark.asyncio @@ -398,25 +410,27 @@ async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None await channel._handle_key_verification_event(_FakeKeyVerificationKey()) + assert client.send_to_device_messages_calls == 1 assert client.confirm_short_auth_string_calls == ["tx1"] + assert client.operation_calls == ["send_pending", "confirm:tx1"] @pytest.mark.asyncio -async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None: +async def test_sas_verification_mac_does_not_resend_mac(monkeypatch) -> None: _patch_key_verification_events(monkeypatch) channel = MatrixChannel( _make_config(allow_from=["@alice:matrix.org"], sas_verification=True), MessageBus(), ) client = _FakeAsyncClient("", "", "", None) - sas = _FakeSas() + sas = _FakeSas(verified=True) client.key_verifications["tx1"] = sas channel.client = client await channel._handle_key_verification_event(_FakeKeyVerificationMac()) - assert sas.get_mac_called is True - assert client.to_device_calls == [{"type": "mac"}] + assert sas.get_mac_called is False + assert client.to_device_calls == [] def test_media_event_filter_does_not_match_text_events() -> None: