mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(matrix): align SAS verification message flow
This commit is contained in:
parent
68712fc489
commit
2b4c984e9a
@ -28,7 +28,6 @@ try:
|
||||
KeyVerificationKey,
|
||||
KeyVerificationMac,
|
||||
KeyVerificationStart,
|
||||
LocalProtocolError,
|
||||
LoginResponse,
|
||||
MatrixRoom,
|
||||
RoomEncryptedMedia,
|
||||
@ -626,26 +625,14 @@ class MatrixChannel(BaseChannel):
|
||||
response = await self.client.accept_key_verification(transaction_id)
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
|
||||
return
|
||||
|
||||
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||
if sas is None:
|
||||
self.logger.warning(
|
||||
"Matrix SAS state missing after accept for transaction {}",
|
||||
transaction_id,
|
||||
)
|
||||
return
|
||||
|
||||
response = await self.client.to_device(sas.share_key())
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning(
|
||||
"Matrix SAS key share failed for {}: {}",
|
||||
sender,
|
||||
response,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(event, KeyVerificationKey):
|
||||
responses = await self.client.send_to_device_messages()
|
||||
if any(isinstance(response, ToDeviceError) for response in responses):
|
||||
self.logger.warning("Matrix SAS key share failed for {}", sender)
|
||||
return
|
||||
|
||||
response = await self.client.confirm_short_auth_string(transaction_id)
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
|
||||
@ -653,20 +640,7 @@ class MatrixChannel(BaseChannel):
|
||||
|
||||
if isinstance(event, KeyVerificationMac):
|
||||
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||
if sas is None:
|
||||
self.logger.warning(
|
||||
"Matrix SAS state missing for MAC transaction {}",
|
||||
transaction_id,
|
||||
)
|
||||
return
|
||||
try:
|
||||
response = await self.client.to_device(sas.get_mac())
|
||||
except LocalProtocolError as e:
|
||||
self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e)
|
||||
return
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response)
|
||||
else:
|
||||
if sas is not None and getattr(sas, "verified", False):
|
||||
self.logger.info("Matrix SAS verification completed for {}", sender)
|
||||
return
|
||||
|
||||
|
||||
@ -53,11 +53,14 @@ class _FakeAsyncClient:
|
||||
self.to_device_callbacks: list[tuple[object, object]] = []
|
||||
self.response_callbacks: list[tuple[object, object]] = []
|
||||
self.key_verifications: dict[str, object] = {}
|
||||
self.operation_calls: list[str] = []
|
||||
self.accept_key_verification_calls: list[str] = []
|
||||
self.confirm_short_auth_string_calls: list[str] = []
|
||||
self.send_to_device_messages_calls = 0
|
||||
self.to_device_calls: list[object] = []
|
||||
self.accept_key_verification_response: object | None = None
|
||||
self.confirm_short_auth_string_response: object | None = None
|
||||
self.send_to_device_messages_response: list[object] = []
|
||||
self.to_device_response: object | None = None
|
||||
self.rooms: dict[str, object] = {}
|
||||
self.room_send_calls: list[dict[str, object]] = []
|
||||
@ -94,14 +97,22 @@ class _FakeAsyncClient:
|
||||
self.join_calls.append(room_id)
|
||||
|
||||
async def accept_key_verification(self, transaction_id: str):
|
||||
self.operation_calls.append(f"accept:{transaction_id}")
|
||||
self.accept_key_verification_calls.append(transaction_id)
|
||||
return self.accept_key_verification_response
|
||||
|
||||
async def confirm_short_auth_string(self, transaction_id: str):
|
||||
self.operation_calls.append(f"confirm:{transaction_id}")
|
||||
self.confirm_short_auth_string_calls.append(transaction_id)
|
||||
return self.confirm_short_auth_string_response
|
||||
|
||||
async def send_to_device_messages(self):
|
||||
self.operation_calls.append("send_pending")
|
||||
self.send_to_device_messages_calls += 1
|
||||
return self.send_to_device_messages_response
|
||||
|
||||
async def to_device(self, message):
|
||||
self.operation_calls.append("to_device")
|
||||
self.to_device_calls.append(message)
|
||||
return self.to_device_response
|
||||
|
||||
@ -190,9 +201,10 @@ class _FakeAsyncClient:
|
||||
|
||||
|
||||
class _FakeSas:
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, *, verified: bool = False) -> None:
|
||||
self.share_key_called = False
|
||||
self.get_mac_called = False
|
||||
self.verified = verified
|
||||
|
||||
def share_key(self):
|
||||
self.share_key_called = True
|
||||
@ -346,8 +358,8 @@ async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> Non
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
||||
|
||||
assert client.accept_key_verification_calls == ["tx1"]
|
||||
assert sas.share_key_called is True
|
||||
assert client.to_device_calls == [{"type": "share_key"}]
|
||||
assert sas.share_key_called is False
|
||||
assert client.to_device_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -398,25 +410,27 @@ async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None
|
||||
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationKey())
|
||||
|
||||
assert client.send_to_device_messages_calls == 1
|
||||
assert client.confirm_short_auth_string_calls == ["tx1"]
|
||||
assert client.operation_calls == ["send_pending", "confirm:tx1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None:
|
||||
async def test_sas_verification_mac_does_not_resend_mac(monkeypatch) -> None:
|
||||
_patch_key_verification_events(monkeypatch)
|
||||
channel = MatrixChannel(
|
||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
sas = _FakeSas()
|
||||
sas = _FakeSas(verified=True)
|
||||
client.key_verifications["tx1"] = sas
|
||||
channel.client = client
|
||||
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationMac())
|
||||
|
||||
assert sas.get_mac_called is True
|
||||
assert client.to_device_calls == [{"type": "mac"}]
|
||||
assert sas.get_mac_called is False
|
||||
assert client.to_device_calls == []
|
||||
|
||||
|
||||
def test_media_event_filter_does_not_match_text_events() -> None:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user