refactor(pairing): apply simplify review fixes

- Extract format_pairing_reply() and format_expiry() to eliminate
duplication between BaseChannel and SlackChannel.
- Use _write_text_atomic() from helpers.py instead of hand-rolled
fsync logic in pairing store.
- Convert approved lists to in-memory sets for O(1) lookup.
- Remove collision retry loop (8-char entropy is sufficient).
- Fix /pairing command parsing to split prefix exactly.
- Remove unused import time from base.py.
- Fix tests to pass subcommand_text, not full /pairing string.
This commit is contained in:
chengyongru 2026-05-14 11:20:49 +08:00
parent cf224c9701
commit 4e7022e73c
5 changed files with 75 additions and 85 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -14,6 +13,8 @@ from nanobot.bus.queue import MessageBus
from nanobot.pairing import ( from nanobot.pairing import (
approve_code, approve_code,
deny_code, deny_code,
format_expiry,
format_pairing_reply,
generate_code, generate_code,
is_approved, is_approved,
list_pending, list_pending,
@ -222,35 +223,15 @@ class BaseChannel(ABC):
session_key: str | None = None, session_key: str | None = None,
is_dm: bool = False, is_dm: bool = False,
) -> None: ) -> None:
""" """Handle an incoming message: check permissions, issue pairing codes in DMs, or forward to bus."""
Handle an incoming message from the chat platform.
This method checks permissions and forwards to the bus.
For DM messages from unrecognised senders, a pairing code is
issued instead of silently dropping the message.
Args:
sender_id: The sender's identifier.
chat_id: The chat/channel identifier.
content: Message text content.
media: Optional list of media URLs.
metadata: Optional channel-specific metadata.
session_key: Optional session key override (e.g. thread-scoped sessions).
is_dm: Whether the message is a direct / private message.
"""
if not self.is_allowed(sender_id): if not self.is_allowed(sender_id):
if is_dm: if is_dm:
code = generate_code(self.name, str(sender_id)) code = generate_code(self.name, str(sender_id))
reply = (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
await self.send( await self.send(
OutboundMessage( OutboundMessage(
channel=self.name, channel=self.name,
chat_id=str(chat_id), chat_id=str(chat_id),
content=reply, content=format_pairing_reply(code),
metadata={"_pairing_code": code}, metadata={"_pairing_code": code},
) )
) )
@ -267,8 +248,9 @@ class BaseChannel(ABC):
return return
# Intercept /pairing slash commands before they reach the agent loop # Intercept /pairing slash commands before they reach the agent loop
if content.strip().startswith("/pairing"): parts = content.strip().split(None, 1)
await self._handle_pairing_command(sender_id, chat_id, content.strip()) if parts and parts[0] == "/pairing":
await self._handle_pairing_command(sender_id, chat_id, parts[1] if len(parts) > 1 else "")
return return
meta = metadata or {} meta = metadata or {}
@ -288,12 +270,12 @@ class BaseChannel(ABC):
await self.bus.publish_inbound(msg) await self.bus.publish_inbound(msg)
async def _handle_pairing_command( async def _handle_pairing_command(
self, sender_id: str, chat_id: str, content: str self, sender_id: str, chat_id: str, subcommand_text: str
) -> None: ) -> None:
"""Execute a ``/pairing`` slash command and reply directly to the user.""" """Execute a ``/pairing`` slash command and reply directly to the user."""
parts = content.split() parts = subcommand_text.split()
sub = parts[1] if len(parts) > 1 else "list" sub = parts[0] if parts else "list"
arg = parts[2] if len(parts) > 2 else None arg = parts[1] if len(parts) > 1 else None
if sub in ("list",): if sub in ("list",):
pending = list_pending() pending = list_pending()
@ -302,8 +284,7 @@ class BaseChannel(ABC):
else: else:
lines = ["Pending pairing requests:"] lines = ["Pending pairing requests:"]
for item in pending: for item in pending:
remaining = int(item.get("expires_at", 0) - time.time()) expiry = format_expiry(item.get("expires_at", 0))
expiry = f"{remaining}s" if remaining > 0 else "expired"
lines.append( lines.append(
f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}" f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}"
) )
@ -335,13 +316,20 @@ class BaseChannel(ABC):
elif sub == "revoke": elif sub == "revoke":
if arg is None: if arg is None:
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`" reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
elif len(parts) == 2:
reply = (
f"Revoked {arg} from {self.name}"
if revoke(self.name, arg)
else f"{arg} was not in the approved list for {self.name}"
)
elif len(parts) == 3:
reply = (
f"Revoked {parts[2]} from {arg}"
if revoke(arg, parts[2])
else f"{parts[2]} was not in the approved list for {arg}"
)
else: else:
target_channel = parts[3] if len(parts) > 3 else self.name reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
target_user = arg if len(parts) <= 3 else parts[3]
if revoke(target_channel, target_user):
reply = f"Revoked {target_user} from {target_channel}"
else:
reply = f"{target_user} was not in the approved list for {target_channel}"
else: else:
reply = ( reply = (

View File

@ -18,6 +18,7 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base from nanobot.config.schema import Base
from nanobot.pairing import format_pairing_reply, generate_code, is_approved
from nanobot.utils.helpers import safe_filename, split_message from nanobot.utils.helpers import safe_filename, split_message
@ -343,13 +344,8 @@ class SlackChannel(BaseChannel):
if not self._is_allowed(sender_id, chat_id, channel_type): if not self._is_allowed(sender_id, chat_id, channel_type):
if channel_type == "im" and self.config.dm.enabled: if channel_type == "im" and self.config.dm.enabled:
from nanobot.pairing import generate_code
code = generate_code(self.name, sender_id) code = generate_code(self.name, sender_id)
reply = ( reply = format_pairing_reply(code)
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
await self.send( await self.send(
OutboundMessage( OutboundMessage(
channel=self.name, channel=self.name,
@ -624,8 +620,6 @@ class SlackChannel(BaseChannel):
self.logger.debug("done reaction failed: {}", e) self.logger.debug("done reaction failed: {}", e)
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool: def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
from nanobot.pairing import is_approved
if channel_type == "im": if channel_type == "im":
if not self.config.dm.enabled: if not self.config.dm.enabled:
return False return False

View File

@ -3,6 +3,8 @@
from nanobot.pairing.store import ( from nanobot.pairing.store import (
approve_code, approve_code,
deny_code, deny_code,
format_expiry,
format_pairing_reply,
generate_code, generate_code,
get_approved, get_approved,
is_approved, is_approved,
@ -13,6 +15,8 @@ from nanobot.pairing.store import (
__all__ = [ __all__ = [
"approve_code", "approve_code",
"deny_code", "deny_code",
"format_expiry",
"format_pairing_reply",
"generate_code", "generate_code",
"get_approved", "get_approved",
"is_approved", "is_approved",

View File

@ -8,7 +8,6 @@ private-assistant scale: small JSON file, simple locking, no external DB.
from __future__ import annotations from __future__ import annotations
import json import json
import os
import secrets import secrets
import string import string
import threading import threading
@ -19,11 +18,11 @@ from typing import Any
from loguru import logger from loguru import logger
from nanobot.config.paths import get_data_dir from nanobot.config.paths import get_data_dir
from nanobot.utils.helpers import _write_text_atomic
_LOCK = threading.Lock() _LOCK = threading.Lock()
_ALPHABET = string.ascii_uppercase + string.digits _ALPHABET = string.ascii_uppercase + string.digits
_CODE_LENGTH = 8 # e.g. XK9-42F-MP _CODE_LENGTH = 8 # e.g. ABCD-EFGH
_TTL_DEFAULT_S = 600 # 10 minutes _TTL_DEFAULT_S = 600 # 10 minutes
@ -37,30 +36,26 @@ def _load() -> dict[str, Any]:
return {"approved": {}, "pending": {}} return {"approved": {}, "pending": {}}
try: try:
with open(path, encoding="utf-8") as f: with open(path, encoding="utf-8") as f:
return json.load(f) data = json.load(f)
except (json.JSONDecodeError, OSError): except (json.JSONDecodeError, OSError):
logger.warning("Corrupted pairing store, resetting") logger.warning("Corrupted pairing store, resetting")
return {"approved": {}, "pending": {}} return {"approved": {}, "pending": {}}
# Convert approved lists to sets for O(1) lookup
for channel, users in data.get("approved", {}).items():
data["approved"][channel] = set(users)
return data
def _save(data: dict[str, Any]) -> None: def _save(data: dict[str, Any]) -> None:
path = _store_path() path = _store_path()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(".tmp") # Convert sets back to lists for JSON serialization
with open(tmp, "w", encoding="utf-8") as f: payload = {
json.dump(data, f, indent=2, ensure_ascii=False) "approved": {ch: sorted(list(users)) for ch, users in data.get("approved", {}).items()},
f.flush() "pending": dict(data.get("pending", {})),
os.fsync(f.fileno()) }
tmp.replace(path) _write_text_atomic(path, json.dumps(payload, indent=2, ensure_ascii=False))
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
try:
fd = os.open(path.parent, os.O_RDONLY)
try:
os.fsync(fd)
finally:
os.close(fd)
except (OSError, NotImplementedError):
pass
def _gc_pending(data: dict[str, Any]) -> None: def _gc_pending(data: dict[str, Any]) -> None:
@ -79,19 +74,13 @@ def generate_code(
) -> str: ) -> str:
"""Create a new pairing code for *sender_id* on *channel*. """Create a new pairing code for *sender_id* on *channel*.
Returns the code (e.g. ``"XK9-42F"``). Returns the code (e.g. ``"ABCD-EFGH"``).
""" """
with _LOCK: with _LOCK:
data = _load() data = _load()
_gc_pending(data) _gc_pending(data)
# Ensure uniqueness raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
for _ in range(100): code = f"{raw[:4]}-{raw[4:]}"
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
code = f"{raw[:4]}-{raw[4:]}"
if code not in data.get("pending", {}):
break
else: # pragma: no cover
raise RuntimeError("Failed to generate unique pairing code")
data.setdefault("pending", {})[code] = { data.setdefault("pending", {})[code] = {
"channel": channel, "channel": channel,
@ -119,7 +108,7 @@ def approve_code(code: str) -> tuple[str, str] | None:
return None return None
channel = info["channel"] channel = info["channel"]
sender_id = info["sender_id"] sender_id = info["sender_id"]
data.setdefault("approved", {}).setdefault(channel, []).append(sender_id) data.setdefault("approved", {}).setdefault(channel, set()).add(sender_id)
_save(data) _save(data)
logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel) logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel)
return channel, sender_id return channel, sender_id
@ -146,8 +135,8 @@ def is_approved(channel: str, sender_id: str) -> bool:
"""Check whether *sender_id* has been approved on *channel*.""" """Check whether *sender_id* has been approved on *channel*."""
with _LOCK: with _LOCK:
data = _load() data = _load()
approved: dict[str, list[str]] = data.get("approved", {}) approved: dict[str, set[str]] = data.get("approved", {})
return str(sender_id) in approved.get(channel, []) return str(sender_id) in approved.get(channel, set())
def list_pending() -> list[dict[str, Any]]: def list_pending() -> list[dict[str, Any]]:
@ -168,11 +157,11 @@ def revoke(channel: str, sender_id: str) -> bool:
""" """
with _LOCK: with _LOCK:
data = _load() data = _load()
approved: dict[str, list[str]] = data.get("approved", {}) approved: dict[str, set[str]] = data.get("approved", {})
lst = approved.get(channel, []) users = approved.get(channel, set())
if sender_id in lst: if sender_id in users:
lst.remove(sender_id) users.discard(sender_id)
if not lst: if not users:
del approved[channel] del approved[channel]
_save(data) _save(data)
logger.info("Revoked {} from {}", sender_id, channel) logger.info("Revoked {} from {}", sender_id, channel)
@ -184,4 +173,19 @@ def get_approved(channel: str) -> list[str]:
"""Return all approved sender IDs for *channel*.""" """Return all approved sender IDs for *channel*."""
with _LOCK: with _LOCK:
data = _load() data = _load()
return list(data.get("approved", {}).get(channel, [])) return sorted(data.get("approved", {}).get(channel, set()))
def format_pairing_reply(code: str) -> str:
"""Return the pairing-code message sent to unrecognised DM senders."""
return (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
def format_expiry(expires_at: float) -> str:
"""Return a human-readable expiry string (e.g. ``"120s"`` or ``"expired"``)."""
remaining = int(expires_at - time.time())
return f"{remaining}s" if remaining > 0 else "expired"

View File

@ -101,7 +101,7 @@ async def test_handle_pairing_command_list(monkeypatch) -> None:
], ],
) )
await channel._handle_pairing_command("owner", "chat1", "/pairing list") await channel._handle_pairing_command("owner", "chat1", "list")
assert len(channel._sent) == 1 assert len(channel._sent) == 1
assert "ABCD-EFGH" in channel._sent[0].content assert "ABCD-EFGH" in channel._sent[0].content
@ -115,7 +115,7 @@ async def test_handle_pairing_command_approve(monkeypatch) -> None:
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None, lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
) )
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH") await channel._handle_pairing_command("owner", "chat1", "approve ABCD-EFGH")
assert len(channel._sent) == 1 assert len(channel._sent) == 1
assert "Approved" in channel._sent[0].content assert "Approved" in channel._sent[0].content
@ -129,7 +129,7 @@ async def test_handle_pairing_command_revoke(monkeypatch) -> None:
lambda ch, sid: sid == "123", lambda ch, sid: sid == "123",
) )
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123") await channel._handle_pairing_command("owner", "chat1", "revoke 123")
assert len(channel._sent) == 1 assert len(channel._sent) == 1
assert "Revoked" in channel._sent[0].content assert "Revoked" in channel._sent[0].content