code-review fixes: fsync, entropy, is_dm propagation, tests

- Add os.fsync with Windows-compatible directory flush in pairing store
- Increase pairing code length from 6 -> 8 characters for higher entropy
- Remove SystemExit on empty allowFrom; empty list now defers to pairing
- Update is_allowed docstring to document pairing fallback semantics
- Propagate is_dm to Matrix (direct rooms) and Slack (im channels)
- Slack _is_allowed now checks pairing store for DM allowlist mode
- Fix /pairing revoke to accept optional channel argument
- Move inline import time to module top-level
- Add WebSocket comment explaining is_dm=True assumption
- Add comprehensive tests for store and BaseChannel pairing integration
- Fix existing tests that expected empty allowFrom to hard-exit

Refs #3774
This commit is contained in:
chengyongru 2026-05-14 11:03:17 +08:00
parent 01e93a895f
commit cf224c9701
9 changed files with 265 additions and 33 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -191,6 +192,10 @@ class BaseChannel(ABC):
2. ``allowFrom`` list allow if sender_id is present. 2. ``allowFrom`` list allow if sender_id is present.
3. Pairing store approved list allow if previously approved. 3. Pairing store approved list allow if previously approved.
4. Otherwise deny. 4. Otherwise deny.
An empty ``allowFrom`` list does not cause a hard exit; instead it
defers to the pairing store so that unknown DM senders can request
access via a pairing code.
""" """
if isinstance(self.config, dict): if isinstance(self.config, dict):
if "allow_from" in self.config: if "allow_from" in self.config:
@ -296,8 +301,6 @@ class BaseChannel(ABC):
reply = "No pending pairing requests." reply = "No pending pairing requests."
else: else:
lines = ["Pending pairing requests:"] lines = ["Pending pairing requests:"]
import time
for item in pending: for item in pending:
remaining = int(item.get("expires_at", 0) - time.time()) remaining = int(item.get("expires_at", 0) - time.time())
expiry = f"{remaining}s" if remaining > 0 else "expired" expiry = f"{remaining}s" if remaining > 0 else "expired"
@ -331,12 +334,14 @@ class BaseChannel(ABC):
elif sub == "revoke": elif sub == "revoke":
if arg is None: if arg is None:
reply = "Usage: `/pairing revoke <user_id>`" reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
else: else:
if revoke(self.name, arg): target_channel = parts[3] if len(parts) > 3 else self.name
reply = f"Revoked {arg} from {self.name}" target_user = arg if len(parts) <= 3 else parts[3]
if revoke(target_channel, target_user):
reply = f"Revoked {target_user} from {target_channel}"
else: else:
reply = f"{arg} was not in the approved list for {self.name}" reply = f"{target_user} was not in the approved list for {target_channel}"
else: else:
reply = ( reply = (

View File

@ -143,9 +143,9 @@ class ChannelManager:
allow = cfg.get("allowFrom") allow = cfg.get("allowFrom")
else: else:
allow = getattr(cfg, "allow_from", None) allow = getattr(cfg, "allow_from", None)
if allow == []: if allow is None:
raise SystemExit( raise SystemExit(
f'Error: "{name}" has empty allowFrom (denies all). ' f'Error: "{name}" is missing allowFrom. '
f'Set ["*"] to allow everyone, or add specific user IDs.' f'Set ["*"] to allow everyone, or add specific user IDs.'
) )

View File

@ -28,10 +28,11 @@ try:
RoomMessageMedia, RoomMessageMedia,
RoomMessageText, RoomMessageText,
RoomSendError, RoomSendError,
RoomSendResponse,
RoomTypingError, RoomTypingError,
SyncError, SyncError,
UploadError, RoomSendResponse, UploadError,
) )
from nio.crypto.attachments import decrypt_attachment from nio.crypto.attachments import decrypt_attachment
from nio.exceptions import EncryptionError from nio.exceptions import EncryptionError
except ImportError as e: except ImportError as e:
@ -107,7 +108,7 @@ class _StreamBuf:
:ivar text: Stores the text content of the buffer. :ivar text: Stores the text content of the buffer.
:type text: str :type text: str
:ivar event_id: Identifier for the associated event. None indicates no :ivar event_id: Identifier for the associated event. None indicates no
specific event association. specific event association.
:type event_id: str | None :type event_id: str | None
:ivar last_edit: Timestamp of the most recent edit to the buffer. :ivar last_edit: Timestamp of the most recent edit to the buffer.
@ -140,19 +141,19 @@ def _build_matrix_text_content(
) -> dict[str, object]: ) -> dict[str, object]:
""" """
Constructs and returns a dictionary representing the matrix text content with optional Constructs and returns a dictionary representing the matrix text content with optional
HTML formatting and reference to an existing event for replacement. This function is HTML formatting and reference to an existing event for replacement. This function is
primarily used to create content payloads compatible with the Matrix messaging protocol. primarily used to create content payloads compatible with the Matrix messaging protocol.
:param text: The plain text content to include in the message. :param text: The plain text content to include in the message.
:type text: str :type text: str
:param event_id: Optional ID of the event to replace. If provided, the function will :param event_id: Optional ID of the event to replace. If provided, the function will
include information indicating that the message is a replacement of the specified include information indicating that the message is a replacement of the specified
event. event.
:type event_id: str | None :type event_id: str | None
:param thread_relates_to: Optional Matrix thread relation metadata. For edits this is :param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
stored in ``m.new_content`` so the replacement remains in the same thread. stored in ``m.new_content`` so the replacement remains in the same thread.
:type thread_relates_to: dict[str, object] | None :type thread_relates_to: dict[str, object] | None
:return: A dictionary containing the matrix text content, potentially enriched with :return: A dictionary containing the matrix text content, potentially enriched with
HTML formatting and replacement metadata if applicable. HTML formatting and replacement metadata if applicable.
:rtype: dict[str, object] :rtype: dict[str, object]
""" """
@ -523,7 +524,7 @@ class MatrixChannel(BaseChannel):
return return
await self._stop_typing_keepalive(chat_id, clear_typing=True) await self._stop_typing_keepalive(chat_id, clear_typing=True)
content = _build_matrix_text_content( content = _build_matrix_text_content(
buf.text, buf.text,
buf.event_id, buf.event_id,
@ -537,7 +538,7 @@ class MatrixChannel(BaseChannel):
buf = _StreamBuf() buf = _StreamBuf()
self._stream_bufs[chat_id] = buf self._stream_bufs[chat_id] = buf
buf.text += delta buf.text += delta
if not buf.text.strip(): if not buf.text.strip():
return return
@ -870,6 +871,7 @@ class MatrixChannel(BaseChannel):
await self._handle_message( await self._handle_message(
sender_id=event.sender, chat_id=room.room_id, sender_id=event.sender, chat_id=room.room_id,
content=event.body, metadata=self._base_metadata(room, event), content=event.body, metadata=self._base_metadata(room, event),
is_dm=self._is_direct_room(room),
) )
except Exception: except Exception:
await self._stop_typing_keepalive(room.room_id, clear_typing=True) await self._stop_typing_keepalive(room.room_id, clear_typing=True)
@ -907,6 +909,7 @@ class MatrixChannel(BaseChannel):
content="\n".join(parts), content="\n".join(parts),
media=[attachment["path"]] if attachment else [], media=[attachment["path"]] if attachment else [],
metadata=meta, metadata=meta,
is_dm=self._is_direct_room(room),
) )
except Exception: except Exception:
await self._stop_typing_keepalive(room.room_id, clear_typing=True) await self._stop_typing_keepalive(room.room_id, clear_typing=True)

View File

@ -342,6 +342,22 @@ class SlackChannel(BaseChannel):
channel_type = event.get("channel_type") or "" channel_type = event.get("channel_type") or ""
if not self._is_allowed(sender_id, chat_id, channel_type): if not self._is_allowed(sender_id, chat_id, channel_type):
if channel_type == "im" and self.config.dm.enabled:
from nanobot.pairing import generate_code
code = generate_code(self.name, sender_id)
reply = (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
await self.send(
OutboundMessage(
channel=self.name,
chat_id=chat_id,
content=reply,
metadata={"_pairing_code": code},
)
)
return return
if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id): if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
@ -608,11 +624,13 @@ class SlackChannel(BaseChannel):
self.logger.debug("done reaction failed: {}", e) self.logger.debug("done reaction failed: {}", e)
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool: def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
from nanobot.pairing import is_approved
if channel_type == "im": if channel_type == "im":
if not self.config.dm.enabled: if not self.config.dm.enabled:
return False return False
if self.config.dm.policy == "allowlist": if self.config.dm.policy == "allowlist":
return sender_id in self.config.dm.allow_from return sender_id in self.config.dm.allow_from or is_approved(self.name, sender_id)
return True return True
# Group / channel messages # Group / channel messages

View File

@ -1249,6 +1249,8 @@ class WebSocketChannel(BaseChannel):
content = _parse_inbound_payload(raw) content = _parse_inbound_payload(raw)
if content is None: if content is None:
continue continue
# WebSocket connections are always treated as 1:1 (DM) because
# each connection represents a single client browser/tab.
await self._handle_message( await self._handle_message(
sender_id=client_id, sender_id=client_id,
chat_id=default_chat_id, chat_id=default_chat_id,

View File

@ -8,6 +8,7 @@ private-assistant scale: small JSON file, simple locking, no external DB.
from __future__ import annotations from __future__ import annotations
import json import json
import os
import secrets import secrets
import string import string
import threading import threading
@ -20,8 +21,9 @@ from loguru import logger
from nanobot.config.paths import get_data_dir from nanobot.config.paths import get_data_dir
_LOCK = threading.Lock() _LOCK = threading.Lock()
_ALPHABET = string.ascii_uppercase + string.digits _ALPHABET = string.ascii_uppercase + string.digits
_CODE_LENGTH = 6 # e.g. XK9-42F _CODE_LENGTH = 8 # e.g. XK9-42F-MP
_TTL_DEFAULT_S = 600 # 10 minutes _TTL_DEFAULT_S = 600 # 10 minutes
@ -48,7 +50,17 @@ def _save(data: dict[str, Any]) -> None:
with open(tmp, "w", encoding="utf-8") as f: with open(tmp, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False) json.dump(data, f, indent=2, ensure_ascii=False)
f.flush() f.flush()
os.fsync(f.fileno())
tmp.replace(path) tmp.replace(path)
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
try:
fd = os.open(path.parent, os.O_RDONLY)
try:
os.fsync(fd)
finally:
os.close(fd)
except (OSError, NotImplementedError):
pass
def _gc_pending(data: dict[str, Any]) -> None: def _gc_pending(data: dict[str, Any]) -> None:
@ -75,7 +87,7 @@ def generate_code(
# Ensure uniqueness # Ensure uniqueness
for _ in range(100): for _ in range(100):
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH)) raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
code = f"{raw[:3]}-{raw[3:]}" code = f"{raw[:4]}-{raw[4:]}"
if code not in data.get("pending", {}): if code not in data.get("pending", {}):
break break
else: # pragma: no cover else: # pragma: no cover

View File

@ -1,5 +1,7 @@
from types import SimpleNamespace from types import SimpleNamespace
import pytest
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
@ -7,6 +9,11 @@ from nanobot.channels.base import BaseChannel
class _DummyChannel(BaseChannel): class _DummyChannel(BaseChannel):
name = "dummy" name = "dummy"
_sent: list[OutboundMessage]
def __init__(self, config, bus):
super().__init__(config, bus)
self._sent = []
async def start(self) -> None: async def start(self) -> None:
return None return None
@ -15,7 +22,7 @@ class _DummyChannel(BaseChannel):
return None return None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
return None self._sent.append(msg)
def test_is_allowed_requires_exact_match() -> None: def test_is_allowed_requires_exact_match() -> None:
@ -35,3 +42,94 @@ def test_is_allowed_denies_empty_dict_allow_from() -> None:
channel = _DummyChannel({"allow_from": []}, MessageBus()) channel = _DummyChannel({"allow_from": []}, MessageBus())
assert channel.is_allowed("alice") is False assert channel.is_allowed("alice") is False
def test_is_allowed_star_allows_all() -> None:
channel = _DummyChannel({"allowFrom": ["*"]}, MessageBus())
assert channel.is_allowed("anyone") is True
def test_is_allowed_pairing_fallback(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": []}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.is_approved", lambda _ch, sid: sid == "paired"
)
assert channel.is_allowed("paired") is True
assert channel.is_allowed("unknown") is False
@pytest.mark.asyncio
async def test_handle_message_dm_sends_pairing_code(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": []}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, sid: "ABCD-EFGH"
)
await channel._handle_message(
sender_id="stranger", chat_id="chat1", content="hello", is_dm=True
)
assert len(channel._sent) == 1
msg = channel._sent[0]
assert "ABCD-EFGH" in msg.content
assert msg.metadata.get("_pairing_code") == "ABCD-EFGH"
@pytest.mark.asyncio
async def test_handle_message_group_ignores_unknown() -> None:
channel = _DummyChannel({"allowFrom": []}, MessageBus())
await channel._handle_message(
sender_id="stranger", chat_id="chat1", content="hello", is_dm=False
)
assert channel._sent == []
@pytest.mark.asyncio
async def test_handle_pairing_command_list(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.list_pending",
lambda: [
{
"code": "ABCD-EFGH",
"channel": "dummy",
"sender_id": "123",
"expires_at": 9999999999,
}
],
)
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
assert len(channel._sent) == 1
assert "ABCD-EFGH" in channel._sent[0].content
@pytest.mark.asyncio
async def test_handle_pairing_command_approve(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.approve_code",
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
)
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
assert len(channel._sent) == 1
assert "Approved" in channel._sent[0].content
@pytest.mark.asyncio
async def test_handle_pairing_command_revoke(monkeypatch) -> None:
channel = _DummyChannel({"allowFrom": ["owner"]}, MessageBus())
monkeypatch.setattr(
"nanobot.channels.base.revoke",
lambda ch, sid: sid == "123",
)
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
assert len(channel._sent) == 1
assert "Revoked" in channel._sent[0].content

View File

@ -961,8 +961,8 @@ class _StartableChannel(BaseChannel):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_allow_from_raises_on_empty_list(): async def test_validate_allow_from_allows_empty_list():
"""_validate_allow_from should raise SystemExit when allow_from is empty list.""" """Empty allow_from is valid now — pairing store handles unapproved senders."""
fake_config = SimpleNamespace( fake_config = SimpleNamespace(
channels=ChannelsConfig(), channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
@ -973,10 +973,8 @@ async def test_validate_allow_from_raises_on_empty_list():
mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])} mgr.channels = {"test": _ChannelWithAllowFrom(fake_config, None, [])}
mgr._dispatch_task = None mgr._dispatch_task = None
with pytest.raises(SystemExit) as exc_info: # Should not raise — empty list defers to pairing store
mgr._validate_allow_from() mgr._validate_allow_from()
assert "empty allowFrom" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -997,8 +995,8 @@ async def test_validate_allow_from_passes_with_asterisk():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_validate_allow_from_raises_on_empty_dict_allow_from(): async def test_validate_allow_from_allows_empty_dict_allow_from():
"""_validate_allow_from should reject empty dict-backed allow_from lists.""" """Empty dict-backed allow_from is valid — pairing store handles approval."""
fake_config = SimpleNamespace( fake_config = SimpleNamespace(
channels=ChannelsConfig(), channels=ChannelsConfig(),
providers=SimpleNamespace(groq=SimpleNamespace(api_key="")), providers=SimpleNamespace(groq=SimpleNamespace(api_key="")),
@ -1009,10 +1007,7 @@ async def test_validate_allow_from_raises_on_empty_dict_allow_from():
mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])} mgr.channels = {"test": _ChannelWithAllowFrom({"enabled": True}, None, [])}
mgr._dispatch_task = None mgr._dispatch_task = None
with pytest.raises(SystemExit) as exc_info: mgr._validate_allow_from()
mgr._validate_allow_from()
assert "empty allowFrom" in str(exc_info.value)
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -0,0 +1,99 @@
import time
import pytest
from nanobot.pairing import store
@pytest.fixture(autouse=True)
def _tmp_store(tmp_path, monkeypatch):
path = tmp_path / "pairing.json"
monkeypatch.setattr(store, "_store_path", lambda: path)
class TestGenerateCode:
def test_format(self) -> None:
code = store.generate_code("telegram", "123")
assert len(code) == 9 # 4 + 1 + 4
assert code[4] == "-"
assert code.replace("-", "").isalnum()
assert code.replace("-", "").isupper()
def test_uniqueness(self) -> None:
codes = {store.generate_code("telegram", str(i)) for i in range(20)}
assert len(codes) == 20
def test_ttl_expiration(self) -> None:
code = store.generate_code("telegram", "123", ttl=1)
assert store.approve_code(code) is not None
code2 = store.generate_code("telegram", "456", ttl=0)
time.sleep(0.1)
assert store.approve_code(code2) is None
class TestApproveDeny:
def test_approve_moves_to_approved(self) -> None:
code = store.generate_code("telegram", "123")
assert store.is_approved("telegram", "123") is False
result = store.approve_code(code)
assert result == ("telegram", "123")
assert store.is_approved("telegram", "123") is True
assert store.get_approved("telegram") == ["123"]
def test_deny_removes_pending(self) -> None:
code = store.generate_code("telegram", "123")
assert store.deny_code(code) is True
assert store.approve_code(code) is None
def test_deny_unknown_returns_false(self) -> None:
assert store.deny_code("UNKNOWN") is False
def test_approve_expired_returns_none(self) -> None:
code = store.generate_code("telegram", "123", ttl=0)
time.sleep(0.1)
assert store.approve_code(code) is None
class TestRevoke:
def test_revoke_removes_sender(self) -> None:
code = store.generate_code("telegram", "123")
store.approve_code(code)
assert store.is_approved("telegram", "123") is True
assert store.revoke("telegram", "123") is True
assert store.is_approved("telegram", "123") is False
assert store.get_approved("telegram") == []
def test_revoke_unknown_returns_false(self) -> None:
assert store.revoke("telegram", "999") is False
class TestListPending:
def test_empty(self) -> None:
assert store.list_pending() == []
def test_shows_pending(self) -> None:
store.generate_code("telegram", "123")
store.generate_code("discord", "456")
pending = store.list_pending()
assert len(pending) == 2
channels = {p["channel"] for p in pending}
assert channels == {"telegram", "discord"}
def test_expired_not_listed(self) -> None:
store.generate_code("telegram", "123", ttl=0)
time.sleep(0.1)
assert store.list_pending() == []
class TestStoreDurability:
def test_corruption_recovery(self, tmp_path, monkeypatch) -> None:
path = tmp_path / "pairing.json"
path.write_text("not json{", encoding="utf-8")
monkeypatch.setattr(store, "_store_path", lambda: path)
# Should recover gracefully and act as empty store
assert store.list_pending() == []
assert store.is_approved("telegram", "123") is False