mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor(pairing): apply simplify review fixes
- Extract format_pairing_reply() and format_expiry() to eliminate duplication between BaseChannel and SlackChannel. - Use _write_text_atomic() from helpers.py instead of hand-rolled fsync logic in pairing store. - Convert approved lists to in-memory sets for O(1) lookup. - Remove collision retry loop (8-char entropy is sufficient). - Fix /pairing command parsing to split prefix exactly. - Remove unused import time from base.py. - Fix tests to pass subcommand_text, not full /pairing string.
This commit is contained in:
parent
cf224c9701
commit
4e7022e73c
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@ -14,6 +13,8 @@ from nanobot.bus.queue import MessageBus
|
|||||||
from nanobot.pairing import (
|
from nanobot.pairing import (
|
||||||
approve_code,
|
approve_code,
|
||||||
deny_code,
|
deny_code,
|
||||||
|
format_expiry,
|
||||||
|
format_pairing_reply,
|
||||||
generate_code,
|
generate_code,
|
||||||
is_approved,
|
is_approved,
|
||||||
list_pending,
|
list_pending,
|
||||||
@ -222,35 +223,15 @@ class BaseChannel(ABC):
|
|||||||
session_key: str | None = None,
|
session_key: str | None = None,
|
||||||
is_dm: bool = False,
|
is_dm: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""Handle an incoming message: check permissions, issue pairing codes in DMs, or forward to bus."""
|
||||||
Handle an incoming message from the chat platform.
|
|
||||||
|
|
||||||
This method checks permissions and forwards to the bus.
|
|
||||||
For DM messages from unrecognised senders, a pairing code is
|
|
||||||
issued instead of silently dropping the message.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sender_id: The sender's identifier.
|
|
||||||
chat_id: The chat/channel identifier.
|
|
||||||
content: Message text content.
|
|
||||||
media: Optional list of media URLs.
|
|
||||||
metadata: Optional channel-specific metadata.
|
|
||||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
|
||||||
is_dm: Whether the message is a direct / private message.
|
|
||||||
"""
|
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
if is_dm:
|
if is_dm:
|
||||||
code = generate_code(self.name, str(sender_id))
|
code = generate_code(self.name, str(sender_id))
|
||||||
reply = (
|
|
||||||
"This assistant requires approval before it can respond.\n"
|
|
||||||
f"Your pairing code is: `{code}`\n"
|
|
||||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
|
||||||
)
|
|
||||||
await self.send(
|
await self.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel=self.name,
|
channel=self.name,
|
||||||
chat_id=str(chat_id),
|
chat_id=str(chat_id),
|
||||||
content=reply,
|
content=format_pairing_reply(code),
|
||||||
metadata={"_pairing_code": code},
|
metadata={"_pairing_code": code},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -267,8 +248,9 @@ class BaseChannel(ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Intercept /pairing slash commands before they reach the agent loop
|
# Intercept /pairing slash commands before they reach the agent loop
|
||||||
if content.strip().startswith("/pairing"):
|
parts = content.strip().split(None, 1)
|
||||||
await self._handle_pairing_command(sender_id, chat_id, content.strip())
|
if parts and parts[0] == "/pairing":
|
||||||
|
await self._handle_pairing_command(sender_id, chat_id, parts[1] if len(parts) > 1 else "")
|
||||||
return
|
return
|
||||||
|
|
||||||
meta = metadata or {}
|
meta = metadata or {}
|
||||||
@ -288,12 +270,12 @@ class BaseChannel(ABC):
|
|||||||
await self.bus.publish_inbound(msg)
|
await self.bus.publish_inbound(msg)
|
||||||
|
|
||||||
async def _handle_pairing_command(
|
async def _handle_pairing_command(
|
||||||
self, sender_id: str, chat_id: str, content: str
|
self, sender_id: str, chat_id: str, subcommand_text: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a ``/pairing`` slash command and reply directly to the user."""
|
"""Execute a ``/pairing`` slash command and reply directly to the user."""
|
||||||
parts = content.split()
|
parts = subcommand_text.split()
|
||||||
sub = parts[1] if len(parts) > 1 else "list"
|
sub = parts[0] if parts else "list"
|
||||||
arg = parts[2] if len(parts) > 2 else None
|
arg = parts[1] if len(parts) > 1 else None
|
||||||
|
|
||||||
if sub in ("list",):
|
if sub in ("list",):
|
||||||
pending = list_pending()
|
pending = list_pending()
|
||||||
@ -302,8 +284,7 @@ class BaseChannel(ABC):
|
|||||||
else:
|
else:
|
||||||
lines = ["Pending pairing requests:"]
|
lines = ["Pending pairing requests:"]
|
||||||
for item in pending:
|
for item in pending:
|
||||||
remaining = int(item.get("expires_at", 0) - time.time())
|
expiry = format_expiry(item.get("expires_at", 0))
|
||||||
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
|
||||||
lines.append(
|
lines.append(
|
||||||
f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}"
|
f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}"
|
||||||
)
|
)
|
||||||
@ -335,13 +316,20 @@ class BaseChannel(ABC):
|
|||||||
elif sub == "revoke":
|
elif sub == "revoke":
|
||||||
if arg is None:
|
if arg is None:
|
||||||
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
||||||
|
elif len(parts) == 2:
|
||||||
|
reply = (
|
||||||
|
f"Revoked {arg} from {self.name}"
|
||||||
|
if revoke(self.name, arg)
|
||||||
|
else f"{arg} was not in the approved list for {self.name}"
|
||||||
|
)
|
||||||
|
elif len(parts) == 3:
|
||||||
|
reply = (
|
||||||
|
f"Revoked {parts[2]} from {arg}"
|
||||||
|
if revoke(arg, parts[2])
|
||||||
|
else f"{parts[2]} was not in the approved list for {arg}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
target_channel = parts[3] if len(parts) > 3 else self.name
|
reply = "Usage: `/pairing revoke <user_id>` or `/pairing revoke <channel> <user_id>`"
|
||||||
target_user = arg if len(parts) <= 3 else parts[3]
|
|
||||||
if revoke(target_channel, target_user):
|
|
||||||
reply = f"Revoked {target_user} from {target_channel}"
|
|
||||||
else:
|
|
||||||
reply = f"{target_user} was not in the approved list for {target_channel}"
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
reply = (
|
reply = (
|
||||||
|
|||||||
@ -18,6 +18,7 @@ from nanobot.bus.queue import MessageBus
|
|||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
|
from nanobot.pairing import format_pairing_reply, generate_code, is_approved
|
||||||
from nanobot.utils.helpers import safe_filename, split_message
|
from nanobot.utils.helpers import safe_filename, split_message
|
||||||
|
|
||||||
|
|
||||||
@ -343,13 +344,8 @@ class SlackChannel(BaseChannel):
|
|||||||
|
|
||||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||||
if channel_type == "im" and self.config.dm.enabled:
|
if channel_type == "im" and self.config.dm.enabled:
|
||||||
from nanobot.pairing import generate_code
|
|
||||||
code = generate_code(self.name, sender_id)
|
code = generate_code(self.name, sender_id)
|
||||||
reply = (
|
reply = format_pairing_reply(code)
|
||||||
"This assistant requires approval before it can respond.\n"
|
|
||||||
f"Your pairing code is: `{code}`\n"
|
|
||||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
|
||||||
)
|
|
||||||
await self.send(
|
await self.send(
|
||||||
OutboundMessage(
|
OutboundMessage(
|
||||||
channel=self.name,
|
channel=self.name,
|
||||||
@ -624,8 +620,6 @@ class SlackChannel(BaseChannel):
|
|||||||
self.logger.debug("done reaction failed: {}", e)
|
self.logger.debug("done reaction failed: {}", e)
|
||||||
|
|
||||||
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
|
||||||
from nanobot.pairing import is_approved
|
|
||||||
|
|
||||||
if channel_type == "im":
|
if channel_type == "im":
|
||||||
if not self.config.dm.enabled:
|
if not self.config.dm.enabled:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -3,6 +3,8 @@
|
|||||||
from nanobot.pairing.store import (
|
from nanobot.pairing.store import (
|
||||||
approve_code,
|
approve_code,
|
||||||
deny_code,
|
deny_code,
|
||||||
|
format_expiry,
|
||||||
|
format_pairing_reply,
|
||||||
generate_code,
|
generate_code,
|
||||||
get_approved,
|
get_approved,
|
||||||
is_approved,
|
is_approved,
|
||||||
@ -13,6 +15,8 @@ from nanobot.pairing.store import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"approve_code",
|
"approve_code",
|
||||||
"deny_code",
|
"deny_code",
|
||||||
|
"format_expiry",
|
||||||
|
"format_pairing_reply",
|
||||||
"generate_code",
|
"generate_code",
|
||||||
"get_approved",
|
"get_approved",
|
||||||
"is_approved",
|
"is_approved",
|
||||||
|
|||||||
@ -8,7 +8,6 @@ private-assistant scale: small JSON file, simple locking, no external DB.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
import threading
|
import threading
|
||||||
@ -19,11 +18,11 @@ from typing import Any
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.config.paths import get_data_dir
|
from nanobot.config.paths import get_data_dir
|
||||||
|
from nanobot.utils.helpers import _write_text_atomic
|
||||||
|
|
||||||
_LOCK = threading.Lock()
|
_LOCK = threading.Lock()
|
||||||
|
|
||||||
_ALPHABET = string.ascii_uppercase + string.digits
|
_ALPHABET = string.ascii_uppercase + string.digits
|
||||||
_CODE_LENGTH = 8 # e.g. XK9-42F-MP
|
_CODE_LENGTH = 8 # e.g. ABCD-EFGH
|
||||||
_TTL_DEFAULT_S = 600 # 10 minutes
|
_TTL_DEFAULT_S = 600 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
@ -37,30 +36,26 @@ def _load() -> dict[str, Any]:
|
|||||||
return {"approved": {}, "pending": {}}
|
return {"approved": {}, "pending": {}}
|
||||||
try:
|
try:
|
||||||
with open(path, encoding="utf-8") as f:
|
with open(path, encoding="utf-8") as f:
|
||||||
return json.load(f)
|
data = json.load(f)
|
||||||
except (json.JSONDecodeError, OSError):
|
except (json.JSONDecodeError, OSError):
|
||||||
logger.warning("Corrupted pairing store, resetting")
|
logger.warning("Corrupted pairing store, resetting")
|
||||||
return {"approved": {}, "pending": {}}
|
return {"approved": {}, "pending": {}}
|
||||||
|
|
||||||
|
# Convert approved lists to sets for O(1) lookup
|
||||||
|
for channel, users in data.get("approved", {}).items():
|
||||||
|
data["approved"][channel] = set(users)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def _save(data: dict[str, Any]) -> None:
|
def _save(data: dict[str, Any]) -> None:
|
||||||
path = _store_path()
|
path = _store_path()
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
tmp = path.with_suffix(".tmp")
|
# Convert sets back to lists for JSON serialization
|
||||||
with open(tmp, "w", encoding="utf-8") as f:
|
payload = {
|
||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
"approved": {ch: sorted(list(users)) for ch, users in data.get("approved", {}).items()},
|
||||||
f.flush()
|
"pending": dict(data.get("pending", {})),
|
||||||
os.fsync(f.fileno())
|
}
|
||||||
tmp.replace(path)
|
_write_text_atomic(path, json.dumps(payload, indent=2, ensure_ascii=False))
|
||||||
# Ensure directory entry is flushed for durability (Unix only; no-op on Windows)
|
|
||||||
try:
|
|
||||||
fd = os.open(path.parent, os.O_RDONLY)
|
|
||||||
try:
|
|
||||||
os.fsync(fd)
|
|
||||||
finally:
|
|
||||||
os.close(fd)
|
|
||||||
except (OSError, NotImplementedError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _gc_pending(data: dict[str, Any]) -> None:
|
def _gc_pending(data: dict[str, Any]) -> None:
|
||||||
@ -79,19 +74,13 @@ def generate_code(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new pairing code for *sender_id* on *channel*.
|
"""Create a new pairing code for *sender_id* on *channel*.
|
||||||
|
|
||||||
Returns the code (e.g. ``"XK9-42F"``).
|
Returns the code (e.g. ``"ABCD-EFGH"``).
|
||||||
"""
|
"""
|
||||||
with _LOCK:
|
with _LOCK:
|
||||||
data = _load()
|
data = _load()
|
||||||
_gc_pending(data)
|
_gc_pending(data)
|
||||||
# Ensure uniqueness
|
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
||||||
for _ in range(100):
|
code = f"{raw[:4]}-{raw[4:]}"
|
||||||
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
|
||||||
code = f"{raw[:4]}-{raw[4:]}"
|
|
||||||
if code not in data.get("pending", {}):
|
|
||||||
break
|
|
||||||
else: # pragma: no cover
|
|
||||||
raise RuntimeError("Failed to generate unique pairing code")
|
|
||||||
|
|
||||||
data.setdefault("pending", {})[code] = {
|
data.setdefault("pending", {})[code] = {
|
||||||
"channel": channel,
|
"channel": channel,
|
||||||
@ -119,7 +108,7 @@ def approve_code(code: str) -> tuple[str, str] | None:
|
|||||||
return None
|
return None
|
||||||
channel = info["channel"]
|
channel = info["channel"]
|
||||||
sender_id = info["sender_id"]
|
sender_id = info["sender_id"]
|
||||||
data.setdefault("approved", {}).setdefault(channel, []).append(sender_id)
|
data.setdefault("approved", {}).setdefault(channel, set()).add(sender_id)
|
||||||
_save(data)
|
_save(data)
|
||||||
logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel)
|
logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel)
|
||||||
return channel, sender_id
|
return channel, sender_id
|
||||||
@ -146,8 +135,8 @@ def is_approved(channel: str, sender_id: str) -> bool:
|
|||||||
"""Check whether *sender_id* has been approved on *channel*."""
|
"""Check whether *sender_id* has been approved on *channel*."""
|
||||||
with _LOCK:
|
with _LOCK:
|
||||||
data = _load()
|
data = _load()
|
||||||
approved: dict[str, list[str]] = data.get("approved", {})
|
approved: dict[str, set[str]] = data.get("approved", {})
|
||||||
return str(sender_id) in approved.get(channel, [])
|
return str(sender_id) in approved.get(channel, set())
|
||||||
|
|
||||||
|
|
||||||
def list_pending() -> list[dict[str, Any]]:
|
def list_pending() -> list[dict[str, Any]]:
|
||||||
@ -168,11 +157,11 @@ def revoke(channel: str, sender_id: str) -> bool:
|
|||||||
"""
|
"""
|
||||||
with _LOCK:
|
with _LOCK:
|
||||||
data = _load()
|
data = _load()
|
||||||
approved: dict[str, list[str]] = data.get("approved", {})
|
approved: dict[str, set[str]] = data.get("approved", {})
|
||||||
lst = approved.get(channel, [])
|
users = approved.get(channel, set())
|
||||||
if sender_id in lst:
|
if sender_id in users:
|
||||||
lst.remove(sender_id)
|
users.discard(sender_id)
|
||||||
if not lst:
|
if not users:
|
||||||
del approved[channel]
|
del approved[channel]
|
||||||
_save(data)
|
_save(data)
|
||||||
logger.info("Revoked {} from {}", sender_id, channel)
|
logger.info("Revoked {} from {}", sender_id, channel)
|
||||||
@ -184,4 +173,19 @@ def get_approved(channel: str) -> list[str]:
|
|||||||
"""Return all approved sender IDs for *channel*."""
|
"""Return all approved sender IDs for *channel*."""
|
||||||
with _LOCK:
|
with _LOCK:
|
||||||
data = _load()
|
data = _load()
|
||||||
return list(data.get("approved", {}).get(channel, []))
|
return sorted(data.get("approved", {}).get(channel, set()))
|
||||||
|
|
||||||
|
|
||||||
|
def format_pairing_reply(code: str) -> str:
|
||||||
|
"""Return the pairing-code message sent to unrecognised DM senders."""
|
||||||
|
return (
|
||||||
|
"This assistant requires approval before it can respond.\n"
|
||||||
|
f"Your pairing code is: `{code}`\n"
|
||||||
|
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def format_expiry(expires_at: float) -> str:
|
||||||
|
"""Return a human-readable expiry string (e.g. ``"120s"`` or ``"expired"``)."""
|
||||||
|
remaining = int(expires_at - time.time())
|
||||||
|
return f"{remaining}s" if remaining > 0 else "expired"
|
||||||
|
|||||||
@ -101,7 +101,7 @@ async def test_handle_pairing_command_list(monkeypatch) -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._handle_pairing_command("owner", "chat1", "/pairing list")
|
await channel._handle_pairing_command("owner", "chat1", "list")
|
||||||
|
|
||||||
assert len(channel._sent) == 1
|
assert len(channel._sent) == 1
|
||||||
assert "ABCD-EFGH" in channel._sent[0].content
|
assert "ABCD-EFGH" in channel._sent[0].content
|
||||||
@ -115,7 +115,7 @@ async def test_handle_pairing_command_approve(monkeypatch) -> None:
|
|||||||
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
|
lambda code: ("dummy", "123") if code == "ABCD-EFGH" else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._handle_pairing_command("owner", "chat1", "/pairing approve ABCD-EFGH")
|
await channel._handle_pairing_command("owner", "chat1", "approve ABCD-EFGH")
|
||||||
|
|
||||||
assert len(channel._sent) == 1
|
assert len(channel._sent) == 1
|
||||||
assert "Approved" in channel._sent[0].content
|
assert "Approved" in channel._sent[0].content
|
||||||
@ -129,7 +129,7 @@ async def test_handle_pairing_command_revoke(monkeypatch) -> None:
|
|||||||
lambda ch, sid: sid == "123",
|
lambda ch, sid: sid == "123",
|
||||||
)
|
)
|
||||||
|
|
||||||
await channel._handle_pairing_command("owner", "chat1", "/pairing revoke 123")
|
await channel._handle_pairing_command("owner", "chat1", "revoke 123")
|
||||||
|
|
||||||
assert len(channel._sent) == 1
|
assert len(channel._sent) == 1
|
||||||
assert "Revoked" in channel._sent[0].content
|
assert "Revoked" in channel._sent[0].content
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user