fix(matrix): align SAS verification message flow

This commit is contained in:
Xubin Ren 2026-05-31 00:54:38 +08:00
parent 68712fc489
commit 2b4c984e9a
2 changed files with 27 additions and 39 deletions

View File

@ -28,7 +28,6 @@ try:
KeyVerificationKey, KeyVerificationKey,
KeyVerificationMac, KeyVerificationMac,
KeyVerificationStart, KeyVerificationStart,
LocalProtocolError,
LoginResponse, LoginResponse,
MatrixRoom, MatrixRoom,
RoomEncryptedMedia, RoomEncryptedMedia,
@ -626,26 +625,14 @@ class MatrixChannel(BaseChannel):
response = await self.client.accept_key_verification(transaction_id) response = await self.client.accept_key_verification(transaction_id)
if isinstance(response, ToDeviceError): if isinstance(response, ToDeviceError):
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response) self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
return
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
if sas is None:
self.logger.warning(
"Matrix SAS state missing after accept for transaction {}",
transaction_id,
)
return
response = await self.client.to_device(sas.share_key())
if isinstance(response, ToDeviceError):
self.logger.warning(
"Matrix SAS key share failed for {}: {}",
sender,
response,
)
return return
if isinstance(event, KeyVerificationKey): if isinstance(event, KeyVerificationKey):
responses = await self.client.send_to_device_messages()
if any(isinstance(response, ToDeviceError) for response in responses):
self.logger.warning("Matrix SAS key share failed for {}", sender)
return
response = await self.client.confirm_short_auth_string(transaction_id) response = await self.client.confirm_short_auth_string(transaction_id)
if isinstance(response, ToDeviceError): if isinstance(response, ToDeviceError):
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response) self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
@ -653,20 +640,7 @@ class MatrixChannel(BaseChannel):
if isinstance(event, KeyVerificationMac): if isinstance(event, KeyVerificationMac):
sas = getattr(self.client, "key_verifications", {}).get(transaction_id) sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
if sas is None: if sas is not None and getattr(sas, "verified", False):
self.logger.warning(
"Matrix SAS state missing for MAC transaction {}",
transaction_id,
)
return
try:
response = await self.client.to_device(sas.get_mac())
except LocalProtocolError as e:
self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e)
return
if isinstance(response, ToDeviceError):
self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response)
else:
self.logger.info("Matrix SAS verification completed for {}", sender) self.logger.info("Matrix SAS verification completed for {}", sender)
return return

View File

@ -53,11 +53,14 @@ class _FakeAsyncClient:
self.to_device_callbacks: list[tuple[object, object]] = [] self.to_device_callbacks: list[tuple[object, object]] = []
self.response_callbacks: list[tuple[object, object]] = [] self.response_callbacks: list[tuple[object, object]] = []
self.key_verifications: dict[str, object] = {} self.key_verifications: dict[str, object] = {}
self.operation_calls: list[str] = []
self.accept_key_verification_calls: list[str] = [] self.accept_key_verification_calls: list[str] = []
self.confirm_short_auth_string_calls: list[str] = [] self.confirm_short_auth_string_calls: list[str] = []
self.send_to_device_messages_calls = 0
self.to_device_calls: list[object] = [] self.to_device_calls: list[object] = []
self.accept_key_verification_response: object | None = None self.accept_key_verification_response: object | None = None
self.confirm_short_auth_string_response: object | None = None self.confirm_short_auth_string_response: object | None = None
self.send_to_device_messages_response: list[object] = []
self.to_device_response: object | None = None self.to_device_response: object | None = None
self.rooms: dict[str, object] = {} self.rooms: dict[str, object] = {}
self.room_send_calls: list[dict[str, object]] = [] self.room_send_calls: list[dict[str, object]] = []
@ -94,14 +97,22 @@ class _FakeAsyncClient:
self.join_calls.append(room_id) self.join_calls.append(room_id)
async def accept_key_verification(self, transaction_id: str): async def accept_key_verification(self, transaction_id: str):
self.operation_calls.append(f"accept:{transaction_id}")
self.accept_key_verification_calls.append(transaction_id) self.accept_key_verification_calls.append(transaction_id)
return self.accept_key_verification_response return self.accept_key_verification_response
async def confirm_short_auth_string(self, transaction_id: str): async def confirm_short_auth_string(self, transaction_id: str):
self.operation_calls.append(f"confirm:{transaction_id}")
self.confirm_short_auth_string_calls.append(transaction_id) self.confirm_short_auth_string_calls.append(transaction_id)
return self.confirm_short_auth_string_response return self.confirm_short_auth_string_response
async def send_to_device_messages(self):
self.operation_calls.append("send_pending")
self.send_to_device_messages_calls += 1
return self.send_to_device_messages_response
async def to_device(self, message): async def to_device(self, message):
self.operation_calls.append("to_device")
self.to_device_calls.append(message) self.to_device_calls.append(message)
return self.to_device_response return self.to_device_response
@ -190,9 +201,10 @@ class _FakeAsyncClient:
class _FakeSas: class _FakeSas:
def __init__(self) -> None: def __init__(self, *, verified: bool = False) -> None:
self.share_key_called = False self.share_key_called = False
self.get_mac_called = False self.get_mac_called = False
self.verified = verified
def share_key(self): def share_key(self):
self.share_key_called = True self.share_key_called = True
@ -346,8 +358,8 @@ async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> Non
await channel._handle_key_verification_event(_FakeKeyVerificationStart()) await channel._handle_key_verification_event(_FakeKeyVerificationStart())
assert client.accept_key_verification_calls == ["tx1"] assert client.accept_key_verification_calls == ["tx1"]
assert sas.share_key_called is True assert sas.share_key_called is False
assert client.to_device_calls == [{"type": "share_key"}] assert client.to_device_calls == []
@pytest.mark.asyncio @pytest.mark.asyncio
@ -398,25 +410,27 @@ async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None
await channel._handle_key_verification_event(_FakeKeyVerificationKey()) await channel._handle_key_verification_event(_FakeKeyVerificationKey())
assert client.send_to_device_messages_calls == 1
assert client.confirm_short_auth_string_calls == ["tx1"] assert client.confirm_short_auth_string_calls == ["tx1"]
assert client.operation_calls == ["send_pending", "confirm:tx1"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None: async def test_sas_verification_mac_does_not_resend_mac(monkeypatch) -> None:
_patch_key_verification_events(monkeypatch) _patch_key_verification_events(monkeypatch)
channel = MatrixChannel( channel = MatrixChannel(
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True), _make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
MessageBus(), MessageBus(),
) )
client = _FakeAsyncClient("", "", "", None) client = _FakeAsyncClient("", "", "", None)
sas = _FakeSas() sas = _FakeSas(verified=True)
client.key_verifications["tx1"] = sas client.key_verifications["tx1"] = sas
channel.client = client channel.client = client
await channel._handle_key_verification_event(_FakeKeyVerificationMac()) await channel._handle_key_verification_event(_FakeKeyVerificationMac())
assert sas.get_mac_called is True assert sas.get_mac_called is False
assert client.to_device_calls == [{"type": "mac"}] assert client.to_device_calls == []
def test_media_event_filter_does_not_match_text_events() -> None: def test_media_event_filter_does_not_match_text_events() -> None: