diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 63c822f1d..48ee3cd00 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import time from abc import ABC, abstractmethod from pathlib import Path from typing import Any @@ -191,6 +192,10 @@ class BaseChannel(ABC): 2. ``allowFrom`` list → allow if sender_id is present. 3. Pairing store approved list → allow if previously approved. 4. Otherwise deny. + + An empty ``allowFrom`` list does not cause a hard exit; instead it + defers to the pairing store so that unknown DM senders can request + access via a pairing code. """ if isinstance(self.config, dict): if "allow_from" in self.config: @@ -296,8 +301,6 @@ class BaseChannel(ABC): reply = "No pending pairing requests." else: lines = ["Pending pairing requests:"] - import time - for item in pending: remaining = int(item.get("expires_at", 0) - time.time()) expiry = f"{remaining}s" if remaining > 0 else "expired" @@ -331,12 +334,14 @@ class BaseChannel(ABC): elif sub == "revoke": if arg is None: - reply = "Usage: `/pairing revoke `" + reply = "Usage: `/pairing revoke ` or `/pairing revoke `" else: - if revoke(self.name, arg): - reply = f"Revoked {arg} from {self.name}" + target_channel = parts[3] if len(parts) > 3 else self.name + target_user = arg if len(parts) <= 3 else parts[3] + if revoke(target_channel, target_user): + reply = f"Revoked {target_user} from {target_channel}" else: - reply = f"{arg} was not in the approved list for {self.name}" + reply = f"{target_user} was not in the approved list for {target_channel}" else: reply = ( diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py index 3a6b6e50f..de0ed0c01 100644 --- a/nanobot/channels/manager.py +++ b/nanobot/channels/manager.py @@ -143,9 +143,9 @@ class ChannelManager: allow = cfg.get("allowFrom") else: allow = getattr(cfg, "allow_from", None) - if allow == []: + if allow is None: raise SystemExit( - f'Error: "{name}" has empty allowFrom (denies all). ' + f'Error: "{name}" is missing allowFrom. ' f'Set ["*"] to allow everyone, or add specific user IDs.' ) diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py index 3d9e33c9d..a11be1e1c 100644 --- a/nanobot/channels/matrix.py +++ b/nanobot/channels/matrix.py @@ -28,10 +28,11 @@ try: RoomMessageMedia, RoomMessageText, RoomSendError, + RoomSendResponse, RoomTypingError, SyncError, - UploadError, RoomSendResponse, -) + UploadError, + ) from nio.crypto.attachments import decrypt_attachment from nio.exceptions import EncryptionError except ImportError as e: @@ -107,7 +108,7 @@ class _StreamBuf: :ivar text: Stores the text content of the buffer. :type text: str - :ivar event_id: Identifier for the associated event. None indicates no + :ivar event_id: Identifier for the associated event. None indicates no specific event association. :type event_id: str | None :ivar last_edit: Timestamp of the most recent edit to the buffer. @@ -140,19 +141,19 @@ def _build_matrix_text_content( ) -> dict[str, object]: """ Constructs and returns a dictionary representing the matrix text content with optional - HTML formatting and reference to an existing event for replacement. This function is + HTML formatting and reference to an existing event for replacement. This function is primarily used to create content payloads compatible with the Matrix messaging protocol. :param text: The plain text content to include in the message. :type text: str - :param event_id: Optional ID of the event to replace. If provided, the function will - include information indicating that the message is a replacement of the specified + :param event_id: Optional ID of the event to replace. If provided, the function will + include information indicating that the message is a replacement of the specified event. :type event_id: str | None :param thread_relates_to: Optional Matrix thread relation metadata. For edits this is stored in ``m.new_content`` so the replacement remains in the same thread. :type thread_relates_to: dict[str, object] | None - :return: A dictionary containing the matrix text content, potentially enriched with + :return: A dictionary containing the matrix text content, potentially enriched with HTML formatting and replacement metadata if applicable. :rtype: dict[str, object] """ @@ -523,7 +524,7 @@ class MatrixChannel(BaseChannel): return await self._stop_typing_keepalive(chat_id, clear_typing=True) - + content = _build_matrix_text_content( buf.text, buf.event_id, @@ -537,7 +538,7 @@ class MatrixChannel(BaseChannel): buf = _StreamBuf() self._stream_bufs[chat_id] = buf buf.text += delta - + if not buf.text.strip(): return @@ -870,6 +871,7 @@ class MatrixChannel(BaseChannel): await self._handle_message( sender_id=event.sender, chat_id=room.room_id, content=event.body, metadata=self._base_metadata(room, event), + is_dm=self._is_direct_room(room), ) except Exception: await self._stop_typing_keepalive(room.room_id, clear_typing=True) @@ -907,6 +909,7 @@ class MatrixChannel(BaseChannel): content="\n".join(parts), media=[attachment["path"]] if attachment else [], metadata=meta, + is_dm=self._is_direct_room(room), ) except Exception: await self._stop_typing_keepalive(room.room_id, clear_typing=True) diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index be3172bff..6c37fd3b1 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -342,6 +342,22 @@ class SlackChannel(BaseChannel): channel_type = event.get("channel_type") or "" if not self._is_allowed(sender_id, chat_id, channel_type): + if channel_type == "im" and self.config.dm.enabled: + from nanobot.pairing import generate_code + code = generate_code(self.name, sender_id) + reply = ( + "This assistant requires approval before it can respond.\n" + f"Your pairing code is: `{code}`\n" + f"Ask the owner to run: `nanobot pairing approve {code}`" + ) + await self.send( + OutboundMessage( + channel=self.name, + chat_id=chat_id, + content=reply, + metadata={"_pairing_code": code}, + ) + ) return if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id): @@ -608,11 +624,13 @@ class SlackChannel(BaseChannel): self.logger.debug("done reaction failed: {}", e) def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool: + from nanobot.pairing import is_approved + if channel_type == "im": if not self.config.dm.enabled: return False if self.config.dm.policy == "allowlist": - return sender_id in self.config.dm.allow_from + return sender_id in self.config.dm.allow_from or is_approved(self.name, sender_id) return True # Group / channel messages diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 0a521e747..0db169512 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -1249,6 +1249,8 @@ class WebSocketChannel(BaseChannel): content = _parse_inbound_payload(raw) if content is None: continue + # WebSocket connections are always treated as 1:1 (DM) because + # each connection represents a single client browser/tab. await self._handle_message( sender_id=client_id, chat_id=default_chat_id, diff --git a/nanobot/pairing/store.py b/nanobot/pairing/store.py index d44ff61f1..fb531abdf 100644 --- a/nanobot/pairing/store.py +++ b/nanobot/pairing/store.py @@ -8,6 +8,7 @@ private-assistant scale: small JSON file, simple locking, no external DB. from __future__ import annotations import json +import os import secrets import string import threading @@ -20,8 +21,9 @@ from loguru import logger from nanobot.config.paths import get_data_dir _LOCK = threading.Lock() + _ALPHABET = string.ascii_uppercase + string.digits -_CODE_LENGTH = 6 # e.g. XK9-42F +_CODE_LENGTH = 8 # e.g. XK9-42F-MP _TTL_DEFAULT_S = 600 # 10 minutes @@ -48,7 +50,17 @@ def _save(data: dict[str, Any]) -> None: with open(tmp, "w", encoding="utf-8") as f: json.dump(data, f, indent=2, ensure_ascii=False) f.flush() + os.fsync(f.fileno()) tmp.replace(path) + # Ensure directory entry is flushed for durability (Unix only; no-op on Windows) + try: + fd = os.open(path.parent, os.O_RDONLY) + try: + os.fsync(fd) + finally: + os.close(fd) + except (OSError, NotImplementedError): + pass def _gc_pending(data: dict[str, Any]) -> None: @@ -75,7 +87,7 @@ def generate_code( # Ensure uniqueness for _ in range(100): raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH)) - code = f"{raw[:3]}-{raw[3:]}" + code = f"{raw[:4]}-{raw[4:]}" if code not in data.get("pending", {}): break else: # pragma: no cover diff --git a/tests/channels/test_base_channel.py b/tests/channels/test_base_channel.py index 660aff60e..651e3365d 100644 --- a/tests/channels/test_base_channel.py +++ b/tests/channels/test_base_channel.py @@ -1,5 +1,7 @@ from types import SimpleNamespace +import pytest + from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel @@ -7,6 +9,11 @@ from nanobot.channels.base import BaseChannel class _DummyChannel(BaseChannel): name = "dummy" + _sent: list[OutboundMessage] + + def __init__(self, config, bus): + super().__init__(config, bus) + self._sent = [] async def start(self) -> None: return None @@ -15,7 +22,7 @@ class _DummyChannel(BaseChannel): return None async def send(self, msg: OutboundMessage) -> None: - return None + self._sent.append(msg) def test_is_allowed_requires_exact_match() -> None: @@ -35,3 +42,94 @@ def test_is_allowed_denies_empty_dict_allow_from() -> None: channel = _DummyChannel({"allow_from": []}, MessageBus()) assert channel.is_allowed("alice") is False + + +def test_is_allowed_star_allows_all() -> None: + channel = _DummyChannel({"allowFrom": ["*"]}, MessageBus()) + assert channel.is_allowed("anyone") is True + + +def test_is_allowed_pairing_fallback(monkeypatch) -> None: + channel = _DummyChannel({"allowFrom": []}, MessageBus()) + monkeypatch.setattr( + "nanobot.channels.base.is_approved", lambda _ch, sid: sid == "paired" + ) + assert channel.is_allowed("paired") is True + assert channel.is_allowed("unknown") is False + + +@pytest.mark.asyncio +async def test_handle_message_dm_sends_pairing_code(monkeypatch) -> None: + channel = _DummyChannel({"allowFrom": []}, MessageBus()) + monkeypatch.setattr( + "nanobot.channels.base.generate_code", lambda _ch, sid: "ABCD-EFGH" + ) + + await channel._handle_message( + sender_id="stranger", chat_id="chat1", content="hello", is_dm=True + ) + + assert len(channel._sent) == 1 + msg = channel._sent[0] + assert "ABCD-EFGH" in msg.content + assert msg.metadata.get("_pairing_code") == "ABCD-EFGH" + + +@pytest.mark.asyncio +async def test_handle_message_group_ignores_unknown() -> None: + channel = _DummyChannel({"allowFrom": []}, MessageBus()) + + await channel._handle_message( + sender_id="stranger", chat_id="chat1", content="hello", is_dm=False + ) + + assert channel._sent == [] + + +@pytest.mark.asyncio +async def test_handle_pairing_command_list(monkeypatch) -> None: + channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus()) + monkeypatch.setattr( + "nanobot.channels.base.list_pending", + lambda: [ + { + "code": "ABCD-EFGH", + "channel": "dummy", + "sender_id": "123", + "expires_at": 9999999999, + } + ], + ) + + await channel._handle_pairing_command("owner", "chat1", "/pairing list") + + assert len(channel._sent) == 1 + assert "ABCD-EFGH" in channel._sent[0].content + + +@pytest.mark.asyncio +async def test_handle_pairing_command_approve(monkeypatch) -> None: + channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus()) + monkeypatch.setattr( + "nanobot.channels.base.approve_code", + lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None, + ) + + await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH") + + assert len(channel._sent) == 1 + assert "Approved" in channel._sent[0].content + + +@pytest.mark.asyncio +async def test_handle_pairing_command_revoke(monkeypatch) -> None: + channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus()) + monkeypatch.setattr( + "nanobot.channels.base.revoke", + lambda ch, sid: sid == "123", + ) + + await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123") + + assert len(channel._sent) == 1 + assert "Revoked" in channel._sent[0].content diff --git a/tests/channels/test_channel_plugins.py b/tests/channels/test_channel_plugins.py index a32d96e1a..9b6e79783 100644 --- a/tests/channels/test_channel_plugins.py +++ b/tests/channels/test_channel_plugins.py @@ -961,8 +961,8 @@ class _StartableChannel(BaseChannel): @pytest.mark.asyncio -async def test_validate_allow_from_raises_on_empty_list(): - """_validate_allow_from should raise SystemExit when allow_from is empty list.""" +async def test_validate_allow_from_allows_empty_list(): + """Empty allow_from is valid now — pairing store handles unapproved senders.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), @@ -973,10 +973,8 @@ async def test_validate_allow_from_raises_on_empty_list(): mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])} mgr._dispatch_task = None - with pytest.raises(SystemExit) as exc_info: - mgr._validate_allow_from() - - assert "empty allowFrom" in str(exc_info.value) + # Should not raise — empty list defers to pairing store + mgr._validate_allow_from() @pytest.mark.asyncio @@ -997,8 +995,8 @@ async def test_validate_allow_from_passes_with_asterisk(): @pytest.mark.asyncio -async def test_validate_allow_from_raises_on_empty_dict_allow_from(): - """_validate_allow_from should reject empty dict-backed allow_from lists.""" +async def test_validate_allow_from_allows_empty_dict_allow_from(): + """Empty dict-backed allow_from is valid — pairing store handles approval.""" fake_config = SimpleNamespace( channels=ChannelsConfig(), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), @@ -1009,10 +1007,7 @@ async def test_validate_allow_from_raises_on_empty_dict_allow_from(): mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])} mgr._dispatch_task = None - with pytest.raises(SystemExit) as exc_info: - mgr._validate_allow_from() - - assert "empty allowFrom" in str(exc_info.value) + mgr._validate_allow_from() @pytest.mark.asyncio diff --git a/tests/pairing/test_store.py b/tests/pairing/test_store.py new file mode 100644 index 000000000..a3bbf7b39 --- /dev/null +++ b/tests/pairing/test_store.py @@ -0,0 +1,99 @@ +import time + +import pytest + +from nanobot.pairing import store + + +@pytest.fixture(autouse=True) +def _tmp_store(tmp_path, monkeypatch): + path = tmp_path / "pairing.json" + monkeypatch.setattr(store, "_store_path", lambda: path) + + +class TestGenerateCode: + def test_format(self) -> None: + code = store.generate_code("telegram", "123") + assert len(code) == 9 # 4 + 1 + 4 + assert code[4] == "-" + assert code.replace("-", "").isalnum() + assert code.replace("-", "").isupper() + + def test_uniqueness(self) -> None: + codes = {store.generate_code("telegram", str(i)) for i in range(20)} + assert len(codes) == 20 + + def test_ttl_expiration(self) -> None: + code = store.generate_code("telegram", "123", ttl=1) + assert store.approve_code(code) is not None + + code2 = store.generate_code("telegram", "456", ttl=0) + time.sleep(0.1) + assert store.approve_code(code2) is None + + +class TestApproveDeny: + def test_approve_moves_to_approved(self) -> None: + code = store.generate_code("telegram", "123") + assert store.is_approved("telegram", "123") is False + + result = store.approve_code(code) + assert result == ("telegram", "123") + assert store.is_approved("telegram", "123") is True + assert store.get_approved("telegram") == ["123"] + + def test_deny_removes_pending(self) -> None: + code = store.generate_code("telegram", "123") + assert store.deny_code(code) is True + assert store.approve_code(code) is None + + def test_deny_unknown_returns_false(self) -> None: + assert store.deny_code("UNKNOWN") is False + + def test_approve_expired_returns_none(self) -> None: + code = store.generate_code("telegram", "123", ttl=0) + time.sleep(0.1) + assert store.approve_code(code) is None + + +class TestRevoke: + def test_revoke_removes_sender(self) -> None: + code = store.generate_code("telegram", "123") + store.approve_code(code) + assert store.is_approved("telegram", "123") is True + + assert store.revoke("telegram", "123") is True + assert store.is_approved("telegram", "123") is False + assert store.get_approved("telegram") == [] + + def test_revoke_unknown_returns_false(self) -> None: + assert store.revoke("telegram", "999") is False + + +class TestListPending: + def test_empty(self) -> None: + assert store.list_pending() == [] + + def test_shows_pending(self) -> None: + store.generate_code("telegram", "123") + store.generate_code("discord", "456") + pending = store.list_pending() + assert len(pending) == 2 + channels = {p["channel"] for p in pending} + assert channels == {"telegram", "discord"} + + def test_expired_not_listed(self) -> None: + store.generate_code("telegram", "123", ttl=0) + time.sleep(0.1) + assert store.list_pending() == [] + + +class TestStoreDurability: + def test_corruption_recovery(self, tmp_path, monkeypatch) -> None: + path = tmp_path / "pairing.json" + path.write_text("not json{", encoding="utf-8") + monkeypatch.setattr(store, "_store_path", lambda: path) + + # Should recover gracefully and act as empty store + assert store.list_pending() == [] + assert store.is_approved("telegram", "123") is False