diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py new file mode 100644 index 000000000..3e35ae676 --- /dev/null +++ b/nanobot/channels/signal.py @@ -0,0 +1,1133 @@ +"""Signal channel implementation using signal-cli daemon JSON-RPC interface.""" + +from __future__ import annotations + +import asyncio +import json +import re +import shutil +import unicodedata +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +from pydantic import Field + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir +from nanobot.config.schema import Base +from nanobot.utils.helpers import safe_filename, split_message + + +@dataclass +class _Run: + text: str + styles: frozenset[str] = field(default_factory=frozenset) + opaque: bool = False # code / table content — skip further pattern processing + + +_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```') +_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`') +_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE) +_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE) +_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE) +_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE) +_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') +_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL) +_SIG_ITALIC_RE = re.compile(r'(? str: + """Strip inline markdown from a table cell for plain-text rendering.""" + s = re.sub(r'\*\*(.+?)\*\*', r'\1', s) + s = re.sub(r'__(.+?)__', r'\1', s) + s = re.sub(r'~~(.+?)~~', r'\1', s) + s = re.sub(r'`([^`]+)`', r'\1', s) + return s.strip() + + +def _sig_render_table(table_lines: list[str]) -> str: + """Render a markdown pipe-table as fixed-width plain text.""" + + def dw(s: str) -> int: + return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s) + + rows: list[list[str]] = [] + has_sep = False + for line in table_lines: + cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')] + if all(re.match(r'^:?-+:?$', c) for c in cells if c): + has_sep = True + continue + rows.append(cells) + if not rows or not has_sep: + return '\n'.join(table_lines) + + ncols = max(len(r) for r in rows) + for r in rows: + r.extend([''] * (ncols - len(r))) + widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] + + def dr(cells: list[str]) -> str: + return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths)) + + out = [dr(rows[0])] + out.append(' '.join('─' * w for w in widths)) + for row in rows[1:]: + out.append(dr(row)) + return '\n'.join(out) + + +def _markdown_to_signal(text: str) -> tuple[str, list[str]]: + """Convert markdown text to Signal plain text + textStyle ranges. + + Returns ``(plain_text, text_styles)`` where ``text_styles`` are + ``"start:length:STYLE"`` strings for the signal-cli ``textStyle`` parameter. + """ + if not text: + return text, [] + + # Phase 1 (text-level): extract code blocks and tables with placeholder tokens + # so they're protected from inline-style processing. + protected: list[str] = [] + + def save_code(m: re.Match) -> str: + protected.append(m.group(1)) + return f"\x00C{len(protected) - 1}\x00" + + text = _SIG_CODE_BLOCK_RE.sub(save_code, text) + + # Detect and render pipe-tables line by line. + lines = text.split('\n') + rebuilt: list[str] = [] + i = 0 + while i < len(lines): + if re.match(r'^\s*\|.+\|', lines[i]): + tbl: list[str] = [] + while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]): + tbl.append(lines[i]) + i += 1 + rendered = _sig_render_table(tbl) + if rendered != '\n'.join(tbl): + protected.append(rendered) + rebuilt.append(f"\x00C{len(protected) - 1}\x00") + else: + rebuilt.extend(tbl) + else: + rebuilt.append(lines[i]) + i += 1 + text = '\n'.join(rebuilt) + + # Phase 2 (run-based): process inline patterns. + runs: list[_Run] = [_Run(text)] + + def transform(pattern: re.Pattern, make_runs: Any) -> None: + new_runs: list[_Run] = [] + for run in runs: + if run.opaque: + new_runs.append(run) + continue + pos = 0 + for m in pattern.finditer(run.text): + if m.start() > pos: + new_runs.append(_Run(run.text[pos:m.start()], run.styles)) + new_runs.extend(make_runs(m, run.styles)) + pos = m.end() + if pos < len(run.text): + new_runs.append(_Run(run.text[pos:], run.styles)) + runs[:] = new_runs + + # Restore code/table placeholders as opaque MONOSPACE runs. + transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)]) + + # Inline code (opaque). + transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) + + # Headers → bold plain text. + transform(_SIG_HEADER_RE, lambda m, s: [_Run(m.group(1), s | {"BOLD"})]) + + # Blockquotes → strip marker. + transform(_SIG_BLOCKQUOTE_RE, lambda m, s: [_Run(m.group(1), s)]) + + # Bullet lists → bullet character. + transform(_SIG_BULLET_RE, lambda m, s: [_Run("• ", s)]) + + # Numbered lists → normalize spacing. + transform(_SIG_OLIST_RE, lambda m, s: [_Run(m.group(1) + ". ", s)]) + + # Links → "text (url)" or bare url when text equals url. + def _link_runs(m: re.Match, s: frozenset) -> list[_Run]: + link_text, url = m.group(1), m.group(2) + + def _norm(u: str) -> str: + return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower() + + if _norm(url) == _norm(link_text): + return [_Run(url, s)] + return [_Run(f"{link_text} ({url})", s)] + + transform(_SIG_LINK_RE, _link_runs) + + # Bold (before italic so ** doesn't interfere). + transform(_SIG_BOLD_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"BOLD"})]) + + # Italic (single * or _). + transform(_SIG_ITALIC_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"ITALIC"})]) + + # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant). + transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})]) + + # Phase 3: assemble output. + plain_text = "" + text_styles: list[str] = [] + for run in runs: + if not run.text: + continue + start = len(plain_text) + plain_text += run.text + length = len(plain_text) - start + for style in sorted(run.styles): + text_styles.append(f"{start}:{length}:{style}") + + return plain_text, text_styles + + +class SignalDMConfig(Base): + """Signal DM policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" + allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers/UUIDs + + +class SignalGroupConfig(Base): + """Signal group policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" - which groups to operate in + allow_from: list[str] = Field(default_factory=list) # Allowed group IDs if allowlist policy + require_mention: bool = True # Whether bot must be mentioned to respond + + +class SignalConfig(Base): + """Signal channel configuration using signal-cli daemon (HTTP mode with -a flag only).""" + + enabled: bool = False + phone_number: str = "" # Your Signal phone number (e.g., "+1234567890") + daemon_host: str = "localhost" + daemon_port: int = 8080 + group_message_buffer_size: int = 20 # Number of recent group messages to keep for context + dm: SignalDMConfig = Field(default_factory=SignalDMConfig) + group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) + + @property + def allow_from(self) -> list[str]: + """Aggregate allowlist for the base-class is_allowed() check. + + Returns the union of dm.allow_from and group.allow_from so the base + channel gate sees a populated list when either sub-policy is configured. + A ``"*"`` wildcard in either sub-list propagates to allow all. + """ + return list(dict.fromkeys(self.dm.allow_from + self.group.allow_from)) + + +class SignalChannel(BaseChannel): + """ + Signal channel using signal-cli daemon via HTTP JSON-RPC interface. + + Requires signal-cli daemon in HTTP mode: + - signal-cli -a +1234567890 daemon --http localhost:8080 + + See https://github.com/AsamK/signal-cli for setup instructions. + """ + + name = "signal" + display_name = "Signal" + _TYPING_REFRESH_SECONDS = 10.0 + _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) + + @classmethod + def default_config(cls) -> dict[str, Any]: + return SignalConfig().model_dump(by_alias=True) + + def __init__(self, config: SignalConfig, bus: MessageBus): + if isinstance(config, dict): + config = SignalConfig.model_validate(config) + super().__init__(config, bus) + self.config: SignalConfig = config + self._http: httpx.AsyncClient | None = None + self._request_id = 0 + self._sse_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._typing_uuid_warnings: set[str] = set() + self._account_id_aliases: set[str] = set() + self._remember_account_id_alias(self.config.phone_number) + + # Rolling message buffer for group context (group_id -> deque of messages) + # Each message is a dict with: sender_name, sender_number, content, timestamp + self._group_buffers: dict[str, deque] = {} + + async def start(self) -> None: + """Start the Signal channel and connect to signal-cli daemon.""" + if not self.config.phone_number: + self.self.logger.error("Signal account not configured") + return + + self._running = True + await self._start_http_mode() + + async def _start_http_mode(self) -> None: + """Start Signal channel using Server-Sent Events for receiving messages.""" + base_url = f"http://{self.config.daemon_host}:{self.config.daemon_port}" + reconnect_delay_s = 1.0 + max_reconnect_delay_s = 30.0 + + while self._running: + try: + self.logger.info(f"Connecting to signal-cli daemon at {base_url}...") + + # Create HTTP client + self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url) + + # Test connection + try: + response = await self._http.get("/api/v1/check") + if response.status_code == 200: + self.logger.info("Connected to signal-cli daemon") + else: + raise ConnectionRefusedError( + f"signal-cli daemon check returned status {response.status_code}" + ) + except Exception as e: + raise ConnectionRefusedError(f"signal-cli daemon not responding: {e}") + + # Reset reconnect delay after successful connection check. + reconnect_delay_s = 1.0 + + # Ensure account-level typing indicators are enabled. + await self._ensure_typing_indicators_enabled() + + # Start SSE receiver and supervise it. If it exits while we're still + # running, treat it as a disconnect and reconnect. + self._sse_task = asyncio.create_task(self._sse_receive_loop()) + await self._sse_task + if self._running: + raise ConnectionError("Signal SSE stream ended unexpectedly") + + except asyncio.CancelledError: + break + except ConnectionRefusedError as e: + self.logger.error( + f"{e}. Make sure signal-cli daemon is running: " + f"signal-cli -a {self.config.phone_number} daemon --http {self.config.daemon_host}:{self.config.daemon_port}" + ) + except Exception as e: + self.logger.error(f"Signal channel error: {e}") + finally: + if self._sse_task: + if not self._sse_task.done(): + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + except Exception: + pass + self._sse_task = None + if self._http: + await self._http.aclose() + self._http = None + + if self._running: + self.logger.info( + f"Reconnecting to signal-cli daemon in {reconnect_delay_s:.0f} seconds..." + ) + await asyncio.sleep(reconnect_delay_s) + reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s) + + async def stop(self) -> None: + """Stop the Signal channel.""" + self._running = False + + # Stop SSE task + if self._sse_task: + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + + # Cancel active typing indicators + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id) + + # Close HTTP client + if self._http: + await self._http.aclose() + self._http = None + + async def send(self, msg: OutboundMessage) -> None: + """Send a message through Signal.""" + is_progress_message = bool(msg.metadata.get("_progress")) + try: + plain_text, text_styles = _markdown_to_signal(msg.content) + if not plain_text and not msg.media: + return + recipient_params = self._recipient_params(msg.chat_id) + + chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""] + for i, chunk in enumerate(chunks): + params: dict[str, Any] = {"message": chunk} + if text_styles and i == 0: + params["textStyle"] = text_styles + params.update(recipient_params) + if msg.media and i == 0: + params["attachments"] = msg.media + + response = await self._send_request("send", params) + + if "error" in response: + self.logger.error(f"Error sending Signal message: {response['error']}") + else: + self.logger.debug( + f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}" + ) + + except Exception: + self.logger.exception("Error sending Signal message") + raise + finally: + # Keep typing active across progress updates; stop on the final reply. + if not is_progress_message: + # Avoid immediate START->STOP for fast responses, which can be invisible + # in some Signal clients. Let indicator expire naturally (~15s). + await self._stop_typing(msg.chat_id, send_stop=False) + + async def _sse_receive_loop(self) -> None: + """Receive messages via Server-Sent Events (HTTP mode).""" + if not self._http: + raise RuntimeError("HTTP client not initialized for Signal SSE stream") + + self.logger.info("Started Signal message receive loop (SSE)") + + try: + async with self._http.stream("GET", "/api/v1/events") as response: + if response.status_code != 200: + raise ConnectionError( + f"SSE connection failed with status {response.status_code}" + ) + + self.logger.info("Subscribed to Signal messages via SSE") + + # Buffer for accumulating SSE data across multiple lines + event_buffer = [] + + async for line in response.aiter_lines(): + if not self._running: + break + + # Debug: log raw SSE lines (except keepalive pings) + if line and line != ":": + self.logger.debug(f"SSE line received: {line[:200]}") + + # SSE format handling + if isinstance(line, str): + # Empty line signals end of event + if not line or line == ":": + if event_buffer: + # Try to parse the accumulated data + data_str = "" + try: + data_str = "".join(event_buffer) + data = json.loads(data_str) + self.logger.debug(f"SSE event parsed: {data}") + await self._handle_receive_notification(data) + except json.JSONDecodeError as e: + self.logger.warning( + f"Invalid JSON in SSE buffer: {e}, data: {data_str[:200]}" + ) + finally: + event_buffer = [] + + # "data:" line - accumulate it + elif line.startswith("data:"): + event_buffer.append(line[5:]) # Skip "data:" prefix + + # "event:" line - just log it (we only care about data) + elif line.startswith("event:"): + pass # Ignore event type for now + + if self._running: + raise ConnectionError("Signal SSE stream closed by remote endpoint") + + except asyncio.CancelledError: + self.logger.info("SSE receive loop cancelled") + raise + except Exception as e: + self.logger.error(f"Error in SSE receive loop: {e}") + raise + + async def _handle_receive_notification(self, params: dict[str, Any]) -> None: + """Handle incoming message notification from signal-cli.""" + self.logger.debug(f"_handle_receive_notification called with: {params}") + try: + # Extract envelope from SSE notification: {"envelope": {...}} + envelope = params.get("envelope", {}) + + self.logger.debug(f"Extracted envelope: {envelope}") + + if not envelope: + self.logger.debug("No envelope found in params") + return + + # Extract sender information + sender_parts = self._collect_sender_id_parts(envelope) + source_name = envelope.get("sourceName") + + if not sender_parts: + self.logger.debug("Received message without source, skipping") + return + + sender_number = self._primary_sender_id(sender_parts) + sender_id = "|".join(sender_parts) + + # Keep aliases of the bot account for robust mention matching. + if any(self._id_matches_account(part) for part in sender_parts): + for part in sender_parts: + self._remember_account_id_alias(part) + + # Check different message types + data_message = envelope.get("dataMessage") + sync_message = envelope.get("syncMessage") + typing_message = envelope.get("typingMessage") + receipt_message = envelope.get("receiptMessage") + + # Ignore receipt messages (delivery/read receipts) + if receipt_message: + return + + # Handle data messages (incoming messages from others) + if data_message: + await self._handle_data_message(sender_id, sender_number, data_message, source_name) + + # Handle sync messages (messages sent from another device) + elif sync_message and sync_message.get("sentMessage"): + sent_msg = sync_message["sentMessage"] + destination = sent_msg.get("destination") or sent_msg.get("destinationNumber") + if destination: + self.logger.debug( + f"Sync message sent to {destination}: {sent_msg.get('message', '')[:50]}" + ) + + # Handle typing indicators (silently ignore) + elif typing_message: + pass # Ignore typing indicators + + except Exception as e: + self.logger.error(f"Error handling receive notification: {e}") + + async def _handle_data_message( + self, + sender_id: str, + sender_number: str, + data_message: dict[str, Any], + sender_name: str | None, + ) -> None: + """Handle a data message (text, attachments, etc.).""" + message_text = data_message.get("message") or "" + attachments = data_message.get("attachments", []) + group_info = data_message.get("groupInfo") + timestamp = data_message.get("timestamp") + mentions = data_message.get("mentions", []) + reaction = data_message.get("reaction") + + # Log full data_message for debugging group detection + self.logger.info( + f"Data message from {sender_number}: " + f"groupInfo={group_info}, " + f"groupV2={data_message.get('groupV2')}, " + f"keys={list(data_message.keys())}" + ) + + # Ignore reaction messages (emoji reactions to messages) + if reaction: + self.logger.debug(f"Ignoring reaction message from {sender_number}: {reaction}") + return + + # Ignore empty messages (e.g., when bot is added to a group) + if not message_text and not attachments: + self.logger.debug(f"Ignoring empty message from {sender_number}") + return + + # Determine chat_id (group ID or sender number) + # Check both groupInfo (v1) and groupV2 (v2) fields for group detection + group_v2 = data_message.get("groupV2") + is_group_message = group_info is not None or group_v2 is not None + group_id = self._extract_group_id(group_info, group_v2) + + is_command = bool(message_text and message_text.strip().startswith("/")) + + if is_group_message: + chat_id = group_id or sender_number + + # Check if this group is allowed before doing anything else + if not self.config.group.enabled: + self.logger.info(f"Ignoring group message from {chat_id} (groups disabled)") + return + if ( + self.config.group.policy == "allowlist" + and chat_id not in self.config.group.allow_from + ): + self.logger.info( + f"Ignoring group message from {chat_id} (policy: {self.config.group.policy})" + ) + return + + # Add to group message buffer (group is allowed) + self._add_to_group_buffer( + group_id=chat_id, + sender_name=sender_name or sender_number, + sender_number=sender_number, + message_text=message_text, + timestamp=timestamp, + ) + + # Commands bypass the mention requirement; non-commands require it. + if not is_command and not self._should_respond_in_group(message_text, mentions): + self.logger.info( + f"Ignoring group message (require_mention: {self.config.group.require_mention})" + ) + return + else: + # Direct message — check policy first, then forward everything to the bus. + chat_id = sender_number + if not self.config.dm.enabled: + self.logger.debug(f"Ignoring DM from {sender_id} (DMs disabled)") + return + if self.config.dm.policy == "allowlist": + allow_list = self.config.dm.allow_from + sender_str = str(sender_id) + parts = [sender_str] + (sender_str.split("|") if "|" in sender_str else []) + if not any(p for p in parts if p in allow_list): + self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})") + return + + # Build content from text and attachments + content_parts = [] + media_paths = [] + + # For group messages, include recent message context + if is_group_message: + buffer_context = self._get_group_buffer_context(chat_id) + if buffer_context: + content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") + + # Prepend sender name for group messages so history shows who said what + if message_text: + # Strip bot mentions from text (for group messages) + if is_group_message: + message_text = self._strip_bot_mention(message_text, mentions) + # Prepend sender name to make it clear who is speaking + display_name = sender_name or sender_number + message_text = f"[{display_name}]: {message_text}" + content_parts.append(message_text) + + # Handle attachments + if attachments: + media_dir = get_media_dir("signal") + + for attachment in attachments: + attachment_id = attachment.get("id") + content_type = attachment.get("contentType", "") + filename = attachment.get("filename") or f"attachment_{attachment_id}" + + if not attachment_id: + continue + + try: + # signal-cli stores attachments in ~/.local/share/signal-cli/attachments/ + source_path = ( + Path.home() / ".local/share/signal-cli/attachments" / attachment_id + ) + + if source_path.exists(): + dest_path = media_dir / f"signal_{safe_filename(filename)}" + shutil.copy2(source_path, dest_path) + media_paths.append(str(dest_path)) + + # Determine media type from content type + media_type = content_type.split("/")[0] if "/" in content_type else "file" + if media_type not in ("image", "audio", "video"): + media_type = "file" + + content_parts.append(f"[{media_type}: {dest_path}]") + self.logger.debug(f"Downloaded attachment: {filename} -> {dest_path}") + else: + self.logger.warning(f"Attachment not found: {source_path}") + content_parts.append(f"[attachment: {filename} - not found]") + + except Exception as e: + self.logger.warning(f"Failed to process attachment {filename}: {e}") + content_parts.append(f"[attachment: {filename} - error]") + + content = "\n".join(content_parts) if content_parts else "[empty message]" + + self.logger.debug(f"Signal message from {sender_number}: {content[:50]}...") + + await self._start_typing(chat_id) + try: + # Forward to message bus + await self._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content=content, + media=media_paths, + metadata={ + "timestamp": timestamp, + "sender_name": sender_name, + "sender_number": sender_number, + "is_group": is_group_message, + "group_id": group_id, + }, + ) + except Exception: + await self._stop_typing(chat_id) + raise + + def _add_to_group_buffer( + self, + group_id: str, + sender_name: str, + sender_number: str, + message_text: str, + timestamp: int | None, + ) -> None: + """ + Add a message to the group's rolling buffer. + + Args: + group_id: The group ID + sender_name: Display name of sender + sender_number: Phone number of sender + message_text: The message content + timestamp: Message timestamp + """ + if self.config.group_message_buffer_size <= 0: + return + + # Create buffer for this group if it doesn't exist + if group_id not in self._group_buffers: + self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) + + # Add message to buffer (deque will automatically drop oldest when full) + self._group_buffers[group_id].append( + { + "sender_name": sender_name, + "sender_number": sender_number, + "content": message_text, + "timestamp": timestamp, + } + ) + + self.logger.debug( + f"Added message to group buffer {group_id}: " + f"{len(self._group_buffers[group_id])}/{self.config.group_message_buffer_size}" + ) + + def _get_group_buffer_context(self, group_id: str) -> str: + """ + Get formatted context from the group's message buffer. + + Args: + group_id: The group ID + + Returns: + Formatted string of recent messages (excluding the current one) + """ + if group_id not in self._group_buffers: + return "" + + buffer = self._group_buffers[group_id] + if len(buffer) <= 1: # Only current message, no context + return "" + + # Format all messages except the last one (which is the current message) + # We want to show context BEFORE the mention + context_messages = list(buffer)[:-1] # Exclude the last (current) message + + lines = [] + for msg in context_messages: + sender = msg["sender_name"] + content = msg["content"][:200] # Limit to 200 chars per message + lines.append(f"{sender}: {content}") + + return "\n".join(lines) + + @staticmethod + def _normalize_signal_id(value: str) -> list[str]: + """Normalize Signal identifiers (phone/uuid/service-id) for matching.""" + raw = value.strip() + if not raw: + return [] + + normalized = [raw, raw.lower()] + if raw.startswith("+") and len(raw) > 1: + normalized.append(raw[1:]) + elif raw.isdigit(): + normalized.append(f"+{raw}") + return list(dict.fromkeys(normalized)) + + def _remember_account_id_alias(self, value: str | None) -> None: + """Remember known bot identifiers for mention matching.""" + if not value: + return + if not isinstance(value, str): + return + for candidate in self._normalize_signal_id(value): + self._account_id_aliases.add(candidate) + + def _id_matches_account(self, value: str | None) -> bool: + """Return True when an identifier refers to the bot account.""" + if not value: + return False + if not isinstance(value, str): + return False + return any( + candidate in self._account_id_aliases for candidate in self._normalize_signal_id(value) + ) + + @staticmethod + def _collect_sender_id_parts(envelope: dict[str, Any]) -> list[str]: + """Collect all known sender identifier variants from an envelope.""" + parts: list[str] = [] + for key in ( + "sourceNumber", + "source", + "sourceUuid", + "sourceServiceId", + "sourceAci", + "sourceACI", + ): + value = envelope.get(key) + if not isinstance(value, str): + continue + candidate = value.strip() + if candidate and candidate not in parts: + parts.append(candidate) + return parts + + @staticmethod + def _primary_sender_id(sender_parts: list[str]) -> str: + """Pick the best sender identifier for routing (prefer phone-like IDs).""" + for part in sender_parts: + if part.startswith("+") or part.isdigit(): + return part + return sender_parts[0] if sender_parts else "" + + @staticmethod + def _extract_group_id(group_info: Any, group_v2: Any) -> str | None: + """Extract group ID from groupInfo/groupV2 payloads across signal-cli variants.""" + for group_obj in (group_info, group_v2): + if not isinstance(group_obj, dict): + continue + for key in ("groupId", "id", "groupID"): + value = group_obj.get(key) + if isinstance(value, str) and value: + return value + return None + + @staticmethod + def _mention_id_candidates(mention: dict[str, Any]) -> list[str]: + """Extract possible identifier fields from a mention payload.""" + ids: list[str] = [] + + def _walk(value: Any, depth: int = 0) -> None: + if depth > 2: + return + if not isinstance(value, dict): + return + for key, child in value.items(): + key_lower = str(key).lower() + if isinstance(child, str) and child: + if any(token in key_lower for token in ("number", "uuid", "serviceid", "aci")): + ids.append(child) + elif isinstance(child, dict): + _walk(child, depth + 1) + + _walk(mention) + return list(dict.fromkeys(ids)) + + @staticmethod + def _mention_span(mention: dict[str, Any]) -> tuple[int, int] | None: + """Extract a safe (start, length) span from a mention.""" + try: + start = int(mention.get("start", 0)) + length = int(mention.get("length", 0)) + except (TypeError, ValueError): + return None + + if start < 0 or length <= 0: + return None + return (start, length) + + @staticmethod + def _leading_placeholder_span(text: str | None) -> tuple[int, int] | None: + """ + Detect a leading Signal mention placeholder when mention metadata is missing. + + Some clients/integrations deliver mentions as a leading placeholder character + (typically U+FFFC) but omit `mentions` metadata in the payload. + """ + if not text: + return None + + start = 0 + while start < len(text) and text[start].isspace(): + start += 1 + + if start >= len(text): + return None + + marker = text[start] + if marker not in ("\ufffc", "\ufffd", "\x1b"): + return None + + next_index = start + 1 + if next_index < len(text) and not text[next_index].isspace(): + return None + + return (start, 1) + + def _should_respond_in_group(self, message_text: str, mentions: list[dict[str, Any]]) -> bool: + """ + Determine if the bot should respond to a group message. + + Args: + message_text: The message text content + mentions: List of mentions from Signal (format: [{"number": "+1234567890", "start": 0, "length": 10}]) + + Returns: + True if bot should respond, False otherwise + """ + # Group reply behavior is controlled only by group.require_mention. + if not self.config.group.require_mention: + return True + + # If mention is required, check if bot was mentioned. + for mention in mentions: + if not isinstance(mention, dict): + continue + for mention_id in self._mention_id_candidates(mention): + if self._id_matches_account(mention_id): + return True + + # Some Signal clients emit mention spans without recipient identifiers + # (for handle-style mentions). Accept a leading identifier-less mention + # as a mention of the bot to avoid false negatives. + for mention in mentions: + if not isinstance(mention, dict): + continue + if self._mention_id_candidates(mention): + continue + span = self._mention_span(mention) + if not span: + continue + start, _ = span + if message_text is not None and not message_text[:start].strip(): + self.logger.debug("Accepting identifier-less leading mention as bot mention") + return True + + # Some payloads omit `mentions` but still include the leading mention + # placeholder character in the message body. + if not mentions and self._leading_placeholder_span(message_text): + self.logger.debug("Accepting leading placeholder mention without mention metadata") + return True + + # Fallback: check for configured phone number in plain text. + if message_text and self.config.phone_number: + for account_id in self._normalize_signal_id(self.config.phone_number): + if account_id and account_id in message_text: + return True + + return False + + def _strip_bot_mention(self, text: str, mentions: list[dict[str, Any]]) -> str: + """ + Remove bot mentions from message text. + + Signal mentions are embedded in the text, so we need to remove them based on + the mentions array which provides start position and length. + + Args: + text: Original message text + mentions: List of mention objects with start/length positions + + Returns: + Text with bot mentions removed + """ + if not text: + return text + + # Build a list of (start, length) tuples for our bot's mentions + bot_mentions = [] + for mention in mentions: + if not isinstance(mention, dict): + continue + mention_ids = self._mention_id_candidates(mention) + span = self._mention_span(mention) + if not span: + continue + + # Strip matched bot mentions by ID. + if any(self._id_matches_account(mention_id) for mention_id in mention_ids): + bot_mentions.append(span) + continue + + # Also strip identifier-less leading mention spans (handle mentions). + if not mention_ids: + start, _ = span + if not text[:start].strip(): + bot_mentions.append(span) + + if not bot_mentions: + placeholder_span = self._leading_placeholder_span(text) + if placeholder_span: + bot_mentions.append(placeholder_span) + + # Sort mentions by start position (descending) to remove from end to start + # This prevents position shifts when removing earlier mentions + bot_mentions.sort(reverse=True) + + # Remove each mention + for start, length in bot_mentions: + if start >= len(text): + continue + end = min(len(text), start + length) + text = text[:start] + text[end:] + + return text.strip() + + @staticmethod + def _is_group_chat_id(chat_id: str) -> bool: + """Return True when chat_id appears to be a Signal group ID (base64).""" + return "=" in chat_id or (len(chat_id) > 40 and "-" not in chat_id) + + def _recipient_params(self, chat_id: str) -> dict[str, Any]: + """Build recipient params for signal-cli JSON-RPC methods.""" + if self._is_group_chat_id(chat_id): + return {"groupId": chat_id} + return {"recipient": [chat_id]} + + async def _start_typing(self, chat_id: str) -> None: + """Start periodic typing indicator updates for a chat.""" + await self._stop_typing(chat_id, send_stop=False) + await self._send_typing(chat_id) + self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) + + async def _stop_typing(self, chat_id: str, send_stop: bool = True) -> None: + """Stop typing indicator updates for a chat.""" + task = self._typing_tasks.pop(chat_id, None) + had_task = task is not None + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if send_stop and had_task: + await self._send_typing(chat_id, stop=True) + + async def _typing_loop(self, chat_id: str) -> None: + """Send typing updates periodically until cancelled.""" + try: + while self._running: + await asyncio.sleep(self._TYPING_REFRESH_SECONDS) + await self._send_typing(chat_id, quiet_success=True) + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.debug(f"Typing indicator loop stopped for {chat_id}: {e}") + + async def _send_typing( + self, chat_id: str, stop: bool = False, quiet_success: bool = False + ) -> None: + """Send a typing START/STOP message via signal-cli.""" + action = "stop" if stop else "start" + if ( + not self._is_group_chat_id(chat_id) + and chat_id.startswith("+") is False + and chat_id not in self._typing_uuid_warnings + ): + self._typing_uuid_warnings.add(chat_id) + self.logger.warning( + "Signal DM recipient is UUID-only (no phone number in envelope). " + "Some Signal clients may not render typing indicators for this recipient form." + ) + candidate_params: list[dict[str, Any]] + if self._is_group_chat_id(chat_id): + candidate_params = [{"groupId": chat_id}, {"groupId": [chat_id]}] + else: + candidate_params = [{"recipient": chat_id}, {"recipient": [chat_id]}] + + last_error: Any | None = None + for params in candidate_params: + if stop: + params["stop"] = True + try: + response = await self._send_request("sendTyping", params) + except Exception as e: + last_error = str(e) + continue + + if "error" not in response: + if not quiet_success: + self.logger.info(f"Signal typing {action} sent for {chat_id}") + return + + last_error = response["error"] + + self.logger.warning(f"Failed to send Signal typing {action} for {chat_id}: {last_error}") + + async def _ensure_typing_indicators_enabled(self) -> None: + """Enable typing indicators on the bot account.""" + response = await self._send_request("updateConfiguration", {"typingIndicators": True}) + if "error" in response: + self.logger.warning(f"Failed to enable Signal typing indicators: {response['error']}") + else: + self.logger.info("Signal typing indicators enabled on account configuration") + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Send a JSON-RPC request via HTTP and wait for response.""" + # Generate request ID + self._request_id += 1 + request_id = self._request_id + + # Build JSON-RPC request + request = {"jsonrpc": "2.0", "method": method, "id": request_id} + + if params: + request["params"] = params + + return await self._send_http_request(request) + + async def _send_http_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send JSON-RPC request via HTTP.""" + if not self._http: + raise RuntimeError("Not connected to signal-cli daemon") + + try: + response = await self._http.post("/api/v1/rpc", json=request) + response.raise_for_status() + return response.json() + except Exception as e: + self.logger.error(f"HTTP request failed: {e}") + return {"error": {"message": str(e)}} diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py new file mode 100644 index 000000000..b5149459b --- /dev/null +++ b/tests/channels/test_signal_channel.py @@ -0,0 +1,1058 @@ +"""Tests for the Signal channel implementation.""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.signal import ( + SignalChannel, + SignalConfig, + SignalDMConfig, + SignalGroupConfig, +) + +# --------------------------------------------------------------------------- +# Fake HTTP client +# --------------------------------------------------------------------------- + + +class _FakeResponse: + def __init__(self, status_code: int = 200, body: dict | None = None) -> None: + self.status_code = status_code + self._body = body or {} + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._body + + +class _FakeHTTPClient: + """Minimal httpx.AsyncClient stand-in that records requests.""" + + def __init__(self, *, default_response: dict | None = None) -> None: + self.posts: list[dict] = [] + self.gets: list[str] = [] + self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}}) + self.closed = False + + async def get(self, path: str) -> _FakeResponse: + self.gets.append(path) + return self._response + + async def post(self, path: str, *, json: dict) -> _FakeResponse: + self.posts.append({"path": path, "json": json}) + return self._response + + async def aclose(self) -> None: + self.closed = True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_channel( + *, + phone_number: str = "+10000000000", + dm_enabled: bool = True, + dm_policy: str = "open", + dm_allow_from: list[str] | None = None, + group_enabled: bool = False, + group_policy: str = "open", + group_allow_from: list[str] | None = None, + require_mention: bool = True, + group_buffer_size: int = 20, +) -> SignalChannel: + config = SignalConfig( + enabled=True, + phone_number=phone_number, + dm=SignalDMConfig( + enabled=dm_enabled, + policy=dm_policy, + allow_from=dm_allow_from or [], + ), + group=SignalGroupConfig( + enabled=group_enabled, + policy=group_policy, + allow_from=group_allow_from or [], + require_mention=require_mention, + ), + group_message_buffer_size=group_buffer_size, + ) + return SignalChannel(config, MessageBus()) + + +def _dm_envelope( + *, + source_number: str = "+19995550001", + source_uuid: str | None = None, + source_name: str | None = "Alice", + message: str = "hello", + attachments: list | None = None, + reaction: dict | None = None, + timestamp: int = 1000, +) -> dict: + data_message: dict = {"message": message, "timestamp": timestamp} + if attachments is not None: + data_message["attachments"] = attachments + if reaction is not None: + data_message["reaction"] = reaction + envelope: dict = { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + if source_uuid: + envelope["sourceUuid"] = source_uuid + return {"envelope": envelope} + + +def _group_envelope( + *, + source_number: str = "+19995550001", + source_name: str = "Bob", + group_id: str = "group123==", + message: str = "hey group", + mentions: list | None = None, + timestamp: int = 2000, + use_v2: bool = False, +) -> dict: + group_obj = {"groupId": group_id} + key = "groupV2" if use_v2 else "groupInfo" + data_message: dict = { + "message": message, + "timestamp": timestamp, + key: group_obj, + "mentions": mentions or [], + } + return { + "envelope": { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + } + + +# --------------------------------------------------------------------------- +# Static utility tests +# --------------------------------------------------------------------------- + + +class TestNormalizeSignalId: + def test_phone_number_kept_and_stripped(self): + result = SignalChannel._normalize_signal_id("+12345678901") + assert "+12345678901" in result + assert "12345678901" in result + + def test_digits_only_gets_plus_prefix(self): + result = SignalChannel._normalize_signal_id("12345678901") + assert "+12345678901" in result + + def test_lowercase_variant_added(self): + result = SignalChannel._normalize_signal_id("SOME-UUID") + assert "some-uuid" in result + + def test_empty_string_returns_empty(self): + assert SignalChannel._normalize_signal_id("") == [] + + def test_whitespace_stripped(self): + result = SignalChannel._normalize_signal_id(" +1234 ") + assert "+1234" in result + + +class TestCollectSenderIdParts: + def test_collects_source_number(self): + env = {"sourceNumber": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + + def test_collects_multiple_keys(self): + env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + assert "uuid-abc" in parts + + def test_deduplicates(self): + env = {"sourceNumber": "+15551234567", "source": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts.count("+15551234567") == 1 + + def test_ignores_non_string_values(self): + env = {"sourceNumber": 12345, "sourceUuid": None} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts == [] + + def test_empty_envelope_returns_empty(self): + assert SignalChannel._collect_sender_id_parts({}) == [] + + +class TestPrimarySenderId: + def test_prefers_phone_number(self): + assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234" + + def test_accepts_digit_only(self): + assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890" + + def test_falls_back_to_first_part(self): + assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc" + + def test_empty_list_returns_empty(self): + assert SignalChannel._primary_sender_id([]) == "" + + +class TestExtractGroupId: + def test_extracts_from_group_info(self): + gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None) + assert gid == "abc==" + + def test_extracts_from_group_v2(self): + gid = SignalChannel._extract_group_id(None, {"id": "xyz=="}) + assert gid == "xyz==" + + def test_prefers_group_info_over_v2(self): + gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"}) + assert gid == "first" + + def test_returns_none_when_both_none(self): + assert SignalChannel._extract_group_id(None, None) is None + + def test_returns_none_when_not_dicts(self): + assert SignalChannel._extract_group_id("bad", 123) is None + + +class TestIsGroupChatId: + def test_base64_with_equals_is_group(self): + assert SignalChannel._is_group_chat_id("abc==") is True + + def test_long_id_without_dash_is_group(self): + long_id = "a" * 41 + assert SignalChannel._is_group_chat_id(long_id) is True + + def test_phone_number_is_not_group(self): + assert SignalChannel._is_group_chat_id("+12345678901") is False + + def test_uuid_with_dashes_is_not_group(self): + assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False + + +class TestRecipientParams: + def test_group_chat_uses_group_id(self): + ch = _make_channel() + params = ch._recipient_params("abc==") + assert params == {"groupId": "abc=="} + + def test_dm_uses_recipient_list(self): + ch = _make_channel() + params = ch._recipient_params("+12345678901") + assert params == {"recipient": ["+12345678901"]} + + +class TestMentionHelpers: + def test_mention_id_candidates_extracts_number(self): + mention = {"number": "+1234567890"} + ids = SignalChannel._mention_id_candidates(mention) + assert "+1234567890" in ids + + def test_mention_id_candidates_extracts_uuid(self): + mention = {"uuid": "some-uuid"} + ids = SignalChannel._mention_id_candidates(mention) + assert "some-uuid" in ids + + def test_mention_span_valid(self): + assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5) + + def test_mention_span_negative_start(self): + assert SignalChannel._mention_span({"start": -1, "length": 5}) is None + + def test_mention_span_zero_length(self): + assert SignalChannel._mention_span({"start": 0, "length": 0}) is None + + def test_mention_span_missing_keys(self): + assert SignalChannel._mention_span({}) is None + + def test_leading_placeholder_ufffc(self): + span = SignalChannel._leading_placeholder_span(" hello") + assert span == (0, 1) + + def test_leading_placeholder_not_at_start(self): + assert SignalChannel._leading_placeholder_span("hello ") is None + + def test_leading_placeholder_empty_string(self): + assert SignalChannel._leading_placeholder_span("") is None + + def test_leading_placeholder_plain_text(self): + assert SignalChannel._leading_placeholder_span("hello") is None + + +# --------------------------------------------------------------------------- +# Account ID alias / mention matching +# --------------------------------------------------------------------------- + + +class TestAccountIdAliases: + def test_phone_number_alias_registered_on_init(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("+10000000000") + + def test_digit_only_variant_matches(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("10000000000") + + def test_remember_alias_adds_uuid(self): + ch = _make_channel() + ch._remember_account_id_alias("some-uuid-abc") + assert ch._id_matches_account("some-uuid-abc") + + def test_non_matching_id_returns_false(self): + ch = _make_channel(phone_number="+10000000000") + assert not ch._id_matches_account("+19999999999") + + def test_none_and_non_string_return_false(self): + ch = _make_channel() + assert not ch._id_matches_account(None) + + +# --------------------------------------------------------------------------- +# _should_respond_in_group +# --------------------------------------------------------------------------- + + +class TestShouldRespondInGroup: + def _make_group_channel(self, require_mention: bool = True) -> SignalChannel: + return _make_channel( + phone_number="+10000000000", + group_enabled=True, + require_mention=require_mention, + ) + + def test_no_require_mention_always_responds(self): + ch = self._make_group_channel(require_mention=False) + assert ch._should_respond_in_group("anything", []) is True + + def test_require_mention_with_no_mentions_returns_false(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hello", []) is False + + def test_require_mention_with_bot_number_mention(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"number": "+10000000000", "start": 0, "length": 12}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_require_mention_with_uuid_mention(self): + ch = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("bot-uuid-123") + mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_leading_mention_accepted(self): + ch = self._make_group_channel(require_mention=True) + # Mention with no IDs but leading span — treated as bot mention + mentions = [{"start": 0, "length": 1}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_non_leading_mention_rejected(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"start": 5, "length": 1}] + assert ch._should_respond_in_group("hello ", mentions) is False + + def test_leading_placeholder_without_mentions_metadata(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group(" hello", []) is True + + def test_phone_number_in_text_triggers_response(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hey +10000000000 help", []) is True + + +# --------------------------------------------------------------------------- +# _strip_bot_mention +# --------------------------------------------------------------------------- + + +class TestStripBotMention: + def _make_channel_with_number(self) -> SignalChannel: + return _make_channel(phone_number="+10000000000") + + def test_strips_mention_by_phone(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_identifier_less_leading_mention(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_leading_placeholder_without_mention_metadata(self): + ch = self._make_channel_with_number() + text = " hello" + result = ch._strip_bot_mention(text, []) + assert result == "hello" + + def test_non_bot_mention_mid_text_not_stripped(self): + # A non-bot mention that is NOT a leading placeholder leaves the text alone. + ch = self._make_channel_with_number() + text = "hello  world" + mentions = [{"number": "+19999999999", "start": 6, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + # Mid-text placeholder from a non-bot mention should be untouched + assert "" in result + + def test_empty_text_returned_unchanged(self): + ch = self._make_channel_with_number() + assert ch._strip_bot_mention("", []) == "" + + +# --------------------------------------------------------------------------- +# Group message buffer +# --------------------------------------------------------------------------- + + +class TestGroupBuffer: + def test_add_and_get_context(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000) + # Only messages before the latest are returned as context + ctx = ch._get_group_buffer_context("g1") + assert "first msg" in ctx + # The last message is not included (it's the "current" one) + assert "second msg" not in ctx + + def test_empty_context_when_only_one_message(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000) + assert ch._get_group_buffer_context("g1") == "" + + def test_empty_context_when_group_unknown(self): + ch = _make_channel() + assert ch._get_group_buffer_context("unknown") == "" + + def test_buffer_respects_max_size(self): + ch = _make_channel(group_buffer_size=3) + for i in range(10): + ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) + assert len(ch._group_buffers["g1"]) == 3 + + def test_zero_buffer_size_does_not_add(self): + ch = _make_channel(group_buffer_size=0) + ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000) + assert "g1" not in ch._group_buffers + + def test_context_limits_message_length(self): + ch = _make_channel(group_buffer_size=5) + long_msg = "x" * 500 + ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000) + ctx = ch._get_group_buffer_context("g1") + # Context is capped at 200 chars per message + assert len(ctx.split("Alice: ", 1)[1]) <= 200 + + +# --------------------------------------------------------------------------- +# _handle_data_message — DM routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageDM: + def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: + ch = _make_channel(dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or []) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + @pytest.mark.asyncio + async def test_dm_open_policy_accepted(self): + ch, handled = self._make_dm_channel(policy="open") + params = _dm_envelope(source_number="+19995550001", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "+19995550001" + assert handled[0]["content"] == "hi" + + @pytest.mark.asyncio + async def test_dm_allowlist_accepted(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_rejected(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) + params = _dm_envelope(source_number="+19995550002") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_dm_disabled_rejected(self): + ch = _make_channel(dm_enabled=False) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_reaction_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999}) + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_empty_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(message="") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_receipt_message_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "receiptMessage": {"when": 1234}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_typing_indicator_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "typingMessage": {"action": "STARTED"}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_missing_envelope_ignored(self): + ch, handled = self._make_dm_channel() + await ch._handle_receive_notification({}) + assert handled == [] + + @pytest.mark.asyncio + async def test_metadata_passed_to_handle(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999) + await ch._handle_receive_notification(params) + meta = handled[0]["metadata"] + assert meta["sender_name"] == "Alice" + assert meta["timestamp"] == 9999 + assert meta["is_group"] is False + + @pytest.mark.asyncio + async def test_sender_id_with_uuid_variant(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + # sender_id combines both parts + assert "+19995550001" in handled[0]["sender_id"] + assert "uuid-abc" in handled[0]["sender_id"] + + @pytest.mark.asyncio + async def test_stop_typing_called_on_handle_error(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + typing_stopped: list[str] = [] + + async def fail_handle(**kwargs): + raise RuntimeError("boom") + + async def noop_typing(chat_id): + pass + + async def record_stop(chat_id, **kwargs): + typing_stopped.append(chat_id) + + ch._handle_message = fail_handle # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + ch._stop_typing = record_stop # type: ignore[method-assign] + + # _handle_receive_notification swallows exceptions; the typing stop + # still fires from _handle_data_message's except clause. + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + + assert "+19995550001" in typing_stopped + + +# --------------------------------------------------------------------------- +# _handle_data_message — group routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageGroup: + def _make_group_channel( + self, + policy="open", + allow_from=None, + require_mention=True, + ) -> tuple[SignalChannel, list]: + ch = _make_channel( + group_enabled=True, + group_policy=policy, + group_allow_from=allow_from or [], + require_mention=require_mention, + ) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + @pytest.mark.asyncio + async def test_group_disabled_rejected(self): + ch = _make_channel(group_enabled=False) + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_blocked_when_required(self): + ch, handled = self._make_group_channel(require_mention=True) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_required(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grp==" + + @pytest.mark.asyncio + async def test_group_allowlist_accepted(self): + ch, handled = self._make_group_channel( + policy="allowlist", allow_from=["grp=="], require_mention=False + ) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_allowlist_rejected(self): + ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="]) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_mention_triggers_response(self): + ch, handled = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("+10000000000") + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_v2_id_extracted(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grpV2==" + + @pytest.mark.asyncio + async def test_group_message_includes_sender_prefix(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", source_name="Bob", message="hello") + await ch._handle_receive_notification(params) + assert "[Bob]:" in handled[0]["content"] + + @pytest.mark.asyncio + async def test_group_message_context_prepended(self): + ch, handled = self._make_group_channel(require_mention=False) + # First message — adds to buffer but no context yet + params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1") + await ch._handle_receive_notification(params1) + # Second message — should include context from first + params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2") + await ch._handle_receive_notification(params2) + assert "[Recent group messages for context:]" in handled[1]["content"] + assert "msg1" in handled[1]["content"] + + @pytest.mark.asyncio + async def test_group_metadata_marks_is_group(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled[0]["metadata"]["is_group"] is True + assert handled[0]["metadata"]["group_id"] == "grp==" + + @pytest.mark.asyncio + async def test_bot_account_alias_learned_from_incoming(self): + ch, handled = self._make_group_channel(require_mention=False) + # If the bot's own UUID appears in an envelope we learn it + params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid") + # DMs from self are processed (learning alias), but DM policy is open + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + ch._start_typing = lambda chat_id: None # type: ignore[method-assign] + await ch._handle_receive_notification(params) + assert ch._id_matches_account("new-bot-uuid") + + +# --------------------------------------------------------------------------- +# Command handling +# --------------------------------------------------------------------------- + + +class TestCommandHandling: + @pytest.mark.asyncio + async def test_dm_command_forwarded_to_bus(self): + """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" + ch = _make_channel(dm_enabled=True, dm_policy="open") + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = AsyncMock() + + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + + assert len(forwarded) == 1 + assert forwarded[0]["content"].strip() == "/reset" + + @pytest.mark.asyncio + async def test_group_command_bypasses_mention_requirement(self): + """Slash commands in groups bypass the mention requirement and reach the bus.""" + ch = _make_channel( + group_enabled=True, group_policy="open", require_mention=True + ) + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = AsyncMock() + + params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + await ch._handle_receive_notification(params) + + assert len(forwarded) == 1 + assert "/reset" in forwarded[0]["content"] + + @pytest.mark.asyncio + async def test_command_denied_for_disallowed_dm_sender(self): + """Commands from senders not on the DM allowlist are dropped.""" + ch = _make_channel(dm_enabled=False) + forwarded: list[dict] = [] + + async def capture(**kw): + forwarded.append(kw) + + ch._handle_message = capture # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + + assert forwarded == [] + + +# --------------------------------------------------------------------------- +# send() — outbound messages +# --------------------------------------------------------------------------- + + +class TestSend: + def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + return ch, client + + @pytest.mark.asyncio + async def test_send_plain_text_posts_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) + assert len(client.posts) == 1 + payload = client.posts[0]["json"] + assert payload["method"] == "send" + assert payload["params"]["message"] == "hello" + + @pytest.mark.asyncio + async def test_send_with_markdown_includes_text_styles(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "textStyle" in params + assert any("BOLD" in s for s in params["textStyle"]) + + @pytest.mark.asyncio + async def test_send_empty_content_skips_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="") + await ch.send(msg) + assert client.posts == [] + + @pytest.mark.asyncio + async def test_send_to_group_uses_group_id(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "groupId" in params + assert "recipient" not in params + + @pytest.mark.asyncio + async def test_send_to_dm_uses_recipient(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "recipient" in params + + @pytest.mark.asyncio + async def test_send_with_media_includes_attachments(self): + ch, client = self._make_send_channel() + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="see attachment", + media=["/tmp/file.jpg"], + ) + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert params.get("attachments") == ["/tmp/file.jpg"] + + @pytest.mark.asyncio + async def test_send_progress_message_does_not_stop_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, **kwargs): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="working...", + metadata={"_progress": True}, + ) + await ch.send(msg) + # Progress messages should NOT stop the typing indicator + assert stopped == [] + + @pytest.mark.asyncio + async def test_send_final_message_stops_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, send_stop=True): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done") + await ch.send(msg) + assert "+19995550001" in stopped + + @pytest.mark.asyncio + async def test_send_logs_daemon_error_without_raising(self): + ch = _make_channel() + # The daemon returns {"error": {...}} in the JSON body — this is not a Python + # exception; send() logs it but does not raise (only HTTP-level exceptions raise). + ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}}) + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) # must not raise + + +# --------------------------------------------------------------------------- +# stop() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_cancels_sse_task() -> None: + ch = _make_channel() + cancelled = False + + async def long_running(): + nonlocal cancelled + try: + await asyncio.sleep(9999) + except asyncio.CancelledError: + cancelled = True + raise + + ch._sse_task = asyncio.create_task(long_running()) + # Yield so the task can enter its body (reach the first await) before cancel. + await asyncio.sleep(0) + ch._running = True + + await ch.stop() + + assert cancelled + assert ch._running is False + + +@pytest.mark.asyncio +async def test_stop_closes_http_client() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + ch._running = True + + await ch.stop() + + assert client.closed + + +@pytest.mark.asyncio +async def test_stop_safe_when_no_sse_task() -> None: + ch = _make_channel() + ch._running = True + # Should not raise even with no _sse_task + await ch.stop() + assert ch._running is False + + +# --------------------------------------------------------------------------- +# _send_request / _send_http_request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_request_increments_id() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + + await ch._send_request("testMethod", {"key": "val"}) + await ch._send_request("testMethod", {"key": "val"}) + + ids = [p["json"]["id"] for p in client.posts] + assert ids == [1, 2] + + +@pytest.mark.asyncio +async def test_send_request_raises_when_not_connected() -> None: + ch = _make_channel() + # _http is None by default + with pytest.raises(RuntimeError, match="Not connected"): + await ch._send_request("testMethod") + + +# --------------------------------------------------------------------------- +# _handle_receive_notification — envelope shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_handle_notification_sync_message_does_not_forward() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "syncMessage": { + "sentMessage": { + "destination": "+19990000000", + "message": "sent from other device", + } + }, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + +@pytest.mark.asyncio +async def test_handle_notification_no_source_skipped() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = {"envelope": {"dataMessage": {"message": "ghost"}}} + await ch._handle_receive_notification(notification) + assert handled == [] + + +# --------------------------------------------------------------------------- +# Config: allow_from property aggregation +# --------------------------------------------------------------------------- + + +def test_config_allow_from_aggregates_dm_and_group() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), + group=SignalGroupConfig( + enabled=True, policy="allowlist", allow_from=["+3333", "+1111"] + ), + ) + combined = config.allow_from + assert "+1111" in combined + assert "+2222" in combined + assert "+3333" in combined + # Duplicates removed + assert combined.count("+1111") == 1 + + +def test_config_allow_from_wildcard_propagates() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]), + group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]), + ) + assert "*" in config.allow_from diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py new file mode 100644 index 000000000..15eca70ff --- /dev/null +++ b/tests/channels/test_signal_markdown.py @@ -0,0 +1,244 @@ +"""Unit tests for the Signal markdown → plain text + textStyle converter.""" + +from nanobot.channels.signal import _markdown_to_signal + + +def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Return a dict mapping each styled substring to its style list.""" + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = plain[start : start + length] + result.setdefault(span, []).append(style) + return result + + +# --------------------------------------------------------------------------- +# Basic cases +# --------------------------------------------------------------------------- + + +def test_empty(): + plain, styles = _markdown_to_signal("") + assert plain == "" + assert styles == [] + + +def test_plain_text(): + plain, styles = _markdown_to_signal("hello world") + assert plain == "hello world" + assert styles == [] + + +def test_bold_stars(): + plain, styles = _markdown_to_signal("say **hello** now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_bold_underscores(): + plain, styles = _markdown_to_signal("say __hello__ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_italic_star(): + plain, styles = _markdown_to_signal("say *hello* now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_italic_underscore(): + plain, styles = _markdown_to_signal("say _hello_ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_strikethrough(): + plain, styles = _markdown_to_signal("say ~~hello~~ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]} + + +# --------------------------------------------------------------------------- +# Code +# --------------------------------------------------------------------------- + + +def test_inline_code(): + plain, styles = _markdown_to_signal("run `ls -la` here") + assert plain == "run ls -la here" + assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]} + + +def test_code_block(): + plain, styles = _markdown_to_signal("```\nprint('hi')\n```") + assert "print('hi')" in plain + assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \ + "MONOSPACE" in str(styles_for(plain, styles)) + + +def test_code_block_with_lang(): + plain, styles = _markdown_to_signal("```python\ncode\n```") + assert "code" in plain + assert any("MONOSPACE" in s for s in styles) + + +def test_code_block_not_processed_further(): + """Markdown inside a code block must not be styled.""" + plain, styles = _markdown_to_signal("```\n**not bold**\n```") + assert "**not bold**" in plain + # Only MONOSPACE should be applied, no BOLD + for entry in styles: + assert "BOLD" not in entry + + +def test_inline_code_not_processed_further(): + """Markdown inside inline code must not be styled.""" + plain, styles = _markdown_to_signal("use `**raw**` please") + assert "**raw**" in plain + for entry in styles: + assert "BOLD" not in entry + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +def test_header_becomes_bold(): + plain, styles = _markdown_to_signal("# My Title") + assert plain == "My Title" + assert styles_for(plain, styles) == {"My Title": ["BOLD"]} + + +def test_h2_becomes_bold(): + plain, styles = _markdown_to_signal("## Sub-section") + assert plain == "Sub-section" + assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]} + + +# --------------------------------------------------------------------------- +# Blockquotes +# --------------------------------------------------------------------------- + + +def test_blockquote_strips_marker(): + plain, styles = _markdown_to_signal("> some quote") + assert plain == "some quote" + assert styles == [] + + +# --------------------------------------------------------------------------- +# Lists +# --------------------------------------------------------------------------- + + +def test_bullet_dash(): + plain, styles = _markdown_to_signal("- item one") + assert plain == "• item one" + + +def test_bullet_star(): + plain, styles = _markdown_to_signal("* item two") + assert plain == "• item two" + + +def test_numbered_list(): + plain, styles = _markdown_to_signal("1. first\n2. second") + assert "1. first" in plain + assert "2. second" in plain + + +# --------------------------------------------------------------------------- +# Links +# --------------------------------------------------------------------------- + + +def test_link_text_differs_from_url(): + plain, styles = _markdown_to_signal("[Click here](https://example.com)") + assert plain == "Click here (https://example.com)" + assert styles == [] + + +def test_link_text_equals_url(): + plain, styles = _markdown_to_signal("[https://example.com](https://example.com)") + assert plain == "https://example.com" + assert styles == [] + + +def test_link_text_equals_url_without_scheme(): + plain, styles = _markdown_to_signal("[example.com](https://example.com)") + assert plain == "https://example.com" + + +# --------------------------------------------------------------------------- +# Mixed / nesting +# --------------------------------------------------------------------------- + + +def test_bold_and_italic_adjacent(): + plain, styles = _markdown_to_signal("**bold** and *italic*") + assert plain == "bold and italic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_header_with_inline_code(): + """Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD).""" + plain, styles = _markdown_to_signal("# Use `grep`") + assert plain == "Use grep" + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles) + assert "MONOSPACE" in sd.get("grep", []) + + +def test_multiline_mixed(): + md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another" + plain, styles = _markdown_to_signal(md) + assert "Title" in plain + assert "italic" in plain + assert "• bullet" in plain + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Title", []) + assert "ITALIC" in sd.get("italic", []) + + +# --------------------------------------------------------------------------- +# Table rendering +# --------------------------------------------------------------------------- + + +def test_table_rendered_as_monospace(): + md = "| A | B |\n| - | - |\n| 1 | 2 |" + plain, styles = _markdown_to_signal(md) + assert "A" in plain and "B" in plain + assert any("MONOSPACE" in s for s in styles) + + +# --------------------------------------------------------------------------- +# Style range format +# --------------------------------------------------------------------------- + + +def test_style_range_format(): + """Each style entry must be 'start:length:STYLE'.""" + _, styles = _markdown_to_signal("**bold** text") + for entry in styles: + parts = entry.split(":") + assert len(parts) == 3 + assert parts[0].isdigit() + assert parts[1].isdigit() + assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"} + + +def test_style_ranges_are_within_bounds(): + text = "hello **world** end" + plain, styles = _markdown_to_signal(text) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= len(plain)