mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor(pairing): apply simplify review fixes
- Extract format_pairing_reply() and format_expiry() to eliminate duplication between BaseChannel and SlackChannel. - Use _write_text_atomic() from helpers.py instead of hand-rolled fsync logic in pairing store. - Convert approved lists to in-memory sets for O(1) lookup. - Remove collision retry loop (8-char entropy is sufficient). - Fix /pairing command parsing to split prefix exactly. - Remove unused import time from base.py. - Fix tests to pass subcommand_text, not full /pairing string.
This commit is contained in:
parent
f8e7e50759
commit
9bc86ee825
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@ -14,6 +13,8 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.pairing import (
|
||||
approve_code,
|
||||
deny_code,
|
||||
format_expiry,
|
||||
format_pairing_reply,
|
||||
generate_code,
|
||||
is_approved,
|
||||
list_pending,
|
||||
@ -222,35 +223,15 @@ class BaseChannel(ABC):
|
||||
session_key: str | None = None,
|
||||
is_dm: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
For DM messages from unrecognised senders, a pairing code is
|
||||
issued instead of silently dropping the message.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
chat_id: The chat/channel identifier.
|
||||
content: Message text content.
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
is_dm: Whether the message is a direct / private message.
|
||||
"""
|
||||
"""Handle an incoming message: check permissions, issue pairing codes in DMs, or forward to bus."""
|
||||
if not self.is_allowed(sender_id):
|
||||
if is_dm:
|
||||
code = generate_code(self.name, str(sender_id))
|
||||
reply = (
|
||||
"This assistant requires approval before it can respond.\n"
|
||||
f"Your pairing code is: `{code}`\n"
|
||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||
)
|
||||
await self.send(
|
||||
OutboundMessage(
|
||||
channel=self.name,
|
||||
chat_id=str(chat_id),
|
||||
content=reply,
|
||||
content=format_pairing_reply(code),
|
||||
metadata={"_pairing_code": code},
|
||||
)
|
||||
)
|
||||
@ -267,8 +248,9 @@ class BaseChannel(ABC):
|
||||
return
|
||||
|
||||
# Intercept /pairing slash commands before they reach the agent loop
|
||||
if content.strip().startswith("/pairing"):
|
||||
await self._handle_pairing_command(sender_id, chat_id, content.strip())
|
||||
parts = content.strip().split(None, 1)
|
||||
if parts and parts[0] == "/pairing":
|
||||
await self._handle_pairing_command(sender_id, chat_id, parts[1] if len(parts) > 1 else "")
|
||||
return
|
||||
|
||||
meta = metadata or {}
|
||||
@ -288,12 +270,12 @@ class BaseChannel(ABC):
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
async def _handle_pairing_command(
|
||||
self, sender_id: str, chat_id: str, content: str
|
||||
self, sender_id: str, chat_id: str, subcommand_text: str
|
||||
) -> None:
|
||||
"""Execute a ``/pairing`` slash command and reply directly to the user."""
|
||||
parts = content.split()
|
||||
sub = parts[1] if len(parts) > 1 else "list"
|
||||
arg = parts[2] if len(parts) > 2 else None
|
||||
parts = subcommand_text.split()
|
||||
sub = parts[0] if parts else "list"
|
||||
arg = parts[1] if len(parts) > 1 else None
|
||||
|
||||
if sub in ("list",):
|
||||
pending = list_pending()
|
||||
@ -302,8 +284,7 @@ class BaseChannel(ABC):
|
||||
else:
|
||||
lines = ["Pending pairing requests:"]
|
||||
for item in pending:
|
||||
remaining = int(item.get("expires_at", 0) - time.time())
|
||||
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
||||
expiry = format_expiry(item.get("expires_at", 0))
|
||||
lines.append(
|
||||
f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}"
|
||||
)
|
||||
@ -335,13 +316,20 @@ class BaseChannel(ABC):
|
||||
elif sub == "revoke":
|
||||
if arg is None:
|
||||
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
||||
elif len(parts) == 2:
|
||||
reply = (
|
||||
f"Revoked {arg} from {self.name}"
|
||||
if revoke(self.name, arg)
|
||||
else f"{arg} was not in the approved list for {self.name}"
|
||||
)
|
||||
elif len(parts) == 3:
|
||||
reply = (
|
||||
f"Revoked {parts[2]} from {arg}"
|
||||
if revoke(arg, parts[2])
|
||||
else f"{parts[2]} was not in the approved list for {arg}"
|
||||
)
|
||||
else:
|
||||
target_channel = parts[3] if len(parts) > 3 else self.name
|
||||
target_user = arg if len(parts) <= 3 else parts[3]
|
||||
if revoke(target_channel, target_user):
|
||||
reply = f"Revoked {target_user} from {target_channel}"
|
||||
else:
|
||||
reply = f"{target_user} was not in the approved list for {target_channel}"
|
||||
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
||||
|
||||
else:
|
||||
reply = (
|
||||
|
||||
@ -18,6 +18,7 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.pairing import format_pairing_reply, generate_code, is_approved
|
||||
from nanobot.utils.helpers import safe_filename, split_message
|
||||
|
||||
|
||||
@ -343,13 +344,8 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
if channel_type == "im" and self.config.dm.enabled:
|
||||
from nanobot.pairing import generate_code
|
||||
code = generate_code(self.name, sender_id)
|
||||
reply = (
|
||||
"This assistant requires approval before it can respond.\n"
|
||||
f"Your pairing code is: `{code}`\n"
|
||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||
)
|
||||
reply = format_pairing_reply(code)
|
||||
await self.send(
|
||||
OutboundMessage(
|
||||
channel=self.name,
|
||||
@ -624,8 +620,6 @@ class SlackChannel(BaseChannel):
|
||||
self.logger.debug("done reaction failed: {}", e)
|
||||
|
||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||
from nanobot.pairing import is_approved
|
||||
|
||||
if channel_type == "im":
|
||||
if not self.config.dm.enabled:
|
||||
return False
|
||||
|
||||
@ -3,6 +3,8 @@
|
||||
from nanobot.pairing.store import (
|
||||
approve_code,
|
||||
deny_code,
|
||||
format_expiry,
|
||||
format_pairing_reply,
|
||||
generate_code,
|
||||
get_approved,
|
||||
is_approved,
|
||||
@ -13,6 +15,8 @@ from nanobot.pairing.store import (
|
||||
__all__ = [
|
||||
"approve_code",
|
||||
"deny_code",
|
||||
"format_expiry",
|
||||
"format_pairing_reply",
|
||||
"generate_code",
|
||||
"get_approved",
|
||||
"is_approved",
|
||||
|
||||
@ -8,7 +8,6 @@ private-assistant scale: small JSON file, simple locking, no external DB.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import threading
|
||||
@ -19,11 +18,11 @@ from typing import Any
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_data_dir
|
||||
from nanobot.utils.helpers import _write_text_atomic
|
||||
|
||||
_LOCK = threading.Lock()
|
||||
|
||||
_ALPHABET = string.ascii_uppercase + string.digits
|
||||
_CODE_LENGTH = 8 # e.g. XK9-42F-MP
|
||||
_CODE_LENGTH = 8 # e.g. ABCD-EFGH
|
||||
_TTL_DEFAULT_S = 600 # 10 minutes
|
||||
|
||||
|
||||
@ -37,30 +36,26 @@ def _load() -> dict[str, Any]:
|
||||
return {"approved": {}, "pending": {}}
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
data = json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("Corrupted pairing store, resetting")
|
||||
return {"approved": {}, "pending": {}}
|
||||
|
||||
# Convert approved lists to sets for O(1) lookup
|
||||
for channel, users in data.get("approved", {}).items():
|
||||
data["approved"][channel] = set(users)
|
||||
return data
|
||||
|
||||
|
||||
def _save(data: dict[str, Any]) -> None:
|
||||
path = _store_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(".tmp")
|
||||
with open(tmp, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
tmp.replace(path)
|
||||
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
|
||||
try:
|
||||
fd = os.open(path.parent, os.O_RDONLY)
|
||||
try:
|
||||
os.fsync(fd)
|
||||
finally:
|
||||
os.close(fd)
|
||||
except (OSError, NotImplementedError):
|
||||
pass
|
||||
# Convert sets back to lists for JSON serialization
|
||||
payload = {
|
||||
"approved": {ch: sorted(list(users)) for ch, users in data.get("approved", {}).items()},
|
||||
"pending": dict(data.get("pending", {})),
|
||||
}
|
||||
_write_text_atomic(path, json.dumps(payload, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def _gc_pending(data: dict[str, Any]) -> None:
|
||||
@ -79,19 +74,13 @@ def generate_code(
|
||||
) -> str:
|
||||
"""Create a new pairing code for *sender_id* on *channel*.
|
||||
|
||||
Returns the code (e.g. ``"XK9-42F"``).
|
||||
Returns the code (e.g. ``"ABCD-EFGH"``).
|
||||
"""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
_gc_pending(data)
|
||||
# Ensure uniqueness
|
||||
for _ in range(100):
|
||||
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
||||
code = f"{raw[:4]}-{raw[4:]}"
|
||||
if code not in data.get("pending", {}):
|
||||
break
|
||||
else: # pragma: no cover
|
||||
raise RuntimeError("Failed to generate unique pairing code")
|
||||
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
||||
code = f"{raw[:4]}-{raw[4:]}"
|
||||
|
||||
data.setdefault("pending", {})[code] = {
|
||||
"channel": channel,
|
||||
@ -119,7 +108,7 @@ def approve_code(code: str) -> tuple[str, str] | None:
|
||||
return None
|
||||
channel = info["channel"]
|
||||
sender_id = info["sender_id"]
|
||||
data.setdefault("approved", {}).setdefault(channel, []).append(sender_id)
|
||||
data.setdefault("approved", {}).setdefault(channel, set()).add(sender_id)
|
||||
_save(data)
|
||||
logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel)
|
||||
return channel, sender_id
|
||||
@ -146,8 +135,8 @@ def is_approved(channel: str, sender_id: str) -> bool:
|
||||
"""Check whether *sender_id* has been approved on *channel*."""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
approved: dict[str, list[str]] = data.get("approved", {})
|
||||
return str(sender_id) in approved.get(channel, [])
|
||||
approved: dict[str, set[str]] = data.get("approved", {})
|
||||
return str(sender_id) in approved.get(channel, set())
|
||||
|
||||
|
||||
def list_pending() -> list[dict[str, Any]]:
|
||||
@ -168,11 +157,11 @@ def revoke(channel: str, sender_id: str) -> bool:
|
||||
"""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
approved: dict[str, list[str]] = data.get("approved", {})
|
||||
lst = approved.get(channel, [])
|
||||
if sender_id in lst:
|
||||
lst.remove(sender_id)
|
||||
if not lst:
|
||||
approved: dict[str, set[str]] = data.get("approved", {})
|
||||
users = approved.get(channel, set())
|
||||
if sender_id in users:
|
||||
users.discard(sender_id)
|
||||
if not users:
|
||||
del approved[channel]
|
||||
_save(data)
|
||||
logger.info("Revoked {} from {}", sender_id, channel)
|
||||
@ -184,4 +173,19 @@ def get_approved(channel: str) -> list[str]:
|
||||
"""Return all approved sender IDs for *channel*."""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
return list(data.get("approved", {}).get(channel, []))
|
||||
return sorted(data.get("approved", {}).get(channel, set()))
|
||||
|
||||
|
||||
def format_pairing_reply(code: str) -> str:
|
||||
"""Return the pairing-code message sent to unrecognised DM senders."""
|
||||
return (
|
||||
"This assistant requires approval before it can respond.\n"
|
||||
f"Your pairing code is: `{code}`\n"
|
||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||
)
|
||||
|
||||
|
||||
def format_expiry(expires_at: float) -> str:
|
||||
"""Return a human-readable expiry string (e.g. ``"120s"`` or ``"expired"``)."""
|
||||
remaining = int(expires_at - time.time())
|
||||
return f"{remaining}s" if remaining > 0 else "expired"
|
||||
|
||||
@ -101,7 +101,7 @@ async def test_handle_pairing_command_list(monkeypatch) -> None:
|
||||
],
|
||||
)
|
||||
|
||||
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
|
||||
await channel._handle_pairing_command("owner", "chat1", "list")
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
assert "ABCD-EFGH" in channel._sent[0].content
|
||||
@ -115,7 +115,7 @@ async def test_handle_pairing_command_approve(monkeypatch) -> None:
|
||||
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
|
||||
)
|
||||
|
||||
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
|
||||
await channel._handle_pairing_command("owner", "chat1", "approve ABCD-EFGH")
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
assert "Approved" in channel._sent[0].content
|
||||
@ -129,7 +129,7 @@ async def test_handle_pairing_command_revoke(monkeypatch) -> None:
|
||||
lambda ch, sid: sid == "123",
|
||||
)
|
||||
|
||||
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
|
||||
await channel._handle_pairing_command("owner", "chat1", "revoke 123")
|
||||
|
||||
assert len(channel._sent) == 1
|
||||
assert "Revoked" in channel._sent[0].content
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user