mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
code-review fixes: fsync, entropy, is_dm propagation, tests
- Add os.fsync with Windows-compatible directory flush in pairing store - Increase pairing code length from 6 -> 8 characters for higher entropy - Remove SystemExit on empty allowFrom; empty list now defers to pairing - Update is_allowed docstring to document pairing fallback semantics - Propagate is_dm to Matrix (direct rooms) and Slack (im channels) - Slack _is_allowed now checks pairing store for DM allowlist mode - Fix /pairing revoke to accept optional channel argument - Move inline import time to module top-level - Add WebSocket comment explaining is_dm=True assumption - Add comprehensive tests for store and BaseChannel pairing integration - Fix existing tests that expected empty allowFrom to hard-exit Refs #3774
This commit is contained in:
parent
4c4a9ae590
commit
f8e7e50759
@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -191,6 +192,10 @@ class BaseChannel(ABC):
|
||||
2. ``allowFrom`` list → allow if sender_id is present.
|
||||
3. Pairing store approved list → allow if previously approved.
|
||||
4. Otherwise deny.
|
||||
|
||||
An empty ``allowFrom`` list does not cause a hard exit; instead it
|
||||
defers to the pairing store so that unknown DM senders can request
|
||||
access via a pairing code.
|
||||
"""
|
||||
if isinstance(self.config, dict):
|
||||
if "allow_from" in self.config:
|
||||
@ -296,8 +301,6 @@ class BaseChannel(ABC):
|
||||
reply = "No pending pairing requests."
|
||||
else:
|
||||
lines = ["Pending pairing requests:"]
|
||||
import time
|
||||
|
||||
for item in pending:
|
||||
remaining = int(item.get("expires_at", 0) - time.time())
|
||||
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
||||
@ -331,12 +334,14 @@ class BaseChannel(ABC):
|
||||
|
||||
elif sub == "revoke":
|
||||
if arg is None:
|
||||
reply = "Usage: `/pairing revoke <user_id>`"
|
||||
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
||||
else:
|
||||
if revoke(self.name, arg):
|
||||
reply = f"Revoked {arg} from {self.name}"
|
||||
target_channel = parts[3] if len(parts) > 3 else self.name
|
||||
target_user = arg if len(parts) <= 3 else parts[3]
|
||||
if revoke(target_channel, target_user):
|
||||
reply = f"Revoked {target_user} from {target_channel}"
|
||||
else:
|
||||
reply = f"{arg} was not in the approved list for {self.name}"
|
||||
reply = f"{target_user} was not in the approved list for {target_channel}"
|
||||
|
||||
else:
|
||||
reply = (
|
||||
|
||||
@ -143,9 +143,9 @@ class ChannelManager:
|
||||
allow = cfg.get("allowFrom")
|
||||
else:
|
||||
allow = getattr(cfg, "allow_from", None)
|
||||
if allow == []:
|
||||
if allow is None:
|
||||
raise SystemExit(
|
||||
f'Error: "{name}" has empty allowFrom (denies all). '
|
||||
f'Error: "{name}" is missing allowFrom. '
|
||||
f'Set ["*"] to allow everyone, or add specific user IDs.'
|
||||
)
|
||||
|
||||
|
||||
@ -28,10 +28,11 @@ try:
|
||||
RoomMessageMedia,
|
||||
RoomMessageText,
|
||||
RoomSendError,
|
||||
RoomSendResponse,
|
||||
RoomTypingError,
|
||||
SyncError,
|
||||
UploadError, RoomSendResponse,
|
||||
)
|
||||
UploadError,
|
||||
)
|
||||
from nio.crypto.attachments import decrypt_attachment
|
||||
from nio.exceptions import EncryptionError
|
||||
except ImportError as e:
|
||||
@ -107,7 +108,7 @@ class _StreamBuf:
|
||||
|
||||
:ivar text: Stores the text content of the buffer.
|
||||
:type text: str
|
||||
:ivar event_id: Identifier for the associated event. None indicates no
|
||||
:ivar event_id: Identifier for the associated event. None indicates no
|
||||
specific event association.
|
||||
:type event_id: str | None
|
||||
:ivar last_edit: Timestamp of the most recent edit to the buffer.
|
||||
@ -140,19 +141,19 @@ def _build_matrix_text_content(
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Constructs and returns a dictionary representing the matrix text content with optional
|
||||
HTML formatting and reference to an existing event for replacement. This function is
|
||||
HTML formatting and reference to an existing event for replacement. This function is
|
||||
primarily used to create content payloads compatible with the Matrix messaging protocol.
|
||||
|
||||
:param text: The plain text content to include in the message.
|
||||
:type text: str
|
||||
:param event_id: Optional ID of the event to replace. If provided, the function will
|
||||
include information indicating that the message is a replacement of the specified
|
||||
:param event_id: Optional ID of the event to replace. If provided, the function will
|
||||
include information indicating that the message is a replacement of the specified
|
||||
event.
|
||||
:type event_id: str | None
|
||||
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
|
||||
stored in ``m.new_content`` so the replacement remains in the same thread.
|
||||
:type thread_relates_to: dict[str, object] | None
|
||||
:return: A dictionary containing the matrix text content, potentially enriched with
|
||||
:return: A dictionary containing the matrix text content, potentially enriched with
|
||||
HTML formatting and replacement metadata if applicable.
|
||||
:rtype: dict[str, object]
|
||||
"""
|
||||
@ -523,7 +524,7 @@ class MatrixChannel(BaseChannel):
|
||||
return
|
||||
|
||||
await self._stop_typing_keepalive(chat_id, clear_typing=True)
|
||||
|
||||
|
||||
content = _build_matrix_text_content(
|
||||
buf.text,
|
||||
buf.event_id,
|
||||
@ -537,7 +538,7 @@ class MatrixChannel(BaseChannel):
|
||||
buf = _StreamBuf()
|
||||
self._stream_bufs[chat_id] = buf
|
||||
buf.text += delta
|
||||
|
||||
|
||||
if not buf.text.strip():
|
||||
return
|
||||
|
||||
@ -870,6 +871,7 @@ class MatrixChannel(BaseChannel):
|
||||
await self._handle_message(
|
||||
sender_id=event.sender, chat_id=room.room_id,
|
||||
content=event.body, metadata=self._base_metadata(room, event),
|
||||
is_dm=self._is_direct_room(room),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
@ -907,6 +909,7 @@ class MatrixChannel(BaseChannel):
|
||||
content="\n".join(parts),
|
||||
media=[attachment["path"]] if attachment else [],
|
||||
metadata=meta,
|
||||
is_dm=self._is_direct_room(room),
|
||||
)
|
||||
except Exception:
|
||||
await self._stop_typing_keepalive(room.room_id, clear_typing=True)
|
||||
|
||||
@ -342,6 +342,22 @@ class SlackChannel(BaseChannel):
|
||||
channel_type = event.get("channel_type") or ""
|
||||
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
if channel_type == "im" and self.config.dm.enabled:
|
||||
from nanobot.pairing import generate_code
|
||||
code = generate_code(self.name, sender_id)
|
||||
reply = (
|
||||
"This assistant requires approval before it can respond.\n"
|
||||
f"Your pairing code is: `{code}`\n"
|
||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||
)
|
||||
await self.send(
|
||||
OutboundMessage(
|
||||
channel=self.name,
|
||||
chat_id=chat_id,
|
||||
content=reply,
|
||||
metadata={"_pairing_code": code},
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
|
||||
@ -608,11 +624,13 @@ class SlackChannel(BaseChannel):
|
||||
self.logger.debug("done reaction failed: {}", e)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
from nanobot.pairing import is_approved
|
||||
|
||||
if channel_type == "im":
|
||||
if not self.config.dm.enabled:
|
||||
return False
|
||||
if self.config.dm.policy == "allowlist":
|
||||
return sender_id in self.config.dm.allow_from
|
||||
return sender_id in self.config.dm.allow_from or is_approved(self.name, sender_id)
|
||||
return True
|
||||
|
||||
# Group / channel messages
|
||||
|
||||
@ -1249,6 +1249,8 @@ class WebSocketChannel(BaseChannel):
|
||||
content = _parse_inbound_payload(raw)
|
||||
if content is None:
|
||||
continue
|
||||
# WebSocket connections are always treated as 1:1 (DM) because
|
||||
# each connection represents a single client browser/tab.
|
||||
await self._handle_message(
|
||||
sender_id=client_id,
|
||||
chat_id=default_chat_id,
|
||||
|
||||
@ -8,6 +8,7 @@ private-assistant scale: small JSON file, simple locking, no external DB.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import threading
|
||||
@ -20,8 +21,9 @@ from loguru import logger
|
||||
from nanobot.config.paths import get_data_dir
|
||||
|
||||
_LOCK = threading.Lock()
|
||||
|
||||
_ALPHABET = string.ascii_uppercase + string.digits
|
||||
_CODE_LENGTH = 6 # e.g. XK9-42F
|
||||
_CODE_LENGTH = 8 # e.g. XK9-42F-MP
|
||||
_TTL_DEFAULT_S = 600 # 10 minutes
|
||||
|
||||
|
||||
@ -48,7 +50,17 @@ def _save(data: dict[str, Any]) -> None:
|
||||
with open(tmp, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
tmp.replace(path)
|
||||
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
|
||||
try:
|
||||
fd = os.open(path.parent, os.O_RDONLY)
|
||||
try:
|
||||
os.fsync(fd)
|
||||
finally:
|
||||
os.close(fd)
|
||||
except (OSError, NotImplementedError):
|
||||
pass
|
||||
|
||||
|
||||
def _gc_pending(data: dict[str, Any]) -> None:
|
||||
@ -75,7 +87,7 @@ def generate_code(
|
||||
# Ensure uniqueness
|
||||
for _ in range(100):
|
||||
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
||||
code = f"{raw[:3]}-{raw[3:]}"
|
||||
code = f"{raw[:4]}-{raw[4:]}"
|
||||
if code not in data.get("pending", {}):
|
||||
break
|
||||
else: # pragma: no cover
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
@ -7,6 +9,11 @@ from nanobot.channels.base import BaseChannel
|
||||
|
||||
class _DummyChannel(BaseChannel):
|
||||
name = "dummy"
|
||||
_sent: list[OutboundMessage]
|
||||
|
||||
def __init__(self, config, bus):
|
||||
super().__init__(config, bus)
|
||||
self._sent = []
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
@ -15,7 +22,7 @@ class _DummyChannel(BaseChannel):
|
||||
return None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
return None
|
||||
self._sent.append(msg)
|
||||
|
||||
|
||||
def test_is_allowed_requires_exact_match() -> None:
|
||||
@ -35,3 +42,94 @@ def test_is_allowed_denies_empty_dict_allow_from() -> None:
|
||||
channel = _DummyChannel({"allow_from": []}, MessageBus())
|
||||
|
||||
assert channel.is_allowed("alice") is False
|
||||
|
||||
|
||||
def test_is_allowed_star_allows_all() -> None:
|
||||
channel = _DummyChannel({"allowFrom": ["*"]}, MessageBus())
|
||||
assert channel.is_allowed("anyone") is True
|
||||
|
||||
|
||||
def test_is_allowed_pairing_fallback(monkeypatch) -> None:
|
||||
channel = _DummyChannel({"allowFrom": []}, MessageBus())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.is_approved", lambda _ch, sid: sid == "paired"
|
||||
)
|
||||
assert channel.is_allowed("paired") is True
|
||||
assert channel.is_allowed("unknown") is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_dm_sends_pairing_code(monkeypatch) -> None:
|
||||
channel = _DummyChannel({"allowFrom": []}, MessageBus())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.generate_code", lambda _ch, sid: "ABCD-EFGH"
|
||||
)
|
||||
|
||||
await channel._handle_message(
|
||||
sender_id="stranger", chat_id="chat1", content="hello", is_dm=True
|
||||
)
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
msg = channel._sent[0]
|
||||
assert "ABCD-EFGH" in msg.content
|
||||
assert msg.metadata.get("_pairing_code") == "ABCD-EFGH"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_group_ignores_unknown() -> None:
|
||||
channel = _DummyChannel({"allowFrom": []}, MessageBus())
|
||||
|
||||
await channel._handle_message(
|
||||
sender_id="stranger", chat_id="chat1", content="hello", is_dm=False
|
||||
)
|
||||
|
||||
assert channel._sent == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_pairing_command_list(monkeypatch) -> None:
|
||||
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.list_pending",
|
||||
lambda: [
|
||||
{
|
||||
"code": "ABCD-EFGH",
|
||||
"channel": "dummy",
|
||||
"sender_id": "123",
|
||||
"expires_at": 9999999999,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
assert "ABCD-EFGH" in channel._sent[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_pairing_command_approve(monkeypatch) -> None:
|
||||
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.approve_code",
|
||||
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
|
||||
)
|
||||
|
||||
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
assert "Approved" in channel._sent[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_pairing_command_revoke(monkeypatch) -> None:
|
||||
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.revoke",
|
||||
lambda ch, sid: sid == "123",
|
||||
)
|
||||
|
||||
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
assert "Revoked" in channel._sent[0].content
|
||||
|
||||
@ -961,8 +961,8 @@ class _StartableChannel(BaseChannel):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_raises_on_empty_list():
|
||||
"""_validate_allow_from should raise SystemExit when allow_from is empty list."""
|
||||
async def test_validate_allow_from_allows_empty_list():
|
||||
"""Empty allow_from is valid now — pairing store handles unapproved senders."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
@ -973,10 +973,8 @@ async def test_validate_allow_from_raises_on_empty_list():
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mgr._validate_allow_from()
|
||||
|
||||
assert "empty allowFrom" in str(exc_info.value)
|
||||
# Should not raise — empty list defers to pairing store
|
||||
mgr._validate_allow_from()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -997,8 +995,8 @@ async def test_validate_allow_from_passes_with_asterisk():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_allow_from_raises_on_empty_dict_allow_from():
|
||||
"""_validate_allow_from should reject empty dict-backed allow_from lists."""
|
||||
async def test_validate_allow_from_allows_empty_dict_allow_from():
|
||||
"""Empty dict-backed allow_from is valid — pairing store handles approval."""
|
||||
fake_config = SimpleNamespace(
|
||||
channels=ChannelsConfig(),
|
||||
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
|
||||
@ -1009,10 +1007,7 @@ async def test_validate_allow_from_raises_on_empty_dict_allow_from():
|
||||
mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])}
|
||||
mgr._dispatch_task = None
|
||||
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
mgr._validate_allow_from()
|
||||
|
||||
assert "empty allowFrom" in str(exc_info.value)
|
||||
mgr._validate_allow_from()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
99
tests/pairing/test_store.py
Normal file
99
tests/pairing/test_store.py
Normal file
@ -0,0 +1,99 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.pairing import store
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _tmp_store(tmp_path, monkeypatch):
|
||||
path = tmp_path / "pairing.json"
|
||||
monkeypatch.setattr(store, "_store_path", lambda: path)
|
||||
|
||||
|
||||
class TestGenerateCode:
|
||||
def test_format(self) -> None:
|
||||
code = store.generate_code("telegram", "123")
|
||||
assert len(code) == 9 # 4 + 1 + 4
|
||||
assert code[4] == "-"
|
||||
assert code.replace("-", "").isalnum()
|
||||
assert code.replace("-", "").isupper()
|
||||
|
||||
def test_uniqueness(self) -> None:
|
||||
codes = {store.generate_code("telegram", str(i)) for i in range(20)}
|
||||
assert len(codes) == 20
|
||||
|
||||
def test_ttl_expiration(self) -> None:
|
||||
code = store.generate_code("telegram", "123", ttl=1)
|
||||
assert store.approve_code(code) is not None
|
||||
|
||||
code2 = store.generate_code("telegram", "456", ttl=0)
|
||||
time.sleep(0.1)
|
||||
assert store.approve_code(code2) is None
|
||||
|
||||
|
||||
class TestApproveDeny:
|
||||
def test_approve_moves_to_approved(self) -> None:
|
||||
code = store.generate_code("telegram", "123")
|
||||
assert store.is_approved("telegram", "123") is False
|
||||
|
||||
result = store.approve_code(code)
|
||||
assert result == ("telegram", "123")
|
||||
assert store.is_approved("telegram", "123") is True
|
||||
assert store.get_approved("telegram") == ["123"]
|
||||
|
||||
def test_deny_removes_pending(self) -> None:
|
||||
code = store.generate_code("telegram", "123")
|
||||
assert store.deny_code(code) is True
|
||||
assert store.approve_code(code) is None
|
||||
|
||||
def test_deny_unknown_returns_false(self) -> None:
|
||||
assert store.deny_code("UNKNOWN") is False
|
||||
|
||||
def test_approve_expired_returns_none(self) -> None:
|
||||
code = store.generate_code("telegram", "123", ttl=0)
|
||||
time.sleep(0.1)
|
||||
assert store.approve_code(code) is None
|
||||
|
||||
|
||||
class TestRevoke:
|
||||
def test_revoke_removes_sender(self) -> None:
|
||||
code = store.generate_code("telegram", "123")
|
||||
store.approve_code(code)
|
||||
assert store.is_approved("telegram", "123") is True
|
||||
|
||||
assert store.revoke("telegram", "123") is True
|
||||
assert store.is_approved("telegram", "123") is False
|
||||
assert store.get_approved("telegram") == []
|
||||
|
||||
def test_revoke_unknown_returns_false(self) -> None:
|
||||
assert store.revoke("telegram", "999") is False
|
||||
|
||||
|
||||
class TestListPending:
|
||||
def test_empty(self) -> None:
|
||||
assert store.list_pending() == []
|
||||
|
||||
def test_shows_pending(self) -> None:
|
||||
store.generate_code("telegram", "123")
|
||||
store.generate_code("discord", "456")
|
||||
pending = store.list_pending()
|
||||
assert len(pending) == 2
|
||||
channels = {p["channel"] for p in pending}
|
||||
assert channels == {"telegram", "discord"}
|
||||
|
||||
def test_expired_not_listed(self) -> None:
|
||||
store.generate_code("telegram", "123", ttl=0)
|
||||
time.sleep(0.1)
|
||||
assert store.list_pending() == []
|
||||
|
||||
|
||||
class TestStoreDurability:
|
||||
def test_corruption_recovery(self, tmp_path, monkeypatch) -> None:
|
||||
path = tmp_path / "pairing.json"
|
||||
path.write_text("not json{", encoding="utf-8")
|
||||
monkeypatch.setattr(store, "_store_path", lambda: path)
|
||||
|
||||
# Should recover gracefully and act as empty store
|
||||
assert store.list_pending() == []
|
||||
assert store.is_approved("telegram", "123") is False
|
||||
Loading…
x
Reference in New Issue
Block a user