From 880097acd523a9c75c5aaa3f385d736f0ee9ff69 Mon Sep 17 00:00:00 2001 From: Kaloyan Tenchov Date: Fri, 15 May 2026 22:02:43 -0400 Subject: [PATCH] feat(signal): add Signal channel support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integrates signal-cli daemon via HTTP JSON-RPC as a nanobot channel. Supports DMs and group chats with open/allowlist access policies, markdown→Signal text style conversion, typing indicators, attachment handling, group message context buffering, and automatic reconnect with exponential backoff. Includes unit tests for channel lifecycle, message routing, mention detection, markdown conversion, and message splitting. Originally based on https://github.com/HKUDS/nanobot/pull/601. --- nanobot/channels/signal.py | 1133 ++++++++++++++++++++++++ tests/channels/test_signal_channel.py | 1058 ++++++++++++++++++++++ tests/channels/test_signal_markdown.py | 244 +++++ 3 files changed, 2435 insertions(+) create mode 100644 nanobot/channels/signal.py create mode 100644 tests/channels/test_signal_channel.py create mode 100644 tests/channels/test_signal_markdown.py diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py new file mode 100644 index 000000000..3e35ae676 --- /dev/null +++ b/nanobot/channels/signal.py @@ -0,0 +1,1133 @@ +"""Signal channel implementation using signal-cli daemon JSON-RPC interface.""" + +from __future__ import annotations + +import asyncio +import json +import re +import shutil +import unicodedata +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +from pydantic import Field + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir +from nanobot.config.schema import Base +from nanobot.utils.helpers import safe_filename, split_message + + +@dataclass +class _Run: + text: str + styles: frozenset[str] = field(default_factory=frozenset) + opaque: bool = False # code / table content — skip further pattern processing + + +_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```') +_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`') +_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE) +_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE) +_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE) +_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE) +_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') +_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL) +_SIG_ITALIC_RE = re.compile(r'(? str: + """Strip inline markdown from a table cell for plain-text rendering.""" + s = re.sub(r'\*\*(.+?)\*\*', r'\1', s) + s = re.sub(r'__(.+?)__', r'\1', s) + s = re.sub(r'~~(.+?)~~', r'\1', s) + s = re.sub(r'`([^`]+)`', r'\1', s) + return s.strip() + + +def _sig_render_table(table_lines: list[str]) -> str: + """Render a markdown pipe-table as fixed-width plain text.""" + + def dw(s: str) -> int: + return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s) + + rows: list[list[str]] = [] + has_sep = False + for line in table_lines: + cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')] + if all(re.match(r'^:?-+:?$', c) for c in cells if c): + has_sep = True + continue + rows.append(cells) + if not rows or not has_sep: + return '\n'.join(table_lines) + + ncols = max(len(r) for r in rows) + for r in rows: + r.extend([''] * (ncols - len(r))) + widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] + + def dr(cells: list[str]) -> str: + return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths)) + + out = [dr(rows[0])] + out.append(' '.join('─' * w for w in widths)) + for row in rows[1:]: + out.append(dr(row)) + return '\n'.join(out) + + +def _markdown_to_signal(text: str) -> tuple[str, list[str]]: + """Convert markdown text to Signal plain text + textStyle ranges. + + Returns ``(plain_text, text_styles)`` where ``text_styles`` are + ``"start:length:STYLE"`` strings for the signal-cli ``textStyle`` parameter. + """ + if not text: + return text, [] + + # Phase 1 (text-level): extract code blocks and tables with placeholder tokens + # so they're protected from inline-style processing. + protected: list[str] = [] + + def save_code(m: re.Match) -> str: + protected.append(m.group(1)) + return f"\x00C{len(protected) - 1}\x00" + + text = _SIG_CODE_BLOCK_RE.sub(save_code, text) + + # Detect and render pipe-tables line by line. + lines = text.split('\n') + rebuilt: list[str] = [] + i = 0 + while i < len(lines): + if re.match(r'^\s*\|.+\|', lines[i]): + tbl: list[str] = [] + while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]): + tbl.append(lines[i]) + i += 1 + rendered = _sig_render_table(tbl) + if rendered != '\n'.join(tbl): + protected.append(rendered) + rebuilt.append(f"\x00C{len(protected) - 1}\x00") + else: + rebuilt.extend(tbl) + else: + rebuilt.append(lines[i]) + i += 1 + text = '\n'.join(rebuilt) + + # Phase 2 (run-based): process inline patterns. + runs: list[_Run] = [_Run(text)] + + def transform(pattern: re.Pattern, make_runs: Any) -> None: + new_runs: list[_Run] = [] + for run in runs: + if run.opaque: + new_runs.append(run) + continue + pos = 0 + for m in pattern.finditer(run.text): + if m.start() > pos: + new_runs.append(_Run(run.text[pos:m.start()], run.styles)) + new_runs.extend(make_runs(m, run.styles)) + pos = m.end() + if pos < len(run.text): + new_runs.append(_Run(run.text[pos:], run.styles)) + runs[:] = new_runs + + # Restore code/table placeholders as opaque MONOSPACE runs. + transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)]) + + # Inline code (opaque). + transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) + + # Headers → bold plain text. + transform(_SIG_HEADER_RE, lambda m, s: [_Run(m.group(1), s | {"BOLD"})]) + + # Blockquotes → strip marker. + transform(_SIG_BLOCKQUOTE_RE, lambda m, s: [_Run(m.group(1), s)]) + + # Bullet lists → bullet character. + transform(_SIG_BULLET_RE, lambda m, s: [_Run("• ", s)]) + + # Numbered lists → normalize spacing. + transform(_SIG_OLIST_RE, lambda m, s: [_Run(m.group(1) + ". ", s)]) + + # Links → "text (url)" or bare url when text equals url. + def _link_runs(m: re.Match, s: frozenset) -> list[_Run]: + link_text, url = m.group(1), m.group(2) + + def _norm(u: str) -> str: + return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower() + + if _norm(url) == _norm(link_text): + return [_Run(url, s)] + return [_Run(f"{link_text} ({url})", s)] + + transform(_SIG_LINK_RE, _link_runs) + + # Bold (before italic so ** doesn't interfere). + transform(_SIG_BOLD_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"BOLD"})]) + + # Italic (single * or _). + transform(_SIG_ITALIC_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"ITALIC"})]) + + # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant). + transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})]) + + # Phase 3: assemble output. + plain_text = "" + text_styles: list[str] = [] + for run in runs: + if not run.text: + continue + start = len(plain_text) + plain_text += run.text + length = len(plain_text) - start + for style in sorted(run.styles): + text_styles.append(f"{start}:{length}:{style}") + + return plain_text, text_styles + + +class SignalDMConfig(Base): + """Signal DM policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" + allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers/UUIDs + + +class SignalGroupConfig(Base): + """Signal group policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" - which groups to operate in + allow_from: list[str] = Field(default_factory=list) # Allowed group IDs if allowlist policy + require_mention: bool = True # Whether bot must be mentioned to respond + + +class SignalConfig(Base): + """Signal channel configuration using signal-cli daemon (HTTP mode with -a flag only).""" + + enabled: bool = False + phone_number: str = "" # Your Signal phone number (e.g., "+1234567890") + daemon_host: str = "localhost" + daemon_port: int = 8080 + group_message_buffer_size: int = 20 # Number of recent group messages to keep for context + dm: SignalDMConfig = Field(default_factory=SignalDMConfig) + group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) + + @property + def allow_from(self) -> list[str]: + """Aggregate allowlist for the base-class is_allowed() check. + + Returns the union of dm.allow_from and group.allow_from so the base + channel gate sees a populated list when either sub-policy is configured. + A ``"*"`` wildcard in either sub-list propagates to allow all. + """ + return list(dict.fromkeys(self.dm.allow_from + self.group.allow_from)) + + +class SignalChannel(BaseChannel): + """ + Signal channel using signal-cli daemon via HTTP JSON-RPC interface. + + Requires signal-cli daemon in HTTP mode: + - signal-cli -a +1234567890 daemon --http localhost:8080 + + See https://github.com/AsamK/signal-cli for setup instructions. + """ + + name = "signal" + display_name = "Signal" + _TYPING_REFRESH_SECONDS = 10.0 + _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) + + @classmethod + def default_config(cls) -> dict[str, Any]: + return SignalConfig().model_dump(by_alias=True) + + def __init__(self, config: SignalConfig, bus: MessageBus): + if isinstance(config, dict): + config = SignalConfig.model_validate(config) + super().__init__(config, bus) + self.config: SignalConfig = config + self._http: httpx.AsyncClient | None = None + self._request_id = 0 + self._sse_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._typing_uuid_warnings: set[str] = set() + self._account_id_aliases: set[str] = set() + self._remember_account_id_alias(self.config.phone_number) + + # Rolling message buffer for group context (group_id -> deque of messages) + # Each message is a dict with: sender_name, sender_number, content, timestamp + self._group_buffers: dict[str, deque] = {} + + async def start(self) -> None: + """Start the Signal channel and connect to signal-cli daemon.""" + if not self.config.phone_number: + self.self.logger.error("Signal account not configured") + return + + self._running = True + await self._start_http_mode() + + async def _start_http_mode(self) -> None: + """Start Signal channel using Server-Sent Events for receiving messages.""" + base_url = f"http://{self.config.daemon_host}:{self.config.daemon_port}" + reconnect_delay_s = 1.0 + max_reconnect_delay_s = 30.0 + + while self._running: + try: + self.logger.info(f"Connecting to signal-cli daemon at {base_url}...") + + # Create HTTP client + self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url) + + # Test connection + try: + response = await self._http.get("/api/v1/check") + if response.status_code == 200: + self.logger.info("Connected to signal-cli daemon") + else: + raise ConnectionRefusedError( + f"signal-cli daemon check returned status {response.status_code}" + ) + except Exception as e: + raise ConnectionRefusedError(f"signal-cli daemon not responding: {e}") + + # Reset reconnect delay after successful connection check. + reconnect_delay_s = 1.0 + + # Ensure account-level typing indicators are enabled. + await self._ensure_typing_indicators_enabled() + + # Start SSE receiver and supervise it. If it exits while we're still + # running, treat it as a disconnect and reconnect. + self._sse_task = asyncio.create_task(self._sse_receive_loop()) + await self._sse_task + if self._running: + raise ConnectionError("Signal SSE stream ended unexpectedly") + + except asyncio.CancelledError: + break + except ConnectionRefusedError as e: + self.logger.error( + f"{e}. Make sure signal-cli daemon is running: " + f"signal-cli -a {self.config.phone_number} daemon --http {self.config.daemon_host}:{self.config.daemon_port}" + ) + except Exception as e: + self.logger.error(f"Signal channel error: {e}") + finally: + if self._sse_task: + if not self._sse_task.done(): + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + except Exception: + pass + self._sse_task = None + if self._http: + await self._http.aclose() + self._http = None + + if self._running: + self.logger.info( + f"Reconnecting to signal-cli daemon in {reconnect_delay_s:.0f} seconds..." + ) + await asyncio.sleep(reconnect_delay_s) + reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s) + + async def stop(self) -> None: + """Stop the Signal channel.""" + self._running = False + + # Stop SSE task + if self._sse_task: + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + + # Cancel active typing indicators + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id) + + # Close HTTP client + if self._http: + await self._http.aclose() + self._http = None + + async def send(self, msg: OutboundMessage) -> None: + """Send a message through Signal.""" + is_progress_message = bool(msg.metadata.get("_progress")) + try: + plain_text, text_styles = _markdown_to_signal(msg.content) + if not plain_text and not msg.media: + return + recipient_params = self._recipient_params(msg.chat_id) + + chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""] + for i, chunk in enumerate(chunks): + params: dict[str, Any] = {"message": chunk} + if text_styles and i == 0: + params["textStyle"] = text_styles + params.update(recipient_params) + if msg.media and i == 0: + params["attachments"] = msg.media + + response = await self._send_request("send", params) + + if "error" in response: + self.logger.error(f"Error sending Signal message: {response['error']}") + else: + self.logger.debug( + f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}" + ) + + except Exception: + self.logger.exception("Error sending Signal message") + raise + finally: + # Keep typing active across progress updates; stop on the final reply. + if not is_progress_message: + # Avoid immediate START->STOP for fast responses, which can be invisible + # in some Signal clients. Let indicator expire naturally (~15s). + await self._stop_typing(msg.chat_id, send_stop=False) + + async def _sse_receive_loop(self) -> None: + """Receive messages via Server-Sent Events (HTTP mode).""" + if not self._http: + raise RuntimeError("HTTP client not initialized for Signal SSE stream") + + self.logger.info("Started Signal message receive loop (SSE)") + + try: + async with self._http.stream("GET", "/api/v1/events") as response: + if response.status_code != 200: + raise ConnectionError( + f"SSE connection failed with status {response.status_code}" + ) + + self.logger.info("Subscribed to Signal messages via SSE") + + # Buffer for accumulating SSE data across multiple lines + event_buffer = [] + + async for line in response.aiter_lines(): + if not self._running: + break + + # Debug: log raw SSE lines (except keepalive pings) + if line and line != ":": + self.logger.debug(f"SSE line received: {line[:200]}") + + # SSE format handling + if isinstance(line, str): + # Empty line signals end of event + if not line or line == ":": + if event_buffer: + # Try to parse the accumulated data + data_str = "" + try: + data_str = "".join(event_buffer) + data = json.loads(data_str) + self.logger.debug(f"SSE event parsed: {data}") + await self._handle_receive_notification(data) + except json.JSONDecodeError as e: + self.logger.warning( + f"Invalid JSON in SSE buffer: {e}, data: {data_str[:200]}" + ) + finally: + event_buffer = [] + + # "data:" line - accumulate it + elif line.startswith("data:"): + event_buffer.append(line[5:]) # Skip "data:" prefix + + # "event:" line - just log it (we only care about data) + elif line.startswith("event:"): + pass # Ignore event type for now + + if self._running: + raise ConnectionError("Signal SSE stream closed by remote endpoint") + + except asyncio.CancelledError: + self.logger.info("SSE receive loop cancelled") + raise + except Exception as e: + self.logger.error(f"Error in SSE receive loop: {e}") + raise + + async def _handle_receive_notification(self, params: dict[str, Any]) -> None: + """Handle incoming message notification from signal-cli.""" + self.logger.debug(f"_handle_receive_notification called with: {params}") + try: + # Extract envelope from SSE notification: {"envelope": {...}} + envelope = params.get("envelope", {}) + + self.logger.debug(f"Extracted envelope: {envelope}") + + if not envelope: + self.logger.debug("No envelope found in params") + return + + # Extract sender information + sender_parts = self._collect_sender_id_parts(envelope) + source_name = envelope.get("sourceName") + + if not sender_parts: + self.logger.debug("Received message without source, skipping") + return + + sender_number = self._primary_sender_id(sender_parts) + sender_id = "|".join(sender_parts) + + # Keep aliases of the bot account for robust mention matching. + if any(self._id_matches_account(part) for part in sender_parts): + for part in sender_parts: + self._remember_account_id_alias(part) + + # Check different message types + data_message = envelope.get("dataMessage") + sync_message = envelope.get("syncMessage") + typing_message = envelope.get("typingMessage") + receipt_message = envelope.get("receiptMessage") + + # Ignore receipt messages (delivery/read receipts) + if receipt_message: + return + + # Handle data messages (incoming messages from others) + if data_message: + await self._handle_data_message(sender_id, sender_number, data_message, source_name) + + # Handle sync messages (messages sent from another device) + elif sync_message and sync_message.get("sentMessage"): + sent_msg = sync_message["sentMessage"] + destination = sent_msg.get("destination") or sent_msg.get("destinationNumber") + if destination: + self.logger.debug( + f"Sync message sent to {destination}: {sent_msg.get('message', '')[:50]}" + ) + + # Handle typing indicators (silently ignore) + elif typing_message: + pass # Ignore typing indicators + + except Exception as e: + self.logger.error(f"Error handling receive notification: {e}") + + async def _handle_data_message( + self, + sender_id: str, + sender_number: str, + data_message: dict[str, Any], + sender_name: str | None, + ) -> None: + """Handle a data message (text, attachments, etc.).""" + message_text = data_message.get("message") or "" + attachments = data_message.get("attachments", []) + group_info = data_message.get("groupInfo") + timestamp = data_message.get("timestamp") + mentions = data_message.get("mentions", []) + reaction = data_message.get("reaction") + + # Log full data_message for debugging group detection + self.logger.info( + f"Data message from {sender_number}: " + f"groupInfo={group_info}, " + f"groupV2={data_message.get('groupV2')}, " + f"keys={list(data_message.keys())}" + ) + + # Ignore reaction messages (emoji reactions to messages) + if reaction: + self.logger.debug(f"Ignoring reaction message from {sender_number}: {reaction}") + return + + # Ignore empty messages (e.g., when bot is added to a group) + if not message_text and not attachments: + self.logger.debug(f"Ignoring empty message from {sender_number}") + return + + # Determine chat_id (group ID or sender number) + # Check both groupInfo (v1) and groupV2 (v2) fields for group detection + group_v2 = data_message.get("groupV2") + is_group_message = group_info is not None or group_v2 is not None + group_id = self._extract_group_id(group_info, group_v2) + + is_command = bool(message_text and message_text.strip().startswith("/")) + + if is_group_message: + chat_id = group_id or sender_number + + # Check if this group is allowed before doing anything else + if not self.config.group.enabled: + self.logger.info(f"Ignoring group message from {chat_id} (groups disabled)") + return + if ( + self.config.group.policy == "allowlist" + and chat_id not in self.config.group.allow_from + ): + self.logger.info( + f"Ignoring group message from {chat_id} (policy: {self.config.group.policy})" + ) + return + + # Add to group message buffer (group is allowed) + self._add_to_group_buffer( + group_id=chat_id, + sender_name=sender_name or sender_number, + sender_number=sender_number, + message_text=message_text, + timestamp=timestamp, + ) + + # Commands bypass the mention requirement; non-commands require it. + if not is_command and not self._should_respond_in_group(message_text, mentions): + self.logger.info( + f"Ignoring group message (require_mention: {self.config.group.require_mention})" + ) + return + else: + # Direct message — check policy first, then forward everything to the bus. + chat_id = sender_number + if not self.config.dm.enabled: + self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") + return + if self.config.dm.policy == "allowlist": + allow_list = self.config.dm.allow_from + sender_str = str(sender_id) + parts = [sender_str] + (sender_str.split("|") if "|" in sender_str else []) + if not any(p for p in parts if p in allow_list): + self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") + return + + # Build content from text and attachments + content_parts = [] + media_paths = [] + + # For group messages, include recent message context + if is_group_message: + buffer_context = self._get_group_buffer_context(chat_id) + if buffer_context: + content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") + + # Prepend sender name for group messages so history shows who said what + if message_text: + # Strip bot mentions from text (for group messages) + if is_group_message: + message_text = self._strip_bot_mention(message_text, mentions) + # Prepend sender name to make it clear who is speaking + display_name = sender_name or sender_number + message_text = f"[{display_name}]: {message_text}" + content_parts.append(message_text) + + # Handle attachments + if attachments: + media_dir = get_media_dir("signal") + + for attachment in attachments: + attachment_id = attachment.get("id") + content_type = attachment.get("contentType", "") + filename = attachment.get("filename") or f"attachment_{attachment_id}" + + if not attachment_id: + continue + + try: + # signal-cli stores attachments in ~/.local/share/signal-cli/attachments/ + source_path = ( + Path.home() / ".local/share/signal-cli/attachments" / attachment_id + ) + + if source_path.exists(): + dest_path = media_dir / f"signal_{safe_filename(filename)}" + shutil.copy2(source_path, dest_path) + media_paths.append(str(dest_path)) + + # Determine media type from content type + media_type = content_type.split("/")[0] if "/" in content_type else "file" + if media_type not in ("image", "audio", "video"): + media_type = "file" + + content_parts.append(f"[{media_type}: {dest_path}]") + self.logger.debug(f"Downloaded attachment: {filename} -> {dest_path}") + else: + self.logger.warning(f"Attachment not found: {source_path}") + content_parts.append(f"[attachment: {filename} - not found]") + + except Exception as e: + self.logger.warning(f"Failed to process attachment {filename}: {e}") + content_parts.append(f"[attachment: {filename} - error]") + + content = "\n".join(content_parts) if content_parts else "[empty message]" + + self.logger.debug(f"Signal message from {sender_number}: {content[:50]}...") + + await self._start_typing(chat_id) + try: + # Forward to message bus + await self._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content=content, + media=media_paths, + metadata={ + "timestamp": timestamp, + "sender_name": sender_name, + "sender_number": sender_number, + "is_group": is_group_message, + "group_id": group_id, + }, + ) + except Exception: + await self._stop_typing(chat_id) + raise + + def _add_to_group_buffer( + self, + group_id: str, + sender_name: str, + sender_number: str, + message_text: str, + timestamp: int | None, + ) -> None: + """ + Add a message to the group's rolling buffer. + + Args: + group_id: The group ID + sender_name: Display name of sender + sender_number: Phone number of sender + message_text: The message content + timestamp: Message timestamp + """ + if self.config.group_message_buffer_size <= 0: + return + + # Create buffer for this group if it doesn't exist + if group_id not in self._group_buffers: + self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) + + # Add message to buffer (deque will automatically drop oldest when full) + self._group_buffers[group_id].append( + { + "sender_name": sender_name, + "sender_number": sender_number, + "content": message_text, + "timestamp": timestamp, + } + ) + + self.logger.debug( + f"Added message to group buffer {group_id}: " + f"{len(self._group_buffers[group_id])}/{self.config.group_message_buffer_size}" + ) + + def _get_group_buffer_context(self, group_id: str) -> str: + """ + Get formatted context from the group's message buffer. + + Args: + group_id: The group ID + + Returns: + Formatted string of recent messages (excluding the current one) + """ + if group_id not in self._group_buffers: + return "" + + buffer = self._group_buffers[group_id] + if len(buffer) <= 1: # Only current message, no context + return "" + + # Format all messages except the last one (which is the current message) + # We want to show context BEFORE the mention + context_messages = list(buffer)[:-1] # Exclude the last (current) message + + lines = [] + for msg in context_messages: + sender = msg["sender_name"] + content = msg["content"][:200] # Limit to 200 chars per message + lines.append(f"{sender}: {content}") + + return "\n".join(lines) + + @staticmethod + def _normalize_signal_id(value: str) -> list[str]: + """Normalize Signal identifiers (phone/uuid/service-id) for matching.""" + raw = value.strip() + if not raw: + return [] + + normalized = [raw, raw.lower()] + if raw.startswith("+") and len(raw) > 1: + normalized.append(raw[1:]) + elif raw.isdigit(): + normalized.append(f"+{raw}") + return list(dict.fromkeys(normalized)) + + def _remember_account_id_alias(self, value: str | None) -> None: + """Remember known bot identifiers for mention matching.""" + if not value: + return + if not isinstance(value, str): + return + for candidate in self._normalize_signal_id(value): + self._account_id_aliases.add(candidate) + + def _id_matches_account(self, value: str | None) -> bool: + """Return True when an identifier refers to the bot account.""" + if not value: + return False + if not isinstance(value, str): + return False + return any( + candidate in self._account_id_aliases for candidate in self._normalize_signal_id(value) + ) + + @staticmethod + def _collect_sender_id_parts(envelope: dict[str, Any]) -> list[str]: + """Collect all known sender identifier variants from an envelope.""" + parts: list[str] = [] + for key in ( + "sourceNumber", + "source", + "sourceUuid", + "sourceServiceId", + "sourceAci", + "sourceACI", + ): + value = envelope.get(key) + if not isinstance(value, str): + continue + candidate = value.strip() + if candidate and candidate not in parts: + parts.append(candidate) + return parts + + @staticmethod + def _primary_sender_id(sender_parts: list[str]) -> str: + """Pick the best sender identifier for routing (prefer phone-like IDs).""" + for part in sender_parts: + if part.startswith("+") or part.isdigit(): + return part + return sender_parts[0] if sender_parts else "" + + @staticmethod + def _extract_group_id(group_info: Any, group_v2: Any) -> str | None: + """Extract group ID from groupInfo/groupV2 payloads across signal-cli variants.""" + for group_obj in (group_info, group_v2): + if not isinstance(group_obj, dict): + continue + for key in ("groupId", "id", "groupID"): + value = group_obj.get(key) + if isinstance(value, str) and value: + return value + return None + + @staticmethod + def _mention_id_candidates(mention: dict[str, Any]) -> list[str]: + """Extract possible identifier fields from a mention payload.""" + ids: list[str] = [] + + def _walk(value: Any, depth: int = 0) -> None: + if depth > 2: + return + if not isinstance(value, dict): + return + for key, child in value.items(): + key_lower = str(key).lower() + if isinstance(child, str) and child: + if any(token in key_lower for token in ("number", "uuid", "serviceid", "aci")): + ids.append(child) + elif isinstance(child, dict): + _walk(child, depth + 1) + + _walk(mention) + return list(dict.fromkeys(ids)) + + @staticmethod + def _mention_span(mention: dict[str, Any]) -> tuple[int, int] | None: + """Extract a safe (start, length) span from a mention.""" + try: + start = int(mention.get("start", 0)) + length = int(mention.get("length", 0)) + except (TypeError, ValueError): + return None + + if start < 0 or length <= 0: + return None + return (start, length) + + @staticmethod + def _leading_placeholder_span(text: str | None) -> tuple[int, int] | None: + """ + Detect a leading Signal mention placeholder when mention metadata is missing. + + Some clients/integrations deliver mentions as a leading placeholder character + (typically U+FFFC) but omit `mentions` metadata in the payload. + """ + if not text: + return None + + start = 0 + while start < len(text) and text[start].isspace(): + start += 1 + + if start >= len(text): + return None + + marker = text[start] + if marker not in ("\ufffc", "\ufffd", "\x1b"): + return None + + next_index = start + 1 + if next_index < len(text) and not text[next_index].isspace(): + return None + + return (start, 1) + + def _should_respond_in_group(self, message_text: str, mentions: list[dict[str, Any]]) -> bool: + """ + Determine if the bot should respond to a group message. + + Args: + message_text: The message text content + mentions: List of mentions from Signal (format: [{"number": "+1234567890", "start": 0, "length": 10}]) + + Returns: + True if bot should respond, False otherwise + """ + # Group reply behavior is controlled only by group.require_mention. + if not self.config.group.require_mention: + return True + + # If mention is required, check if bot was mentioned. + for mention in mentions: + if not isinstance(mention, dict): + continue + for mention_id in self._mention_id_candidates(mention): + if self._id_matches_account(mention_id): + return True + + # Some Signal clients emit mention spans without recipient identifiers + # (for handle-style mentions). Accept a leading identifier-less mention + # as a mention of the bot to avoid false negatives. + for mention in mentions: + if not isinstance(mention, dict): + continue + if self._mention_id_candidates(mention): + continue + span = self._mention_span(mention) + if not span: + continue + start, _ = span + if message_text is not None and not message_text[:start].strip(): + self.logger.debug("Accepting identifier-less leading mention as bot mention") + return True + + # Some payloads omit `mentions` but still include the leading mention + # placeholder character in the message body. + if not mentions and self._leading_placeholder_span(message_text): + self.logger.debug("Accepting leading placeholder mention without mention metadata") + return True + + # Fallback: check for configured phone number in plain text. + if message_text and self.config.phone_number: + for account_id in self._normalize_signal_id(self.config.phone_number): + if account_id and account_id in message_text: + return True + + return False + + def _strip_bot_mention(self, text: str, mentions: list[dict[str, Any]]) -> str: + """ + Remove bot mentions from message text. + + Signal mentions are embedded in the text, so we need to remove them based on + the mentions array which provides start position and length. + + Args: + text: Original message text + mentions: List of mention objects with start/length positions + + Returns: + Text with bot mentions removed + """ + if not text: + return text + + # Build a list of (start, length) tuples for our bot's mentions + bot_mentions = [] + for mention in mentions: + if not isinstance(mention, dict): + continue + mention_ids = self._mention_id_candidates(mention) + span = self._mention_span(mention) + if not span: + continue + + # Strip matched bot mentions by ID. + if any(self._id_matches_account(mention_id) for mention_id in mention_ids): + bot_mentions.append(span) + continue + + # Also strip identifier-less leading mention spans (handle mentions). + if not mention_ids: + start, _ = span + if not text[:start].strip(): + bot_mentions.append(span) + + if not bot_mentions: + placeholder_span = self._leading_placeholder_span(text) + if placeholder_span: + bot_mentions.append(placeholder_span) + + # Sort mentions by start position (descending) to remove from end to start + # This prevents position shifts when removing earlier mentions + bot_mentions.sort(reverse=True) + + # Remove each mention + for start, length in bot_mentions: + if start >= len(text): + continue + end = min(len(text), start + length) + text = text[:start] + text[end:] + + return text.strip() + + @staticmethod + def _is_group_chat_id(chat_id: str) -> bool: + """Return True when chat_id appears to be a Signal group ID (base64).""" + return "=" in chat_id or (len(chat_id) > 40 and "-" not in chat_id) + + def _recipient_params(self, chat_id: str) -> dict[str, Any]: + """Build recipient params for signal-cli JSON-RPC methods.""" + if self._is_group_chat_id(chat_id): + return {"groupId": chat_id} + return {"recipient": [chat_id]} + + async def _start_typing(self, chat_id: str) -> None: + """Start periodic typing indicator updates for a chat.""" + await self._stop_typing(chat_id, send_stop=False) + await self._send_typing(chat_id) + self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) + + async def _stop_typing(self, chat_id: str, send_stop: bool = True) -> None: + """Stop typing indicator updates for a chat.""" + task = self._typing_tasks.pop(chat_id, None) + had_task = task is not None + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if send_stop and had_task: + await self._send_typing(chat_id, stop=True) + + async def _typing_loop(self, chat_id: str) -> None: + """Send typing updates periodically until cancelled.""" + try: + while self._running: + await asyncio.sleep(self._TYPING_REFRESH_SECONDS) + await self._send_typing(chat_id, quiet_success=True) + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.debug(f"Typing indicator loop stopped for {chat_id}: {e}") + + async def _send_typing( + self, chat_id: str, stop: bool = False, quiet_success: bool = False + ) -> None: + """Send a typing START/STOP message via signal-cli.""" + action = "stop" if stop else "start" + if ( + not self._is_group_chat_id(chat_id) + and chat_id.startswith("+") is False + and chat_id not in self._typing_uuid_warnings + ): + self._typing_uuid_warnings.add(chat_id) + self.logger.warning( + "Signal DM recipient is UUID-only (no phone number in envelope). " + "Some Signal clients may not render typing indicators for this recipient form." + ) + candidate_params: list[dict[str, Any]] + if self._is_group_chat_id(chat_id): + candidate_params = [{"groupId": chat_id}, {"groupId": [chat_id]}] + else: + candidate_params = [{"recipient": chat_id}, {"recipient": [chat_id]}] + + last_error: Any | None = None + for params in candidate_params: + if stop: + params["stop"] = True + try: + response = await self._send_request("sendTyping", params) + except Exception as e: + last_error = str(e) + continue + + if "error" not in response: + if not quiet_success: + self.logger.info(f"Signal typing {action} sent for {chat_id}") + return + + last_error = response["error"] + + self.logger.warning(f"Failed to send Signal typing {action} for {chat_id}: {last_error}") + + async def _ensure_typing_indicators_enabled(self) -> None: + """Enable typing indicators on the bot account.""" + response = await self._send_request("updateConfiguration", {"typingIndicators": True}) + if "error" in response: + self.logger.warning(f"Failed to enable Signal typing indicators: {response['error']}") + else: + self.logger.info("Signal typing indicators enabled on account configuration") + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Send a JSON-RPC request via HTTP and wait for response.""" + # Generate request ID + self._request_id += 1 + request_id = self._request_id + + # Build JSON-RPC request + request = {"jsonrpc": "2.0", "method": method, "id": request_id} + + if params: + request["params"] = params + + return await self._send_http_request(request) + + async def _send_http_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send JSON-RPC request via HTTP.""" + if not self._http: + raise RuntimeError("Not connected to signal-cli daemon") + + try: + response = await self._http.post("/api/v1/rpc", json=request) + response.raise_for_status() + return response.json() + except Exception as e: + self.logger.error(f"HTTP request failed: {e}") + return {"error": {"message": str(e)}} diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py new file mode 100644 index 000000000..b5149459b --- /dev/null +++ b/tests/channels/test_signal_channel.py @@ -0,0 +1,1058 @@ +"""Tests for the Signal channel implementation.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.signal import ( + SignalChannel, + SignalConfig, + SignalDMConfig, + SignalGroupConfig, +) + +# --------------------------------------------------------------------------- +# Fake HTTP client +# --------------------------------------------------------------------------- + + +class _FakeResponse: + def __init__(self, status_code: int = 200, body: dict | None = None) -> None: + self.status_code = status_code + self._body = body or {} + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._body + + +class _FakeHTTPClient: + """Minimal httpx.AsyncClient stand-in that records requests.""" + + def __init__(self, *, default_response: dict | None = None) -> None: + self.posts: list[dict] = [] + self.gets: list[str] = [] + self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}}) + self.closed = False + + async def get(self, path: str) -> _FakeResponse: + self.gets.append(path) + return self._response + + async def post(self, path: str, *, json: dict) -> _FakeResponse: + self.posts.append({"path": path, "json": json}) + return self._response + + async def aclose(self) -> None: + self.closed = True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_channel( + *, + phone_number: str = "+10000000000", + dm_enabled: bool = True, + dm_policy: str = "open", + dm_allow_from: list[str] | None = None, + group_enabled: bool = False, + group_policy: str = "open", + group_allow_from: list[str] | None = None, + require_mention: bool = True, + group_buffer_size: int = 20, +) -> SignalChannel: + config = SignalConfig( + enabled=True, + phone_number=phone_number, + dm=SignalDMConfig( + enabled=dm_enabled, + policy=dm_policy, + allow_from=dm_allow_from or [], + ), + group=SignalGroupConfig( + enabled=group_enabled, + policy=group_policy, + allow_from=group_allow_from or [], + require_mention=require_mention, + ), + group_message_buffer_size=group_buffer_size, + ) + return SignalChannel(config, MessageBus()) + + +def _dm_envelope( + *, + source_number: str = "+19995550001", + source_uuid: str | None = None, + source_name: str | None = "Alice", + message: str = "hello", + attachments: list | None = None, + reaction: dict | None = None, + timestamp: int = 1000, +) -> dict: + data_message: dict = {"message": message, "timestamp": timestamp} + if attachments is not None: + data_message["attachments"] = attachments + if reaction is not None: + data_message["reaction"] = reaction + envelope: dict = { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + if source_uuid: + envelope["sourceUuid"] = source_uuid + return {"envelope": envelope} + + +def _group_envelope( + *, + source_number: str = "+19995550001", + source_name: str = "Bob", + group_id: str = "group123==", + message: str = "hey group", + mentions: list | None = None, + timestamp: int = 2000, + use_v2: bool = False, +) -> dict: + group_obj = {"groupId": group_id} + key = "groupV2" if use_v2 else "groupInfo" + data_message: dict = { + "message": message, + "timestamp": timestamp, + key: group_obj, + "mentions": mentions or [], + } + return { + "envelope": { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + } + + +# --------------------------------------------------------------------------- +# Static utility tests +# --------------------------------------------------------------------------- + + +class TestNormalizeSignalId: + def test_phone_number_kept_and_stripped(self): + result = SignalChannel._normalize_signal_id("+12345678901") + assert "+12345678901" in result + assert "12345678901" in result + + def test_digits_only_gets_plus_prefix(self): + result = SignalChannel._normalize_signal_id("12345678901") + assert "+12345678901" in result + + def test_lowercase_variant_added(self): + result = SignalChannel._normalize_signal_id("SOME-UUID") + assert "some-uuid" in result + + def test_empty_string_returns_empty(self): + assert SignalChannel._normalize_signal_id("") == [] + + def test_whitespace_stripped(self): + result = SignalChannel._normalize_signal_id(" +1234 ") + assert "+1234" in result + + +class TestCollectSenderIdParts: + def test_collects_source_number(self): + env = {"sourceNumber": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + + def test_collects_multiple_keys(self): + env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + assert "uuid-abc" in parts + + def test_deduplicates(self): + env = {"sourceNumber": "+15551234567", "source": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts.count("+15551234567") == 1 + + def test_ignores_non_string_values(self): + env = {"sourceNumber": 12345, "sourceUuid": None} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts == [] + + def test_empty_envelope_returns_empty(self): + assert SignalChannel._collect_sender_id_parts({}) == [] + + +class TestPrimarySenderId: + def test_prefers_phone_number(self): + assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234" + + def test_accepts_digit_only(self): + assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890" + + def test_falls_back_to_first_part(self): + assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc" + + def test_empty_list_returns_empty(self): + assert SignalChannel._primary_sender_id([]) == "" + + +class TestExtractGroupId: + def test_extracts_from_group_info(self): + gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None) + assert gid == "abc==" + + def test_extracts_from_group_v2(self): + gid = SignalChannel._extract_group_id(None, {"id": "xyz=="}) + assert gid == "xyz==" + + def test_prefers_group_info_over_v2(self): + gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"}) + assert gid == "first" + + def test_returns_none_when_both_none(self): + assert SignalChannel._extract_group_id(None, None) is None + + def test_returns_none_when_not_dicts(self): + assert SignalChannel._extract_group_id("bad", 123) is None + + +class TestIsGroupChatId: + def test_base64_with_equals_is_group(self): + assert SignalChannel._is_group_chat_id("abc==") is True + + def test_long_id_without_dash_is_group(self): + long_id = "a" * 41 + assert SignalChannel._is_group_chat_id(long_id) is True + + def test_phone_number_is_not_group(self): + assert SignalChannel._is_group_chat_id("+12345678901") is False + + def test_uuid_with_dashes_is_not_group(self): + assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False + + +class TestRecipientParams: + def test_group_chat_uses_group_id(self): + ch = _make_channel() + params = ch._recipient_params("abc==") + assert params == {"groupId": "abc=="} + + def test_dm_uses_recipient_list(self): + ch = _make_channel() + params = ch._recipient_params("+12345678901") + assert params == {"recipient": ["+12345678901"]} + + +class TestMentionHelpers: + def test_mention_id_candidates_extracts_number(self): + mention = {"number": "+1234567890"} + ids = SignalChannel._mention_id_candidates(mention) + assert "+1234567890" in ids + + def test_mention_id_candidates_extracts_uuid(self): + mention = {"uuid": "some-uuid"} + ids = SignalChannel._mention_id_candidates(mention) + assert "some-uuid" in ids + + def test_mention_span_valid(self): + assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5) + + def test_mention_span_negative_start(self): + assert SignalChannel._mention_span({"start": -1, "length": 5}) is None + + def test_mention_span_zero_length(self): + assert SignalChannel._mention_span({"start": 0, "length": 0}) is None + + def test_mention_span_missing_keys(self): + assert SignalChannel._mention_span({}) is None + + def test_leading_placeholder_ufffc(self): + span = SignalChannel._leading_placeholder_span(" hello") + assert span == (0, 1) + + def test_leading_placeholder_not_at_start(self): + assert SignalChannel._leading_placeholder_span("hello ") is None + + def test_leading_placeholder_empty_string(self): + assert SignalChannel._leading_placeholder_span("") is None + + def test_leading_placeholder_plain_text(self): + assert SignalChannel._leading_placeholder_span("hello") is None + + +# --------------------------------------------------------------------------- +# Account ID alias / mention matching +# --------------------------------------------------------------------------- + + +class TestAccountIdAliases: + def test_phone_number_alias_registered_on_init(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("+10000000000") + + def test_digit_only_variant_matches(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("10000000000") + + def test_remember_alias_adds_uuid(self): + ch = _make_channel() + ch._remember_account_id_alias("some-uuid-abc") + assert ch._id_matches_account("some-uuid-abc") + + def test_non_matching_id_returns_false(self): + ch = _make_channel(phone_number="+10000000000") + assert not ch._id_matches_account("+19999999999") + + def test_none_and_non_string_return_false(self): + ch = _make_channel() + assert not ch._id_matches_account(None) + + +# --------------------------------------------------------------------------- +# _should_respond_in_group +# --------------------------------------------------------------------------- + + +class TestShouldRespondInGroup: + def _make_group_channel(self, require_mention: bool = True) -> SignalChannel: + return _make_channel( + phone_number="+10000000000", + group_enabled=True, + require_mention=require_mention, + ) + + def test_no_require_mention_always_responds(self): + ch = self._make_group_channel(require_mention=False) + assert ch._should_respond_in_group("anything", []) is True + + def test_require_mention_with_no_mentions_returns_false(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hello", []) is False + + def test_require_mention_with_bot_number_mention(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"number": "+10000000000", "start": 0, "length": 12}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_require_mention_with_uuid_mention(self): + ch = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("bot-uuid-123") + mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_leading_mention_accepted(self): + ch = self._make_group_channel(require_mention=True) + # Mention with no IDs but leading span — treated as bot mention + mentions = [{"start": 0, "length": 1}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_non_leading_mention_rejected(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"start": 5, "length": 1}] + assert ch._should_respond_in_group("hello ", mentions) is False + + def test_leading_placeholder_without_mentions_metadata(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group(" hello", []) is True + + def test_phone_number_in_text_triggers_response(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hey +10000000000 help", []) is True + + +# --------------------------------------------------------------------------- +# _strip_bot_mention +# --------------------------------------------------------------------------- + + +class TestStripBotMention: + def _make_channel_with_number(self) -> SignalChannel: + return _make_channel(phone_number="+10000000000") + + def test_strips_mention_by_phone(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_identifier_less_leading_mention(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_leading_placeholder_without_mention_metadata(self): + ch = self._make_channel_with_number() + text = " hello" + result = ch._strip_bot_mention(text, []) + assert result == "hello" + + def test_non_bot_mention_mid_text_not_stripped(self): + # A non-bot mention that is NOT a leading placeholder leaves the text alone. + ch = self._make_channel_with_number() + text = "hello  world" + mentions = [{"number": "+19999999999", "start": 6, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + # Mid-text placeholder from a non-bot mention should be untouched + assert "" in result + + def test_empty_text_returned_unchanged(self): + ch = self._make_channel_with_number() + assert ch._strip_bot_mention("", []) == "" + + +# --------------------------------------------------------------------------- +# Group message buffer +# --------------------------------------------------------------------------- + + +class TestGroupBuffer: + def test_add_and_get_context(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000) + # Only messages before the latest are returned as context + ctx = ch._get_group_buffer_context("g1") + assert "first msg" in ctx + # The last message is not included (it's the "current" one) + assert "second msg" not in ctx + + def test_empty_context_when_only_one_message(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000) + assert ch._get_group_buffer_context("g1") == "" + + def test_empty_context_when_group_unknown(self): + ch = _make_channel() + assert ch._get_group_buffer_context("unknown") == "" + + def test_buffer_respects_max_size(self): + ch = _make_channel(group_buffer_size=3) + for i in range(10): + ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) + assert len(ch._group_buffers["g1"]) == 3 + + def test_zero_buffer_size_does_not_add(self): + ch = _make_channel(group_buffer_size=0) + ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000) + assert "g1" not in ch._group_buffers + + def test_context_limits_message_length(self): + ch = _make_channel(group_buffer_size=5) + long_msg = "x" * 500 + ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000) + ctx = ch._get_group_buffer_context("g1") + # Context is capped at 200 chars per message + assert len(ctx.split("Alice: ", 1)[1]) <= 200 + + +# --------------------------------------------------------------------------- +# _handle_data_message — DM routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageDM: + def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: + ch = _make_channel(dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + @pytest.mark.asyncio + async def test_dm_open_policy_accepted(self): + ch, handled = self._make_dm_channel(policy="open") + params = _dm_envelope(source_number="+19995550001", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "+19995550001" + assert handled[0]["content"] == "hi" + + @pytest.mark.asyncio + async def test_dm_allowlist_accepted(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_rejected(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) + params = _dm_envelope(source_number="+19995550002") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_dm_disabled_rejected(self): + ch = _make_channel(dm_enabled=False) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_reaction_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999}) + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_empty_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(message="") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_receipt_message_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "receiptMessage": {"when": 1234}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_typing_indicator_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "typingMessage": {"action": "STARTED"}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_missing_envelope_ignored(self): + ch, handled = self._make_dm_channel() + await ch._handle_receive_notification({}) + assert handled == [] + + @pytest.mark.asyncio + async def test_metadata_passed_to_handle(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999) + await ch._handle_receive_notification(params) + meta = handled[0]["metadata"] + assert meta["sender_name"] == "Alice" + assert meta["timestamp"] == 9999 + assert meta["is_group"] is False + + @pytest.mark.asyncio + async def test_sender_id_with_uuid_variant(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + # sender_id combines both parts + assert "+19995550001" in handled[0]["sender_id"] + assert "uuid-abc" in handled[0]["sender_id"] + + @pytest.mark.asyncio + async def test_stop_typing_called_on_handle_error(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + typing_stopped: list[str] = [] + + async def fail_handle(**kwargs): + raise RuntimeError("boom") + + async def noop_typing(chat_id): + pass + + async def record_stop(chat_id, **kwargs): + typing_stopped.append(chat_id) + + ch._handle_message = fail_handle # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + ch._stop_typing = record_stop # type: ignore[method-assign] + + # _handle_receive_notification swallows exceptions; the typing stop + # still fires from _handle_data_message's except clause. + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + + assert "+19995550001" in typing_stopped + + +# --------------------------------------------------------------------------- +# _handle_data_message — group routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageGroup: + def _make_group_channel( + self, + policy="open", + allow_from=None, + require_mention=True, + ) -> tuple[SignalChannel, list]: + ch = _make_channel( + group_enabled=True, + group_policy=policy, + group_allow_from=allow_from or [], + require_mention=require_mention, + ) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + @pytest.mark.asyncio + async def test_group_disabled_rejected(self): + ch = _make_channel(group_enabled=False) + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_blocked_when_required(self): + ch, handled = self._make_group_channel(require_mention=True) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_required(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grp==" + + @pytest.mark.asyncio + async def test_group_allowlist_accepted(self): + ch, handled = self._make_group_channel( + policy="allowlist", allow_from=["grp=="], require_mention=False + ) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_allowlist_rejected(self): + ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="]) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_mention_triggers_response(self): + ch, handled = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("+10000000000") + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_v2_id_extracted(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grpV2==" + + @pytest.mark.asyncio + async def test_group_message_includes_sender_prefix(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", source_name="Bob", message="hello") + await ch._handle_receive_notification(params) + assert "[Bob]:" in handled[0]["content"] + + @pytest.mark.asyncio + async def test_group_message_context_prepended(self): + ch, handled = self._make_group_channel(require_mention=False) + # First message — adds to buffer but no context yet + params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1") + await ch._handle_receive_notification(params1) + # Second message — should include context from first + params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2") + await ch._handle_receive_notification(params2) + assert "[Recent group messages for context:]" in handled[1]["content"] + assert "msg1" in handled[1]["content"] + + @pytest.mark.asyncio + async def test_group_metadata_marks_is_group(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled[0]["metadata"]["is_group"] is True + assert handled[0]["metadata"]["group_id"] == "grp==" + + @pytest.mark.asyncio + async def test_bot_account_alias_learned_from_incoming(self): + ch, handled = self._make_group_channel(require_mention=False) + # If the bot's own UUID appears in an envelope we learn it + params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid") + # DMs from self are processed (learning alias), but DM policy is open + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + ch._start_typing = lambda chat_id: None # type: ignore[method-assign] + await ch._handle_receive_notification(params) + assert ch._id_matches_account("new-bot-uuid") + + +# --------------------------------------------------------------------------- +# Command handling +# --------------------------------------------------------------------------- + + +class TestCommandHandling: + @pytest.mark.asyncio + async def test_dm_command_forwarded_to_bus(self): + """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" + ch = _make_channel(dm_enabled=True, dm_policy="open") + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = AsyncMock() + + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + + assert len(forwarded) == 1 + assert forwarded[0]["content"].strip() == "/reset" + + @pytest.mark.asyncio + async def test_group_command_bypasses_mention_requirement(self): + """Slash commands in groups bypass the mention requirement and reach the bus.""" + ch = _make_channel( + group_enabled=True, group_policy="open", require_mention=True + ) + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = AsyncMock() + + params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + await ch._handle_receive_notification(params) + + assert len(forwarded) == 1 + assert "/reset" in forwarded[0]["content"] + + @pytest.mark.asyncio + async def test_command_denied_for_disallowed_dm_sender(self): + """Commands from senders not on the DM allowlist are dropped.""" + ch = _make_channel(dm_enabled=False) + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + + assert forwarded == [] + + +# --------------------------------------------------------------------------- +# send() — outbound messages +# --------------------------------------------------------------------------- + + +class TestSend: + def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + return ch, client + + @pytest.mark.asyncio + async def test_send_plain_text_posts_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) + assert len(client.posts) == 1 + payload = client.posts[0]["json"] + assert payload["method"] == "send" + assert payload["params"]["message"] == "hello" + + @pytest.mark.asyncio + async def test_send_with_markdown_includes_text_styles(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "textStyle" in params + assert any("BOLD" in s for s in params["textStyle"]) + + @pytest.mark.asyncio + async def test_send_empty_content_skips_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="") + await ch.send(msg) + assert client.posts == [] + + @pytest.mark.asyncio + async def test_send_to_group_uses_group_id(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "groupId" in params + assert "recipient" not in params + + @pytest.mark.asyncio + async def test_send_to_dm_uses_recipient(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "recipient" in params + + @pytest.mark.asyncio + async def test_send_with_media_includes_attachments(self): + ch, client = self._make_send_channel() + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="see attachment", + media=["/tmp/file.jpg"], + ) + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert params.get("attachments") == ["/tmp/file.jpg"] + + @pytest.mark.asyncio + async def test_send_progress_message_does_not_stop_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, **kwargs): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="working...", + metadata={"_progress": True}, + ) + await ch.send(msg) + # Progress messages should NOT stop the typing indicator + assert stopped == [] + + @pytest.mark.asyncio + async def test_send_final_message_stops_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, send_stop=True): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done") + await ch.send(msg) + assert "+19995550001" in stopped + + @pytest.mark.asyncio + async def test_send_logs_daemon_error_without_raising(self): + ch = _make_channel() + # The daemon returns {"error": {...}} in the JSON body — this is not a Python + # exception; send() logs it but does not raise (only HTTP-level exceptions raise). + ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}}) + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) # must not raise + + +# --------------------------------------------------------------------------- +# stop() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_cancels_sse_task() -> None: + ch = _make_channel() + cancelled = False + + async def long_running(): + nonlocal cancelled + try: + await asyncio.sleep(9999) + except asyncio.CancelledError: + cancelled = True + raise + + ch._sse_task = asyncio.create_task(long_running()) + # Yield so the task can enter its body (reach the first await) before cancel. + await asyncio.sleep(0) + ch._running = True + + await ch.stop() + + assert cancelled + assert ch._running is False + + +@pytest.mark.asyncio +async def test_stop_closes_http_client() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + ch._running = True + + await ch.stop() + + assert client.closed + + +@pytest.mark.asyncio +async def test_stop_safe_when_no_sse_task() -> None: + ch = _make_channel() + ch._running = True + # Should not raise even with no _sse_task + await ch.stop() + assert ch._running is False + + +# --------------------------------------------------------------------------- +# _send_request / _send_http_request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_request_increments_id() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + + await ch._send_request("testMethod", {"key": "val"}) + await ch._send_request("testMethod", {"key": "val"}) + + ids = [p["json"]["id"] for p in client.posts] + assert ids == [1, 2] + + +@pytest.mark.asyncio +async def test_send_request_raises_when_not_connected() -> None: + ch = _make_channel() + # _http is None by default + with pytest.raises(RuntimeError, match="Not connected"): + await ch._send_request("testMethod") + + +# --------------------------------------------------------------------------- +# _handle_receive_notification — envelope shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_handle_notification_sync_message_does_not_forward() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "syncMessage": { + "sentMessage": { + "destination": "+19990000000", + "message": "sent from other device", + } + }, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + +@pytest.mark.asyncio +async def test_handle_notification_no_source_skipped() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = {"envelope": {"dataMessage": {"message": "ghost"}}} + await ch._handle_receive_notification(notification) + assert handled == [] + + +# --------------------------------------------------------------------------- +# Config: allow_from property aggregation +# --------------------------------------------------------------------------- + + +def test_config_allow_from_aggregates_dm_and_group() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), + group=SignalGroupConfig( + enabled=True, policy="allowlist", allow_from=["+3333", "+1111"] + ), + ) + combined = config.allow_from + assert "+1111" in combined + assert "+2222" in combined + assert "+3333" in combined + # Duplicates removed + assert combined.count("+1111") == 1 + + +def test_config_allow_from_wildcard_propagates() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]), + group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]), + ) + assert "*" in config.allow_from diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py new file mode 100644 index 000000000..15eca70ff --- /dev/null +++ b/tests/channels/test_signal_markdown.py @@ -0,0 +1,244 @@ +"""Unit tests for the Signal markdown → plain text + textStyle converter.""" + +from nanobot.channels.signal import _markdown_to_signal + + +def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Return a dict mapping each styled substring to its style list.""" + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = plain[start : start + length] + result.setdefault(span, []).append(style) + return result + + +# --------------------------------------------------------------------------- +# Basic cases +# --------------------------------------------------------------------------- + + +def test_empty(): + plain, styles = _markdown_to_signal("") + assert plain == "" + assert styles == [] + + +def test_plain_text(): + plain, styles = _markdown_to_signal("hello world") + assert plain == "hello world" + assert styles == [] + + +def test_bold_stars(): + plain, styles = _markdown_to_signal("say **hello** now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_bold_underscores(): + plain, styles = _markdown_to_signal("say __hello__ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_italic_star(): + plain, styles = _markdown_to_signal("say *hello* now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_italic_underscore(): + plain, styles = _markdown_to_signal("say _hello_ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_strikethrough(): + plain, styles = _markdown_to_signal("say ~~hello~~ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]} + + +# --------------------------------------------------------------------------- +# Code +# --------------------------------------------------------------------------- + + +def test_inline_code(): + plain, styles = _markdown_to_signal("run `ls -la` here") + assert plain == "run ls -la here" + assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]} + + +def test_code_block(): + plain, styles = _markdown_to_signal("```\nprint('hi')\n```") + assert "print('hi')" in plain + assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \ + "MONOSPACE" in str(styles_for(plain, styles)) + + +def test_code_block_with_lang(): + plain, styles = _markdown_to_signal("```python\ncode\n```") + assert "code" in plain + assert any("MONOSPACE" in s for s in styles) + + +def test_code_block_not_processed_further(): + """Markdown inside a code block must not be styled.""" + plain, styles = _markdown_to_signal("```\n**not bold**\n```") + assert "**not bold**" in plain + # Only MONOSPACE should be applied, no BOLD + for entry in styles: + assert "BOLD" not in entry + + +def test_inline_code_not_processed_further(): + """Markdown inside inline code must not be styled.""" + plain, styles = _markdown_to_signal("use `**raw**` please") + assert "**raw**" in plain + for entry in styles: + assert "BOLD" not in entry + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +def test_header_becomes_bold(): + plain, styles = _markdown_to_signal("# My Title") + assert plain == "My Title" + assert styles_for(plain, styles) == {"My Title": ["BOLD"]} + + +def test_h2_becomes_bold(): + plain, styles = _markdown_to_signal("## Sub-section") + assert plain == "Sub-section" + assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]} + + +# --------------------------------------------------------------------------- +# Blockquotes +# --------------------------------------------------------------------------- + + +def test_blockquote_strips_marker(): + plain, styles = _markdown_to_signal("> some quote") + assert plain == "some quote" + assert styles == [] + + +# --------------------------------------------------------------------------- +# Lists +# --------------------------------------------------------------------------- + + +def test_bullet_dash(): + plain, styles = _markdown_to_signal("- item one") + assert plain == "• item one" + + +def test_bullet_star(): + plain, styles = _markdown_to_signal("* item two") + assert plain == "• item two" + + +def test_numbered_list(): + plain, styles = _markdown_to_signal("1. first\n2. second") + assert "1. first" in plain + assert "2. second" in plain + + +# --------------------------------------------------------------------------- +# Links +# --------------------------------------------------------------------------- + + +def test_link_text_differs_from_url(): + plain, styles = _markdown_to_signal("[Click here](https://example.com)") + assert plain == "Click here (https://example.com)" + assert styles == [] + + +def test_link_text_equals_url(): + plain, styles = _markdown_to_signal("[https://example.com](https://example.com)") + assert plain == "https://example.com" + assert styles == [] + + +def test_link_text_equals_url_without_scheme(): + plain, styles = _markdown_to_signal("[example.com](https://example.com)") + assert plain == "https://example.com" + + +# --------------------------------------------------------------------------- +# Mixed / nesting +# --------------------------------------------------------------------------- + + +def test_bold_and_italic_adjacent(): + plain, styles = _markdown_to_signal("**bold** and *italic*") + assert plain == "bold and italic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_header_with_inline_code(): + """Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD).""" + plain, styles = _markdown_to_signal("# Use `grep`") + assert plain == "Use grep" + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles) + assert "MONOSPACE" in sd.get("grep", []) + + +def test_multiline_mixed(): + md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another" + plain, styles = _markdown_to_signal(md) + assert "Title" in plain + assert "italic" in plain + assert "• bullet" in plain + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Title", []) + assert "ITALIC" in sd.get("italic", []) + + +# --------------------------------------------------------------------------- +# Table rendering +# --------------------------------------------------------------------------- + + +def test_table_rendered_as_monospace(): + md = "| A | B |\n| - | - |\n| 1 | 2 |" + plain, styles = _markdown_to_signal(md) + assert "A" in plain and "B" in plain + assert any("MONOSPACE" in s for s in styles) + + +# --------------------------------------------------------------------------- +# Style range format +# --------------------------------------------------------------------------- + + +def test_style_range_format(): + """Each style entry must be 'start:length:STYLE'.""" + _, styles = _markdown_to_signal("**bold** text") + for entry in styles: + parts = entry.split(":") + assert len(parts) == 3 + assert parts[0].isdigit() + assert parts[1].isdigit() + assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"} + + +def test_style_ranges_are_within_bounds(): + text = "hello **world** end" + plain, styles = _markdown_to_signal(text) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= len(plain)