refactor(pairing): apply simplify review fixes

- Extract format_pairing_reply() and format_expiry() to eliminate
duplication between BaseChannel and SlackChannel.
- Use _write_text_atomic() from helpers.py instead of hand-rolled
fsync logic in pairing store.
- Convert approved lists to in-memory sets for O(1) lookup.
- Remove collision retry loop (8-char entropy is sufficient).
- Fix /pairing command parsing to split prefix exactly.
- Remove unused import time from base.py.
- Fix tests to pass subcommand_text, not full /pairing string.
This commit is contained in:
chengyongru 2026-05-14 11:20:49 +08:00
parent cf224c9701
commit 4e7022e73c
5 changed files with 75 additions and 85 deletions

View File

@ -2,7 +2,6 @@
from __future__ import annotations
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
@ -14,6 +13,8 @@ from nanobot.bus.queue import MessageBus
from nanobot.pairing import (
approve_code,
deny_code,
format_expiry,
format_pairing_reply,
generate_code,
is_approved,
list_pending,
@ -222,35 +223,15 @@ class BaseChannel(ABC):
session_key: str | None = None,
is_dm: bool = False,
) -> None:
"""
Handle an incoming message from the chat platform.
This method checks permissions and forwards to the bus.
For DM messages from unrecognised senders, a pairing code is
issued instead of silently dropping the message.
Args:
sender_id: The sender's identifier.
chat_id: The chat/channel identifier.
content: Message text content.
media: Optional list of media URLs.
metadata: Optional channel-specific metadata.
session_key: Optional session key override (e.g. thread-scoped sessions).
is_dm: Whether the message is a direct / private message.
"""
"""Handle an incoming message: check permissions, issue pairing codes in DMs, or forward to bus."""
if not self.is_allowed(sender_id):
if is_dm:
code = generate_code(self.name, str(sender_id))
reply = (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
await self.send(
OutboundMessage(
channel=self.name,
chat_id=str(chat_id),
content=reply,
content=format_pairing_reply(code),
metadata={"_pairing_code": code},
)
)
@ -267,8 +248,9 @@ class BaseChannel(ABC):
return
# Intercept /pairing slash commands before they reach the agent loop
if content.strip().startswith("/pairing"):
await self._handle_pairing_command(sender_id, chat_id, content.strip())
parts = content.strip().split(None, 1)
if parts and parts[0] == "/pairing":
await self._handle_pairing_command(sender_id, chat_id, parts[1] if len(parts) > 1 else "")
return
meta = metadata or {}
@ -288,12 +270,12 @@ class BaseChannel(ABC):
await self.bus.publish_inbound(msg)
async def _handle_pairing_command(
self, sender_id: str, chat_id: str, content: str
self, sender_id: str, chat_id: str, subcommand_text: str
) -> None:
"""Execute a ``/pairing`` slash command and reply directly to the user."""
parts = content.split()
sub = parts[1] if len(parts) > 1 else "list"
arg = parts[2] if len(parts) > 2 else None
parts = subcommand_text.split()
sub = parts[0] if parts else "list"
arg = parts[1] if len(parts) > 1 else None
if sub in ("list",):
pending = list_pending()
@ -302,8 +284,7 @@ class BaseChannel(ABC):
else:
lines = ["Pending pairing requests:"]
for item in pending:
remaining = int(item.get("expires_at", 0) - time.time())
expiry = f"{remaining}s" if remaining > 0 else "expired"
expiry = format_expiry(item.get("expires_at", 0))
lines.append(
f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}"
)
@ -335,13 +316,20 @@ class BaseChannel(ABC):
elif sub == "revoke":
if arg is None:
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
elif len(parts) == 2:
reply = (
f"Revoked {arg} from {self.name}"
if revoke(self.name, arg)
else f"{arg} was not in the approved list for {self.name}"
)
elif len(parts) == 3:
reply = (
f"Revoked {parts[2]} from {arg}"
if revoke(arg, parts[2])
else f"{parts[2]} was not in the approved list for {arg}"
)
else:
target_channel = parts[3] if len(parts) > 3 else self.name
target_user = arg if len(parts) <= 3 else parts[3]
if revoke(target_channel, target_user):
reply = f"Revoked {target_user} from {target_channel}"
else:
reply = f"{target_user} was not in the approved list for {target_channel}"
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
else:
reply = (

View File

@ -18,6 +18,7 @@ from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.pairing import format_pairing_reply, generate_code, is_approved
from nanobot.utils.helpers import safe_filename, split_message
@ -343,13 +344,8 @@ class SlackChannel(BaseChannel):
if not self._is_allowed(sender_id, chat_id, channel_type):
if channel_type == "im" and self.config.dm.enabled:
from nanobot.pairing import generate_code
code = generate_code(self.name, sender_id)
reply = (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
reply = format_pairing_reply(code)
await self.send(
OutboundMessage(
channel=self.name,
@ -624,8 +620,6 @@ class SlackChannel(BaseChannel):
self.logger.debug("done reaction failed: {}", e)
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
from nanobot.pairing import is_approved
if channel_type == "im":
if not self.config.dm.enabled:
return False

View File

@ -3,6 +3,8 @@
from nanobot.pairing.store import (
approve_code,
deny_code,
format_expiry,
format_pairing_reply,
generate_code,
get_approved,
is_approved,
@ -13,6 +15,8 @@ from nanobot.pairing.store import (
__all__ = [
"approve_code",
"deny_code",
"format_expiry",
"format_pairing_reply",
"generate_code",
"get_approved",
"is_approved",

View File

@ -8,7 +8,6 @@ private-assistant scale: small JSON file, simple locking, no external DB.
from __future__ import annotations
import json
import os
import secrets
import string
import threading
@ -19,11 +18,11 @@ from typing import Any
from loguru import logger
from nanobot.config.paths import get_data_dir
from nanobot.utils.helpers import _write_text_atomic
_LOCK = threading.Lock()
_ALPHABET = string.ascii_uppercase + string.digits
_CODE_LENGTH = 8 # e.g. XK9-42F-MP
_CODE_LENGTH = 8 # e.g. ABCD-EFGH
_TTL_DEFAULT_S = 600 # 10 minutes
@ -37,30 +36,26 @@ def _load() -> dict[str, Any]:
return {"approved": {}, "pending": {}}
try:
with open(path, encoding="utf-8") as f:
return json.load(f)
data = json.load(f)
except (json.JSONDecodeError, OSError):
logger.warning("Corrupted pairing store, resetting")
return {"approved": {}, "pending": {}}
# Convert approved lists to sets for O(1) lookup
for channel, users in data.get("approved", {}).items():
data["approved"][channel] = set(users)
return data
def _save(data: dict[str, Any]) -> None:
path = _store_path()
path.parent.mkdir(parents=True, exist_ok=True)
tmp = path.with_suffix(".tmp")
with open(tmp, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2, ensure_ascii=False)
f.flush()
os.fsync(f.fileno())
tmp.replace(path)
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
try:
fd = os.open(path.parent, os.O_RDONLY)
try:
os.fsync(fd)
finally:
os.close(fd)
except (OSError, NotImplementedError):
pass
# Convert sets back to lists for JSON serialization
payload = {
"approved": {ch: sorted(list(users)) for ch, users in data.get("approved", {}).items()},
"pending": dict(data.get("pending", {})),
}
_write_text_atomic(path, json.dumps(payload, indent=2, ensure_ascii=False))
def _gc_pending(data: dict[str, Any]) -> None:
@ -79,19 +74,13 @@ def generate_code(
) -> str:
"""Create a new pairing code for *sender_id* on *channel*.
Returns the code (e.g. ``"XK9-42F"``).
Returns the code (e.g. ``"ABCD-EFGH"``).
"""
with _LOCK:
data = _load()
_gc_pending(data)
# Ensure uniqueness
for _ in range(100):
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
code = f"{raw[:4]}-{raw[4:]}"
if code not in data.get("pending", {}):
break
else: # pragma: no cover
raise RuntimeError("Failed to generate unique pairing code")
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
code = f"{raw[:4]}-{raw[4:]}"
data.setdefault("pending", {})[code] = {
"channel": channel,
@ -119,7 +108,7 @@ def approve_code(code: str) -> tuple[str, str] | None:
return None
channel = info["channel"]
sender_id = info["sender_id"]
data.setdefault("approved", {}).setdefault(channel, []).append(sender_id)
data.setdefault("approved", {}).setdefault(channel, set()).add(sender_id)
_save(data)
logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel)
return channel, sender_id
@ -146,8 +135,8 @@ def is_approved(channel: str, sender_id: str) -> bool:
"""Check whether *sender_id* has been approved on *channel*."""
with _LOCK:
data = _load()
approved: dict[str, list[str]] = data.get("approved", {})
return str(sender_id) in approved.get(channel, [])
approved: dict[str, set[str]] = data.get("approved", {})
return str(sender_id) in approved.get(channel, set())
def list_pending() -> list[dict[str, Any]]:
@ -168,11 +157,11 @@ def revoke(channel: str, sender_id: str) -> bool:
"""
with _LOCK:
data = _load()
approved: dict[str, list[str]] = data.get("approved", {})
lst = approved.get(channel, [])
if sender_id in lst:
lst.remove(sender_id)
if not lst:
approved: dict[str, set[str]] = data.get("approved", {})
users = approved.get(channel, set())
if sender_id in users:
users.discard(sender_id)
if not users:
del approved[channel]
_save(data)
logger.info("Revoked {} from {}", sender_id, channel)
@ -184,4 +173,19 @@ def get_approved(channel: str) -> list[str]:
"""Return all approved sender IDs for *channel*."""
with _LOCK:
data = _load()
return list(data.get("approved", {}).get(channel, []))
return sorted(data.get("approved", {}).get(channel, set()))
def format_pairing_reply(code: str) -> str:
"""Return the pairing-code message sent to unrecognised DM senders."""
return (
"This assistant requires approval before it can respond.\n"
f"Your pairing code is: `{code}`\n"
f"Ask the owner to run: `nanobot pairing approve {code}`"
)
def format_expiry(expires_at: float) -> str:
"""Return a human-readable expiry string (e.g. ``"120s"`` or ``"expired"``)."""
remaining = int(expires_at - time.time())
return f"{remaining}s" if remaining > 0 else "expired"

View File

@ -101,7 +101,7 @@ async def test_handle_pairing_command_list(monkeypatch) -> None:
],
)
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
await channel._handle_pairing_command("owner", "chat1", "list")
assert len(channel._sent) == 1
assert "ABCD-EFGH" in channel._sent[0].content
@ -115,7 +115,7 @@ async def test_handle_pairing_command_approve(monkeypatch) -> None:
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
)
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
await channel._handle_pairing_command("owner", "chat1", "approve ABCD-EFGH")
assert len(channel._sent) == 1
assert "Approved" in channel._sent[0].content
@ -129,7 +129,7 @@ async def test_handle_pairing_command_revoke(monkeypatch) -> None:
lambda ch, sid: sid == "123",
)
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
await channel._handle_pairing_command("owner", "chat1", "revoke 123")
assert len(channel._sent) == 1
assert "Revoked" in channel._sent[0].content