mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
feat(pairing): chat-native DM sender approval
Replace the file-editing onboarding workflow with a chat-native pairing flow: - New pairing store (nanobot/pairing/store.py) persists approved senders and pending codes in ~/.nanobot/pairing.json. - DM messages from unknown senders receive a short pairing code instead of silent denial. Group chats remain silently ignored. - Existing allowFrom semantics are fully preserved; approved pairing users are merged at runtime so no config migration is needed. - nanobot pairing list/approve/deny/revoke CLI commands for bootstrap and emergency management. - /pairing slash commands intercepted in-channel so owners can approve senders without leaving the chat. - is_dm flag added to BaseChannel._handle_message; Telegram, Discord and WebSocket updated to pass it. Closes #3768
This commit is contained in:
parent
c10ec6094e
commit
4c4a9ae590
@ -10,6 +10,14 @@ from loguru import logger
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.pairing import (
|
||||
approve_code,
|
||||
deny_code,
|
||||
generate_code,
|
||||
is_approved,
|
||||
list_pending,
|
||||
revoke,
|
||||
)
|
||||
|
||||
|
||||
class BaseChannel(ABC):
|
||||
@ -176,7 +184,14 @@ class BaseChannel(ABC):
|
||||
return bool(streaming) and type(self).send_delta is not BaseChannel.send_delta
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
"""Check if *sender_id* is permitted. Empty list → deny all; ``"*"`` → allow all."""
|
||||
"""Check if *sender_id* is permitted.
|
||||
|
||||
Priority:
|
||||
1. ``allowFrom: ["*"]`` → allow all.
|
||||
2. ``allowFrom`` list → allow if sender_id is present.
|
||||
3. Pairing store approved list → allow if previously approved.
|
||||
4. Otherwise deny.
|
||||
"""
|
||||
if isinstance(self.config, dict):
|
||||
if "allow_from" in self.config:
|
||||
allow_list = self.config.get("allow_from")
|
||||
@ -184,12 +199,13 @@ class BaseChannel(ABC):
|
||||
allow_list = self.config.get("allowFrom", [])
|
||||
else:
|
||||
allow_list = getattr(self.config, "allow_from", [])
|
||||
if not allow_list:
|
||||
self.logger.warning("allow_from is empty — all access denied")
|
||||
return False
|
||||
if "*" in allow_list:
|
||||
return True
|
||||
return str(sender_id) in allow_list
|
||||
if str(sender_id) in allow_list:
|
||||
return True
|
||||
if is_approved(self.name, str(sender_id)):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
@ -199,11 +215,14 @@ class BaseChannel(ABC):
|
||||
media: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
is_dm: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Handle an incoming message from the chat platform.
|
||||
|
||||
This method checks permissions and forwards to the bus.
|
||||
For DM messages from unrecognised senders, a pairing code is
|
||||
issued instead of silently dropping the message.
|
||||
|
||||
Args:
|
||||
sender_id: The sender's identifier.
|
||||
@ -212,13 +231,39 @@ class BaseChannel(ABC):
|
||||
media: Optional list of media URLs.
|
||||
metadata: Optional channel-specific metadata.
|
||||
session_key: Optional session key override (e.g. thread-scoped sessions).
|
||||
is_dm: Whether the message is a direct / private message.
|
||||
"""
|
||||
if not self.is_allowed(sender_id):
|
||||
self.logger.warning(
|
||||
"Access denied for sender {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id,
|
||||
)
|
||||
if is_dm:
|
||||
code = generate_code(self.name, str(sender_id))
|
||||
reply = (
|
||||
"This assistant requires approval before it can respond.\n"
|
||||
f"Your pairing code is: `{code}`\n"
|
||||
f"Ask the owner to run: `nanobot pairing approve {code}`"
|
||||
)
|
||||
await self.send(
|
||||
OutboundMessage(
|
||||
channel=self.name,
|
||||
chat_id=str(chat_id),
|
||||
content=reply,
|
||||
metadata={"_pairing_code": code},
|
||||
)
|
||||
)
|
||||
self.logger.info(
|
||||
"Sent pairing code {} to sender {} in chat {}",
|
||||
code, sender_id, chat_id,
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Access denied for sender {}. "
|
||||
"Add them to allowFrom list in config to grant access.",
|
||||
sender_id,
|
||||
)
|
||||
return
|
||||
|
||||
# Intercept /pairing slash commands before they reach the agent loop
|
||||
if content.strip().startswith("/pairing"):
|
||||
await self._handle_pairing_command(sender_id, chat_id, content.strip())
|
||||
return
|
||||
|
||||
meta = metadata or {}
|
||||
@ -237,6 +282,77 @@ class BaseChannel(ABC):
|
||||
|
||||
await self.bus.publish_inbound(msg)
|
||||
|
||||
async def _handle_pairing_command(
|
||||
self, sender_id: str, chat_id: str, content: str
|
||||
) -> None:
|
||||
"""Execute a ``/pairing`` slash command and reply directly to the user."""
|
||||
parts = content.split()
|
||||
sub = parts[1] if len(parts) > 1 else "list"
|
||||
arg = parts[2] if len(parts) > 2 else None
|
||||
|
||||
if sub in ("list",):
|
||||
pending = list_pending()
|
||||
if not pending:
|
||||
reply = "No pending pairing requests."
|
||||
else:
|
||||
lines = ["Pending pairing requests:"]
|
||||
import time
|
||||
|
||||
for item in pending:
|
||||
remaining = int(item.get("expires_at", 0) - time.time())
|
||||
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
||||
lines.append(
|
||||
f"- `{item['code']}` | {item['channel']} | {item['sender_id']} | {expiry}"
|
||||
)
|
||||
reply = "\n".join(lines)
|
||||
|
||||
elif sub == "approve":
|
||||
if arg is None:
|
||||
reply = "Usage: `/pairing approve <code>`"
|
||||
else:
|
||||
result = approve_code(arg)
|
||||
if result is None:
|
||||
reply = f"Invalid or expired pairing code: `{arg}`"
|
||||
else:
|
||||
channel, sid = result
|
||||
reply = (
|
||||
f"Approved pairing code `{arg}` — "
|
||||
f"{sid} can now access {channel}"
|
||||
)
|
||||
|
||||
elif sub == "deny":
|
||||
if arg is None:
|
||||
reply = "Usage: `/pairing deny <code>`"
|
||||
else:
|
||||
if deny_code(arg):
|
||||
reply = f"Denied pairing code `{arg}`"
|
||||
else:
|
||||
reply = f"Pairing code `{arg}` not found or already expired"
|
||||
|
||||
elif sub == "revoke":
|
||||
if arg is None:
|
||||
reply = "Usage: `/pairing revoke <user_id>`"
|
||||
else:
|
||||
if revoke(self.name, arg):
|
||||
reply = f"Revoked {arg} from {self.name}"
|
||||
else:
|
||||
reply = f"{arg} was not in the approved list for {self.name}"
|
||||
|
||||
else:
|
||||
reply = (
|
||||
"Unknown pairing command.\n"
|
||||
"Usage: `/pairing [list|approve <code>|deny <code>|revoke <user_id>]`"
|
||||
)
|
||||
|
||||
await self.send(
|
||||
OutboundMessage(
|
||||
channel=self.name,
|
||||
chat_id=str(chat_id),
|
||||
content=reply,
|
||||
metadata={"_pairing_command": True},
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
"""Return default config for onboard. Override in plugins to auto-populate config.json."""
|
||||
|
||||
@ -577,6 +577,7 @@ class DiscordChannel(BaseChannel):
|
||||
media=media_paths,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
is_dm=message.guild is None,
|
||||
)
|
||||
except Exception:
|
||||
await self._clear_reactions(channel_id)
|
||||
|
||||
@ -1011,6 +1011,7 @@ class TelegramChannel(BaseChannel):
|
||||
content=content,
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
session_key=self._derive_topic_session_key(message),
|
||||
is_dm=message.chat.type == "private",
|
||||
)
|
||||
|
||||
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
|
||||
|
||||
@ -1254,6 +1254,7 @@ class WebSocketChannel(BaseChannel):
|
||||
chat_id=default_chat_id,
|
||||
content=content,
|
||||
metadata={"remote": getattr(connection, "remote_address", None)},
|
||||
is_dm=True,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.debug("connection ended: {}", e)
|
||||
@ -1399,6 +1400,7 @@ class WebSocketChannel(BaseChannel):
|
||||
content=content,
|
||||
media=media_paths or None,
|
||||
metadata=metadata,
|
||||
is_dm=True,
|
||||
)
|
||||
return
|
||||
await self._send_event(connection, "error", detail=f"unknown type: {t!r}")
|
||||
|
||||
@ -1620,5 +1620,94 @@ def _login_github_copilot() -> None:
|
||||
raise typer.Exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pairing Commands
|
||||
# ============================================================================
|
||||
|
||||
pairing_app = typer.Typer(help="Manage DM pairing approvals")
|
||||
app.add_typer(pairing_app, name="pairing")
|
||||
|
||||
|
||||
@pairing_app.command("list")
|
||||
def pairing_list():
|
||||
"""Show pending pairing requests."""
|
||||
from nanobot.pairing import list_pending
|
||||
|
||||
pending = list_pending()
|
||||
if not pending:
|
||||
console.print("[dim]No pending pairing requests.[/dim]")
|
||||
return
|
||||
|
||||
table = Table(title="Pending Pairing Requests")
|
||||
table.add_column("Code", style="cyan")
|
||||
table.add_column("Channel", style="magenta")
|
||||
table.add_column("Sender ID", style="yellow")
|
||||
table.add_column("Expires", style="green")
|
||||
|
||||
import time
|
||||
|
||||
for item in pending:
|
||||
remaining = int(item.get("expires_at", 0) - time.time())
|
||||
expiry = f"{remaining}s" if remaining > 0 else "expired"
|
||||
table.add_row(
|
||||
item["code"],
|
||||
item["channel"],
|
||||
item["sender_id"],
|
||||
expiry,
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@pairing_app.command("approve")
|
||||
def pairing_approve(
|
||||
code: str = typer.Argument(..., help="Pairing code to approve"),
|
||||
):
|
||||
"""Approve a pending pairing code."""
|
||||
from nanobot.pairing import approve_code
|
||||
|
||||
result = approve_code(code)
|
||||
if result is None:
|
||||
console.print(f"[red]✗[/red] Invalid or expired pairing code: {code}")
|
||||
raise typer.Exit(1)
|
||||
|
||||
channel, sender_id = result
|
||||
console.print(
|
||||
f"[green]✓[/green] Approved pairing code {code} — "
|
||||
f"{sender_id} can now access {channel}"
|
||||
)
|
||||
|
||||
|
||||
@pairing_app.command("deny")
|
||||
def pairing_deny(
|
||||
code: str = typer.Argument(..., help="Pairing code to deny"),
|
||||
):
|
||||
"""Deny and discard a pending pairing code."""
|
||||
from nanobot.pairing import deny_code
|
||||
|
||||
if deny_code(code):
|
||||
console.print(f"[green]✓[/green] Denied pairing code {code}")
|
||||
else:
|
||||
console.print(f"[yellow]! Pairing code {code} not found or already expired[/yellow]")
|
||||
|
||||
|
||||
@pairing_app.command("revoke")
|
||||
def pairing_revoke(
|
||||
channel: str = typer.Argument(..., help="Channel name (e.g. telegram)"),
|
||||
user_id: str = typer.Argument(..., help="User ID to revoke"),
|
||||
):
|
||||
"""Revoke an approved sender from a channel."""
|
||||
from nanobot.pairing import revoke
|
||||
|
||||
if revoke(channel, user_id):
|
||||
console.print(
|
||||
f"[green]✓[/green] Revoked {user_id} from {channel}"
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
f"[yellow]! {user_id} was not in the approved list for {channel}[/yellow]"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
|
||||
21
nanobot/pairing/__init__.py
Normal file
21
nanobot/pairing/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
"""Pairing module for DM sender approval."""
|
||||
|
||||
from nanobot.pairing.store import (
|
||||
approve_code,
|
||||
deny_code,
|
||||
generate_code,
|
||||
get_approved,
|
||||
is_approved,
|
||||
list_pending,
|
||||
revoke,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"approve_code",
|
||||
"deny_code",
|
||||
"generate_code",
|
||||
"get_approved",
|
||||
"is_approved",
|
||||
"list_pending",
|
||||
"revoke",
|
||||
]
|
||||
175
nanobot/pairing/store.py
Normal file
175
nanobot/pairing/store.py
Normal file
@ -0,0 +1,175 @@
|
||||
"""Pairing store for DM sender approval.
|
||||
|
||||
Persistent storage at ``~/.nanobot/pairing.json`` keeps approved senders
|
||||
and pending pairing codes per channel. The store is designed for
|
||||
private-assistant scale: small JSON file, simple locking, no external DB.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_data_dir
|
||||
|
||||
_LOCK = threading.Lock()
|
||||
_ALPHABET = string.ascii_uppercase + string.digits
|
||||
_CODE_LENGTH = 6 # e.g. XK9-42F
|
||||
_TTL_DEFAULT_S = 600 # 10 minutes
|
||||
|
||||
|
||||
def _store_path() -> Path:
|
||||
return get_data_dir() / "pairing.json"
|
||||
|
||||
|
||||
def _load() -> dict[str, Any]:
|
||||
path = _store_path()
|
||||
if not path.exists():
|
||||
return {"approved": {}, "pending": {}}
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, OSError):
|
||||
logger.warning("Corrupted pairing store, resetting")
|
||||
return {"approved": {}, "pending": {}}
|
||||
|
||||
|
||||
def _save(data: dict[str, Any]) -> None:
|
||||
path = _store_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(".tmp")
|
||||
with open(tmp, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
f.flush()
|
||||
tmp.replace(path)
|
||||
|
||||
|
||||
def _gc_pending(data: dict[str, Any]) -> None:
|
||||
"""Remove expired pending entries in-place."""
|
||||
now = time.time()
|
||||
pending: dict[str, Any] = data.get("pending", {})
|
||||
expired = [code for code, info in pending.items() if info.get("expires_at", 0) < now]
|
||||
for code in expired:
|
||||
del pending[code]
|
||||
|
||||
|
||||
def generate_code(
|
||||
channel: str,
|
||||
sender_id: str,
|
||||
ttl: int = _TTL_DEFAULT_S,
|
||||
) -> str:
|
||||
"""Create a new pairing code for *sender_id* on *channel*.
|
||||
|
||||
Returns the code (e.g. ``"XK9-42F"``).
|
||||
"""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
_gc_pending(data)
|
||||
# Ensure uniqueness
|
||||
for _ in range(100):
|
||||
raw = "".join(secrets.choice(_ALPHABET) for _ in range(_CODE_LENGTH))
|
||||
code = f"{raw[:3]}-{raw[3:]}"
|
||||
if code not in data.get("pending", {}):
|
||||
break
|
||||
else: # pragma: no cover
|
||||
raise RuntimeError("Failed to generate unique pairing code")
|
||||
|
||||
data.setdefault("pending", {})[code] = {
|
||||
"channel": channel,
|
||||
"sender_id": sender_id,
|
||||
"created_at": time.time(),
|
||||
"expires_at": time.time() + ttl,
|
||||
}
|
||||
_save(data)
|
||||
logger.info("Generated pairing code {} for {}@{}", code, sender_id, channel)
|
||||
return code
|
||||
|
||||
|
||||
def approve_code(code: str) -> tuple[str, str] | None:
|
||||
"""Approve a pending pairing code.
|
||||
|
||||
Returns ``(channel, sender_id)`` on success, or ``None`` if the code
|
||||
does not exist or has expired.
|
||||
"""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
_gc_pending(data)
|
||||
pending: dict[str, Any] = data.get("pending", {})
|
||||
info = pending.pop(code, None)
|
||||
if info is None:
|
||||
return None
|
||||
channel = info["channel"]
|
||||
sender_id = info["sender_id"]
|
||||
data.setdefault("approved", {}).setdefault(channel, []).append(sender_id)
|
||||
_save(data)
|
||||
logger.info("Approved pairing code {} for {}@{}", code, sender_id, channel)
|
||||
return channel, sender_id
|
||||
|
||||
|
||||
def deny_code(code: str) -> bool:
|
||||
"""Reject and discard a pending pairing code.
|
||||
|
||||
Returns ``True`` if the code existed and was removed.
|
||||
"""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
_gc_pending(data)
|
||||
pending: dict[str, Any] = data.get("pending", {})
|
||||
if code in pending:
|
||||
del pending[code]
|
||||
_save(data)
|
||||
logger.info("Denied pairing code {}", code)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_approved(channel: str, sender_id: str) -> bool:
|
||||
"""Check whether *sender_id* has been approved on *channel*."""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
approved: dict[str, list[str]] = data.get("approved", {})
|
||||
return str(sender_id) in approved.get(channel, [])
|
||||
|
||||
|
||||
def list_pending() -> list[dict[str, Any]]:
|
||||
"""Return all non-expired pending pairing requests."""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
_gc_pending(data)
|
||||
return [
|
||||
{"code": code, **info}
|
||||
for code, info in data.get("pending", {}).items()
|
||||
]
|
||||
|
||||
|
||||
def revoke(channel: str, sender_id: str) -> bool:
|
||||
"""Remove an approved sender from *channel*.
|
||||
|
||||
Returns ``True`` if the sender was present and removed.
|
||||
"""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
approved: dict[str, list[str]] = data.get("approved", {})
|
||||
lst = approved.get(channel, [])
|
||||
if sender_id in lst:
|
||||
lst.remove(sender_id)
|
||||
if not lst:
|
||||
del approved[channel]
|
||||
_save(data)
|
||||
logger.info("Revoked {} from {}", sender_id, channel)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_approved(channel: str) -> list[str]:
|
||||
"""Return all approved sender IDs for *channel*."""
|
||||
with _LOCK:
|
||||
data = _load()
|
||||
return list(data.get("approved", {}).get(channel, []))
|
||||
Loading…
x
Reference in New Issue
Block a user