mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(matrix): align SAS verification message flow
This commit is contained in:
parent
68712fc489
commit
2b4c984e9a
@ -28,7 +28,6 @@ try:
|
|||||||
KeyVerificationKey,
|
KeyVerificationKey,
|
||||||
KeyVerificationMac,
|
KeyVerificationMac,
|
||||||
KeyVerificationStart,
|
KeyVerificationStart,
|
||||||
LocalProtocolError,
|
|
||||||
LoginResponse,
|
LoginResponse,
|
||||||
MatrixRoom,
|
MatrixRoom,
|
||||||
RoomEncryptedMedia,
|
RoomEncryptedMedia,
|
||||||
@ -628,24 +627,12 @@ class MatrixChannel(BaseChannel):
|
|||||||
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
|
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
|
||||||
return
|
return
|
||||||
|
|
||||||
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
|
||||||
if sas is None:
|
|
||||||
self.logger.warning(
|
|
||||||
"Matrix SAS state missing after accept for transaction {}",
|
|
||||||
transaction_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
response = await self.client.to_device(sas.share_key())
|
|
||||||
if isinstance(response, ToDeviceError):
|
|
||||||
self.logger.warning(
|
|
||||||
"Matrix SAS key share failed for {}: {}",
|
|
||||||
sender,
|
|
||||||
response,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if isinstance(event, KeyVerificationKey):
|
if isinstance(event, KeyVerificationKey):
|
||||||
|
responses = await self.client.send_to_device_messages()
|
||||||
|
if any(isinstance(response, ToDeviceError) for response in responses):
|
||||||
|
self.logger.warning("Matrix SAS key share failed for {}", sender)
|
||||||
|
return
|
||||||
|
|
||||||
response = await self.client.confirm_short_auth_string(transaction_id)
|
response = await self.client.confirm_short_auth_string(transaction_id)
|
||||||
if isinstance(response, ToDeviceError):
|
if isinstance(response, ToDeviceError):
|
||||||
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
|
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
|
||||||
@ -653,20 +640,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
|
|
||||||
if isinstance(event, KeyVerificationMac):
|
if isinstance(event, KeyVerificationMac):
|
||||||
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||||
if sas is None:
|
if sas is not None and getattr(sas, "verified", False):
|
||||||
self.logger.warning(
|
|
||||||
"Matrix SAS state missing for MAC transaction {}",
|
|
||||||
transaction_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
response = await self.client.to_device(sas.get_mac())
|
|
||||||
except LocalProtocolError as e:
|
|
||||||
self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e)
|
|
||||||
return
|
|
||||||
if isinstance(response, ToDeviceError):
|
|
||||||
self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response)
|
|
||||||
else:
|
|
||||||
self.logger.info("Matrix SAS verification completed for {}", sender)
|
self.logger.info("Matrix SAS verification completed for {}", sender)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -53,11 +53,14 @@ class _FakeAsyncClient:
|
|||||||
self.to_device_callbacks: list[tuple[object, object]] = []
|
self.to_device_callbacks: list[tuple[object, object]] = []
|
||||||
self.response_callbacks: list[tuple[object, object]] = []
|
self.response_callbacks: list[tuple[object, object]] = []
|
||||||
self.key_verifications: dict[str, object] = {}
|
self.key_verifications: dict[str, object] = {}
|
||||||
|
self.operation_calls: list[str] = []
|
||||||
self.accept_key_verification_calls: list[str] = []
|
self.accept_key_verification_calls: list[str] = []
|
||||||
self.confirm_short_auth_string_calls: list[str] = []
|
self.confirm_short_auth_string_calls: list[str] = []
|
||||||
|
self.send_to_device_messages_calls = 0
|
||||||
self.to_device_calls: list[object] = []
|
self.to_device_calls: list[object] = []
|
||||||
self.accept_key_verification_response: object | None = None
|
self.accept_key_verification_response: object | None = None
|
||||||
self.confirm_short_auth_string_response: object | None = None
|
self.confirm_short_auth_string_response: object | None = None
|
||||||
|
self.send_to_device_messages_response: list[object] = []
|
||||||
self.to_device_response: object | None = None
|
self.to_device_response: object | None = None
|
||||||
self.rooms: dict[str, object] = {}
|
self.rooms: dict[str, object] = {}
|
||||||
self.room_send_calls: list[dict[str, object]] = []
|
self.room_send_calls: list[dict[str, object]] = []
|
||||||
@ -94,14 +97,22 @@ class _FakeAsyncClient:
|
|||||||
self.join_calls.append(room_id)
|
self.join_calls.append(room_id)
|
||||||
|
|
||||||
async def accept_key_verification(self, transaction_id: str):
|
async def accept_key_verification(self, transaction_id: str):
|
||||||
|
self.operation_calls.append(f"accept:{transaction_id}")
|
||||||
self.accept_key_verification_calls.append(transaction_id)
|
self.accept_key_verification_calls.append(transaction_id)
|
||||||
return self.accept_key_verification_response
|
return self.accept_key_verification_response
|
||||||
|
|
||||||
async def confirm_short_auth_string(self, transaction_id: str):
|
async def confirm_short_auth_string(self, transaction_id: str):
|
||||||
|
self.operation_calls.append(f"confirm:{transaction_id}")
|
||||||
self.confirm_short_auth_string_calls.append(transaction_id)
|
self.confirm_short_auth_string_calls.append(transaction_id)
|
||||||
return self.confirm_short_auth_string_response
|
return self.confirm_short_auth_string_response
|
||||||
|
|
||||||
|
async def send_to_device_messages(self):
|
||||||
|
self.operation_calls.append("send_pending")
|
||||||
|
self.send_to_device_messages_calls += 1
|
||||||
|
return self.send_to_device_messages_response
|
||||||
|
|
||||||
async def to_device(self, message):
|
async def to_device(self, message):
|
||||||
|
self.operation_calls.append("to_device")
|
||||||
self.to_device_calls.append(message)
|
self.to_device_calls.append(message)
|
||||||
return self.to_device_response
|
return self.to_device_response
|
||||||
|
|
||||||
@ -190,9 +201,10 @@ class _FakeAsyncClient:
|
|||||||
|
|
||||||
|
|
||||||
class _FakeSas:
|
class _FakeSas:
|
||||||
def __init__(self) -> None:
|
def __init__(self, *, verified: bool = False) -> None:
|
||||||
self.share_key_called = False
|
self.share_key_called = False
|
||||||
self.get_mac_called = False
|
self.get_mac_called = False
|
||||||
|
self.verified = verified
|
||||||
|
|
||||||
def share_key(self):
|
def share_key(self):
|
||||||
self.share_key_called = True
|
self.share_key_called = True
|
||||||
@ -346,8 +358,8 @@ async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> Non
|
|||||||
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
||||||
|
|
||||||
assert client.accept_key_verification_calls == ["tx1"]
|
assert client.accept_key_verification_calls == ["tx1"]
|
||||||
assert sas.share_key_called is True
|
assert sas.share_key_called is False
|
||||||
assert client.to_device_calls == [{"type": "share_key"}]
|
assert client.to_device_calls == []
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -398,25 +410,27 @@ async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None
|
|||||||
|
|
||||||
await channel._handle_key_verification_event(_FakeKeyVerificationKey())
|
await channel._handle_key_verification_event(_FakeKeyVerificationKey())
|
||||||
|
|
||||||
|
assert client.send_to_device_messages_calls == 1
|
||||||
assert client.confirm_short_auth_string_calls == ["tx1"]
|
assert client.confirm_short_auth_string_calls == ["tx1"]
|
||||||
|
assert client.operation_calls == ["send_pending", "confirm:tx1"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None:
|
async def test_sas_verification_mac_does_not_resend_mac(monkeypatch) -> None:
|
||||||
_patch_key_verification_events(monkeypatch)
|
_patch_key_verification_events(monkeypatch)
|
||||||
channel = MatrixChannel(
|
channel = MatrixChannel(
|
||||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||||
MessageBus(),
|
MessageBus(),
|
||||||
)
|
)
|
||||||
client = _FakeAsyncClient("", "", "", None)
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
sas = _FakeSas()
|
sas = _FakeSas(verified=True)
|
||||||
client.key_verifications["tx1"] = sas
|
client.key_verifications["tx1"] = sas
|
||||||
channel.client = client
|
channel.client = client
|
||||||
|
|
||||||
await channel._handle_key_verification_event(_FakeKeyVerificationMac())
|
await channel._handle_key_verification_event(_FakeKeyVerificationMac())
|
||||||
|
|
||||||
assert sas.get_mac_called is True
|
assert sas.get_mac_called is False
|
||||||
assert client.to_device_calls == [{"type": "mac"}]
|
assert client.to_device_calls == []
|
||||||
|
|
||||||
|
|
||||||
def test_media_event_filter_does_not_match_text_events() -> None:
|
def test_media_event_filter_does_not_match_text_events() -> None:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user