code-review fixes: fsync, entropy, is_dm propagation, tests

- Add os.fsync with Windows-compatible directory flush in pairing store
- Increase pairing code length from 6 -> 8 characters for higher entropy
- Remove SystemExit on empty allowFrom; empty list now defers to pairing
- Update is_allowed docstring to document pairing fallback semantics
- Propagate is_dm to Matrix (direct rooms) and Slack (im channels)
- Slack _is_allowed now checks pairing store for DM allowlist mode
- Fix /pairing revoke to accept optional channel argument
- Move inline import time to module top-level
- Add WebSocket comment explaining is_dm=True assumption
- Add comprehensive tests for store and BaseChannel pairing integration
- Fix existing tests that expected empty allowFrom to hard-exit

Refs #3774
This commit is contained in:
chengyongru 2026-05-14 11:03:17 +08:00 committed by Xubin Ren
parent 4c4a9ae590
commit f8e7e50759
9 changed files with 265 additions and 33 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -191,6 +192,10 @@ class BaseChannel(ABC):
2. ``allowFrom`` list allow if sender_id is present. 2. ``allowFrom`` list allow if sender_id is present.
3. Pairing store approved list allow if previously approved. 3. Pairing store approved list allow if previously approved.
4. Otherwise deny. 4. Otherwise deny.
An empty ``allowFrom`` list does not cause a hard exit; instead it
defers to the pairing store so that unknown DM senders can request
access via a pairing code.
""" """
if isinstance(self.config, dict): if isinstance(self.config, dict):
if "allow_from" in self.config: if "allow_from" in self.config:
@ -296,8 +301,6 @@ class BaseChannel(ABC):
reply = "No pending pairing requests." reply = "No pending pairing requests."
else: else:
lines = ["Pending pairing requests:"] lines = ["Pending pairing requests:"]
import time
for item in pending: for item in pending:
remaining = int(item.get("expires_at", 0) - time.time()) remaining = int(item.get("expires_at", 0) - time.time())
expiry = f"{remaining}s" if remaining > 0 else "expired" expiry = f"{remaining}s" if remaining > 0 else "expired"
@ -331,12 +334,14 @@ class BaseChannel(ABC):
elif sub == "revoke": elif sub == "revoke":
if arg is None: if arg is None:
reply = "Usage: `/pairing revoke <user_id>`" reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
else: else:
if revoke(self.name, arg): target_channel = parts[3] if len(parts) > 3 else self.name
reply = f"Revoked {arg} from {self.name}" target_user = arg if len(parts) <= 3 else parts[3]
if revoke(target_channel, target_user):
reply = f"Revoked {target_user} from {target_channel}"
else: else:
reply = f"{arg} was not in the approved list for {self.name}" reply = f"{target_user} was not in the approved list for {target_channel}"
else: else:
reply = ( reply = (

View File

@ -143,9 +143,9 @@ class ChannelManager:
allow = cfg.get("allowFrom") allow = cfg.get("allowFrom")
else: else:
allow = getattr(cfg, "allow_from", None) allow = getattr(cfg, "allow_from", None)
if allow == []: if allow is None:
raise SystemExit( raise SystemExit(
f'Error: "{name}" has empty allowFrom (denies all). ' f'Error: "{name}" is missing allowFrom. '
f'Set ["*"] to allow everyone, or add specific user IDs.' f'Set ["*"] to allow everyone, or add specific user IDs.'
) )

View File

@ -28,10 +28,11 @@ try:
RoomMessageMedia, RoomMessageMedia,
RoomMessageText, RoomMessageText,
RoomSendError, RoomSendError,
RoomSendResponse,
RoomTypingError, RoomTypingError,
SyncError, SyncError,
UploadError, RoomSendResponse, UploadError,
) )
from nio.crypto.attachments import decrypt_attachment from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError from nio.exceptions import EncryptionError
except ImportError as e: except ImportError as e:
@ -870,6 +871,7 @@ class MatrixChannel(BaseChannel):
await self._handle_message( await self._handle_message(
sender_id=event.sender, chat_id=room.room_id, sender_id=event.sender, chat_id=room.room_id,
content=event.body, metadata=self._base_metadata(room, event), content=event.body, metadata=self._base_metadata(room, event),
is_dm=self._is_direct_room(room),
) )
except Exception: except Exception:
await self._stop_typing_keepalive(room.room_id, clear_typing=True) await self._stop_typing_keepalive(room.room_id, clear_typing=True)
@ -907,6 +909,7 @@ class MatrixChannel(BaseChannel):
content="\n".join(parts), content="\n".join(parts),
media=[attachment["path"]] if attachment else [], media=[attachment["path"]] if attachment else [],
metadata=meta, metadata=meta,
is_dm=self._is_direct_room(room),
) )
except Exception: except Exception:
await self._stop_typing_keepalive(room.room_id, clear_typing=True) await self._stop_typing_keepalive(room.room_id, clear_typing=True)

View File

@ -342,6 +342,22 @@ class SlackChannel(BaseChannel):
channel_type = event.get("channel_type") or "" channel_type = event.get("channel_type") or ""
if not self._is_allowed(sender_id, chat_id, channel_type): if not self._is_allowed(sender_id, chat_id, channel_type):
if channel_type == "im" and self.config.dm.enabled:
from nanobot.pairing import generate_code
code = generate_code(self.name, sender_id)
reply = (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
await self.send(
OutboundMessage(
channel=self.name,
chat_id=chat_id,
content=reply,
metadata={"_pairing_code": code},
)
)
return return
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id): if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
@ -608,11 +624,13 @@ class SlackChannel(BaseChannel):
self.logger.debug("done reaction failed: {}", e) self.logger.debug("done reaction failed: {}", e)
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool: def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
from nanobot.pairing import is_approved
if channel_type == "im": if channel_type == "im":
if not self.config.dm.enabled: if not self.config.dm.enabled:
return False return False
if self.config.dm.policy == "allowlist": if self.config.dm.policy == "allowlist":
return sender_id in self.config.dm.allow_from return sender_id in self.config.dm.allow_from or is_approved(self.name, sender_id)
return True return True
# Group / channel messages # Group / channel messages

View File

@ -1249,6 +1249,8 @@ class WebSocketChannel(BaseChannel):
content = _parse_inbound_payload(raw) content = _parse_inbound_payload(raw)
if content is None: if content is None:
continue continue
# WebSocket connections are always treated as 1:1 (DM) because
# each connection represents a single client browser/tab.
await self._handle_message( await self._handle_message(
sender_id=client_id, sender_id=client_id,
chat_id=default_chat_id, chat_id=default_chat_id,

View File

@ -8,6 +8,7 @@ private-assistant scale: small JSON file, simple locking, no external DB.
from __future__ import annotations from __future__ import annotations
import json import json
import os
import secrets import secrets
import string import string
import threading import threading
@ -20,8 +21,9 @@ from loguru import logger
from nanobot.config.paths import get_data_dir from nanobot.config.paths import get_data_dir
_LOCK = threading.Lock() _LOCK = threading.Lock()
_ALPHABET = string.ascii_uppercase + string.digits _ALPHABET = string.ascii_uppercase + string.digits
_CODE_LENGTH = 6 # e.g. XK9-42F _CODE_LENGTH = 8 # e.g. XK9-42F-MP
_TTL_DEFAULT_S = 600 # 10 minutes _TTL_DEFAULT_S = 600 # 10 minutes
@ -48,7 +50,17 @@ def _save(data: dict[str, Any]) -> None:
with open(tmp, "w", encoding="utf-8") as f: with open(tmp, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False) json.dump(data, f, indent=2, ensure_ascii=False)
f.flush() f.flush()
os.fsync(f.fileno())
tmp.replace(path) tmp.replace(path)
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
try:
fd = os.open(path.parent, os.O_RDONLY)
try:
os.fsync(fd)
finally:
os.close(fd)
except (OSError, NotImplementedError):
pass
def _gc_pending(data: dict[str, Any]) -> None: def _gc_pending(data: dict[str, Any]) -> None:
@ -75,7 +87,7 @@ def generate_code(
# Ensure uniqueness # Ensure uniqueness
for _ in range(100): for _ in range(100):
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH)) raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
code = f"{raw[:3]}-{raw[3:]}" code = f"{raw[:4]}-{raw[4:]}"
if code not in data.get("pending", {}): if code not in data.get("pending", {}):
break break
else: # pragma: no cover else: # pragma: no cover

View File

@ -1,5 +1,7 @@
from types import SimpleNamespace from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
@ -7,6 +9,11 @@ from nanobot.channels.base import BaseChannel
class _DummyChannel(BaseChannel): class _DummyChannel(BaseChannel):
name = "dummy" name = "dummy"
_sent: list[OutboundMessage]
def __init__(self, config, bus):
super().__init__(config, bus)
self._sent = []
async def start(self) -> None: async def start(self) -> None:
return None return None
@ -15,7 +22,7 @@ class _DummyChannel(BaseChannel):
return None return None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
return None self._sent.append(msg)
def test_is_allowed_requires_exact_match() -> None: def test_is_allowed_requires_exact_match() -> None:
@ -35,3 +42,94 @@ def test_is_allowed_denies_empty_dict_allow_from() -> None:
channel = _DummyChannel({"allow_from": []}, MessageBus()) channel = _DummyChannel({"allow_from": []}, MessageBus())
assert channel.is_allowed("alice") is False assert channel.is_allowed("alice") is False
def test_is_allowed_star_allows_all() -> None:
channel = _DummyChannel({"allowFrom": ["*"]}, MessageBus())
assert channel.is_allowed("anyone") is True
def test_is_allowed_pairing_fallback(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": []}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.is_approved", lambda _ch, sid: sid == "paired"
)
assert channel.is_allowed("paired") is True
assert channel.is_allowed("unknown") is False
@pytest.mark.asyncio
async def test_handle_message_dm_sends_pairing_code(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": []}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, sid: "ABCD-EFGH"
)
await channel._handle_message(
sender_id="stranger", chat_id="chat1", content="hello", is_dm=True
)
assert len(channel._sent) == 1
msg = channel._sent[0]
assert "ABCD-EFGH" in msg.content
assert msg.metadata.get("_pairing_code") == "ABCD-EFGH"
@pytest.mark.asyncio
async def test_handle_message_group_ignores_unknown() -> None:
channel = _DummyChannel({"allowFrom": []}, MessageBus())
await channel._handle_message(
sender_id="stranger", chat_id="chat1", content="hello", is_dm=False
)
assert channel._sent == []
@pytest.mark.asyncio
async def test_handle_pairing_command_list(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.list_pending",
lambda: [
{
"code": "ABCD-EFGH",
"channel": "dummy",
"sender_id": "123",
"expires_at": 9999999999,
}
],
)
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
assert len(channel._sent) == 1
assert "ABCD-EFGH" in channel._sent[0].content
@pytest.mark.asyncio
async def test_handle_pairing_command_approve(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.approve_code",
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
)
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
assert len(channel._sent) == 1
assert "Approved" in channel._sent[0].content
@pytest.mark.asyncio
async def test_handle_pairing_command_revoke(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.revoke",
lambda ch, sid: sid == "123",
)
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
assert len(channel._sent) == 1
assert "Revoked" in channel._sent[0].content

View File

@ -961,8 +961,8 @@ class _StartableChannel(BaseChannel):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_allow_from_raises_on_empty_list(): async def test_validate_allow_from_allows_empty_list():
"""_validate_allow_from should raise SystemExit when allow_from is empty list.""" """Empty allow_from is valid now — pairing store handles unapproved senders."""
fake_config = SimpleNamespace( fake_config = SimpleNamespace(
channels=ChannelsConfig(), channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
@ -973,10 +973,8 @@ async def test_validate_allow_from_raises_on_empty_list():
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])} mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
mgr._dispatch_task = None mgr._dispatch_task = None
with pytest.raises(SystemExit) as exc_info: # Should not raise — empty list defers to pairing store
mgr._validate_allow_from() mgr._validate_allow_from()
assert "empty allowFrom" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -997,8 +995,8 @@ async def test_validate_allow_from_passes_with_asterisk():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_allow_from_raises_on_empty_dict_allow_from(): async def test_validate_allow_from_allows_empty_dict_allow_from():
"""_validate_allow_from should reject empty dict-backed allow_from lists.""" """Empty dict-backed allow_from is valid — pairing store handles approval."""
fake_config = SimpleNamespace( fake_config = SimpleNamespace(
channels=ChannelsConfig(), channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
@ -1009,10 +1007,7 @@ async def test_validate_allow_from_raises_on_empty_dict_allow_from():
mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])} mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])}
mgr._dispatch_task = None mgr._dispatch_task = None
with pytest.raises(SystemExit) as exc_info: mgr._validate_allow_from()
mgr._validate_allow_from()
assert "empty allowFrom" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -0,0 +1,99 @@
import time
import pytest
from nanobot.pairing import store
@pytest.fixture(autouse=True)
def _tmp_store(tmp_path, monkeypatch):
path = tmp_path / "pairing.json"
monkeypatch.setattr(store, "_store_path", lambda: path)
class TestGenerateCode:
def test_format(self) -> None:
code = store.generate_code("telegram", "123")
assert len(code) == 9 # 4 + 1 + 4
assert code[4] == "-"
assert code.replace("-", "").isalnum()
assert code.replace("-", "").isupper()
def test_uniqueness(self) -> None:
codes = {store.generate_code("telegram", str(i)) for i in range(20)}
assert len(codes) == 20
def test_ttl_expiration(self) -> None:
code = store.generate_code("telegram", "123", ttl=1)
assert store.approve_code(code) is not None
code2 = store.generate_code("telegram", "456", ttl=0)
time.sleep(0.1)
assert store.approve_code(code2) is None
class TestApproveDeny:
def test_approve_moves_to_approved(self) -> None:
code = store.generate_code("telegram", "123")
assert store.is_approved("telegram", "123") is False
result = store.approve_code(code)
assert result == ("telegram", "123")
assert store.is_approved("telegram", "123") is True
assert store.get_approved("telegram") == ["123"]
def test_deny_removes_pending(self) -> None:
code = store.generate_code("telegram", "123")
assert store.deny_code(code) is True
assert store.approve_code(code) is None
def test_deny_unknown_returns_false(self) -> None:
assert store.deny_code("UNKNOWN") is False
def test_approve_expired_returns_none(self) -> None:
code = store.generate_code("telegram", "123", ttl=0)
time.sleep(0.1)
assert store.approve_code(code) is None
class TestRevoke:
def test_revoke_removes_sender(self) -> None:
code = store.generate_code("telegram", "123")
store.approve_code(code)
assert store.is_approved("telegram", "123") is True
assert store.revoke("telegram", "123") is True
assert store.is_approved("telegram", "123") is False
assert store.get_approved("telegram") == []
def test_revoke_unknown_returns_false(self) -> None:
assert store.revoke("telegram", "999") is False
class TestListPending:
def test_empty(self) -> None:
assert store.list_pending() == []
def test_shows_pending(self) -> None:
store.generate_code("telegram", "123")
store.generate_code("discord", "456")
pending = store.list_pending()
assert len(pending) == 2
channels = {p["channel"] for p in pending}
assert channels == {"telegram", "discord"}
def test_expired_not_listed(self) -> None:
store.generate_code("telegram", "123", ttl=0)
time.sleep(0.1)
assert store.list_pending() == []
class TestStoreDurability:
def test_corruption_recovery(self, tmp_path, monkeypatch) -> None:
path = tmp_path / "pairing.json"
path.write_text("not json{", encoding="utf-8")
monkeypatch.setattr(store, "_store_path", lambda: path)
# Should recover gracefully and act as empty store
assert store.list_pending() == []
assert store.is_approved("telegram", "123") is False