mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
code-review fixes: fsync, entropy, is_dm propagation, tests
- Add os.fsync with Windows-compatible directory flush in pairing store - Increase pairing code length from 6 -> 8 characters for higher entropy - Remove SystemExit on empty allowFrom; empty list now defers to pairing - Update is_allowed docstring to document pairing fallback semantics - Propagate is_dm to Matrix (direct rooms) and Slack (im channels) - Slack _is_allowed now checks pairing store for DM allowlist mode - Fix /pairing revoke to accept optional channel argument - Move inline import time to module top-level - Add WebSocket comment explaining is_dm=True assumption - Add comprehensive tests for store and BaseChannel pairing integration - Fix existing tests that expected empty allowFrom to hard-exit Refs #3774
This commit is contained in:
parent
4c4a9ae590
commit
f8e7e50759
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -191,6 +192,10 @@ class BaseChannel(ABC):
|
|||||||
2. ``allowFrom`` list → allow if sender_id is present.
|
2. ``allowFrom`` list → allow if sender_id is present.
|
||||||
3. Pairing store approved list → allow if previously approved.
|
3. Pairing store approved list → allow if previously approved.
|
||||||
4. Otherwise deny.
|
4. Otherwise deny.
|
||||||
|
|
||||||
|
An empty ``allowFrom`` list does not cause a hard exit; instead it
|
||||||
|
defers to the pairing store so that unknown DM senders can request
|
||||||
|
access via a pairing code.
|
||||||
"""
|
"""
|
||||||
if isinstance(self.config, dict):
|
if isinstance(self.config, dict):
|
||||||
if "allow_from" in self.config:
|
if "allow_from" in self.config:
|
||||||
@ -296,8 +301,6 @@ class BaseChannel(ABC):
|
|||||||
reply = "No pending pairing requests."
|
reply = "No pending pairing requests."
|
||||||
else:
|
else:
|
||||||
lines = ["Pending pairing requests:"]
|
lines = ["Pending pairing requests:"]
|
||||||
import time
|
|
||||||
|
|
||||||
for item in pending:
|
for item in pending:
|
||||||
remaining = int(item.get("expires_at", 0) - time.time())
|
remaining = int(item.get("expires_at", 0) - time.time())
|
||||||
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
||||||
@ -331,12 +334,14 @@ class BaseChannel(ABC):
|
|||||||
|
|
||||||
elif sub == "revoke":
|
elif sub == "revoke":
|
||||||
if arg is None:
|
if arg is None:
|
||||||
reply = "Usage: `/pairing revoke <user_id>`"
|
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
||||||
else:
|
else:
|
||||||
if revoke(self.name, arg):
|
target_channel = parts[3] if len(parts) > 3 else self.name
|
||||||
reply = f"Revoked {arg} from {self.name}"
|
target_user = arg if len(parts) <= 3 else parts[3]
|
||||||
|
if revoke(target_channel, target_user):
|
||||||
|
reply = f"Revoked {target_user} from {target_channel}"
|
||||||
else:
|
else:
|
||||||
reply = f"{arg} was not in the approved list for {self.name}"
|
reply = f"{target_user} was not in the approved list for {target_channel}"
|
||||||
|
|
||||||
else:
|
else:
|
||||||
reply = (
|
reply = (
|
||||||
|
|||||||
@ -143,9 +143,9 @@ class ChannelManager:
|
|||||||
allow = cfg.get("allowFrom")
|
allow = cfg.get("allowFrom")
|
||||||
else:
|
else:
|
||||||
allow = getattr(cfg, "allow_from", None)
|
allow = getattr(cfg, "allow_from", None)
|
||||||
if allow == []:
|
if allow is None:
|
||||||
raise SystemExit(
|
raise SystemExit(
|
||||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
f'Error: "{name}" is missing allowFrom. '
|
||||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -28,10 +28,11 @@ try:
|
|||||||
RoomMessageMedia,
|
RoomMessageMedia,
|
||||||
RoomMessageText,
|
RoomMessageText,
|
||||||
RoomSendError,
|
RoomSendError,
|
||||||
|
RoomSendResponse,
|
||||||
RoomTypingError,
|
RoomTypingError,
|
||||||
SyncError,
|
SyncError,
|
||||||
UploadError, RoomSendResponse,
|
UploadError,
|
||||||
)
|
)
|
||||||
from nio.crypto.attachments import decrypt_attachment
|
from nio.crypto.attachments import decrypt_attachment
|
||||||
from nio.exceptions import EncryptionError
|
from nio.exceptions import EncryptionError
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
@ -107,7 +108,7 @@ class _StreamBuf:
|
|||||||
|
|
||||||
:ivar text: Stores the text content of the buffer.
|
:ivar text: Stores the text content of the buffer.
|
||||||
:type text: str
|
:type text: str
|
||||||
:ivar event_id: Identifier for the associated event. None indicates no
|
:ivar event_id: Identifier for the associated event. None indicates no
|
||||||
specific event association.
|
specific event association.
|
||||||
:type event_id: str | None
|
:type event_id: str | None
|
||||||
:ivar last_edit: Timestamp of the most recent edit to the buffer.
|
:ivar last_edit: Timestamp of the most recent edit to the buffer.
|
||||||
@ -140,19 +141,19 @@ def _build_matrix_text_content(
|
|||||||
) -> dict[str, object]:
|
) -> dict[str, object]:
|
||||||
"""
|
"""
|
||||||
Constructs and returns a dictionary representing the matrix text content with optional
|
Constructs and returns a dictionary representing the matrix text content with optional
|
||||||
HTML formatting and reference to an existing event for replacement. This function is
|
HTML formatting and reference to an existing event for replacement. This function is
|
||||||
primarily used to create content payloads compatible with the Matrix messaging protocol.
|
primarily used to create content payloads compatible with the Matrix messaging protocol.
|
||||||
|
|
||||||
:param text: The plain text content to include in the message.
|
:param text: The plain text content to include in the message.
|
||||||
:type text: str
|
:type text: str
|
||||||
:param event_id: Optional ID of the event to replace. If provided, the function will
|
:param event_id: Optional ID of the event to replace. If provided, the function will
|
||||||
include information indicating that the message is a replacement of the specified
|
include information indicating that the message is a replacement of the specified
|
||||||
event.
|
event.
|
||||||
:type event_id: str | None
|
:type event_id: str | None
|
||||||
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
|
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
|
||||||
stored in ``m.new_content`` so the replacement remains in the same thread.
|
stored in ``m.new_content`` so the replacement remains in the same thread.
|
||||||
:type thread_relates_to: dict[str, object] | None
|
:type thread_relates_to: dict[str, object] | None
|
||||||
:return: A dictionary containing the matrix text content, potentially enriched with
|
:return: A dictionary containing the matrix text content, potentially enriched with
|
||||||
HTML formatting and replacement metadata if applicable.
|
HTML formatting and replacement metadata if applicable.
|
||||||
:rtype: dict[str, object]
|
:rtype: dict[str, object]
|
||||||
"""
|
"""
|
||||||
@ -523,7 +524,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||||
|
|
||||||
content = _build_matrix_text_content(
|
content = _build_matrix_text_content(
|
||||||
buf.text,
|
buf.text,
|
||||||
buf.event_id,
|
buf.event_id,
|
||||||
@ -537,7 +538,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
buf = _StreamBuf()
|
buf = _StreamBuf()
|
||||||
self._stream_bufs[chat_id] = buf
|
self._stream_bufs[chat_id] = buf
|
||||||
buf.text += delta
|
buf.text += delta
|
||||||
|
|
||||||
if not buf.text.strip():
|
if not buf.text.strip():
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -870,6 +871,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=event.sender, chat_id=room.room_id,
|
sender_id=event.sender, chat_id=room.room_id,
|
||||||
content=event.body, metadata=self._base_metadata(room, event),
|
content=event.body, metadata=self._base_metadata(room, event),
|
||||||
|
is_dm=self._is_direct_room(room),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||||
@ -907,6 +909,7 @@ class MatrixChannel(BaseChannel):
|
|||||||
content="\n".join(parts),
|
content="\n".join(parts),
|
||||||
media=[attachment["path"]] if attachment else [],
|
media=[attachment["path"]] if attachment else [],
|
||||||
metadata=meta,
|
metadata=meta,
|
||||||
|
is_dm=self._is_direct_room(room),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||||
|
|||||||
@ -342,6 +342,22 @@ class SlackChannel(BaseChannel):
|
|||||||
channel_type = event.get("channel_type") or ""
|
channel_type = event.get("channel_type") or ""
|
||||||
|
|
||||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||||
|
if channel_type == "im" and self.config.dm.enabled:
|
||||||
|
from nanobot.pairing import generate_code
|
||||||
|
code = generate_code(self.name, sender_id)
|
||||||
|
reply = (
|
||||||
|
"This assistant requires approval before it can respond.\n"
|
||||||
|
f"Your pairing code is: `{code}`\n"
|
||||||
|
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||||
|
)
|
||||||
|
await self.send(
|
||||||
|
OutboundMessage(
|
||||||
|
channel=self.name,
|
||||||
|
chat_id=chat_id,
|
||||||
|
content=reply,
|
||||||
|
metadata={"_pairing_code": code},
|
||||||
|
)
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
||||||
@ -608,11 +624,13 @@ class SlackChannel(BaseChannel):
|
|||||||
self.logger.debug("done reaction failed: {}", e)
|
self.logger.debug("done reaction failed: {}", e)
|
||||||
|
|
||||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||||
|
from nanobot.pairing import is_approved
|
||||||
|
|
||||||
if channel_type == "im":
|
if channel_type == "im":
|
||||||
if not self.config.dm.enabled:
|
if not self.config.dm.enabled:
|
||||||
return False
|
return False
|
||||||
if self.config.dm.policy == "allowlist":
|
if self.config.dm.policy == "allowlist":
|
||||||
return sender_id in self.config.dm.allow_from
|
return sender_id in self.config.dm.allow_from or is_approved(self.name, sender_id)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Group / channel messages
|
# Group / channel messages
|
||||||
|
|||||||
@ -1249,6 +1249,8 @@ class WebSocketChannel(BaseChannel):
|
|||||||
content = _parse_inbound_payload(raw)
|
content = _parse_inbound_payload(raw)
|
||||||
if content is None:
|
if content is None:
|
||||||
continue
|
continue
|
||||||
|
# WebSocket connections are always treated as 1:1 (DM) because
|
||||||
|
# each connection represents a single client browser/tab.
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=client_id,
|
sender_id=client_id,
|
||||||
chat_id=default_chat_id,
|
chat_id=default_chat_id,
|
||||||
|
|||||||
@ -8,6 +8,7 @@ private-assistant scale: small JSON file, simple locking, no external DB.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
import threading
|
import threading
|
||||||
@ -20,8 +21,9 @@ from loguru import logger
|
|||||||
from nanobot.config.paths import get_data_dir
|
from nanobot.config.paths import get_data_dir
|
||||||
|
|
||||||
_LOCK = threading.Lock()
|
_LOCK = threading.Lock()
|
||||||
|
|
||||||
_ALPHABET = string.ascii_uppercase + string.digits
|
_ALPHABET = string.ascii_uppercase + string.digits
|
||||||
_CODE_LENGTH = 6 # e.g. XK9-42F
|
_CODE_LENGTH = 8 # e.g. XK9-42F-MP
|
||||||
_TTL_DEFAULT_S = 600 # 10 minutes
|
_TTL_DEFAULT_S = 600 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
@ -48,7 +50,17 @@ def _save(data: dict[str, Any]) -> None:
|
|||||||
with open(tmp, "w", encoding="utf-8") as f:
|
with open(tmp, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
f.flush()
|
f.flush()
|
||||||
|
os.fsync(f.fileno())
|
||||||
tmp.replace(path)
|
tmp.replace(path)
|
||||||
|
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
|
||||||
|
try:
|
||||||
|
fd = os.open(path.parent, os.O_RDONLY)
|
||||||
|
try:
|
||||||
|
os.fsync(fd)
|
||||||
|
finally:
|
||||||
|
os.close(fd)
|
||||||
|
except (OSError, NotImplementedError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _gc_pending(data: dict[str, Any]) -> None:
|
def _gc_pending(data: dict[str, Any]) -> None:
|
||||||
@ -75,7 +87,7 @@ def generate_code(
|
|||||||
# Ensure uniqueness
|
# Ensure uniqueness
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
||||||
code = f"{raw[:3]}-{raw[3:]}"
|
code = f"{raw[:4]}-{raw[4:]}"
|
||||||
if code not in data.get("pending", {}):
|
if code not in data.get("pending", {}):
|
||||||
break
|
break
|
||||||
else: # pragma: no cover
|
else: # pragma: no cover
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
@ -7,6 +9,11 @@ from nanobot.channels.base import BaseChannel
|
|||||||
|
|
||||||
class _DummyChannel(BaseChannel):
|
class _DummyChannel(BaseChannel):
|
||||||
name = "dummy"
|
name = "dummy"
|
||||||
|
_sent: list[OutboundMessage]
|
||||||
|
|
||||||
|
def __init__(self, config, bus):
|
||||||
|
super().__init__(config, bus)
|
||||||
|
self._sent = []
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
return None
|
return None
|
||||||
@ -15,7 +22,7 @@ class _DummyChannel(BaseChannel):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
return None
|
self._sent.append(msg)
|
||||||
|
|
||||||
|
|
||||||
def test_is_allowed_requires_exact_match() -> None:
|
def test_is_allowed_requires_exact_match() -> None:
|
||||||
@ -35,3 +42,94 @@ def test_is_allowed_denies_empty_dict_allow_from() -> None:
|
|||||||
channel = _DummyChannel({"allow_from": []}, MessageBus())
|
channel = _DummyChannel({"allow_from": []}, MessageBus())
|
||||||
|
|
||||||
assert channel.is_allowed("alice") is False
|
assert channel.is_allowed("alice") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_star_allows_all() -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": ["*"]}, MessageBus())
|
||||||
|
assert channel.is_allowed("anyone") is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_allowed_pairing_fallback(monkeypatch) -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": []}, MessageBus())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.is_approved", lambda _ch, sid: sid == "paired"
|
||||||
|
)
|
||||||
|
assert channel.is_allowed("paired") is True
|
||||||
|
assert channel.is_allowed("unknown") is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_dm_sends_pairing_code(monkeypatch) -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": []}, MessageBus())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.generate_code", lambda _ch, sid: "ABCD-EFGH"
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._handle_message(
|
||||||
|
sender_id="stranger", chat_id="chat1", content="hello", is_dm=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(channel._sent) == 1
|
||||||
|
msg = channel._sent[0]
|
||||||
|
assert "ABCD-EFGH" in msg.content
|
||||||
|
assert msg.metadata.get("_pairing_code") == "ABCD-EFGH"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_message_group_ignores_unknown() -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": []}, MessageBus())
|
||||||
|
|
||||||
|
await channel._handle_message(
|
||||||
|
sender_id="stranger", chat_id="chat1", content="hello", is_dm=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._sent == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_pairing_command_list(monkeypatch) -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.list_pending",
|
||||||
|
lambda: [
|
||||||
|
{
|
||||||
|
"code": "ABCD-EFGH",
|
||||||
|
"channel": "dummy",
|
||||||
|
"sender_id": "123",
|
||||||
|
"expires_at": 9999999999,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
|
||||||
|
|
||||||
|
assert len(channel._sent) == 1
|
||||||
|
assert "ABCD-EFGH" in channel._sent[0].content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_pairing_command_approve(monkeypatch) -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.approve_code",
|
||||||
|
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
|
||||||
|
|
||||||
|
assert len(channel._sent) == 1
|
||||||
|
assert "Approved" in channel._sent[0].content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_pairing_command_revoke(monkeypatch) -> None:
|
||||||
|
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.revoke",
|
||||||
|
lambda ch, sid: sid == "123",
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
|
||||||
|
|
||||||
|
assert len(channel._sent) == 1
|
||||||
|
assert "Revoked" in channel._sent[0].content
|
||||||
|
|||||||
@ -961,8 +961,8 @@ class _StartableChannel(BaseChannel):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_validate_allow_from_raises_on_empty_list():
|
async def test_validate_allow_from_allows_empty_list():
|
||||||
"""_validate_allow_from should raise SystemExit when allow_from is empty list."""
|
"""Empty allow_from is valid now — pairing store handles unapproved senders."""
|
||||||
fake_config = SimpleNamespace(
|
fake_config = SimpleNamespace(
|
||||||
channels=ChannelsConfig(),
|
channels=ChannelsConfig(),
|
||||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||||
@ -973,10 +973,8 @@ async def test_validate_allow_from_raises_on_empty_list():
|
|||||||
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
|
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
|
||||||
mgr._dispatch_task = None
|
mgr._dispatch_task = None
|
||||||
|
|
||||||
with pytest.raises(SystemExit) as exc_info:
|
# Should not raise — empty list defers to pairing store
|
||||||
mgr._validate_allow_from()
|
mgr._validate_allow_from()
|
||||||
|
|
||||||
assert "empty allowFrom" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -997,8 +995,8 @@ async def test_validate_allow_from_passes_with_asterisk():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_validate_allow_from_raises_on_empty_dict_allow_from():
|
async def test_validate_allow_from_allows_empty_dict_allow_from():
|
||||||
"""_validate_allow_from should reject empty dict-backed allow_from lists."""
|
"""Empty dict-backed allow_from is valid — pairing store handles approval."""
|
||||||
fake_config = SimpleNamespace(
|
fake_config = SimpleNamespace(
|
||||||
channels=ChannelsConfig(),
|
channels=ChannelsConfig(),
|
||||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||||
@ -1009,10 +1007,7 @@ async def test_validate_allow_from_raises_on_empty_dict_allow_from():
|
|||||||
mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])}
|
mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])}
|
||||||
mgr._dispatch_task = None
|
mgr._dispatch_task = None
|
||||||
|
|
||||||
with pytest.raises(SystemExit) as exc_info:
|
mgr._validate_allow_from()
|
||||||
mgr._validate_allow_from()
|
|
||||||
|
|
||||||
assert "empty allowFrom" in str(exc_info.value)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
99
tests/pairing/test_store.py
Normal file
99
tests/pairing/test_store.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.pairing import store
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _tmp_store(tmp_path, monkeypatch):
|
||||||
|
path = tmp_path / "pairing.json"
|
||||||
|
monkeypatch.setattr(store, "_store_path", lambda: path)
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenerateCode:
|
||||||
|
def test_format(self) -> None:
|
||||||
|
code = store.generate_code("telegram", "123")
|
||||||
|
assert len(code) == 9 # 4 + 1 + 4
|
||||||
|
assert code[4] == "-"
|
||||||
|
assert code.replace("-", "").isalnum()
|
||||||
|
assert code.replace("-", "").isupper()
|
||||||
|
|
||||||
|
def test_uniqueness(self) -> None:
|
||||||
|
codes = {store.generate_code("telegram", str(i)) for i in range(20)}
|
||||||
|
assert len(codes) == 20
|
||||||
|
|
||||||
|
def test_ttl_expiration(self) -> None:
|
||||||
|
code = store.generate_code("telegram", "123", ttl=1)
|
||||||
|
assert store.approve_code(code) is not None
|
||||||
|
|
||||||
|
code2 = store.generate_code("telegram", "456", ttl=0)
|
||||||
|
time.sleep(0.1)
|
||||||
|
assert store.approve_code(code2) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestApproveDeny:
|
||||||
|
def test_approve_moves_to_approved(self) -> None:
|
||||||
|
code = store.generate_code("telegram", "123")
|
||||||
|
assert store.is_approved("telegram", "123") is False
|
||||||
|
|
||||||
|
result = store.approve_code(code)
|
||||||
|
assert result == ("telegram", "123")
|
||||||
|
assert store.is_approved("telegram", "123") is True
|
||||||
|
assert store.get_approved("telegram") == ["123"]
|
||||||
|
|
||||||
|
def test_deny_removes_pending(self) -> None:
|
||||||
|
code = store.generate_code("telegram", "123")
|
||||||
|
assert store.deny_code(code) is True
|
||||||
|
assert store.approve_code(code) is None
|
||||||
|
|
||||||
|
def test_deny_unknown_returns_false(self) -> None:
|
||||||
|
assert store.deny_code("UNKNOWN") is False
|
||||||
|
|
||||||
|
def test_approve_expired_returns_none(self) -> None:
|
||||||
|
code = store.generate_code("telegram", "123", ttl=0)
|
||||||
|
time.sleep(0.1)
|
||||||
|
assert store.approve_code(code) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevoke:
|
||||||
|
def test_revoke_removes_sender(self) -> None:
|
||||||
|
code = store.generate_code("telegram", "123")
|
||||||
|
store.approve_code(code)
|
||||||
|
assert store.is_approved("telegram", "123") is True
|
||||||
|
|
||||||
|
assert store.revoke("telegram", "123") is True
|
||||||
|
assert store.is_approved("telegram", "123") is False
|
||||||
|
assert store.get_approved("telegram") == []
|
||||||
|
|
||||||
|
def test_revoke_unknown_returns_false(self) -> None:
|
||||||
|
assert store.revoke("telegram", "999") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestListPending:
|
||||||
|
def test_empty(self) -> None:
|
||||||
|
assert store.list_pending() == []
|
||||||
|
|
||||||
|
def test_shows_pending(self) -> None:
|
||||||
|
store.generate_code("telegram", "123")
|
||||||
|
store.generate_code("discord", "456")
|
||||||
|
pending = store.list_pending()
|
||||||
|
assert len(pending) == 2
|
||||||
|
channels = {p["channel"] for p in pending}
|
||||||
|
assert channels == {"telegram", "discord"}
|
||||||
|
|
||||||
|
def test_expired_not_listed(self) -> None:
|
||||||
|
store.generate_code("telegram", "123", ttl=0)
|
||||||
|
time.sleep(0.1)
|
||||||
|
assert store.list_pending() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestStoreDurability:
|
||||||
|
def test_corruption_recovery(self, tmp_path, monkeypatch) -> None:
|
||||||
|
path = tmp_path / "pairing.json"
|
||||||
|
path.write_text("not json{", encoding="utf-8")
|
||||||
|
monkeypatch.setattr(store, "_store_path", lambda: path)
|
||||||
|
|
||||||
|
# Should recover gracefully and act as empty store
|
||||||
|
assert store.list_pending() == []
|
||||||
|
assert store.is_approved("telegram", "123") is False
|
||||||
Loading…
x
Reference in New Issue
Block a user