mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
fix(matrix): handle SAS device verification
This commit is contained in:
parent
0cc58a80a4
commit
68712fc489
@ -244,6 +244,7 @@ for reliable encryption, password login is recommended instead. If the
|
|||||||
"userId": "@nanobot:matrix.org",
|
"userId": "@nanobot:matrix.org",
|
||||||
"password": "mypasswordhere",
|
"password": "mypasswordhere",
|
||||||
"e2eeEnabled": true,
|
"e2eeEnabled": true,
|
||||||
|
"sasVerification": true,
|
||||||
"allowFrom": ["@your_user:matrix.org"],
|
"allowFrom": ["@your_user:matrix.org"],
|
||||||
"groupPolicy": "open",
|
"groupPolicy": "open",
|
||||||
"groupAllowFrom": [],
|
"groupAllowFrom": [],
|
||||||
@ -263,6 +264,7 @@ for reliable encryption, password login is recommended instead. If the
|
|||||||
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
|
| `groupAllowFrom` | Room allowlist (used when policy is `allowlist`). |
|
||||||
| `allowRoomMentions` | Accept `@room` mentions in mention mode. |
|
| `allowRoomMentions` | Accept `@room` mentions in mention mode. |
|
||||||
| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. |
|
| `e2eeEnabled` | E2EE support (default `true`). Set `false` for plaintext-only. |
|
||||||
|
| `sasVerification` | Auto-complete SAS device verification requests from allowed users (default `false`). Useful for Element X, which does not expose manual trust for third-party devices. |
|
||||||
| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. |
|
| `maxMediaBytes` | Max attachment size (default `20MB`). Set `0` to block all media. |
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -23,6 +23,12 @@ try:
|
|||||||
AsyncClientConfig,
|
AsyncClientConfig,
|
||||||
InviteEvent,
|
InviteEvent,
|
||||||
JoinError,
|
JoinError,
|
||||||
|
KeyVerificationCancel,
|
||||||
|
KeyVerificationEvent,
|
||||||
|
KeyVerificationKey,
|
||||||
|
KeyVerificationMac,
|
||||||
|
KeyVerificationStart,
|
||||||
|
LocalProtocolError,
|
||||||
LoginResponse,
|
LoginResponse,
|
||||||
MatrixRoom,
|
MatrixRoom,
|
||||||
RoomEncryptedMedia,
|
RoomEncryptedMedia,
|
||||||
@ -33,6 +39,7 @@ try:
|
|||||||
RoomSendResponse,
|
RoomSendResponse,
|
||||||
RoomTypingError,
|
RoomTypingError,
|
||||||
SyncError,
|
SyncError,
|
||||||
|
ToDeviceError,
|
||||||
UploadError,
|
UploadError,
|
||||||
)
|
)
|
||||||
from nio.crypto.attachments import decrypt_attachment
|
from nio.crypto.attachments import decrypt_attachment
|
||||||
@ -194,6 +201,7 @@ class MatrixConfig(Base):
|
|||||||
access_token: str = ""
|
access_token: str = ""
|
||||||
device_id: str = ""
|
device_id: str = ""
|
||||||
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
|
e2ee_enabled: bool = Field(default=True, alias="e2eeEnabled")
|
||||||
|
sas_verification: bool = Field(default=False, alias="sasVerification")
|
||||||
sync_stop_grace_seconds: int = 2
|
sync_stop_grace_seconds: int = 2
|
||||||
max_media_bytes: int = 20 * 1024 * 1024
|
max_media_bytes: int = 20 * 1024 * 1024
|
||||||
max_concurrent_media_downloads: int = 2
|
max_concurrent_media_downloads: int = 2
|
||||||
@ -268,6 +276,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._register_event_callbacks()
|
self._register_event_callbacks()
|
||||||
|
self._register_to_device_callbacks()
|
||||||
self._register_response_callbacks()
|
self._register_response_callbacks()
|
||||||
|
|
||||||
if not self.config.e2ee_enabled:
|
if not self.config.e2ee_enabled:
|
||||||
@ -572,11 +581,102 @@ class MatrixChannel(BaseChannel):
|
|||||||
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
|
||||||
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
self.client.add_event_callback(self._on_room_invite, InviteEvent)
|
||||||
|
|
||||||
|
def _register_to_device_callbacks(self) -> None:
|
||||||
|
if self.config.e2ee_enabled and self.config.sas_verification:
|
||||||
|
self.client.add_to_device_callback(
|
||||||
|
self._on_key_verification_event,
|
||||||
|
(KeyVerificationEvent,),
|
||||||
|
)
|
||||||
|
|
||||||
def _register_response_callbacks(self) -> None:
|
def _register_response_callbacks(self) -> None:
|
||||||
self.client.add_response_callback(self._on_sync_error, SyncError)
|
self.client.add_response_callback(self._on_sync_error, SyncError)
|
||||||
self.client.add_response_callback(self._on_join_error, JoinError)
|
self.client.add_response_callback(self._on_join_error, JoinError)
|
||||||
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
self.client.add_response_callback(self._on_send_error, RoomSendError)
|
||||||
|
|
||||||
|
def _is_sas_sender_allowed(self, sender: str) -> bool:
|
||||||
|
return bool(sender and self.is_allowed(sender))
|
||||||
|
|
||||||
|
async def _on_key_verification_event(self, event: KeyVerificationEvent) -> None:
|
||||||
|
try:
|
||||||
|
await self._handle_key_verification_event(event)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
self.logger.exception("Matrix SAS verification handling failed")
|
||||||
|
|
||||||
|
async def _handle_key_verification_event(self, event: KeyVerificationEvent) -> None:
|
||||||
|
if not (self.config.e2ee_enabled and self.config.sas_verification):
|
||||||
|
return
|
||||||
|
if not self.client:
|
||||||
|
return
|
||||||
|
|
||||||
|
sender = str(getattr(event, "sender", "") or "")
|
||||||
|
transaction_id = str(getattr(event, "transaction_id", "") or "")
|
||||||
|
if not transaction_id or not self._is_sas_sender_allowed(sender):
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(event, KeyVerificationStart):
|
||||||
|
if "emoji" not in (getattr(event, "short_authentication_string", None) or []):
|
||||||
|
self.logger.info(
|
||||||
|
"Ignoring Matrix SAS verification from {} without emoji support",
|
||||||
|
sender,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
response = await self.client.accept_key_verification(transaction_id)
|
||||||
|
if isinstance(response, ToDeviceError):
|
||||||
|
self.logger.warning("Matrix SAS accept failed for {}: {}", sender, response)
|
||||||
|
return
|
||||||
|
|
||||||
|
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||||
|
if sas is None:
|
||||||
|
self.logger.warning(
|
||||||
|
"Matrix SAS state missing after accept for transaction {}",
|
||||||
|
transaction_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
response = await self.client.to_device(sas.share_key())
|
||||||
|
if isinstance(response, ToDeviceError):
|
||||||
|
self.logger.warning(
|
||||||
|
"Matrix SAS key share failed for {}: {}",
|
||||||
|
sender,
|
||||||
|
response,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(event, KeyVerificationKey):
|
||||||
|
response = await self.client.confirm_short_auth_string(transaction_id)
|
||||||
|
if isinstance(response, ToDeviceError):
|
||||||
|
self.logger.warning("Matrix SAS confirm failed for {}: {}", sender, response)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(event, KeyVerificationMac):
|
||||||
|
sas = getattr(self.client, "key_verifications", {}).get(transaction_id)
|
||||||
|
if sas is None:
|
||||||
|
self.logger.warning(
|
||||||
|
"Matrix SAS state missing for MAC transaction {}",
|
||||||
|
transaction_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
response = await self.client.to_device(sas.get_mac())
|
||||||
|
except LocalProtocolError as e:
|
||||||
|
self.logger.warning("Matrix SAS MAC failed for {}: {}", sender, e)
|
||||||
|
return
|
||||||
|
if isinstance(response, ToDeviceError):
|
||||||
|
self.logger.warning("Matrix SAS MAC send failed for {}: {}", sender, response)
|
||||||
|
else:
|
||||||
|
self.logger.info("Matrix SAS verification completed for {}", sender)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(event, KeyVerificationCancel):
|
||||||
|
self.logger.info(
|
||||||
|
"Matrix SAS verification cancelled by {}: {}",
|
||||||
|
sender,
|
||||||
|
getattr(event, "reason", ""),
|
||||||
|
)
|
||||||
|
|
||||||
def _is_fatal_auth_response(self, response: Any) -> bool:
|
def _is_fatal_auth_response(self, response: Any) -> bool:
|
||||||
code = getattr(response, "status_code", None)
|
code = getattr(response, "status_code", None)
|
||||||
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
|
||||||
|
|||||||
@ -50,7 +50,15 @@ class _FakeAsyncClient:
|
|||||||
self.stop_sync_forever_called = False
|
self.stop_sync_forever_called = False
|
||||||
self.join_calls: list[str] = []
|
self.join_calls: list[str] = []
|
||||||
self.callbacks: list[tuple[object, object]] = []
|
self.callbacks: list[tuple[object, object]] = []
|
||||||
|
self.to_device_callbacks: list[tuple[object, object]] = []
|
||||||
self.response_callbacks: list[tuple[object, object]] = []
|
self.response_callbacks: list[tuple[object, object]] = []
|
||||||
|
self.key_verifications: dict[str, object] = {}
|
||||||
|
self.accept_key_verification_calls: list[str] = []
|
||||||
|
self.confirm_short_auth_string_calls: list[str] = []
|
||||||
|
self.to_device_calls: list[object] = []
|
||||||
|
self.accept_key_verification_response: object | None = None
|
||||||
|
self.confirm_short_auth_string_response: object | None = None
|
||||||
|
self.to_device_response: object | None = None
|
||||||
self.rooms: dict[str, object] = {}
|
self.rooms: dict[str, object] = {}
|
||||||
self.room_send_calls: list[dict[str, object]] = []
|
self.room_send_calls: list[dict[str, object]] = []
|
||||||
self.typing_calls: list[tuple[str, bool, int]] = []
|
self.typing_calls: list[tuple[str, bool, int]] = []
|
||||||
@ -70,6 +78,9 @@ class _FakeAsyncClient:
|
|||||||
def add_event_callback(self, callback, event_type) -> None:
|
def add_event_callback(self, callback, event_type) -> None:
|
||||||
self.callbacks.append((callback, event_type))
|
self.callbacks.append((callback, event_type))
|
||||||
|
|
||||||
|
def add_to_device_callback(self, callback, event_type) -> None:
|
||||||
|
self.to_device_callbacks.append((callback, event_type))
|
||||||
|
|
||||||
def add_response_callback(self, callback, response_type) -> None:
|
def add_response_callback(self, callback, response_type) -> None:
|
||||||
self.response_callbacks.append((callback, response_type))
|
self.response_callbacks.append((callback, response_type))
|
||||||
|
|
||||||
@ -82,6 +93,18 @@ class _FakeAsyncClient:
|
|||||||
async def join(self, room_id: str) -> None:
|
async def join(self, room_id: str) -> None:
|
||||||
self.join_calls.append(room_id)
|
self.join_calls.append(room_id)
|
||||||
|
|
||||||
|
async def accept_key_verification(self, transaction_id: str):
|
||||||
|
self.accept_key_verification_calls.append(transaction_id)
|
||||||
|
return self.accept_key_verification_response
|
||||||
|
|
||||||
|
async def confirm_short_auth_string(self, transaction_id: str):
|
||||||
|
self.confirm_short_auth_string_calls.append(transaction_id)
|
||||||
|
return self.confirm_short_auth_string_response
|
||||||
|
|
||||||
|
async def to_device(self, message):
|
||||||
|
self.to_device_calls.append(message)
|
||||||
|
return self.to_device_response
|
||||||
|
|
||||||
async def room_send(
|
async def room_send(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
@ -166,6 +189,61 @@ class _FakeAsyncClient:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeSas:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.share_key_called = False
|
||||||
|
self.get_mac_called = False
|
||||||
|
|
||||||
|
def share_key(self):
|
||||||
|
self.share_key_called = True
|
||||||
|
return {"type": "share_key"}
|
||||||
|
|
||||||
|
def get_mac(self):
|
||||||
|
self.get_mac_called = True
|
||||||
|
return {"type": "mac"}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeKeyVerificationStart:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sender: str = "@alice:matrix.org",
|
||||||
|
transaction_id: str = "tx1",
|
||||||
|
short_authentication_string: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.sender = sender
|
||||||
|
self.transaction_id = transaction_id
|
||||||
|
self.short_authentication_string = short_authentication_string or ["emoji"]
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeKeyVerificationKey:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sender: str = "@alice:matrix.org",
|
||||||
|
transaction_id: str = "tx1",
|
||||||
|
) -> None:
|
||||||
|
self.sender = sender
|
||||||
|
self.transaction_id = transaction_id
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeKeyVerificationMac:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sender: str = "@alice:matrix.org",
|
||||||
|
transaction_id: str = "tx1",
|
||||||
|
) -> None:
|
||||||
|
self.sender = sender
|
||||||
|
self.transaction_id = transaction_id
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_key_verification_events(monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(matrix_module, "KeyVerificationStart", _FakeKeyVerificationStart)
|
||||||
|
monkeypatch.setattr(matrix_module, "KeyVerificationKey", _FakeKeyVerificationKey)
|
||||||
|
monkeypatch.setattr(matrix_module, "KeyVerificationMac", _FakeKeyVerificationMac)
|
||||||
|
|
||||||
|
|
||||||
def _make_config(**kwargs) -> MatrixConfig:
|
def _make_config(**kwargs) -> MatrixConfig:
|
||||||
kwargs.setdefault("allow_from", ["*"])
|
kwargs.setdefault("allow_from", ["*"])
|
||||||
return MatrixConfig(
|
return MatrixConfig(
|
||||||
@ -209,6 +287,7 @@ async def test_start_skips_load_store_when_device_id_missing(
|
|||||||
assert clients[0].config.encryption_enabled is True
|
assert clients[0].config.encryption_enabled is True
|
||||||
assert clients[0].load_store_called is False
|
assert clients[0].load_store_called is False
|
||||||
assert len(clients[0].callbacks) == 3
|
assert len(clients[0].callbacks) == 3
|
||||||
|
assert clients[0].to_device_callbacks == []
|
||||||
assert len(clients[0].response_callbacks) == 3
|
assert len(clients[0].response_callbacks) == 3
|
||||||
|
|
||||||
await channel.stop()
|
await channel.stop()
|
||||||
@ -227,6 +306,119 @@ async def test_register_event_callbacks_uses_media_base_filter() -> None:
|
|||||||
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
|
assert client.callbacks[1][1] == matrix_module.MATRIX_MEDIA_EVENT_FILTER
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_to_device_callbacks_when_sas_verification_enabled() -> None:
|
||||||
|
channel = MatrixChannel(_make_config(sas_verification=True), MessageBus())
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
channel._register_to_device_callbacks()
|
||||||
|
|
||||||
|
assert client.to_device_callbacks == [
|
||||||
|
(channel._on_key_verification_event, (matrix_module.KeyVerificationEvent,))
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_to_device_callbacks_skips_when_e2ee_disabled() -> None:
|
||||||
|
channel = MatrixChannel(
|
||||||
|
_make_config(e2ee_enabled=False, sas_verification=True),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
channel._register_to_device_callbacks()
|
||||||
|
|
||||||
|
assert client.to_device_callbacks == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sas_verification_start_accepts_allowed_sender(monkeypatch) -> None:
|
||||||
|
_patch_key_verification_events(monkeypatch)
|
||||||
|
channel = MatrixChannel(
|
||||||
|
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
sas = _FakeSas()
|
||||||
|
client.key_verifications["tx1"] = sas
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
||||||
|
|
||||||
|
assert client.accept_key_verification_calls == ["tx1"]
|
||||||
|
assert sas.share_key_called is True
|
||||||
|
assert client.to_device_calls == [{"type": "share_key"}]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sas_verification_ignores_denied_sender(monkeypatch) -> None:
|
||||||
|
_patch_key_verification_events(monkeypatch)
|
||||||
|
channel = MatrixChannel(
|
||||||
|
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
client.key_verifications["tx1"] = _FakeSas()
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
await channel._handle_key_verification_event(
|
||||||
|
_FakeKeyVerificationStart(sender="@mallory:matrix.org")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert client.accept_key_verification_calls == []
|
||||||
|
assert client.to_device_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sas_verification_ignores_when_disabled(monkeypatch) -> None:
|
||||||
|
_patch_key_verification_events(monkeypatch)
|
||||||
|
channel = MatrixChannel(
|
||||||
|
_make_config(allow_from=["@alice:matrix.org"], sas_verification=False),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
client.key_verifications["tx1"] = _FakeSas()
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
await channel._handle_key_verification_event(_FakeKeyVerificationStart())
|
||||||
|
|
||||||
|
assert client.accept_key_verification_calls == []
|
||||||
|
assert client.to_device_calls == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sas_verification_key_confirms_allowed_sender(monkeypatch) -> None:
|
||||||
|
_patch_key_verification_events(monkeypatch)
|
||||||
|
channel = MatrixChannel(
|
||||||
|
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
await channel._handle_key_verification_event(_FakeKeyVerificationKey())
|
||||||
|
|
||||||
|
assert client.confirm_short_auth_string_calls == ["tx1"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sas_verification_mac_sends_mac_for_allowed_sender(monkeypatch) -> None:
|
||||||
|
_patch_key_verification_events(monkeypatch)
|
||||||
|
channel = MatrixChannel(
|
||||||
|
_make_config(allow_from=["@alice:matrix.org"], sas_verification=True),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
client = _FakeAsyncClient("", "", "", None)
|
||||||
|
sas = _FakeSas()
|
||||||
|
client.key_verifications["tx1"] = sas
|
||||||
|
channel.client = client
|
||||||
|
|
||||||
|
await channel._handle_key_verification_event(_FakeKeyVerificationMac())
|
||||||
|
|
||||||
|
assert sas.get_mac_called is True
|
||||||
|
assert client.to_device_calls == [{"type": "mac"}]
|
||||||
|
|
||||||
|
|
||||||
def test_media_event_filter_does_not_match_text_events() -> None:
|
def test_media_event_filter_does_not_match_text_events() -> None:
|
||||||
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
|
assert not issubclass(matrix_module.RoomMessageText, matrix_module.MATRIX_MEDIA_EVENT_FILTER)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user