mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(matrix): handle SAS device verification
This commit is contained in:
parent
0cc58a80a4
commit
68712fc489
@ -244,6 +244,7 @@ for reliable encryption, password login is recommended instead. If the
|
||||
"userId": "@nanobot:matrix.org",
|
||||
"password": "mypasswordhere",
|
||||
"e2eeEnabled": true,
|
||||
"sasVerification": true,
|
||||
"allowFrom": ["@your_user:matrix.org"],
|
||||
"groupPolicy": "open",
|
||||
"groupAllowFrom": [],
|
||||
@ -263,6 +264,7 @@ for reliable encryption, password login is recommended instead. If the
|
||||
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
|
||||
| `allowRoomMentions` | Accept `@room` mentions in mention mode. |
|
||||
| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. |
|
||||
| `sasVerification` | Auto-complete SAS device verification requests from allowed users (default `false`). Useful for Element X, which does not expose manual trust for third-party devices. |
|
||||
| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. |
|
||||
|
||||
|
||||
|
||||
@ -23,6 +23,12 @@ try:
|
||||
AsyncClientConfig,
|
||||
InviteEvent,
|
||||
JoinError,
|
||||
KeyVerificationCancel,
|
||||
KeyVerificationEvent,
|
||||
KeyVerificationKey,
|
||||
KeyVerificationMac,
|
||||
KeyVerificationStart,
|
||||
LocalProtocolError,
|
||||
LoginResponse,
|
||||
MatrixRoom,
|
||||
RoomEncryptedMedia,
|
||||
@ -33,6 +39,7 @@ try:
|
||||
RoomSendResponse,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
ToDeviceError,
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
@ -194,6 +201,7 @@ class MatrixConfig(Base):
|
||||
access_token: str = ""
|
||||
device_id: str = ""
|
||||
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
|
||||
sas_verification: bool = Field(default=False, alias="sasVerification")
|
||||
sync_stop_grace_seconds: int = 2
|
||||
max_media_bytes: int = 20 * 1024 * 1024
|
||||
max_concurrent_media_downloads: int = 2
|
||||
@ -268,6 +276,7 @@ class MatrixChannel(BaseChannel):
|
||||
)
|
||||
|
||||
self._register_event_callbacks()
|
||||
self._register_to_device_callbacks()
|
||||
self._register_response_callbacks()
|
||||
|
||||
if not self.config.e2ee_enabled:
|
||||
@ -572,11 +581,102 @@ class MatrixChannel(BaseChannel):
|
||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||
|
||||
def _register_to_device_callbacks(self) -> None:
|
||||
if self.config.e2ee_enabled and self.config.sas_verification:
|
||||
self.client.add_to_device_callback(
|
||||
self._on_key_verification_event,
|
||||
(KeyVerificationEvent,),
|
||||
)
|
||||
|
||||
def _register_response_callbacks(self) -> None:
|
||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||
|
||||
def _is_sas_sender_allowed(self, sender: str) -> bool:
|
||||
return bool(sender and self.is_allowed(sender))
|
||||
|
||||
async def _on_key_verification_event(self, event: KeyVerificationEvent) -> None:
|
||||
try:
|
||||
await self._handle_key_verification_event(event)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception:
|
||||
self.logger.exception("Matrix SAS verification handling failed")
|
||||
|
||||
async def _handle_key_verification_event(self, event: KeyVerificationEvent) -> None:
|
||||
if not (self.config.e2ee_enabled and self.config.sas_verification):
|
||||
return
|
||||
if not self.client:
|
||||
return
|
||||
|
||||
sender = str(getattr(event, "sender", "") or "")
|
||||
transaction_id = str(getattr(event, "transaction_id", "") or "")
|
||||
if not transaction_id or not self._is_sas_sender_allowed(sender):
|
||||
return
|
||||
|
||||
if isinstance(event, KeyVerificationStart):
|
||||
if "emoji" not in (getattr(event, "short_authentication_string", None) or []):
|
||||
self.logger.info(
|
||||
"Ignoring Matrix SAS verification from {} without emoji support",
|
||||
sender,
|
||||
)
|
||||
return
|
||||
|
||||
response = await self.client.accept_key_verification(transaction_id)
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
|
||||
return
|
||||
|
||||
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||
if sas is None:
|
||||
self.logger.warning(
|
||||
"Matrix SAS state missing after accept for transaction {}",
|
||||
transaction_id,
|
||||
)
|
||||
return
|
||||
|
||||
response = await self.client.to_device(sas.share_key())
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning(
|
||||
"Matrix SAS key share failed for {}: {}",
|
||||
sender,
|
||||
response,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(event, KeyVerificationKey):
|
||||
response = await self.client.confirm_short_auth_string(transaction_id)
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
|
||||
return
|
||||
|
||||
if isinstance(event, KeyVerificationMac):
|
||||
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||
if sas is None:
|
||||
self.logger.warning(
|
||||
"Matrix SAS state missing for MAC transaction {}",
|
||||
transaction_id,
|
||||
)
|
||||
return
|
||||
try:
|
||||
response = await self.client.to_device(sas.get_mac())
|
||||
except LocalProtocolError as e:
|
||||
self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e)
|
||||
return
|
||||
if isinstance(response, ToDeviceError):
|
||||
self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response)
|
||||
else:
|
||||
self.logger.info("Matrix SAS verification completed for {}", sender)
|
||||
return
|
||||
|
||||
if isinstance(event, KeyVerificationCancel):
|
||||
self.logger.info(
|
||||
"Matrix SAS verification cancelled by {}: {}",
|
||||
sender,
|
||||
getattr(event, "reason", ""),
|
||||
)
|
||||
|
||||
def _is_fatal_auth_response(self, response: Any) -> bool:
|
||||
code = getattr(response, "status_code", None)
|
||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||
|
||||
@ -50,7 +50,15 @@ class _FakeAsyncClient:
|
||||
self.stop_sync_forever_called = False
|
||||
self.join_calls: list[str] = []
|
||||
self.callbacks: list[tuple[object, object]] = []
|
||||
self.to_device_callbacks: list[tuple[object, object]] = []
|
||||
self.response_callbacks: list[tuple[object, object]] = []
|
||||
self.key_verifications: dict[str, object] = {}
|
||||
self.accept_key_verification_calls: list[str] = []
|
||||
self.confirm_short_auth_string_calls: list[str] = []
|
||||
self.to_device_calls: list[object] = []
|
||||
self.accept_key_verification_response: object | None = None
|
||||
self.confirm_short_auth_string_response: object | None = None
|
||||
self.to_device_response: object | None = None
|
||||
self.rooms: dict[str, object] = {}
|
||||
self.room_send_calls: list[dict[str, object]] = []
|
||||
self.typing_calls: list[tuple[str, bool, int]] = []
|
||||
@ -70,6 +78,9 @@ class _FakeAsyncClient:
|
||||
def add_event_callback(self, callback, event_type) -> None:
|
||||
self.callbacks.append((callback, event_type))
|
||||
|
||||
def add_to_device_callback(self, callback, event_type) -> None:
|
||||
self.to_device_callbacks.append((callback, event_type))
|
||||
|
||||
def add_response_callback(self, callback, response_type) -> None:
|
||||
self.response_callbacks.append((callback, response_type))
|
||||
|
||||
@ -82,6 +93,18 @@ class _FakeAsyncClient:
|
||||
async def join(self, room_id: str) -> None:
|
||||
self.join_calls.append(room_id)
|
||||
|
||||
async def accept_key_verification(self, transaction_id: str):
|
||||
self.accept_key_verification_calls.append(transaction_id)
|
||||
return self.accept_key_verification_response
|
||||
|
||||
async def confirm_short_auth_string(self, transaction_id: str):
|
||||
self.confirm_short_auth_string_calls.append(transaction_id)
|
||||
return self.confirm_short_auth_string_response
|
||||
|
||||
async def to_device(self, message):
|
||||
self.to_device_calls.append(message)
|
||||
return self.to_device_response
|
||||
|
||||
async def room_send(
|
||||
self,
|
||||
room_id: str,
|
||||
@ -166,6 +189,61 @@ class _FakeAsyncClient:
|
||||
return None
|
||||
|
||||
|
||||
class _FakeSas:
|
||||
def __init__(self) -> None:
|
||||
self.share_key_called = False
|
||||
self.get_mac_called = False
|
||||
|
||||
def share_key(self):
|
||||
self.share_key_called = True
|
||||
return {"type": "share_key"}
|
||||
|
||||
def get_mac(self):
|
||||
self.get_mac_called = True
|
||||
return {"type": "mac"}
|
||||
|
||||
|
||||
class _FakeKeyVerificationStart:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sender: str = "@alice:matrix.org",
|
||||
transaction_id: str = "tx1",
|
||||
short_authentication_string: list[str] | None = None,
|
||||
) -> None:
|
||||
self.sender = sender
|
||||
self.transaction_id = transaction_id
|
||||
self.short_authentication_string = short_authentication_string or ["emoji"]
|
||||
|
||||
|
||||
class _FakeKeyVerificationKey:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sender: str = "@alice:matrix.org",
|
||||
transaction_id: str = "tx1",
|
||||
) -> None:
|
||||
self.sender = sender
|
||||
self.transaction_id = transaction_id
|
||||
|
||||
|
||||
class _FakeKeyVerificationMac:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sender: str = "@alice:matrix.org",
|
||||
transaction_id: str = "tx1",
|
||||
) -> None:
|
||||
self.sender = sender
|
||||
self.transaction_id = transaction_id
|
||||
|
||||
|
||||
def _patch_key_verification_events(monkeypatch) -> None:
|
||||
monkeypatch.setattr(matrix_module, "KeyVerificationStart", _FakeKeyVerificationStart)
|
||||
monkeypatch.setattr(matrix_module, "KeyVerificationKey", _FakeKeyVerificationKey)
|
||||
monkeypatch.setattr(matrix_module, "KeyVerificationMac", _FakeKeyVerificationMac)
|
||||
|
||||
|
||||
def _make_config(**kwargs) -> MatrixConfig:
|
||||
kwargs.setdefault("allow_from", ["*"])
|
||||
return MatrixConfig(
|
||||
@ -209,6 +287,7 @@ async def test_start_skips_load_store_when_device_id_missing(
|
||||
assert clients[0].config.encryption_enabled is True
|
||||
assert clients[0].load_store_called is False
|
||||
assert len(clients[0].callbacks) == 3
|
||||
assert clients[0].to_device_callbacks == []
|
||||
assert len(clients[0].response_callbacks) == 3
|
||||
|
||||
await channel.stop()
|
||||
@ -227,6 +306,119 @@ async def test_register_event_callbacks_uses_media_base_filter() -> None:
|
||||
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
|
||||
|
||||
|
||||
def test_register_to_device_callbacks_when_sas_verification_enabled() -> None:
|
||||
channel = MatrixChannel(_make_config(sas_verification=True), MessageBus())
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
channel._register_to_device_callbacks()
|
||||
|
||||
assert client.to_device_callbacks == [
|
||||
(channel._on_key_verification_event, (matrix_module.KeyVerificationEvent,))
|
||||
]
|
||||
|
||||
|
||||
def test_register_to_device_callbacks_skips_when_e2ee_disabled() -> None:
|
||||
channel = MatrixChannel(
|
||||
_make_config(e2ee_enabled=False, sas_verification=True),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
channel._register_to_device_callbacks()
|
||||
|
||||
assert client.to_device_callbacks == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> None:
|
||||
_patch_key_verification_events(monkeypatch)
|
||||
channel = MatrixChannel(
|
||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
sas = _FakeSas()
|
||||
client.key_verifications["tx1"] = sas
|
||||
channel.client = client
|
||||
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
||||
|
||||
assert client.accept_key_verification_calls == ["tx1"]
|
||||
assert sas.share_key_called is True
|
||||
assert client.to_device_calls == [{"type": "share_key"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sas_verification_ignores_denied_sender(monkeypatch) -> None:
|
||||
_patch_key_verification_events(monkeypatch)
|
||||
channel = MatrixChannel(
|
||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
client.key_verifications["tx1"] = _FakeSas()
|
||||
channel.client = client
|
||||
|
||||
await channel._handle_key_verification_event(
|
||||
_FakeKeyVerificationStart(sender="@mallory:matrix.org")
|
||||
)
|
||||
|
||||
assert client.accept_key_verification_calls == []
|
||||
assert client.to_device_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sas_verification_ignores_when_disabled(monkeypatch) -> None:
|
||||
_patch_key_verification_events(monkeypatch)
|
||||
channel = MatrixChannel(
|
||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=False),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
client.key_verifications["tx1"] = _FakeSas()
|
||||
channel.client = client
|
||||
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
||||
|
||||
assert client.accept_key_verification_calls == []
|
||||
assert client.to_device_calls == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None:
|
||||
_patch_key_verification_events(monkeypatch)
|
||||
channel = MatrixChannel(
|
||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
channel.client = client
|
||||
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationKey())
|
||||
|
||||
assert client.confirm_short_auth_string_calls == ["tx1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None:
|
||||
_patch_key_verification_events(monkeypatch)
|
||||
channel = MatrixChannel(
|
||||
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||
MessageBus(),
|
||||
)
|
||||
client = _FakeAsyncClient("", "", "", None)
|
||||
sas = _FakeSas()
|
||||
client.key_verifications["tx1"] = sas
|
||||
channel.client = client
|
||||
|
||||
await channel._handle_key_verification_event(_FakeKeyVerificationMac())
|
||||
|
||||
assert sas.get_mac_called is True
|
||||
assert client.to_device_calls == [{"type": "mac"}]
|
||||
|
||||
|
||||
def test_media_event_filter_does_not_match_text_events() -> None:
|
||||
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user