From 9bc86ee82572010879422c0447a2be7d6ed33dd0 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Thu, 14 May 2026 11:20:49 +0800 Subject: [PATCH] refactor(pairing): apply simplify review fixes - Extract format_pairing_reply() and format_expiry() to eliminate duplication between BaseChannel and SlackChannel. - Use _write_text_atomic() from helpers.py instead of hand-rolled fsync logic in pairing store. - Convert approved lists to in-memory sets for O(1) lookup. - Remove collision retry loop (8-char entropy is sufficient). - Fix /pairing command parsing to split prefix exactly. - Remove unused import time from base.py. - Fix tests to pass subcommand_text, not full /pairing string. --- nanobot/channels/base.py | 62 +++++++++-------------- nanobot/channels/slack.py | 10 +--- nanobot/pairing/__init__.py | 4 ++ nanobot/pairing/store.py | 78 +++++++++++++++-------------- tests/channels/test_base_channel.py | 6 +-- 5 files changed, 75 insertions(+), 85 deletions(-) diff --git a/nanobot/channels/base.py b/nanobot/channels/base.py index 48ee3cd00..c43b3904f 100644 --- a/nanobot/channels/base.py +++ b/nanobot/channels/base.py @@ -2,7 +2,6 @@ from __future__ import annotations -import time from abc import ABC, abstractmethod from pathlib import Path from typing import Any @@ -14,6 +13,8 @@ from nanobot.bus.queue import MessageBus from nanobot.pairing import ( approve_code, deny_code, + format_expiry, + format_pairing_reply, generate_code, is_approved, list_pending, @@ -222,35 +223,15 @@ class BaseChannel(ABC): session_key: str | None = None, is_dm: bool = False, ) -> None: - """ - Handle an incoming message from the chat platform. - - This method checks permissions and forwards to the bus. - For DM messages from unrecognised senders, a pairing code is - issued instead of silently dropping the message. - - Args: - sender_id: The sender's identifier. - chat_id: The chat/channel identifier. - content: Message text content. - media: Optional list of media URLs. - metadata: Optional channel-specific metadata. - session_key: Optional session key override (e.g. thread-scoped sessions). - is_dm: Whether the message is a direct / private message. - """ + """Handle an incoming message: check permissions, issue pairing codes in DMs, or forward to bus.""" if not self.is_allowed(sender_id): if is_dm: code = generate_code(self.name, str(sender_id)) - reply = ( - "This assistant requires approval before it can respond.\n" - f"Your pairing code is: `{code}`\n" - f"Ask the owner to run: `nanobot pairing approve {code}`" - ) await self.send( OutboundMessage( channel=self.name, chat_id=str(chat_id), - content=reply, + content=format_pairing_reply(code), metadata={"_pairing_code": code}, ) ) @@ -267,8 +248,9 @@ class BaseChannel(ABC): return # Intercept /pairing slash commands before they reach the agent loop - if content.strip().startswith("/pairing"): - await self._handle_pairing_command(sender_id, chat_id, content.strip()) + parts = content.strip().split(None, 1) + if parts and parts[0] == "/pairing": + await self._handle_pairing_command(sender_id, chat_id, parts[1] if len(parts) > 1 else "") return meta = metadata or {} @@ -288,12 +270,12 @@ class BaseChannel(ABC): await self.bus.publish_inbound(msg) async def _handle_pairing_command( - self, sender_id: str, chat_id: str, content: str + self, sender_id: str, chat_id: str, subcommand_text: str ) -> None: """Execute a ``/pairing`` slash command and reply directly to the user.""" - parts = content.split() - sub = parts[1] if len(parts) > 1 else "list" - arg = parts[2] if len(parts) > 2 else None + parts = subcommand_text.split() + sub = parts[0] if parts else "list" + arg = parts[1] if len(parts) > 1 else None if sub in ("list",): pending = list_pending() @@ -302,8 +284,7 @@ class BaseChannel(ABC): else: lines = ["Pending pairing requests:"] for item in pending: - remaining = int(item.get("expires_at", 0) - time.time()) - expiry = f"{remaining}s" if remaining > 0 else "expired" + expiry = format_expiry(item.get("expires_at", 0)) lines.append( f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}" ) @@ -335,13 +316,20 @@ class BaseChannel(ABC): elif sub == "revoke": if arg is None: reply = "Usage: `/pairing revoke ` or `/pairing revoke `" + elif len(parts) == 2: + reply = ( + f"Revoked {arg} from {self.name}" + if revoke(self.name, arg) + else f"{arg} was not in the approved list for {self.name}" + ) + elif len(parts) == 3: + reply = ( + f"Revoked {parts[2]} from {arg}" + if revoke(arg, parts[2]) + else f"{parts[2]} was not in the approved list for {arg}" + ) else: - target_channel = parts[3] if len(parts) > 3 else self.name - target_user = arg if len(parts) <= 3 else parts[3] - if revoke(target_channel, target_user): - reply = f"Revoked {target_user} from {target_channel}" - else: - reply = f"{target_user} was not in the approved list for {target_channel}" + reply = "Usage: `/pairing revoke ` or `/pairing revoke `" else: reply = ( diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py index 6c37fd3b1..8f55338d6 100644 --- a/nanobot/channels/slack.py +++ b/nanobot/channels/slack.py @@ -18,6 +18,7 @@ from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base +from nanobot.pairing import format_pairing_reply, generate_code, is_approved from nanobot.utils.helpers import safe_filename, split_message @@ -343,13 +344,8 @@ class SlackChannel(BaseChannel): if not self._is_allowed(sender_id, chat_id, channel_type): if channel_type == "im" and self.config.dm.enabled: - from nanobot.pairing import generate_code code = generate_code(self.name, sender_id) - reply = ( - "This assistant requires approval before it can respond.\n" - f"Your pairing code is: `{code}`\n" - f"Ask the owner to run: `nanobot pairing approve {code}`" - ) + reply = format_pairing_reply(code) await self.send( OutboundMessage( channel=self.name, @@ -624,8 +620,6 @@ class SlackChannel(BaseChannel): self.logger.debug("done reaction failed: {}", e) def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool: - from nanobot.pairing import is_approved - if channel_type == "im": if not self.config.dm.enabled: return False diff --git a/nanobot/pairing/__init__.py b/nanobot/pairing/__init__.py index 55f1c9f8c..0d1367c93 100644 --- a/nanobot/pairing/__init__.py +++ b/nanobot/pairing/__init__.py @@ -3,6 +3,8 @@ from nanobot.pairing.store import ( approve_code, deny_code, + format_expiry, + format_pairing_reply, generate_code, get_approved, is_approved, @@ -13,6 +15,8 @@ from nanobot.pairing.store import ( __all__ = [ "approve_code", "deny_code", + "format_expiry", + "format_pairing_reply", "generate_code", "get_approved", "is_approved", diff --git a/nanobot/pairing/store.py b/nanobot/pairing/store.py index fb531abdf..17e954602 100644 --- a/nanobot/pairing/store.py +++ b/nanobot/pairing/store.py @@ -8,7 +8,6 @@ private-assistant scale: small JSON file, simple locking, no external DB. from __future__ import annotations import json -import os import secrets import string import threading @@ -19,11 +18,11 @@ from typing import Any from loguru import logger from nanobot.config.paths import get_data_dir +from nanobot.utils.helpers import _write_text_atomic _LOCK = threading.Lock() - _ALPHABET = string.ascii_uppercase + string.digits -_CODE_LENGTH = 8 # e.g. XK9-42F-MP +_CODE_LENGTH = 8 # e.g. ABCD-EFGH _TTL_DEFAULT_S = 600 # 10 minutes @@ -37,30 +36,26 @@ def _load() -> dict[str, Any]: return {"approved": {}, "pending": {}} try: with open(path, encoding="utf-8") as f: - return json.load(f) + data = json.load(f) except (json.JSONDecodeError, OSError): logger.warning("Corrupted pairing store, resetting") return {"approved": {}, "pending": {}} + # Convert approved lists to sets for O(1) lookup + for channel, users in data.get("approved", {}).items(): + data["approved"][channel] = set(users) + return data + def _save(data: dict[str, Any]) -> None: path = _store_path() path.parent.mkdir(parents=True, exist_ok=True) - tmp = path.with_suffix(".tmp") - with open(tmp, "w", encoding="utf-8") as f: - json.dump(data, f, indent=2, ensure_ascii=False) - f.flush() - os.fsync(f.fileno()) - tmp.replace(path) - # Ensure directory entry is flushed for durability (Unix only; no-op on Windows) - try: - fd = os.open(path.parent, os.O_RDONLY) - try: - os.fsync(fd) - finally: - os.close(fd) - except (OSError, NotImplementedError): - pass + # Convert sets back to lists for JSON serialization + payload = { + "approved": {ch: sorted(list(users)) for ch, users in data.get("approved", {}).items()}, + "pending": dict(data.get("pending", {})), + } + _write_text_atomic(path, json.dumps(payload, indent=2, ensure_ascii=False)) def _gc_pending(data: dict[str, Any]) -> None: @@ -79,19 +74,13 @@ def generate_code( ) -> str: """Create a new pairing code for *sender_id* on *channel*. - Returns the code (e.g. ``"XK9-42F"``). + Returns the code (e.g. ``"ABCD-EFGH"``). """ with _LOCK: data = _load() _gc_pending(data) - # Ensure uniqueness - for _ in range(100): - raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH)) - code = f"{raw[:4]}-{raw[4:]}" - if code not in data.get("pending", {}): - break - else: # pragma: no cover - raise RuntimeError("Failed to generate unique pairing code") + raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH)) + code = f"{raw[:4]}-{raw[4:]}" data.setdefault("pending", {})[code] = { "channel": channel, @@ -119,7 +108,7 @@ def approve_code(code: str) -> tuple[str, str] | None: return None channel = info["channel"] sender_id = info["sender_id"] - data.setdefault("approved", {}).setdefault(channel, []).append(sender_id) + data.setdefault("approved", {}).setdefault(channel, set()).add(sender_id) _save(data) logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel) return channel, sender_id @@ -146,8 +135,8 @@ def is_approved(channel: str, sender_id: str) -> bool: """Check whether *sender_id* has been approved on *channel*.""" with _LOCK: data = _load() - approved: dict[str, list[str]] = data.get("approved", {}) - return str(sender_id) in approved.get(channel, []) + approved: dict[str, set[str]] = data.get("approved", {}) + return str(sender_id) in approved.get(channel, set()) def list_pending() -> list[dict[str, Any]]: @@ -168,11 +157,11 @@ def revoke(channel: str, sender_id: str) -> bool: """ with _LOCK: data = _load() - approved: dict[str, list[str]] = data.get("approved", {}) - lst = approved.get(channel, []) - if sender_id in lst: - lst.remove(sender_id) - if not lst: + approved: dict[str, set[str]] = data.get("approved", {}) + users = approved.get(channel, set()) + if sender_id in users: + users.discard(sender_id) + if not users: del approved[channel] _save(data) logger.info("Revoked {} from {}", sender_id, channel) @@ -184,4 +173,19 @@ def get_approved(channel: str) -> list[str]: """Return all approved sender IDs for *channel*.""" with _LOCK: data = _load() - return list(data.get("approved", {}).get(channel, [])) + return sorted(data.get("approved", {}).get(channel, set())) + + +def format_pairing_reply(code: str) -> str: + """Return the pairing-code message sent to unrecognised DM senders.""" + return ( + "This assistant requires approval before it can respond.\n" + f"Your pairing code is: `{code}`\n" + f"Ask the owner to run: `nanobot pairing approve {code}`" + ) + + +def format_expiry(expires_at: float) -> str: + """Return a human-readable expiry string (e.g. ``"120s"`` or ``"expired"``).""" + remaining = int(expires_at - time.time()) + return f"{remaining}s" if remaining > 0 else "expired" diff --git a/tests/channels/test_base_channel.py b/tests/channels/test_base_channel.py index 651e3365d..ab321dde2 100644 --- a/tests/channels/test_base_channel.py +++ b/tests/channels/test_base_channel.py @@ -101,7 +101,7 @@ async def test_handle_pairing_command_list(monkeypatch) -> None: ], ) - await channel._handle_pairing_command("owner", "chat1", "/pairing list") + await channel._handle_pairing_command("owner", "chat1", "list") assert len(channel._sent) == 1 assert "ABCD-EFGH" in channel._sent[0].content @@ -115,7 +115,7 @@ async def test_handle_pairing_command_approve(monkeypatch) -> None: lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None, ) - await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH") + await channel._handle_pairing_command("owner", "chat1", "approve ABCD-EFGH") assert len(channel._sent) == 1 assert "Approved" in channel._sent[0].content @@ -129,7 +129,7 @@ async def test_handle_pairing_command_revoke(monkeypatch) -> None: lambda ch, sid: sid == "123", ) - await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123") + await channel._handle_pairing_command("owner", "chat1", "revoke 123") assert len(channel._sent) == 1 assert "Revoked" in channel._sent[0].content