diff --git a/README.md b/README.md index b97545731..1dbc82db8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,18 @@ ![cover-v5-optimized](./images/GitHub_README.png)
+

+ English | + 简体中文 | + 繁體中文 | + Español | + Français | + Bahasa Indonesia | + 日本語 | + 한국어 | + Русский | + Tiếng Việt +

PyPI Downloads diff --git a/docs/chat-apps.md b/docs/chat-apps.md index c0c1b4ba0..88242a5f7 100644 --- a/docs/chat-apps.md +++ b/docs/chat-apps.md @@ -17,6 +17,7 @@ Connect nanobot to your favorite chat platform. Want to build your own? See the | **Wecom** | Bot ID + Bot Secret | | **Microsoft Teams** | App ID + App Password + public HTTPS endpoint | | **Mochat** | Claw token (auto-setup available) | +| **Signal** | signal-cli daemon + phone number |

Telegram (Recommended) @@ -669,3 +670,69 @@ nanobot gateway ```
+ +
+Signal + +Uses **signal-cli** daemon in HTTP mode — receive messages via SSE, send via JSON-RPC. + +**1. Install signal-cli** + +Install [signal-cli](https://github.com/AsamK/signal-cli) and register a phone number: + +```bash +signal-cli -u +1234567890 register +signal-cli -u +1234567890 verify +``` + +Start the daemon: + +```bash +signal-cli -a +1234567890 daemon --http localhost:8080 +``` + +**2. Configure** + +```json +{ + "channels": { + "signal": { + "enabled": true, + "phoneNumber": "+1234567890", + "daemonHost": "localhost", + "daemonPort": 8080, + "dm": { + "enabled": true, + "policy": "open" + }, + "group": { + "enabled": true, + "policy": "open", + "requireMention": true + } + } + } +} +``` + +> - `phoneNumber`: Your registered Signal phone number. +> - `daemonHost` / `daemonPort`: Where signal-cli daemon is listening (default `localhost:8080`). +> - `dm.policy`: `"open"` (anyone can DM) or `"allowlist"` (only listed numbers/UUIDs). When `"allowlist"`, unlisted DM senders receive a pairing code. +> - `dm.allowFrom`: List of allowed phone numbers or UUIDs (used when policy is `"allowlist"`). +> - `group.policy`: `"open"` (all groups) or `"allowlist"` (only listed group IDs). +> - `group.requireMention`: When `true` (default), the bot only responds in groups when @mentioned. +> - `group.allowFrom`: List of allowed group IDs (used when group policy is `"allowlist"`). +> - `attachmentsDir`: Override the directory where signal-cli stores inbound attachments. Defaults to `~/.local/share/signal-cli/attachments` (the Linux default). Set this if signal-cli runs with a custom `XDG_DATA_HOME` or on macOS/Windows. +> - `groupMessageBufferSize`: Number of recent group messages kept for context (default `20`, must be > 0). + +**3. Run** + +```bash +nanobot gateway +``` + +> [!TIP] +> The channel automatically reconnects to the signal-cli daemon with exponential backoff if the connection drops. +> Markdown in bot replies is automatically converted to Signal text styles (bold, italic, code, etc.). + +
diff --git a/nanobot/channels/signal.py b/nanobot/channels/signal.py new file mode 100644 index 000000000..2a38f60ac --- /dev/null +++ b/nanobot/channels/signal.py @@ -0,0 +1,1402 @@ +"""Signal channel implementation using signal-cli daemon JSON-RPC interface.""" + +from __future__ import annotations + +import asyncio +import json +import re +import shutil +import unicodedata +from collections import deque +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import httpx +from pydantic import Field, computed_field, field_validator + +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.base import BaseChannel +from nanobot.config.paths import get_media_dir +from nanobot.config.schema import Base +from nanobot.pairing import is_approved +from nanobot.utils.helpers import safe_filename, split_message + + +@dataclass +class _Run: + text: str + styles: frozenset[str] = field(default_factory=frozenset) + opaque: bool = False # code / table content — skip further pattern processing + + +_SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```") +_SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`") +_SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE) +_SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE) +_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE) +_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE) +_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)") +_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL) +_SIG_ITALIC_RE = re.compile( + r"(? int: + """UTF-16 code-unit length, matching Signal BodyRange semantics.""" + return len(s.encode("utf-16-le")) // 2 + + +def _sig_strip_cell(s: str) -> str: + """Strip inline markdown from a table cell for plain-text rendering.""" + for pattern, repl in _SIG_CELL_STRIP_PATTERNS: + s = pattern.sub(repl, s) + return s.strip() + + +def _sig_render_table(table_lines: list[str]) -> str: + """Render a markdown pipe-table as fixed-width plain text.""" + + def dw(s: str) -> int: + return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s) + + rows: list[list[str]] = [] + has_sep = False + for line in table_lines: + cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")] + if all(re.match(r"^:?-+:?$", c) for c in cells if c): + has_sep = True + continue + rows.append(cells) + if not rows or not has_sep: + return "\n".join(table_lines) + + ncols = max(len(r) for r in rows) + for r in rows: + r.extend([""] * (ncols - len(r))) + widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] + + def dr(cells: list[str]) -> str: + return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths)) + + out = [dr(rows[0])] + out.append(" ".join("─" * w for w in widths)) + for row in rows[1:]: + out.append(dr(row)) + return "\n".join(out) + + +def _markdown_to_signal(text: str) -> tuple[str, list[str]]: + """Convert markdown text to Signal plain text + textStyle ranges. + + Returns ``(plain_text, text_styles)`` where ``text_styles`` are + ``"start:length:STYLE"`` strings for the signal-cli ``textStyle`` parameter. + """ + if not text: + return text, [] + + # Phase 1 (text-level): extract code blocks and tables with placeholder tokens + # so they're protected from inline-style processing. + protected: list[str] = [] + + def save_code(m: re.Match) -> str: + protected.append(m.group(1)) + return f"\x00C{len(protected) - 1}\x00" + + text = _SIG_CODE_BLOCK_RE.sub(save_code, text) + + # Detect and render pipe-tables line by line. + lines = text.split("\n") + rebuilt: list[str] = [] + i = 0 + while i < len(lines): + if re.match(r"^\s*\|.+\|", lines[i]): + tbl: list[str] = [] + while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]): + tbl.append(lines[i]) + i += 1 + rendered = _sig_render_table(tbl) + if rendered != "\n".join(tbl): + protected.append(rendered) + rebuilt.append(f"\x00C{len(protected) - 1}\x00") + else: + rebuilt.extend(tbl) + else: + rebuilt.append(lines[i]) + i += 1 + text = "\n".join(rebuilt) + + # Phase 2 (run-based): process inline patterns. + runs: list[_Run] = [_Run(text)] + + def transform( + pattern: re.Pattern, + make_runs: Callable[[re.Match, frozenset[str]], list[_Run]], + ) -> None: + new_runs: list[_Run] = [] + for run in runs: + if run.opaque: + new_runs.append(run) + continue + pos = 0 + for m in pattern.finditer(run.text): + if m.start() > pos: + new_runs.append(_Run(run.text[pos : m.start()], run.styles)) + new_runs.extend(make_runs(m, run.styles)) + pos = m.end() + if pos < len(run.text): + new_runs.append(_Run(run.text[pos:], run.styles)) + runs[:] = new_runs + + # Restore code/table placeholders as opaque MONOSPACE runs. + transform( + _SIG_TOKEN_RE, + lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)], + ) + + # Inline code (opaque). + transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) + + # Headers → bold plain text. + transform(_SIG_HEADER_RE, lambda m, s: [_Run(m.group(1), s | {"BOLD"})]) + + # Blockquotes → strip marker. + transform(_SIG_BLOCKQUOTE_RE, lambda m, s: [_Run(m.group(1), s)]) + + # Bullet lists → bullet character. + transform(_SIG_BULLET_RE, lambda m, s: [_Run("• ", s)]) + + # Numbered lists → normalize spacing. + transform(_SIG_OLIST_RE, lambda m, s: [_Run(m.group(1) + ". ", s)]) + + # Links → "text (url)" or bare url when text equals url. + def _link_runs(m: re.Match, s: frozenset) -> list[_Run]: + link_text, url = m.group(1), m.group(2) + + def _norm(u: str) -> str: + return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower() + + if _norm(url) == _norm(link_text): + return [_Run(url, s)] + return [_Run(f"{link_text} ({url})", s)] + + transform(_SIG_LINK_RE, _link_runs) + + # Bold (before italic so ** doesn't interfere). + transform(_SIG_BOLD_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"BOLD"})]) + + # Italic (single * or _). + transform(_SIG_ITALIC_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"ITALIC"})]) + + # Strikethrough: ~~text~~ (standard) or ~text~ (single-tilde variant). + transform(_SIG_STRIKE_RE, lambda m, s: [_Run(m.group(1) or m.group(2), s | {"STRIKETHROUGH"})]) + + # Phase 3: assemble output. Offsets and lengths are emitted in UTF-16 code + # units because Signal's BodyRange (via signal-cli's textStyle) interprets + # them as such; Python's len() counts code points, which would shift ranges + # left by 1 unit per non-BMP character preceding them. + plain_text = "" + text_styles: list[str] = [] + utf16_offset = 0 + for run in runs: + if not run.text: + continue + plain_text += run.text + start = utf16_offset + length = _utf16_len(run.text) + utf16_offset += length + for style in sorted(run.styles): + text_styles.append(f"{start}:{length}:{style}") + + return plain_text, text_styles + + +def _partition_styles( + plain_text: str, chunks: list[str], text_styles: list[str] +) -> list[list[str]]: + """Partition Signal textStyle ranges across message chunks. + + ``split_message`` slices ``plain_text`` into pieces (optionally trimming + whitespace at the boundaries), but the style ranges produced by + ``_markdown_to_signal`` are expressed in UTF-16 offsets relative to the + full ``plain_text``. This redistributes them per chunk with offsets + rebased to each chunk's start. Ranges that span a boundary are split + across the chunks they touch; ranges that fall entirely in trimmed + whitespace are dropped. + """ + if not chunks: + return [] + if not text_styles: + return [[] for _ in chunks] + + # Locate each chunk's UTF-16 start in plain_text. split_message lstrips at + # boundaries (but not before the first chunk), so we skip whitespace + # between chunks to mirror that. + chunk_ranges: list[tuple[int, int]] = [] + cursor = 0 # Python codepoint cursor in plain_text + for i, chunk in enumerate(chunks): + if i > 0: + while cursor < len(plain_text) and plain_text[cursor].isspace(): + cursor += 1 + utf16_start = _utf16_len(plain_text[:cursor]) + utf16_end = utf16_start + _utf16_len(chunk) + chunk_ranges.append((utf16_start, utf16_end)) + cursor += len(chunk) + + result: list[list[str]] = [[] for _ in chunks] + for entry in text_styles: + s, ln, style = entry.split(":", 2) + r_start = int(s) + r_end = r_start + int(ln) + for i, (c_start, c_end) in enumerate(chunk_ranges): + if r_end <= c_start or r_start >= c_end: + continue + new_start = max(r_start, c_start) - c_start + new_end = min(r_end, c_end) - c_start + new_length = new_end - new_start + if new_length > 0: + result[i].append(f"{new_start}:{new_length}:{style}") + return result + + +class SignalDMConfig(Base): + """Signal DM policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" + allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers/UUIDs + + +class SignalGroupConfig(Base): + """Signal group policy configuration.""" + + enabled: bool = False + policy: str = "allowlist" # "open" or "allowlist" - which groups to operate in + allow_from: list[str] = Field(default_factory=list) # Allowed group IDs if allowlist policy + require_mention: bool = True # Whether bot must be mentioned to respond + + +class SignalConfig(Base): + """Signal channel configuration using signal-cli daemon (HTTP mode with -a flag only).""" + + enabled: bool = False + phone_number: str = "" # Your Signal phone number (e.g., "+1234567890") + daemon_host: str = "localhost" + daemon_port: int = 8080 + group_message_buffer_size: int = 20 # Number of recent group messages to keep for context + # Override the directory signal-cli writes inbound attachments to. When + # None, defaults to ~/.local/share/signal-cli/attachments (the daemon's + # platform default on Linux). Set this if the daemon is running with a + # custom XDG_DATA_HOME or on macOS/Windows where the default path differs. + attachments_dir: str | None = None + dm: SignalDMConfig = Field(default_factory=SignalDMConfig) + group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) + + @field_validator("group_message_buffer_size") + @classmethod + def _validate_buffer_size(cls, v: int) -> int: + if v <= 0: + raise ValueError("group_message_buffer_size must be > 0") + return v + + @computed_field # type: ignore[prop-decorator] + @property + def allow_from(self) -> list[str]: + """Aggregate allowlist for the base-class is_allowed() check. + + Returns the union of dm.allow_from and group.allow_from so the base + channel gate sees a populated list when either sub-policy is configured. + A ``"*"`` wildcard in either sub-list propagates to allow all. + """ + return list(dict.fromkeys(self.dm.allow_from + self.group.allow_from)) + + +class SignalChannel(BaseChannel): + """ + Signal channel using signal-cli daemon via HTTP JSON-RPC interface. + + Requires signal-cli daemon in HTTP mode: + - signal-cli -a +1234567890 daemon --http localhost:8080 + + See https://github.com/AsamK/signal-cli for setup instructions. + """ + + name = "signal" + display_name = "Signal" + _TYPING_REFRESH_SECONDS = 10.0 + _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) + _HTTP_TIMEOUT_SECONDS = 60.0 + + @classmethod + def default_config(cls) -> dict[str, Any]: + return SignalConfig().model_dump(by_alias=True) + + def __init__(self, config: SignalConfig, bus: MessageBus): + if isinstance(config, dict): + config = SignalConfig.model_validate(config) + super().__init__(config, bus) + self.config: SignalConfig = config + self._http: httpx.AsyncClient | None = None + self._request_id = 0 + self._sse_task: asyncio.Task | None = None + self._typing_tasks: dict[str, asyncio.Task] = {} + self._typing_uuid_warnings: set[str] = set() + self._account_id_aliases: set[str] = set() + self._remember_account_id_alias(self.config.phone_number) + + # Rolling message buffer for group context (group_id -> deque of messages) + # Each message is a dict with: sender_name, sender_number, content, timestamp + self._group_buffers: dict[str, deque] = {} + + def is_allowed(self, sender_id: str) -> bool: + """Override base check to normalize and split pipe-joined identifiers. + + ``sender_id`` from Signal is the pipe-joined composite produced by + ``_collect_sender_id_parts``; allow_from entries may be single + identifiers or composites and may use the ``+`` prefix variant or + not. Delegates to ``_sender_matches_allowlist`` so the base gate + matches the per-policy DM gate. + """ + allow_list = self.config.allow_from + if "*" in allow_list: + return True + if self._sender_matches_allowlist(sender_id, allow_list): + return True + if self._sender_approved_via_pairing(sender_id): + return True + if not allow_list: + self.logger.warning("allow_from is empty — all access denied") + return False + + def _sender_approved_via_pairing(self, sender_id: str) -> bool: + """Return True if any normalized variant of sender_id is in the pairing store. + + Pairing approval may be recorded under any of the identifier forms + signal exposes (phone with/without ``+``, UUID, ACI), so we check + each part of the pipe-joined composite against ``is_approved``. + """ + for part in str(sender_id).split("|"): + for variant in self._normalize_signal_id(part): + if is_approved(self.name, variant): + return True + return False + + async def _handle_message( + self, + sender_id: str, + chat_id: str, + content: str, + media: list[str] | None = None, + metadata: dict[str, Any] | None = None, + session_key: str | None = None, + is_dm: bool = False, + ) -> None: + """Handle an inbound message whose policy has already been checked. + + ``_check_inbound_policy`` is the authoritative gate for DM/group + access, so we skip the base-class ``is_allowed()`` check and publish + directly to the bus. The denied-DM pairing path calls + ``super()._handle_message`` instead, which goes through + ``is_allowed`` and issues a pairing code. + """ + meta = metadata or {} + if self.supports_streaming: + meta = {**meta, "_wants_stream": True} + await self.bus.publish_inbound( + InboundMessage( + channel=self.name, + sender_id=str(sender_id), + chat_id=str(chat_id), + content=content, + media=media or [], + metadata=meta, + session_key_override=session_key, + ) + ) + + async def start(self) -> None: + """Start the Signal channel and connect to signal-cli daemon.""" + if not self.config.phone_number: + self.logger.error("Signal account not configured") + return + + self._running = True + await self._start_http_mode() + + async def _start_http_mode(self) -> None: + """Start Signal channel using Server-Sent Events for receiving messages.""" + base_url = f"http://{self.config.daemon_host}:{self.config.daemon_port}" + reconnect_delay_s = 1.0 + max_reconnect_delay_s = 30.0 + + while self._running: + try: + self.logger.info("Connecting to signal-cli daemon at {}...", base_url) + + # Create HTTP client + self._http = httpx.AsyncClient( + timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url + ) + + # Test connection + try: + response = await self._http.get("/api/v1/check") + if response.status_code == 200: + self.logger.info("Connected to signal-cli daemon") + else: + raise ConnectionRefusedError( + f"signal-cli daemon check returned status {response.status_code}" + ) + except Exception as e: + raise ConnectionRefusedError(f"signal-cli daemon not responding: {e}") + + # Reset reconnect delay after successful connection check. + reconnect_delay_s = 1.0 + + # Ensure account-level typing indicators are enabled. + await self._ensure_typing_indicators_enabled() + + # Start SSE receiver and supervise it. If it exits while we're still + # running, treat it as a disconnect and reconnect. + self._sse_task = asyncio.create_task(self._sse_receive_loop()) + await self._sse_task + if self._running: + raise ConnectionError("Signal SSE stream ended unexpectedly") + + except asyncio.CancelledError: + break + except ConnectionRefusedError as e: + self.logger.error( + "{}. Make sure signal-cli daemon is running: " + "signal-cli -a {} daemon --http {}:{}", + e, + self.config.phone_number, + self.config.daemon_host, + self.config.daemon_port, + ) + except Exception as e: + self.logger.error("Signal channel error: {}", e) + finally: + if self._sse_task: + if not self._sse_task.done(): + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + except Exception: + pass + self._sse_task = None + if self._http: + await self._http.aclose() + self._http = None + + if self._running: + self.logger.info( + "Reconnecting to signal-cli daemon in {:.0f} seconds...", reconnect_delay_s + ) + await asyncio.sleep(reconnect_delay_s) + reconnect_delay_s = min(reconnect_delay_s * 2, max_reconnect_delay_s) + + async def stop(self) -> None: + """Stop the Signal channel.""" + self._running = False + + # Stop SSE task + if self._sse_task: + self._sse_task.cancel() + try: + await self._sse_task + except asyncio.CancelledError: + pass + + # Cancel active typing indicators + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id) + + # Close HTTP client + if self._http: + await self._http.aclose() + self._http = None + + async def send(self, msg: OutboundMessage) -> None: + """Send a message through Signal.""" + is_progress_message = bool(msg.metadata.get("_progress")) + try: + plain_text, text_styles = _markdown_to_signal(msg.content) + if not plain_text and not msg.media: + return + recipient_params = self._recipient_params(msg.chat_id) + + chunks = split_message(plain_text, self._MAX_MESSAGE_LEN) if plain_text else [""] + chunk_styles = _partition_styles(plain_text, chunks, text_styles) + for i, chunk in enumerate(chunks): + params: dict[str, Any] = {"message": chunk} + if chunk_styles[i]: + params["textStyle"] = chunk_styles[i] + params.update(recipient_params) + if msg.media and i == 0: + params["attachments"] = msg.media + + response = await self._send_request("send", params) + + if "error" in response: + self.logger.error("Error sending Signal message: {}", response['error']) + raise RuntimeError(f"signal-cli send failed: {response['error']}") + else: + self.logger.debug( + f"Signal message sent, timestamp: {response.get('result', {}).get('timestamp')}" + ) + + except Exception: + self.logger.exception("Error sending Signal message") + raise + finally: + # Keep typing active across progress updates; stop on the final reply. + if not is_progress_message: + # Avoid immediate START->STOP for fast responses, which can be invisible + # in some Signal clients. Let indicator expire naturally (~15s). + await self._stop_typing(msg.chat_id, send_stop=False) + + async def _sse_receive_loop(self) -> None: + """Receive messages via Server-Sent Events (HTTP mode).""" + if not self._http: + raise RuntimeError("HTTP client not initialized for Signal SSE stream") + + self.logger.info("Started Signal message receive loop (SSE)") + + try: + async with self._http.stream("GET", "/api/v1/events") as response: + if response.status_code != 200: + raise ConnectionError( + f"SSE connection failed with status {response.status_code}" + ) + + self.logger.info("Subscribed to Signal messages via SSE") + + # Buffer for accumulating SSE data across multiple lines + event_buffer = [] + + async for line in response.aiter_lines(): + if not self._running: + break + + # Debug: log raw SSE lines (except keepalive pings) + if line and line != ":": + self.logger.debug("SSE line received: {}", line[:200]) + + # SSE format handling + if isinstance(line, str): + # Empty line signals end of event + if not line or line == ":": + if event_buffer: + # Try to parse the accumulated data + data_str = "" + try: + data_str = "\n".join(event_buffer) + data = json.loads(data_str) + self.logger.debug("SSE event parsed: {}", data) + await self._handle_receive_notification(data) + except json.JSONDecodeError as e: + self.logger.warning( + "Invalid JSON in SSE buffer: {}, data: {}", + e, + data_str[:200], + ) + finally: + event_buffer = [] + + # "data:" line - accumulate it + elif line.startswith("data:"): + # SSE spec: strip one optional leading space after "data:". + event_buffer.append(line[6:] if line[5:6] == " " else line[5:]) + + # "event:" line - just log it (we only care about data) + elif line.startswith("event:"): + pass # Ignore event type for now + + if self._running: + raise ConnectionError("Signal SSE stream closed by remote endpoint") + + except asyncio.CancelledError: + self.logger.info("SSE receive loop cancelled") + raise + except Exception as e: + self.logger.error("Error in SSE receive loop: {}", e) + raise + + @asynccontextmanager + async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]: + """Swallow and log any exception from a top-level handler block. + + Logs `self.logger.error` with the action name, the exception, and a + bounded ``repr`` of the offending payload so the offending input is + recoverable from logs without having to correlate by timestamp. + """ + try: + yield + except Exception as e: + snippet = repr(payload)[:200] if payload is not None else "" + text = f"Error in {action}: {e}" + if snippet: + text += f" | payload={snippet}" + self.logger.opt(exception=True).error(text) + + async def _handle_receive_notification(self, params: dict[str, Any]) -> None: + """Handle incoming message notification from signal-cli.""" + self.logger.debug("_handle_receive_notification called with: {}", params) + async with self._safe_handle("receive notification", params): + # Extract envelope from SSE notification: {"envelope": {...}} + envelope = params.get("envelope", {}) + + self.logger.debug("Extracted envelope: {}", envelope) + + if not envelope: + self.logger.debug("No envelope found in params") + return + + # Extract sender information + sender_parts = self._collect_sender_id_parts(envelope) + source_name = envelope.get("sourceName") + + if not sender_parts: + self.logger.debug("Received message without source, skipping") + return + + sender_number = self._primary_sender_id(sender_parts) + sender_id = "|".join(sender_parts) + + # Keep aliases of the bot account for robust mention matching. + if any(self._id_matches_account(part) for part in sender_parts): + for part in sender_parts: + self._remember_account_id_alias(part) + + # Check different message types + data_message = envelope.get("dataMessage") + sync_message = envelope.get("syncMessage") + typing_message = envelope.get("typingMessage") + receipt_message = envelope.get("receiptMessage") + + # Ignore receipt messages (delivery/read receipts) + if receipt_message: + return + + # Handle data messages (incoming messages from others) + if data_message: + await self._handle_data_message(sender_id, sender_number, data_message, source_name) + + # Handle sync messages (messages sent from another device) + elif sync_message and sync_message.get("sentMessage"): + sent_msg = sync_message["sentMessage"] + destination = sent_msg.get("destination") or sent_msg.get("destinationNumber") + if destination: + self.logger.debug( + "Sync message sent to {}: {}", destination, sent_msg.get("message", "")[:50] + ) + + # Handle typing indicators (silently ignore) + elif typing_message: + pass # Ignore typing indicators + + async def _handle_data_message( + self, + sender_id: str, + sender_number: str, + data_message: dict[str, Any], + sender_name: str | None, + ) -> None: + """Handle a data message (text, attachments, etc.).""" + message_text = data_message.get("message") or "" + attachments = data_message.get("attachments", []) + mentions = data_message.get("mentions", []) + timestamp = data_message.get("timestamp") + + self.logger.info( + "Data message from {}: groupInfo={}, groupV2={}, keys={}", + sender_number, + data_message.get("groupInfo"), + data_message.get("groupV2"), + list(data_message.keys()), + ) + + if data_message.get("reaction"): + self.logger.debug( + "Ignoring reaction message from {}: {}", sender_number, data_message["reaction"] + ) + return + if not message_text and not attachments: + self.logger.debug("Ignoring empty message from {}", sender_number) + return + + group_info = data_message.get("groupInfo") + group_v2 = data_message.get("groupV2") + is_group_message = group_info is not None or group_v2 is not None + group_id = self._extract_group_id(group_info, group_v2) + + allowed, chat_id = self._check_inbound_policy( + sender_id=sender_id, + sender_number=sender_number, + group_id=group_id, + is_group_message=is_group_message, + message_text=message_text, + mentions=mentions, + sender_name=sender_name, + timestamp=timestamp, + ) + if not allowed: + # Mirror Slack: let denied DMs reach the base-class + # _handle_message so it can reply with a pairing code. + # Group denials stay dropped. + if not is_group_message and self.config.dm.enabled: + await super()._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content="", + is_dm=True, + ) + return + + content, media_paths = self._assemble_inbound_content( + sender_name=sender_name, + sender_number=sender_number, + message_text=message_text, + attachments=attachments, + mentions=mentions, + is_group_message=is_group_message, + chat_id=chat_id, + ) + + self.logger.debug("Signal message from {}: {}...", sender_number, content[:50]) + + await self._start_typing(chat_id) + try: + await self._handle_message( + sender_id=sender_id, + chat_id=chat_id, + content=content, + media=media_paths, + metadata={ + "timestamp": timestamp, + "sender_name": sender_name, + "sender_number": sender_number, + "is_group": is_group_message, + "group_id": group_id, + }, + is_dm=not is_group_message, + ) + except Exception: + await self._stop_typing(chat_id) + raise + + def _check_inbound_policy( + self, + *, + sender_id: str, + sender_number: str, + group_id: str | None, + is_group_message: bool, + message_text: str, + mentions: list, + sender_name: str | None, + timestamp: int | None, + ) -> tuple[bool, str]: + """Decide whether to route an inbound message past DM/group policy. + + Returns ``(allow, chat_id)``. Has one side effect: when a group + message passes the enabled+allowlist gates, it is appended to the + group's rolling context buffer before the mention check. + """ + if is_group_message: + chat_id = group_id or sender_number + if not self.config.group.enabled: + self.logger.info("Ignoring group message from {} (groups disabled)", chat_id) + return False, chat_id + if ( + self.config.group.policy == "allowlist" + and chat_id not in self.config.group.allow_from + ): + self.logger.info( + "Ignoring group message from {} (policy: {})", + chat_id, + self.config.group.policy, + ) + return False, chat_id + + self._add_to_group_buffer( + group_id=chat_id, + sender_name=sender_name or sender_number, + sender_number=sender_number, + message_text=message_text, + timestamp=timestamp, + ) + + is_command = bool(message_text and message_text.strip().startswith("/")) + if not is_command and not self._should_respond_in_group(message_text, mentions): + self.logger.info( + "Ignoring group message (require_mention: {})", + self.config.group.require_mention, + ) + return False, chat_id + return True, chat_id + + # Direct message + chat_id = sender_number + if not self.config.dm.enabled: + self.logger.debug("Ignoring DM from {} (DMs disabled)", sender_id) + return False, chat_id + if self.config.dm.policy == "allowlist": + if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): + self.logger.debug( + "Ignoring DM from {} (policy: {})", sender_id, self.config.dm.policy + ) + return False, chat_id + return True, chat_id + + def _assemble_inbound_content( + self, + *, + sender_name: str | None, + sender_number: str, + message_text: str, + attachments: list, + mentions: list, + is_group_message: bool, + chat_id: str, + ) -> tuple[str, list[str]]: + """Build ``(content, media_paths)`` for an inbound message. + + Pulls in group context, strips bot mentions, prefixes the sender's + display name on group messages, and copies any attachments from + signal-cli's storage into the channel media dir. + """ + content_parts: list[str] = [] + media_paths: list[str] = [] + + if is_group_message: + buffer_context = self._get_group_buffer_context(chat_id) + if buffer_context: + content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---") + + if message_text: + if is_group_message: + message_text = self._strip_bot_mention(message_text, mentions) + display_name = sender_name or sender_number + message_text = f"[{display_name}]: {message_text}" + content_parts.append(message_text) + + if attachments: + media_dir = get_media_dir("signal") + for attachment in attachments: + attachment_id = attachment.get("id") + content_type = attachment.get("contentType", "") + filename = attachment.get("filename") or f"attachment_{attachment_id}" + if not attachment_id: + continue + try: + source_path = self._signal_attachments_dir() / attachment_id + if source_path.exists(): + dest_path = media_dir / f"signal_{safe_filename(filename)}" + shutil.copy2(source_path, dest_path) + media_paths.append(str(dest_path)) + media_type = content_type.split("/")[0] if "/" in content_type else "file" + if media_type not in ("image", "audio", "video"): + media_type = "file" + content_parts.append(f"[{media_type}: {dest_path}]") + self.logger.debug("Downloaded attachment: {} -> {}", filename, dest_path) + else: + self.logger.warning("Attachment not found: {}", source_path) + content_parts.append(f"[attachment: {filename} - not found]") + except Exception as e: + self.logger.warning("Failed to process attachment {}: {}", filename, e) + content_parts.append(f"[attachment: {filename} - error]") + + content = "\n".join(content_parts) if content_parts else "[empty message]" + return content, media_paths + + def _add_to_group_buffer( + self, + group_id: str, + sender_name: str, + sender_number: str, + message_text: str, + timestamp: int | None, + ) -> None: + """ + Add a message to the group's rolling buffer. + + Args: + group_id: The group ID + sender_name: Display name of sender + sender_number: Phone number of sender + message_text: The message content + timestamp: Message timestamp + """ + # Create buffer for this group if it doesn't exist + if group_id not in self._group_buffers: + self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) + + # Add message to buffer (deque will automatically drop oldest when full) + self._group_buffers[group_id].append( + { + "sender_name": sender_name, + "sender_number": sender_number, + "content": message_text, + "timestamp": timestamp, + } + ) + + self.logger.debug( + "Added message to group buffer {}: {}/{}", + group_id, + len(self._group_buffers[group_id]), + self.config.group_message_buffer_size, + ) + + def _get_group_buffer_context(self, group_id: str) -> str: + """ + Get formatted context from the group's message buffer. + + Args: + group_id: The group ID + + Returns: + Formatted string of recent messages (excluding the current one) + """ + if group_id not in self._group_buffers: + return "" + + buffer = self._group_buffers[group_id] + if len(buffer) <= 1: # Only current message, no context + return "" + + # Format all messages except the last one (which is the current message) + # We want to show context BEFORE the mention + context_messages = list(buffer)[:-1] # Exclude the last (current) message + + lines = [] + for msg in context_messages: + sender = msg["sender_name"] + content = msg["content"][:200] # Limit to 200 chars per message + lines.append(f"{sender}: {content}") + + return "\n".join(lines) + + def _signal_attachments_dir(self) -> Path: + """Return the directory signal-cli writes inbound attachments to. + + Defaults to ``~/.local/share/signal-cli/attachments`` (the daemon's + platform default on Linux) when ``config.attachments_dir`` is unset. + """ + configured = self.config.attachments_dir + if configured: + return Path(configured).expanduser() + return Path.home() / ".local/share/signal-cli/attachments" + + @staticmethod + def _normalize_signal_id(value: str) -> list[str]: + """Normalize Signal identifiers (phone/uuid/service-id) for matching.""" + raw = value.strip() + if not raw: + return [] + + normalized = [raw, raw.lower()] + if raw.startswith("+") and len(raw) > 1: + normalized.append(raw[1:]) + elif raw.isdigit(): + normalized.append(f"+{raw}") + return list(dict.fromkeys(normalized)) + + @classmethod + def _sender_matches_allowlist(cls, sender_id: str, allow_list: list[str]) -> bool: + """Return True if any normalized variant of sender_id is on allow_list. + + Both ``sender_id`` and each allow_list entry can be a single + identifier or a pipe-joined composite of several (e.g. + ``"+1234567890|uuid-abc"``); both sides are split on ``|`` and each + part is run through ``_normalize_signal_id`` so an allowlist entry + like ``1234567890`` matches a sender ``+1234567890`` (and vice + versa), and case-only differences in UUIDs/ACIs match too. + """ + if not allow_list: + return False + sender_variants: set[str] = set() + for part in str(sender_id).split("|"): + sender_variants.update(cls._normalize_signal_id(part)) + if not sender_variants: + return False + allow_variants: set[str] = set() + for entry in allow_list: + for part in str(entry).split("|"): + allow_variants.update(cls._normalize_signal_id(part)) + return bool(sender_variants & allow_variants) + + def _remember_account_id_alias(self, value: str | None) -> None: + """Remember known bot identifiers for mention matching.""" + if not value: + return + if not isinstance(value, str): + return + for candidate in self._normalize_signal_id(value): + self._account_id_aliases.add(candidate) + + def _id_matches_account(self, value: str | None) -> bool: + """Return True when an identifier refers to the bot account.""" + if not value: + return False + if not isinstance(value, str): + return False + return any( + candidate in self._account_id_aliases for candidate in self._normalize_signal_id(value) + ) + + @staticmethod + def _collect_sender_id_parts(envelope: dict[str, Any]) -> list[str]: + """Collect all known sender identifier variants from an envelope.""" + parts: list[str] = [] + for key in ( + "sourceNumber", + "source", + "sourceUuid", + "sourceServiceId", + "sourceAci", + "sourceACI", + ): + value = envelope.get(key) + if not isinstance(value, str): + continue + candidate = value.strip() + if candidate and candidate not in parts: + parts.append(candidate) + return parts + + @staticmethod + def _primary_sender_id(sender_parts: list[str]) -> str: + """Pick the best sender identifier for routing (prefer phone-like IDs).""" + for part in sender_parts: + if part.startswith("+") or part.isdigit(): + return part + return sender_parts[0] if sender_parts else "" + + @staticmethod + def _extract_group_id(group_info: Any, group_v2: Any) -> str | None: + """Extract group ID from groupInfo/groupV2 payloads across signal-cli variants.""" + for group_obj in (group_info, group_v2): + if not isinstance(group_obj, dict): + continue + for key in ("groupId", "id", "groupID"): + value = group_obj.get(key) + if isinstance(value, str) and value: + return value + return None + + @staticmethod + def _mention_id_candidates(mention: dict[str, Any]) -> list[str]: + """Extract possible identifier fields from a mention payload.""" + ids: list[str] = [] + + def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None: + if depth > 2: + return + if not isinstance(value, dict): + return + for key, child in value.items(): + key_lower = str(key).lower() + if isinstance(child, str) and child: + if any(token in key_lower for token in ("number", "uuid", "serviceid", "aci")): + ids.append(child) + elif isinstance(child, dict): + _walk(child, depth + 1) + + _walk(mention) + return list(dict.fromkeys(ids)) + + @staticmethod + def _mention_span(mention: dict[str, Any]) -> tuple[int, int] | None: + """Extract a safe (start, length) span from a mention.""" + try: + start = int(mention.get("start", 0)) + length = int(mention.get("length", 0)) + except (TypeError, ValueError): + return None + + if start < 0 or length <= 0: + return None + return (start, length) + + @staticmethod + def _leading_placeholder_span(text: str | None) -> tuple[int, int] | None: + """ + Detect a leading Signal mention placeholder when mention metadata is missing. + + Some clients/integrations deliver mentions as a leading placeholder character + (typically U+FFFC) but omit `mentions` metadata in the payload. + """ + if not text: + return None + + start = 0 + while start < len(text) and text[start].isspace(): + start += 1 + + if start >= len(text): + return None + + marker = text[start] + if marker not in ("\ufffc", "\ufffd", "\x1b"): + return None + + next_index = start + 1 + if next_index < len(text) and not text[next_index].isspace(): + return None + + return (start, 1) + + def _should_respond_in_group(self, message_text: str, mentions: list[dict[str, Any]]) -> bool: + """ + Determine if the bot should respond to a group message. + + Args: + message_text: The message text content + mentions: List of mentions from Signal (format: [{"number": "+1234567890", "start": 0, "length": 10}]) + + Returns: + True if bot should respond, False otherwise + """ + # Group reply behavior is controlled only by group.require_mention. + if not self.config.group.require_mention: + return True + + # If mention is required, check if bot was mentioned. + for mention in mentions: + if not isinstance(mention, dict): + continue + for mention_id in self._mention_id_candidates(mention): + if self._id_matches_account(mention_id): + return True + + # Some Signal clients emit mention spans without recipient identifiers + # (for handle-style mentions). Accept a leading identifier-less mention + # as a mention of the bot to avoid false negatives. + for mention in mentions: + if not isinstance(mention, dict): + continue + if self._mention_id_candidates(mention): + continue + span = self._mention_span(mention) + if not span: + continue + start, _ = span + if message_text is not None and not message_text[:start].strip(): + self.logger.debug("Accepting identifier-less leading mention as bot mention") + return True + + # Some payloads omit `mentions` but still include the leading mention + # placeholder character in the message body. + if not mentions and self._leading_placeholder_span(message_text): + self.logger.debug("Accepting leading placeholder mention without mention metadata") + return True + + # Fallback: check for configured phone number in plain text. + if message_text and self.config.phone_number: + for account_id in self._normalize_signal_id(self.config.phone_number): + if account_id and account_id in message_text: + return True + + return False + + def _strip_bot_mention(self, text: str, mentions: list[dict[str, Any]]) -> str: + """ + Remove bot mentions from message text. + + Signal mentions are embedded in the text, so we need to remove them based on + the mentions array which provides start position and length. + + Args: + text: Original message text + mentions: List of mention objects with start/length positions + + Returns: + Text with bot mentions removed + """ + if not text: + return text + + # Build a list of (start, length) tuples for our bot's mentions + bot_mentions = [] + for mention in mentions: + if not isinstance(mention, dict): + continue + mention_ids = self._mention_id_candidates(mention) + span = self._mention_span(mention) + if not span: + continue + + # Strip matched bot mentions by ID. + if any(self._id_matches_account(mention_id) for mention_id in mention_ids): + bot_mentions.append(span) + continue + + # Also strip identifier-less leading mention spans (handle mentions). + if not mention_ids: + start, _ = span + if not text[:start].strip(): + bot_mentions.append(span) + + if not bot_mentions: + placeholder_span = self._leading_placeholder_span(text) + if placeholder_span: + bot_mentions.append(placeholder_span) + + # Sort mentions by start position (descending) to remove from end to start + # This prevents position shifts when removing earlier mentions + bot_mentions.sort(reverse=True) + + # Remove each mention + for start, length in bot_mentions: + if start >= len(text): + continue + end = min(len(text), start + length) + text = text[:start] + text[end:] + + return text.strip() + + @staticmethod + def _is_group_chat_id(chat_id: str) -> bool: + """Return True when chat_id appears to be a Signal group ID (base64).""" + return "=" in chat_id or (len(chat_id) > 40 and "-" not in chat_id) + + def _recipient_params(self, chat_id: str) -> dict[str, Any]: + """Build recipient params for signal-cli JSON-RPC methods.""" + if self._is_group_chat_id(chat_id): + return {"groupId": chat_id} + return {"recipient": [chat_id]} + + async def _start_typing(self, chat_id: str) -> None: + """Start periodic typing indicator updates for a chat.""" + await self._stop_typing(chat_id, send_stop=False) + await self._send_typing(chat_id) + self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id)) + + async def _stop_typing(self, chat_id: str, send_stop: bool = True) -> None: + """Stop typing indicator updates for a chat.""" + task = self._typing_tasks.pop(chat_id, None) + had_task = task is not None + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + if send_stop and had_task: + await self._send_typing(chat_id, stop=True) + + async def _typing_loop(self, chat_id: str) -> None: + """Send typing updates periodically until cancelled.""" + try: + while self._running: + await asyncio.sleep(self._TYPING_REFRESH_SECONDS) + await self._send_typing(chat_id, quiet_success=True) + except asyncio.CancelledError: + pass + except Exception as e: + self.logger.debug("Typing indicator loop stopped for {}: {}", chat_id, e) + + async def _send_typing( + self, chat_id: str, stop: bool = False, quiet_success: bool = False + ) -> None: + """Send a typing START/STOP message via signal-cli.""" + action = "stop" if stop else "start" + if ( + not self._is_group_chat_id(chat_id) + and chat_id.startswith("+") is False + and chat_id not in self._typing_uuid_warnings + ): + self._typing_uuid_warnings.add(chat_id) + self.logger.warning( + "Signal DM recipient is UUID-only (no phone number in envelope). " + "Some Signal clients may not render typing indicators for this recipient form." + ) + candidate_params: list[dict[str, Any]] + if self._is_group_chat_id(chat_id): + candidate_params = [{"groupId": chat_id}, {"groupId": [chat_id]}] + else: + candidate_params = [{"recipient": chat_id}, {"recipient": [chat_id]}] + + last_error: Any | None = None + for params in candidate_params: + if stop: + params["stop"] = True + try: + response = await self._send_request("sendTyping", params) + except Exception as e: + last_error = str(e) + continue + + if "error" not in response: + if not quiet_success: + self.logger.info("Signal typing {} sent for {}", action, chat_id) + return + + last_error = response["error"] + + self.logger.warning( + "Failed to send Signal typing {} for {}: {}", action, chat_id, last_error + ) + + async def _ensure_typing_indicators_enabled(self) -> None: + """Enable typing indicators on the bot account.""" + response = await self._send_request("updateConfiguration", {"typingIndicators": True}) + if "error" in response: + self.logger.warning( + "Failed to enable Signal typing indicators: {}", response["error"] + ) + else: + self.logger.info("Signal typing indicators enabled on account configuration") + + async def _send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Send a JSON-RPC request via HTTP and wait for response.""" + # Generate request ID + self._request_id += 1 + request_id = self._request_id + + # Build JSON-RPC request + request = {"jsonrpc": "2.0", "method": method, "id": request_id} + + if params: + request["params"] = params + + return await self._send_http_request(request) + + async def _send_http_request(self, request: dict[str, Any]) -> dict[str, Any]: + """Send JSON-RPC request via HTTP.""" + if not self._http: + raise RuntimeError("Not connected to signal-cli daemon") + + try: + response = await self._http.post("/api/v1/rpc", json=request) + response.raise_for_status() + return response.json() + except Exception as e: + self.logger.error("HTTP request failed: {}", e) + return {"error": {"message": str(e)}} diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index b8112b529..03ab35a0e 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -1097,6 +1097,15 @@ class OpenAICompatProvider(LLMProvider): if delta: _accum_legacy_function_call(getattr(delta, "function_call", None)) + # Some providers (e.g. Zhipu/GLM) reuse the same tool_call id for + # parallel tool calls in streaming mode. Deduplicate before building + # the response so downstream tool messages don't collide. + _seen_tc_ids: set[str] = set() + for b in tc_bufs.values(): + if not b["id"] or b["id"] in _seen_tc_ids: + b["id"] = _short_tool_id() + _seen_tc_ids.add(b["id"]) + return LLMResponse( content="".join(content_parts) or None, tool_calls=[ diff --git a/nanobot/utils/file_edit_events.py b/nanobot/utils/file_edit_events.py index acef725b0..f02022c13 100644 --- a/nanobot/utils/file_edit_events.py +++ b/nanobot/utils/file_edit_events.py @@ -9,7 +9,6 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Awaitable, Callable - TRACKED_FILE_EDIT_TOOLS = frozenset({"write_file", "edit_file", "apply_patch"}) _MAX_SNAPSHOT_BYTES = 2 * 1024 * 1024 _LIVE_EMIT_INTERVAL_S = 0.18 @@ -457,12 +456,14 @@ class StreamingFileEditTracker: def apply_final_call_ids(self, final_tool_calls: list[Any]) -> None: """Keep final start/end events keyed to any earlier streamed placeholder.""" + used_canonicals: set[str] = set() for tool_call in final_tool_calls: canonical = self.canonical_call_id_for(tool_call) - if canonical: + if canonical and canonical not in used_canonicals: try: tool_call.id = canonical - except Exception: + used_canonicals.add(canonical) + except (AttributeError, TypeError): pass def canonical_call_id_for(self, tool_call: Any) -> str | None: diff --git a/tests/channels/test_signal_channel.py b/tests/channels/test_signal_channel.py new file mode 100644 index 000000000..277c85b83 --- /dev/null +++ b/tests/channels/test_signal_channel.py @@ -0,0 +1,1514 @@ +"""Tests for the Signal channel implementation.""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.signal import ( + SignalChannel, + SignalConfig, + SignalDMConfig, + SignalGroupConfig, +) + +# --------------------------------------------------------------------------- +# Fake HTTP client +# --------------------------------------------------------------------------- + + +class _FakeResponse: + def __init__(self, status_code: int = 200, body: dict | None = None) -> None: + self.status_code = status_code + self._body = body or {} + + def raise_for_status(self) -> None: + if self.status_code >= 400: + raise RuntimeError(f"HTTP {self.status_code}") + + def json(self) -> dict: + return self._body + + +class _FakeHTTPClient: + """Minimal httpx.AsyncClient stand-in that records requests.""" + + def __init__(self, *, default_response: dict | None = None) -> None: + self.posts: list[dict] = [] + self.gets: list[str] = [] + self._response = _FakeResponse(body=default_response or {"result": {"timestamp": 123}}) + self.closed = False + + async def get(self, path: str) -> _FakeResponse: + self.gets.append(path) + return self._response + + async def post(self, path: str, *, json: dict) -> _FakeResponse: + self.posts.append({"path": path, "json": json}) + return self._response + + async def aclose(self) -> None: + self.closed = True + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_channel_with_capture(**overrides) -> tuple[SignalChannel, list[dict]]: + """Build a SignalChannel with _handle_message captured into a list and a + no-op _start_typing, used by every receive-flow test class. + """ + ch = _make_channel(**overrides) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + async def noop_typing(chat_id): + pass + + ch._handle_message = capture # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + return ch, handled + + +def _make_channel( + *, + phone_number: str = "+10000000000", + dm_enabled: bool = True, + dm_policy: str = "open", + dm_allow_from: list[str] | None = None, + group_enabled: bool = False, + group_policy: str = "open", + group_allow_from: list[str] | None = None, + require_mention: bool = True, + group_buffer_size: int = 20, + attachments_dir: str | None = None, +) -> SignalChannel: + config = SignalConfig( + enabled=True, + phone_number=phone_number, + dm=SignalDMConfig( + enabled=dm_enabled, + policy=dm_policy, + allow_from=dm_allow_from or [], + ), + group=SignalGroupConfig( + enabled=group_enabled, + policy=group_policy, + allow_from=group_allow_from or [], + require_mention=require_mention, + ), + group_message_buffer_size=group_buffer_size, + attachments_dir=attachments_dir, + ) + return SignalChannel(config, MessageBus()) + + +def _dm_envelope( + *, + source_number: str = "+19995550001", + source_uuid: str | None = None, + source_name: str | None = "Alice", + message: str = "hello", + attachments: list | None = None, + reaction: dict | None = None, + timestamp: int = 1000, +) -> dict: + data_message: dict = {"message": message, "timestamp": timestamp} + if attachments is not None: + data_message["attachments"] = attachments + if reaction is not None: + data_message["reaction"] = reaction + envelope: dict = { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + if source_uuid: + envelope["sourceUuid"] = source_uuid + return {"envelope": envelope} + + +def _group_envelope( + *, + source_number: str = "+19995550001", + source_name: str = "Bob", + group_id: str = "group123==", + message: str = "hey group", + mentions: list | None = None, + timestamp: int = 2000, + use_v2: bool = False, +) -> dict: + group_obj = {"groupId": group_id} + key = "groupV2" if use_v2 else "groupInfo" + data_message: dict = { + "message": message, + "timestamp": timestamp, + key: group_obj, + "mentions": mentions or [], + } + return { + "envelope": { + "sourceNumber": source_number, + "sourceName": source_name, + "dataMessage": data_message, + } + } + + +# --------------------------------------------------------------------------- +# Static utility tests +# --------------------------------------------------------------------------- + + +class TestNormalizeSignalId: + def test_phone_number_kept_and_stripped(self): + result = SignalChannel._normalize_signal_id("+12345678901") + assert "+12345678901" in result + assert "12345678901" in result + + def test_digits_only_gets_plus_prefix(self): + result = SignalChannel._normalize_signal_id("12345678901") + assert "+12345678901" in result + + def test_lowercase_variant_added(self): + result = SignalChannel._normalize_signal_id("SOME-UUID") + assert "some-uuid" in result + + def test_empty_string_returns_empty(self): + assert SignalChannel._normalize_signal_id("") == [] + + def test_whitespace_stripped(self): + result = SignalChannel._normalize_signal_id(" +1234 ") + assert "+1234" in result + + +class TestCollectSenderIdParts: + def test_collects_source_number(self): + env = {"sourceNumber": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + + def test_collects_multiple_keys(self): + env = {"sourceNumber": "+15551234567", "sourceUuid": "uuid-abc"} + parts = SignalChannel._collect_sender_id_parts(env) + assert "+15551234567" in parts + assert "uuid-abc" in parts + + def test_deduplicates(self): + env = {"sourceNumber": "+15551234567", "source": "+15551234567"} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts.count("+15551234567") == 1 + + def test_ignores_non_string_values(self): + env = {"sourceNumber": 12345, "sourceUuid": None} + parts = SignalChannel._collect_sender_id_parts(env) + assert parts == [] + + def test_empty_envelope_returns_empty(self): + assert SignalChannel._collect_sender_id_parts({}) == [] + + +class TestPrimarySenderId: + def test_prefers_phone_number(self): + assert SignalChannel._primary_sender_id(["+1234", "uuid-abc"]) == "+1234" + + def test_accepts_digit_only(self): + assert SignalChannel._primary_sender_id(["1234567890", "uuid-abc"]) == "1234567890" + + def test_falls_back_to_first_part(self): + assert SignalChannel._primary_sender_id(["uuid-abc", "other"]) == "uuid-abc" + + def test_empty_list_returns_empty(self): + assert SignalChannel._primary_sender_id([]) == "" + + +class TestExtractGroupId: + def test_extracts_from_group_info(self): + gid = SignalChannel._extract_group_id({"groupId": "abc=="}, None) + assert gid == "abc==" + + def test_extracts_from_group_v2(self): + gid = SignalChannel._extract_group_id(None, {"id": "xyz=="}) + assert gid == "xyz==" + + def test_prefers_group_info_over_v2(self): + gid = SignalChannel._extract_group_id({"groupId": "first"}, {"groupId": "second"}) + assert gid == "first" + + def test_returns_none_when_both_none(self): + assert SignalChannel._extract_group_id(None, None) is None + + def test_returns_none_when_not_dicts(self): + assert SignalChannel._extract_group_id("bad", 123) is None + + +class TestIsGroupChatId: + def test_base64_with_equals_is_group(self): + assert SignalChannel._is_group_chat_id("abc==") is True + + def test_long_id_without_dash_is_group(self): + long_id = "a" * 41 + assert SignalChannel._is_group_chat_id(long_id) is True + + def test_phone_number_is_not_group(self): + assert SignalChannel._is_group_chat_id("+12345678901") is False + + def test_uuid_with_dashes_is_not_group(self): + assert SignalChannel._is_group_chat_id("550e8400-e29b-41d4-a716-446655440000") is False + + +class TestRecipientParams: + def test_group_chat_uses_group_id(self): + ch = _make_channel() + params = ch._recipient_params("abc==") + assert params == {"groupId": "abc=="} + + def test_dm_uses_recipient_list(self): + ch = _make_channel() + params = ch._recipient_params("+12345678901") + assert params == {"recipient": ["+12345678901"]} + + +class TestMentionHelpers: + def test_mention_id_candidates_extracts_number(self): + mention = {"number": "+1234567890"} + ids = SignalChannel._mention_id_candidates(mention) + assert "+1234567890" in ids + + def test_mention_id_candidates_extracts_uuid(self): + mention = {"uuid": "some-uuid"} + ids = SignalChannel._mention_id_candidates(mention) + assert "some-uuid" in ids + + def test_mention_span_valid(self): + assert SignalChannel._mention_span({"start": 0, "length": 5}) == (0, 5) + + def test_mention_span_negative_start(self): + assert SignalChannel._mention_span({"start": -1, "length": 5}) is None + + def test_mention_span_zero_length(self): + assert SignalChannel._mention_span({"start": 0, "length": 0}) is None + + def test_mention_span_missing_keys(self): + assert SignalChannel._mention_span({}) is None + + def test_leading_placeholder_ufffc(self): + span = SignalChannel._leading_placeholder_span(" hello") + assert span == (0, 1) + + def test_leading_placeholder_not_at_start(self): + assert SignalChannel._leading_placeholder_span("hello ") is None + + def test_leading_placeholder_empty_string(self): + assert SignalChannel._leading_placeholder_span("") is None + + def test_leading_placeholder_plain_text(self): + assert SignalChannel._leading_placeholder_span("hello") is None + + +# --------------------------------------------------------------------------- +# Account ID alias / mention matching +# --------------------------------------------------------------------------- + + +class TestAccountIdAliases: + def test_phone_number_alias_registered_on_init(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("+10000000000") + + def test_digit_only_variant_matches(self): + ch = _make_channel(phone_number="+10000000000") + assert ch._id_matches_account("10000000000") + + def test_remember_alias_adds_uuid(self): + ch = _make_channel() + ch._remember_account_id_alias("some-uuid-abc") + assert ch._id_matches_account("some-uuid-abc") + + def test_non_matching_id_returns_false(self): + ch = _make_channel(phone_number="+10000000000") + assert not ch._id_matches_account("+19999999999") + + def test_none_and_non_string_return_false(self): + ch = _make_channel() + assert not ch._id_matches_account(None) + + +# --------------------------------------------------------------------------- +# _should_respond_in_group +# --------------------------------------------------------------------------- + + +class TestShouldRespondInGroup: + def _make_group_channel(self, require_mention: bool = True) -> SignalChannel: + return _make_channel( + phone_number="+10000000000", + group_enabled=True, + require_mention=require_mention, + ) + + def test_no_require_mention_always_responds(self): + ch = self._make_group_channel(require_mention=False) + assert ch._should_respond_in_group("anything", []) is True + + def test_require_mention_with_no_mentions_returns_false(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hello", []) is False + + def test_require_mention_with_bot_number_mention(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"number": "+10000000000", "start": 0, "length": 12}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_require_mention_with_uuid_mention(self): + ch = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("bot-uuid-123") + mentions = [{"uuid": "bot-uuid-123", "start": 0, "length": 8}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_leading_mention_accepted(self): + ch = self._make_group_channel(require_mention=True) + # Mention with no IDs but leading span — treated as bot mention + mentions = [{"start": 0, "length": 1}] + assert ch._should_respond_in_group(" hello", mentions) is True + + def test_identifier_less_non_leading_mention_rejected(self): + ch = self._make_group_channel(require_mention=True) + mentions = [{"start": 5, "length": 1}] + assert ch._should_respond_in_group("hello ", mentions) is False + + def test_leading_placeholder_without_mentions_metadata(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group(" hello", []) is True + + def test_phone_number_in_text_triggers_response(self): + ch = self._make_group_channel(require_mention=True) + assert ch._should_respond_in_group("hey +10000000000 help", []) is True + + +# --------------------------------------------------------------------------- +# _strip_bot_mention +# --------------------------------------------------------------------------- + + +class TestStripBotMention: + def _make_channel_with_number(self) -> SignalChannel: + return _make_channel(phone_number="+10000000000") + + def test_strips_mention_by_phone(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_identifier_less_leading_mention(self): + ch = self._make_channel_with_number() + text = " hello" + mentions = [{"start": 0, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + assert result == "hello" + + def test_strips_leading_placeholder_without_mention_metadata(self): + ch = self._make_channel_with_number() + text = " hello" + result = ch._strip_bot_mention(text, []) + assert result == "hello" + + def test_non_bot_mention_mid_text_not_stripped(self): + # A non-bot mention that is NOT a leading placeholder leaves the text alone. + ch = self._make_channel_with_number() + text = "hello  world" + mentions = [{"number": "+19999999999", "start": 6, "length": 1}] + result = ch._strip_bot_mention(text, mentions) + # Mid-text placeholder from a non-bot mention should be untouched + assert "" in result + + def test_empty_text_returned_unchanged(self): + ch = self._make_channel_with_number() + assert ch._strip_bot_mention("", []) == "" + + +# --------------------------------------------------------------------------- +# Group message buffer +# --------------------------------------------------------------------------- + + +class TestGroupBuffer: + def test_add_and_get_context(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "first msg", 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "second msg", 2000) + # Only messages before the latest are returned as context + ctx = ch._get_group_buffer_context("g1") + assert "first msg" in ctx + # The last message is not included (it's the "current" one) + assert "second msg" not in ctx + + def test_empty_context_when_only_one_message(self): + ch = _make_channel(group_buffer_size=5) + ch._add_to_group_buffer("g1", "Alice", "+1111", "only msg", 1000) + assert ch._get_group_buffer_context("g1") == "" + + def test_empty_context_when_group_unknown(self): + ch = _make_channel() + assert ch._get_group_buffer_context("unknown") == "" + + def test_buffer_respects_max_size(self): + ch = _make_channel(group_buffer_size=3) + for i in range(10): + ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) + assert len(ch._group_buffers["g1"]) == 3 + + def test_zero_buffer_size_rejected_by_validator(self): + with pytest.raises(ValueError, match="group_message_buffer_size"): + _make_channel(group_buffer_size=0) + + def test_negative_buffer_size_rejected_by_validator(self): + with pytest.raises(ValueError, match="group_message_buffer_size"): + _make_channel(group_buffer_size=-1) + + def test_context_limits_message_length(self): + ch = _make_channel(group_buffer_size=5) + long_msg = "x" * 500 + ch._add_to_group_buffer("g1", "Alice", "+1111", long_msg, 1000) + ch._add_to_group_buffer("g1", "Bob", "+2222", "short", 2000) + ctx = ch._get_group_buffer_context("g1") + # Context is capped at 200 chars per message + assert len(ctx.split("Alice: ", 1)[1]) <= 200 + + +# --------------------------------------------------------------------------- +# _handle_data_message — DM routing +# --------------------------------------------------------------------------- + + +class TestIsAllowed: + """The base-channel allowlist gate is overridden to understand Signal's + pipe-joined composite sender_ids and the +/no-+ phone variants. + """ + + def test_denies_when_allowlist_empty(self): + ch = _make_channel(dm_enabled=True, dm_policy="allowlist") + assert ch.is_allowed("+19995550001") is False + + def test_denies_when_no_policy_allows(self): + """When both dm and group are disabled, is_allowed denies.""" + ch = _make_channel(dm_enabled=False, group_enabled=False) + assert ch.is_allowed("+19995550001") is False + + def test_allows_wildcard(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["*"]) + assert ch.is_allowed("+19995550001|some-uuid") is True + + def test_allows_composite_sender_against_split_allowlist(self): + """Composite sender_id, single-id allow_from — must match either part.""" + ch = _make_channel( + dm_policy="allowlist", + dm_allow_from=["+19995550001"], + ) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is True + + def test_allows_composite_sender_against_composite_allowlist_entry(self): + """Backward compat: pipe-joined composite allowlist entries still match.""" + composite = "+19995550001|1872ba20-uuid" + ch = _make_channel(dm_policy="allowlist", dm_allow_from=[composite]) + assert ch.is_allowed(composite) is True + + def test_allows_when_only_uuid_part_is_listed(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["1872ba20-uuid"]) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is True + + def test_denies_when_no_part_matches(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) + assert ch.is_allowed("+19995550001|1872ba20-uuid") is False + + def test_allowlist_union_includes_group_ids(self): + """allow_from is the union of dm.allow_from and group.allow_from.""" + ch = _make_channel( + group_enabled=True, + group_policy="allowlist", + group_allow_from=["group-id-base64=="], + ) + assert "group-id-base64==" in ch.config.allow_from + + +class TestEndToEndDMRouting: + """End-to-end tests that keep the real _handle_message chain (no mock), + verifying that _check_inbound_policy + _handle_message work together + correctly for DM routing. The override of _handle_message publishes + directly to bus (policy already checked); denied DMs call + super()._handle_message which issues a pairing code. + """ + + @pytest.mark.asyncio + async def test_open_dm_policy_publishes_to_bus(self): + """Open DM: _check_inbound_policy passes → _handle_message publishes.""" + ch = _make_channel(dm_enabled=True, dm_policy="open") + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550001", message="hello") + await ch._handle_receive_notification(params) + + assert len(published) == 1 + assert published[0].content == "hello" + assert published[0].sender_id == "+19995550001" + + @pytest.mark.asyncio + async def test_allowlist_dm_denied_triggers_pairing(self): + """Allowlist DM: denied sender triggers pairing code via send().""" + ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[]) + ch._http = _FakeHTTPClient() # type: ignore[assignment] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550002", message="hello") + await ch._handle_receive_notification(params) + + # Should NOT publish to bus — sender is not on allowlist. + assert published == [] + # Should have sent a pairing code via send (captured in HTTP posts). + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined] + assert "pairing" in sent_text.lower() or "pair" in sent_text.lower() + + @pytest.mark.asyncio + async def test_allowlist_dm_denied_with_group_open_still_pairs(self): + """dm.policy="allowlist" + group.policy="open": denied DM sender + must still get a pairing code, not be leaked by the group open check.""" + ch = _make_channel( + dm_enabled=True, + dm_policy="allowlist", + dm_allow_from=[], + group_enabled=True, + group_policy="open", + ) + ch._http = _FakeHTTPClient() # type: ignore[assignment] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _dm_envelope(source_number="+19995550002", message="hello") + await ch._handle_receive_notification(params) + + assert published == [] + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + + @pytest.mark.asyncio + async def test_open_group_policy_publishes_to_bus(self): + """Open group: group message from unknown sender publishes to bus.""" + ch = _make_channel( + group_enabled=True, + group_policy="open", + require_mention=False, + ) + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + published: list[InboundMessage] = [] + + async def capture_publish(msg: InboundMessage): + published.append(msg) + + ch.bus.publish_inbound = capture_publish # type: ignore[method-assign] + + params = _group_envelope(group_id="grp==", message="hello group") + await ch._handle_receive_notification(params) + + assert len(published) == 1 + assert "hello group" in published[0].content + + +class TestCheckInboundPolicy: + """Direct tests for the policy gate that _handle_data_message now delegates to.""" + + def _call( + self, + ch: SignalChannel, + *, + sender_id: str = "+19995550001", + sender_number: str = "+19995550001", + group_id: str | None = None, + is_group_message: bool = False, + message_text: str = "hi", + mentions: list | None = None, + sender_name: str | None = "Alice", + timestamp: int | None = 1000, + ) -> tuple[bool, str]: + return ch._check_inbound_policy( + sender_id=sender_id, + sender_number=sender_number, + group_id=group_id, + is_group_message=is_group_message, + message_text=message_text, + mentions=mentions or [], + sender_name=sender_name, + timestamp=timestamp, + ) + + def test_dm_open_allows(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + allowed, chat_id = self._call(ch) + assert allowed is True + assert chat_id == "+19995550001" + + def test_dm_disabled_blocks(self): + ch = _make_channel(dm_enabled=False) + allowed, _ = self._call(ch) + assert allowed is False + + def test_dm_allowlist_blocks_unknown_sender(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+12223334444"]) + allowed, _ = self._call(ch, sender_id="+19995550001") + assert allowed is False + + def test_dm_allowlist_allows_known_sender(self): + ch = _make_channel(dm_policy="allowlist", dm_allow_from=["+19995550001"]) + allowed, _ = self._call(ch, sender_id="+19995550001") + assert allowed is True + + def test_group_disabled_blocks(self): + ch = _make_channel(group_enabled=False) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1") + assert allowed is False + + def test_group_open_with_mention_allows(self): + ch = _make_channel( + group_enabled=True, + group_policy="open", + phone_number="+10000000000", + require_mention=True, + ) + allowed, chat_id = self._call( + ch, + is_group_message=True, + group_id="g1", + message_text="hello @bot", + mentions=[{"number": "+10000000000", "start": 6, "length": 4}], + ) + assert allowed is True + assert chat_id == "g1" + + def test_group_open_without_mention_blocks(self): + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk") + assert allowed is False + + def test_group_command_bypasses_mention_requirement(self): + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) + allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help") + assert allowed is True + + def test_allowed_group_appends_to_buffer(self): + """Side effect: when a group message is allowed, it lands in the buffer.""" + ch = _make_channel(group_enabled=True, group_policy="open", require_mention=False) + self._call(ch, is_group_message=True, group_id="g1", message_text="first") + self._call(ch, is_group_message=True, group_id="g1", message_text="second") + assert len(ch._group_buffers["g1"]) == 2 + + def test_blocked_group_does_not_append_to_buffer(self): + """Side effect: when a group is disabled, the buffer must not change.""" + ch = _make_channel(group_enabled=False) + self._call(ch, is_group_message=True, group_id="g1", message_text="hi") + assert "g1" not in ch._group_buffers + + +class TestAttachmentsDir: + def test_default_attachments_dir(self): + ch = _make_channel() + expected = Path.home() / ".local/share/signal-cli/attachments" + assert ch._signal_attachments_dir() == expected + + def test_configured_attachments_dir(self, tmp_path): + ch = _make_channel(attachments_dir=str(tmp_path / "custom")) + assert ch._signal_attachments_dir() == tmp_path / "custom" + + def test_attachments_dir_expands_user(self): + ch = _make_channel(attachments_dir="~/signal-attachments") + assert ch._signal_attachments_dir() == Path.home() / "signal-attachments" + + +class TestHandleDataMessageDM: + def _make_dm_channel(self, policy="open", allow_from=None) -> tuple[SignalChannel, list]: + return _make_channel_with_capture( + dm_enabled=True, dm_policy=policy, dm_allow_from=allow_from or [] + ) + + @pytest.mark.asyncio + async def test_dm_open_policy_accepted(self): + ch, handled = self._make_dm_channel(policy="open") + params = _dm_envelope(source_number="+19995550001", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "+19995550001" + assert handled[0]["content"] == "hi" + + @pytest.mark.asyncio + async def test_dm_allowlist_accepted(self): + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_rejected_triggers_pairing(self): + # Denied DM senders go through super()._handle_message which checks + # is_allowed → sends pairing code via self.send(). + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+10000000001"]) + ch._http = _FakeHTTPClient() # type: ignore[attr-defined] + params = _dm_envelope(source_number="+19995550002") + await ch._handle_receive_notification(params) + # The denied DM path calls super()._handle_message, not self._handle_message, + # so the capture list stays empty. Verify pairing code was sent via HTTP. + assert handled == [] + assert len(ch._http.posts) == 1 # type: ignore[attr-defined] + sent_text = ch._http.posts[0]["json"]["params"]["message"] # type: ignore[attr-defined] + assert "pairing" in sent_text.lower() or "pair" in sent_text.lower() + + @pytest.mark.asyncio + async def test_dm_paired_sender_allowed_without_allowlist_entry(self, monkeypatch): + # Once a sender completes pairing they should pass is_allowed on every + # subsequent message — otherwise the pairing reply loops forever. + approved = {"+19995550002"} + monkeypatch.setattr( + "nanobot.channels.signal.is_approved", + lambda channel, sender_id: sender_id in approved, + ) + ch = _make_channel(dm_enabled=True, dm_policy="allowlist", dm_allow_from=[]) + assert ch.is_allowed("+19995550002") is True + # Variant forms (with/without "+") must still match a stored approval. + assert ch.is_allowed("19995550002") is True + # Unpaired sender stays denied. + assert ch.is_allowed("+19995559999") is False + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_without_plus_prefix(self): + """An allowlist entry without '+' must match a sender that carries '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["19995550001"]) + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_with_plus_prefix(self): + """An allowlist entry with '+' must match a sender without '+'.""" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=["+19995550001"]) + params = _dm_envelope(source_number="+19995550001", source_uuid=None) + # Replace envelope's sourceNumber with the non-prefixed form by editing + # the constructed dict directly so _collect_sender_id_parts sees it. + params["envelope"]["sourceNumber"] = "19995550001" + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_uuid_case_insensitive(self): + """UUID matching must be case-insensitive.""" + uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()]) + params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_allowlist_matches_pipe_joined_composite_entry(self): + """Allowlist entries written as ``phone|uuid`` composites still work. + + Some configs pre-date the per-part splitting and store the full + sender_id composite as a single allow_from entry. Keep matching it. + """ + composite = "+19995550001|1872ba20-f52a-4bad-b434-bf7f808c8b22" + ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[composite]) + params = _dm_envelope( + source_number="+19995550001", + source_uuid="1872ba20-f52a-4bad-b434-bf7f808c8b22", + ) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_dm_disabled_rejected(self): + ch = _make_channel(dm_enabled=False) + handled: list[dict] = [] + + async def capture(**kwargs): + handled.append(kwargs) + + ch._handle_message = capture # type: ignore[method-assign] + + async def noop_typing(chat_id): + pass + + ch._start_typing = noop_typing # type: ignore[method-assign] + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_reaction_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(reaction={"emoji": "👍", "targetTimestamp": 999}) + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_empty_message_ignored(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(message="") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_receipt_message_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "receiptMessage": {"when": 1234}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_typing_indicator_ignored(self): + ch, handled = self._make_dm_channel() + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "typingMessage": {"action": "STARTED"}, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + @pytest.mark.asyncio + async def test_missing_envelope_ignored(self): + ch, handled = self._make_dm_channel() + await ch._handle_receive_notification({}) + assert handled == [] + + @pytest.mark.asyncio + async def test_metadata_passed_to_handle(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_name="Alice", timestamp=9999) + await ch._handle_receive_notification(params) + meta = handled[0]["metadata"] + assert meta["sender_name"] == "Alice" + assert meta["timestamp"] == 9999 + assert meta["is_group"] is False + + @pytest.mark.asyncio + async def test_sender_id_with_uuid_variant(self): + ch, handled = self._make_dm_channel() + params = _dm_envelope(source_number="+19995550001", source_uuid="uuid-abc") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + # sender_id combines both parts + assert "+19995550001" in handled[0]["sender_id"] + assert "uuid-abc" in handled[0]["sender_id"] + + @pytest.mark.asyncio + async def test_stop_typing_called_on_handle_error(self): + ch = _make_channel(dm_enabled=True, dm_policy="open") + typing_stopped: list[str] = [] + + async def fail_handle(**kwargs): + raise RuntimeError("boom") + + async def noop_typing(chat_id): + pass + + async def record_stop(chat_id, **kwargs): + typing_stopped.append(chat_id) + + ch._handle_message = fail_handle # type: ignore[method-assign] + ch._start_typing = noop_typing # type: ignore[method-assign] + ch._stop_typing = record_stop # type: ignore[method-assign] + + # _handle_receive_notification swallows exceptions; the typing stop + # still fires from _handle_data_message's except clause. + params = _dm_envelope(source_number="+19995550001") + await ch._handle_receive_notification(params) + + assert "+19995550001" in typing_stopped + + +# --------------------------------------------------------------------------- +# _handle_data_message — group routing +# --------------------------------------------------------------------------- + + +class TestHandleDataMessageGroup: + def _make_group_channel( + self, + policy="open", + allow_from=None, + require_mention=True, + ) -> tuple[SignalChannel, list]: + return _make_channel_with_capture( + group_enabled=True, + group_policy=policy, + group_allow_from=allow_from or [], + require_mention=require_mention, + ) + + @pytest.mark.asyncio + async def test_group_disabled_rejected(self): + ch = _make_channel(group_enabled=False) + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_blocked_when_required(self): + ch, handled = self._make_group_channel(require_mention=True) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_open_policy_no_mention_required(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hey everyone") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grp==" + + @pytest.mark.asyncio + async def test_group_allowlist_accepted(self): + ch, handled = self._make_group_channel( + policy="allowlist", allow_from=["grp=="], require_mention=False + ) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_allowlist_rejected(self): + ch, handled = self._make_group_channel(policy="allowlist", allow_from=["other=="]) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled == [] + + @pytest.mark.asyncio + async def test_group_mention_triggers_response(self): + ch, handled = self._make_group_channel(require_mention=True) + ch._remember_account_id_alias("+10000000000") + mentions = [{"number": "+10000000000", "start": 0, "length": 1}] + params = _group_envelope(group_id="grp==", message=" hello", mentions=mentions) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + + @pytest.mark.asyncio + async def test_group_v2_id_extracted(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grpV2==", message="hi", use_v2=True) + await ch._handle_receive_notification(params) + assert len(handled) == 1 + assert handled[0]["chat_id"] == "grpV2==" + + @pytest.mark.asyncio + async def test_group_message_includes_sender_prefix(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", source_name="Bob", message="hello") + await ch._handle_receive_notification(params) + assert "[Bob]:" in handled[0]["content"] + + @pytest.mark.asyncio + async def test_group_message_context_prepended(self): + ch, handled = self._make_group_channel(require_mention=False) + # First message — adds to buffer but no context yet + params1 = _group_envelope(group_id="grp==", source_name="Alice", message="msg1") + await ch._handle_receive_notification(params1) + # Second message — should include context from first + params2 = _group_envelope(group_id="grp==", source_name="Bob", message="msg2") + await ch._handle_receive_notification(params2) + assert "[Recent group messages for context:]" in handled[1]["content"] + assert "msg1" in handled[1]["content"] + + @pytest.mark.asyncio + async def test_group_metadata_marks_is_group(self): + ch, handled = self._make_group_channel(require_mention=False) + params = _group_envelope(group_id="grp==", message="hi") + await ch._handle_receive_notification(params) + assert handled[0]["metadata"]["is_group"] is True + assert handled[0]["metadata"]["group_id"] == "grp==" + + @pytest.mark.asyncio + async def test_bot_account_alias_learned_from_incoming(self): + ch, handled = self._make_group_channel(require_mention=False) + # If the bot's own UUID appears in an envelope we learn it + params = _dm_envelope(source_number="+10000000000", source_uuid="new-bot-uuid") + # DMs from self are processed (learning alias), but DM policy is open + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + ch._start_typing = lambda chat_id: None # type: ignore[method-assign] + await ch._handle_receive_notification(params) + assert ch._id_matches_account("new-bot-uuid") + + +# --------------------------------------------------------------------------- +# Lifecycle / SSE +# --------------------------------------------------------------------------- + + +class _FakeSSEResponse: + """Minimal stand-in for httpx Response under stream().""" + + def __init__(self, lines: list[str], status_code: int = 200) -> None: + self.status_code = status_code + self._lines = lines + + async def aiter_lines(self): + for line in self._lines: + yield line + + +def _fake_streaming_client(lines: list[str], *, status_code: int = 200) -> MagicMock: + """Return an httpx.AsyncClient stand-in whose .stream() yields a FakeSSEResponse.""" + response = _FakeSSEResponse(lines, status_code=status_code) + + @asynccontextmanager + async def _ctx(*_args, **_kwargs): + yield response + + http = MagicMock() + http.stream = lambda *a, **kw: _ctx(*a, **kw) + return http + + +class TestLifecycle: + @pytest.mark.asyncio + async def test_start_returns_early_when_phone_missing(self): + """start() with an empty phone number must not enter the HTTP loop.""" + ch = _make_channel(phone_number="") + await ch.start() + assert ch._running is False + assert ch._http is None + assert ch._sse_task is None + + +class TestSSEReceiveLoop: + @pytest.mark.asyncio + async def test_dispatches_valid_envelope(self): + ch = _make_channel() + ch._running = True + + captured: list[dict] = [] + + async def capture(params): + captured.append(params) + + ch._handle_receive_notification = capture # type: ignore[method-assign] + ch._http = _fake_streaming_client( + ['data: {"envelope":{"sourceNumber":"+19995550001"}}', ""] + ) + + # Loop ends when lines exhaust; the surrounding _start_http_mode would + # treat that as a disconnect, but the loop itself raises ConnectionError + # when the stream closes while still running. + with pytest.raises(ConnectionError): + await ch._sse_receive_loop() + assert captured == [{"envelope": {"sourceNumber": "+19995550001"}}] + + @pytest.mark.asyncio + async def test_handles_invalid_json_frame(self): + """An unparseable SSE frame is logged and skipped without crashing.""" + ch = _make_channel() + ch._running = True + + captured: list[dict] = [] + + async def capture(params): + captured.append(params) + + ch._handle_receive_notification = capture # type: ignore[method-assign] + ch._http = _fake_streaming_client( + [ + "data: this-is-not-json", + "", # event boundary triggers parse attempt + 'data: {"envelope":{"sourceNumber":"+1"}}', + "", + ] + ) + + with pytest.raises(ConnectionError): + await ch._sse_receive_loop() + # Bad frame skipped; good frame still dispatched. + assert captured == [{"envelope": {"sourceNumber": "+1"}}] + + @pytest.mark.asyncio + async def test_non_200_status_raises(self): + ch = _make_channel() + ch._running = True + ch._http = _fake_streaming_client([], status_code=503) + with pytest.raises(ConnectionError, match="status 503"): + await ch._sse_receive_loop() + + @pytest.mark.asyncio + async def test_no_http_client_raises(self): + ch = _make_channel() + ch._http = None + with pytest.raises(RuntimeError, match="HTTP client not initialized"): + await ch._sse_receive_loop() + + +# --------------------------------------------------------------------------- +# Command handling +# --------------------------------------------------------------------------- + + +class TestCommandHandling: + @pytest.mark.asyncio + async def test_dm_command_forwarded_to_bus(self): + """Slash commands in DMs are forwarded to the bus for AgentLoop to handle.""" + ch, forwarded = _make_channel_with_capture(dm_enabled=True, dm_policy="open") + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + assert len(forwarded) == 1 + assert forwarded[0]["content"].strip() == "/reset" + + @pytest.mark.asyncio + async def test_group_command_bypasses_mention_requirement(self): + """Slash commands in groups bypass the mention requirement and reach the bus.""" + ch, forwarded = _make_channel_with_capture( + group_enabled=True, group_policy="open", require_mention=True + ) + params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset") + await ch._handle_receive_notification(params) + assert len(forwarded) == 1 + assert "/reset" in forwarded[0]["content"] + + @pytest.mark.asyncio + async def test_command_denied_for_disallowed_dm_sender(self): + """Commands from senders not on the DM allowlist are dropped.""" + ch, forwarded = _make_channel_with_capture(dm_enabled=False) + params = _dm_envelope(source_number="+19995550001", message="/reset") + await ch._handle_receive_notification(params) + assert forwarded == [] + + +# --------------------------------------------------------------------------- +# send() — outbound messages +# --------------------------------------------------------------------------- + + +class TestSend: + def _make_send_channel(self) -> tuple[SignalChannel, _FakeHTTPClient]: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + return ch, client + + @pytest.mark.asyncio + async def test_send_plain_text_posts_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + await ch.send(msg) + assert len(client.posts) == 1 + payload = client.posts[0]["json"] + assert payload["method"] == "send" + assert payload["params"]["message"] == "hello" + + @pytest.mark.asyncio + async def test_send_with_markdown_includes_text_styles(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="**bold**") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "textStyle" in params + assert any("BOLD" in s for s in params["textStyle"]) + + @pytest.mark.asyncio + async def test_send_split_message_redistributes_text_styles(self): + """Long message split across chunks: each chunk gets its own textStyle + with offsets rebased to that chunk.""" + ch, client = self._make_send_channel() + ch._MAX_MESSAGE_LEN = 12 # type: ignore[attr-defined] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="**head** middle and **tail**", + ) + await ch.send(msg) + assert len(client.posts) >= 2 + # Chunk 0 has BOLD for "head"; chunk 1+ must also carry BOLD for "tail". + bold_chunks = [ + p["json"]["params"] + for p in client.posts + if any("BOLD" in s for s in p["json"]["params"].get("textStyle", [])) + ] + assert len(bold_chunks) >= 2, ( + "expected BOLD ranges in more than one chunk; got " + f"{[p['json']['params'] for p in client.posts]}" + ) + # Each emitted range must point inside its own chunk's text. + for params in bold_chunks: + chunk_text = params["message"] + for entry in params["textStyle"]: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + end_units = start + length + assert end_units <= len(chunk_text.encode("utf-16-le")) // 2 + + @pytest.mark.asyncio + async def test_send_empty_content_skips_rpc(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="") + await ch.send(msg) + assert client.posts == [] + + @pytest.mark.asyncio + async def test_send_to_group_uses_group_id(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="grp==", content="hi group") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "groupId" in params + assert "recipient" not in params + + @pytest.mark.asyncio + async def test_send_to_dm_uses_recipient(self): + ch, client = self._make_send_channel() + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hi") + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert "recipient" in params + + @pytest.mark.asyncio + async def test_send_with_media_includes_attachments(self): + ch, client = self._make_send_channel() + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="see attachment", + media=["/tmp/file.jpg"], + ) + await ch.send(msg) + params = client.posts[0]["json"]["params"] + assert params.get("attachments") == ["/tmp/file.jpg"] + + @pytest.mark.asyncio + async def test_send_progress_message_does_not_stop_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, **kwargs): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage( + channel="signal", + chat_id="+19995550001", + content="working...", + metadata={"_progress": True}, + ) + await ch.send(msg) + # Progress messages should NOT stop the typing indicator + assert stopped == [] + + @pytest.mark.asyncio + async def test_send_final_message_stops_typing(self): + ch, client = self._make_send_channel() + stopped: list[str] = [] + + async def record_stop(chat_id, send_stop=True): + stopped.append(chat_id) + + ch._stop_typing = record_stop # type: ignore[method-assign] + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="done") + await ch.send(msg) + assert "+19995550001" in stopped + + @pytest.mark.asyncio + async def test_send_raises_on_daemon_error(self): + # _send_http_request turns every exception into {"error": ...}, so this branch + # is the only place ChannelManager retry can be triggered — must raise. + ch = _make_channel() + ch._http = _FakeHTTPClient(default_response={"error": {"message": "fail"}}) + msg = OutboundMessage(channel="signal", chat_id="+19995550001", content="hello") + with pytest.raises(RuntimeError, match="signal-cli send failed"): + await ch.send(msg) + + +# --------------------------------------------------------------------------- +# stop() +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_stop_cancels_sse_task() -> None: + ch = _make_channel() + cancelled = False + + async def long_running(): + nonlocal cancelled + try: + await asyncio.sleep(9999) + except asyncio.CancelledError: + cancelled = True + raise + + ch._sse_task = asyncio.create_task(long_running()) + # Yield so the task can enter its body (reach the first await) before cancel. + await asyncio.sleep(0) + ch._running = True + + await ch.stop() + + assert cancelled + assert ch._running is False + + +@pytest.mark.asyncio +async def test_stop_closes_http_client() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + ch._running = True + + await ch.stop() + + assert client.closed + + +@pytest.mark.asyncio +async def test_stop_safe_when_no_sse_task() -> None: + ch = _make_channel() + ch._running = True + # Should not raise even with no _sse_task + await ch.stop() + assert ch._running is False + + +# --------------------------------------------------------------------------- +# _send_request / _send_http_request +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_send_request_increments_id() -> None: + ch = _make_channel() + client = _FakeHTTPClient() + ch._http = client # type: ignore[assignment] + + await ch._send_request("testMethod", {"key": "val"}) + await ch._send_request("testMethod", {"key": "val"}) + + ids = [p["json"]["id"] for p in client.posts] + assert ids == [1, 2] + + +@pytest.mark.asyncio +async def test_send_request_raises_when_not_connected() -> None: + ch = _make_channel() + # _http is None by default + with pytest.raises(RuntimeError, match="Not connected"): + await ch._send_request("testMethod") + + +# --------------------------------------------------------------------------- +# _handle_receive_notification — envelope shapes +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_handle_notification_sync_message_does_not_forward() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = { + "envelope": { + "sourceNumber": "+19995550001", + "syncMessage": { + "sentMessage": { + "destination": "+19990000000", + "message": "sent from other device", + } + }, + } + } + await ch._handle_receive_notification(notification) + assert handled == [] + + +@pytest.mark.asyncio +async def test_handle_notification_no_source_skipped() -> None: + ch = _make_channel(dm_enabled=True, dm_policy="open") + handled: list[dict] = [] + ch._handle_message = lambda **kw: handled.append(kw) # type: ignore[method-assign] + + notification = {"envelope": {"dataMessage": {"message": "ghost"}}} + await ch._handle_receive_notification(notification) + assert handled == [] + + +# --------------------------------------------------------------------------- +# Config: allow_from property aggregation +# --------------------------------------------------------------------------- + + +def test_config_allow_from_aggregates_dm_and_group() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), + group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]), + ) + combined = config.allow_from + assert "+1111" in combined + assert "+2222" in combined + assert "+3333" in combined + # Duplicates removed + assert combined.count("+1111") == 1 + + +def test_config_allow_from_wildcard_propagates() -> None: + config = SignalConfig( + enabled=True, + phone_number="+10000000000", + dm=SignalDMConfig(enabled=True, policy="open", allow_from=["*"]), + group=SignalGroupConfig(enabled=True, policy="open", allow_from=[]), + ) + assert "*" in config.allow_from diff --git a/tests/channels/test_signal_markdown.py b/tests/channels/test_signal_markdown.py new file mode 100644 index 000000000..37a21c6d8 --- /dev/null +++ b/tests/channels/test_signal_markdown.py @@ -0,0 +1,525 @@ +"""Unit tests for the Signal markdown → plain text + textStyle converter.""" + +from nanobot.channels.signal import _markdown_to_signal, _partition_styles +from nanobot.utils.helpers import split_message + + +def _utf16_len(s: str) -> int: + return len(s.encode("utf-16-le")) // 2 + + +def styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Return a dict mapping each styled substring to its style list.""" + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = plain[start : start + length] + result.setdefault(span, []).append(style) + return result + + +def utf16_styles_for(plain: str, text_styles: list[str]) -> dict[str, list[str]]: + """Like styles_for, but slices `plain` using UTF-16 offsets (Signal's units).""" + encoded = plain.encode("utf-16-le") + result: dict[str, list[str]] = {} + for entry in text_styles: + start_s, length_s, style = entry.split(":", 2) + start, length = int(start_s), int(length_s) + span = encoded[start * 2 : (start + length) * 2].decode("utf-16-le") + result.setdefault(span, []).append(style) + return result + + +# --------------------------------------------------------------------------- +# Basic cases +# --------------------------------------------------------------------------- + + +def test_empty(): + plain, styles = _markdown_to_signal("") + assert plain == "" + assert styles == [] + + +def test_plain_text(): + plain, styles = _markdown_to_signal("hello world") + assert plain == "hello world" + assert styles == [] + + +def test_bold_stars(): + plain, styles = _markdown_to_signal("say **hello** now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_bold_underscores(): + plain, styles = _markdown_to_signal("say __hello__ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["BOLD"]} + + +def test_italic_star(): + plain, styles = _markdown_to_signal("say *hello* now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_italic_underscore(): + plain, styles = _markdown_to_signal("say _hello_ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["ITALIC"]} + + +def test_strikethrough(): + plain, styles = _markdown_to_signal("say ~~hello~~ now") + assert plain == "say hello now" + assert styles_for(plain, styles) == {"hello": ["STRIKETHROUGH"]} + + +# --------------------------------------------------------------------------- +# Code +# --------------------------------------------------------------------------- + + +def test_inline_code(): + plain, styles = _markdown_to_signal("run `ls -la` here") + assert plain == "run ls -la here" + assert styles_for(plain, styles) == {"ls -la": ["MONOSPACE"]} + + +def test_code_block(): + plain, styles = _markdown_to_signal("```\nprint('hi')\n```") + assert "print('hi')" in plain + assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str( + styles_for(plain, styles) + ) + + +def test_code_block_with_lang(): + plain, styles = _markdown_to_signal("```python\ncode\n```") + assert "code" in plain + assert any("MONOSPACE" in s for s in styles) + + +def test_code_block_not_processed_further(): + """Markdown inside a code block must not be styled.""" + plain, styles = _markdown_to_signal("```\n**not bold**\n```") + assert "**not bold**" in plain + # Only MONOSPACE should be applied, no BOLD + for entry in styles: + assert "BOLD" not in entry + + +def test_inline_code_not_processed_further(): + """Markdown inside inline code must not be styled.""" + plain, styles = _markdown_to_signal("use `**raw**` please") + assert "**raw**" in plain + for entry in styles: + assert "BOLD" not in entry + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +def test_header_becomes_bold(): + plain, styles = _markdown_to_signal("# My Title") + assert plain == "My Title" + assert styles_for(plain, styles) == {"My Title": ["BOLD"]} + + +def test_h2_becomes_bold(): + plain, styles = _markdown_to_signal("## Sub-section") + assert plain == "Sub-section" + assert styles_for(plain, styles) == {"Sub-section": ["BOLD"]} + + +# --------------------------------------------------------------------------- +# Blockquotes +# --------------------------------------------------------------------------- + + +def test_blockquote_strips_marker(): + plain, styles = _markdown_to_signal("> some quote") + assert plain == "some quote" + assert styles == [] + + +# --------------------------------------------------------------------------- +# Lists +# --------------------------------------------------------------------------- + + +def test_bullet_dash(): + plain, styles = _markdown_to_signal("- item one") + assert plain == "• item one" + + +def test_bullet_star(): + plain, styles = _markdown_to_signal("* item two") + assert plain == "• item two" + + +def test_numbered_list(): + plain, styles = _markdown_to_signal("1. first\n2. second") + assert "1. first" in plain + assert "2. second" in plain + + +# --------------------------------------------------------------------------- +# Links +# --------------------------------------------------------------------------- + + +def test_link_text_differs_from_url(): + plain, styles = _markdown_to_signal("[Click here](https://example.com)") + assert plain == "Click here (https://example.com)" + assert styles == [] + + +def test_link_text_equals_url(): + plain, styles = _markdown_to_signal("[https://example.com](https://example.com)") + assert plain == "https://example.com" + assert styles == [] + + +def test_link_text_equals_url_without_scheme(): + plain, styles = _markdown_to_signal("[example.com](https://example.com)") + assert plain == "https://example.com" + + +# --------------------------------------------------------------------------- +# Mixed / nesting +# --------------------------------------------------------------------------- + + +def test_bold_and_italic_adjacent(): + plain, styles = _markdown_to_signal("**bold** and *italic*") + assert plain == "bold and italic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_header_with_inline_code(): + """Header becomes BOLD; code inside becomes MONOSPACE (not double-BOLD).""" + plain, styles = _markdown_to_signal("# Use `grep`") + assert plain == "Use grep" + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Use ", []) or "BOLD" in str(styles) + assert "MONOSPACE" in sd.get("grep", []) + + +def test_multiline_mixed(): + md = "**Title**\n\nSome *italic* text.\n\n- bullet\n- another" + plain, styles = _markdown_to_signal(md) + assert "Title" in plain + assert "italic" in plain + assert "• bullet" in plain + sd = styles_for(plain, styles) + assert "BOLD" in sd.get("Title", []) + assert "ITALIC" in sd.get("italic", []) + + +# --------------------------------------------------------------------------- +# Table rendering +# --------------------------------------------------------------------------- + + +def test_table_rendered_as_monospace(): + md = "| A | B |\n| - | - |\n| 1 | 2 |" + plain, styles = _markdown_to_signal(md) + assert "A" in plain and "B" in plain + assert any("MONOSPACE" in s for s in styles) + + +# --------------------------------------------------------------------------- +# Style range format +# --------------------------------------------------------------------------- + + +def test_style_range_format(): + """Each style entry must be 'start:length:STYLE'.""" + _, styles = _markdown_to_signal("**bold** text") + for entry in styles: + parts = entry.split(":") + assert len(parts) == 3 + assert parts[0].isdigit() + assert parts[1].isdigit() + assert parts[2] in {"BOLD", "ITALIC", "STRIKETHROUGH", "MONOSPACE", "SPOILER"} + + +def test_style_ranges_are_within_bounds(): + text = "hello **world** end" + plain, styles = _markdown_to_signal(text) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= len(plain) + + +# --------------------------------------------------------------------------- +# Non-BMP / UTF-16 offsets +# +# Signal's BodyRange (and signal-cli's textStyle) interprets start/length in +# UTF-16 code units. Python's len() counts code points, so characters outside +# the BMP (emojis, supplementary CJK) shift offsets by +1 per occurrence. +# --------------------------------------------------------------------------- + + +def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None: + limit = _utf16_len(plain) + for entry in styles: + start_s, length_s, _ = entry.split(":", 2) + start, length = int(start_s), int(length_s) + assert start >= 0 + assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}" + + +def test_bold_with_emoji_inside(): + plain, styles = _markdown_to_signal("**hi 🎉 bye**") + assert plain == "hi 🎉 bye" + assert utf16_styles_for(plain, styles) == {"hi 🎉 bye": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_italic_with_trailing_emoji(): + plain, styles = _markdown_to_signal("*bye 🎉*") + assert plain == "bye 🎉" + assert utf16_styles_for(plain, styles) == {"bye 🎉": ["ITALIC"]} + assert_within_utf16_bounds(plain, styles) + + +def test_bold_after_emoji_prefix(): + plain, styles = _markdown_to_signal("🎉 **bold**") + assert plain == "🎉 bold" + assert utf16_styles_for(plain, styles) == {"bold": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_bold_after_and_inside_emoji(): + plain, styles = _markdown_to_signal("🎉 **a 🎊 b**") + assert plain == "🎉 a 🎊 b" + assert utf16_styles_for(plain, styles) == {"a 🎊 b": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_supplementary_cjk_in_bold(): + """Non-BMP CJK (U+20BB7) proves the bug is UTF-16, not emoji-specific.""" + plain, styles = _markdown_to_signal("**𠮷野家**") + assert plain == "𠮷野家" + assert utf16_styles_for(plain, styles) == {"𠮷野家": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_zwj_emoji_in_bold(): + """ZWJ family sequence = multiple surrogate pairs + BMP ZWJs.""" + plain, styles = _markdown_to_signal("**hi 👨‍👩‍👧 bye**") + assert plain == "hi 👨‍👩‍👧 bye" + assert utf16_styles_for(plain, styles) == {"hi 👨‍👩‍👧 bye": ["BOLD"]} + assert_within_utf16_bounds(plain, styles) + + +def test_ascii_offsets_unchanged(): + """ASCII-only path must produce the same offsets as before the UTF-16 fix.""" + plain, styles = _markdown_to_signal("**bold** plain *it*") + assert plain == "bold plain it" + assert sorted(styles) == sorted(["0:4:BOLD", "11:2:ITALIC"]) + + +def test_reported_daily_brief_pattern(): + """Regression for the reported bug: a single non-BMP emoji shifts every + subsequent styled span left by 1 UTF-16 unit, lopping off the last letter. + """ + md = ( + "**Weather**\n" + "- Conditions: 🌩️ Thunderstorms\n\n" + "**News**\n" + "*World*\n" + "*Local*\n\n" + "**Quote of the Day**" + ) + plain, styles = _markdown_to_signal(md) + sd = utf16_styles_for(plain, styles) + assert sd.get("Weather") == ["BOLD"] + assert sd.get("News") == ["BOLD"] + assert sd.get("World") == ["ITALIC"] + assert sd.get("Local") == ["ITALIC"] + assert sd.get("Quote of the Day") == ["BOLD"] + assert_within_utf16_bounds(plain, styles) + + +# --------------------------------------------------------------------------- +# Chunk redistribution +# +# split_message can break a long Signal payload into multiple chunks. The +# style ranges from _markdown_to_signal are anchored to the full text, so +# they must be redistributed per-chunk with rebased offsets — otherwise +# styles for chunks 1..N are silently lost. +# --------------------------------------------------------------------------- + + +def _resolve_chunk_styles(text: str, max_len: int) -> tuple[list[str], list[list[str]]]: + """Helper: full markdown → signal pipeline, including chunking.""" + plain, styles = _markdown_to_signal(text) + chunks = split_message(plain, max_len) if plain else [""] + return chunks, _partition_styles(plain, chunks, styles) + + +def test_partition_styles_single_chunk_passthrough(): + plain, styles = _markdown_to_signal("**bold** plain *it*") + parts = _partition_styles(plain, [plain], styles) + assert parts == [styles] + + +def test_partition_styles_no_styles(): + plain = "hello world" + assert _partition_styles(plain, [plain], []) == [[]] + assert _partition_styles(plain, ["hello", "world"], []) == [[], []] + + +def test_partition_styles_drops_styles_outside_chunks(): + """Whitespace trimmed by split_message must not carry a style range.""" + plain = "a b" + # Fake a style spanning the trimmed whitespace only. + chunks = ["a", "b"] + parts = _partition_styles(plain, chunks, ["1:3:BOLD"]) + assert parts == [[], []] + + +def test_partition_styles_long_message_preserves_chunk_one_styles(): + """A bold span deep in the message must follow the message into chunk 1.""" + # Two ~30-char paragraphs separated by a blank line, then **tail**. + line_a = "alpha " * 5 # 30 chars, ends with space + line_b = "beta " * 5 + md = f"{line_a.strip()}\n\n{line_b.strip()}\n\n**tail**" + plain, styles = _markdown_to_signal(md) + # Force a split between the paragraphs. + max_len = len(line_a.strip()) + 2 # fits paragraph A + the "\n\n" + chunks = split_message(plain, max_len) + assert len(chunks) >= 2, "test setup must produce a split" + parts = _partition_styles(plain, chunks, styles) + # The bold "tail" should land in the last chunk, with chunk-relative offset. + final_chunk = chunks[-1] + final_styles = parts[-1] + assert any("BOLD" in s for s in final_styles) + for entry in final_styles: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode( + "utf-16-le" + ) + assert slice_ == "tail" + + +def test_partition_styles_chunk_zero_styles_unchanged(): + """Styles entirely in chunk 0 keep their original offsets.""" + md = "**head** middle and **tail**" + plain, styles = _markdown_to_signal(md) + # Split so chunk 0 contains "head" and part of the rest, chunk 1 contains "tail". + chunks = split_message(plain, 12) + assert len(chunks) >= 2 + parts = _partition_styles(plain, chunks, styles) + # "head" lives in chunk 0; assert its offset is unchanged (chunk 0 starts at 0). + head_entries = [s for s in parts[0] if "BOLD" in s] + assert any(s.startswith("0:4:") for s in head_entries) + + +def test_partition_styles_with_non_bmp_chunk_offset(): + """Chunk-start offsets must be expressed in UTF-16 code units.""" + # Emoji in chunk 0, bold in chunk 1. + md = "🎉 alpha beta gamma\n\n**tail**" + plain, styles = _markdown_to_signal(md) + chunks = split_message(plain, 18) + assert len(chunks) >= 2 + parts = _partition_styles(plain, chunks, styles) + final_styles = parts[-1] + assert any("BOLD" in s for s in final_styles) + final_chunk = chunks[-1] + for entry in final_styles: + s, ln, _ = entry.split(":", 2) + start, length = int(s), int(ln) + slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode( + "utf-16-le" + ) + assert slice_ == "tail" + + +def test_partition_styles_range_spanning_chunks_is_split(): + """A style range that straddles a chunk boundary gets sliced into both chunks.""" + # Construct manually: plain = "abc def", style covers "abc def" (whole thing). + plain = "abc def" + chunks = split_message(plain, 4) # "abc" / "def" + assert chunks == ["abc", "def"] + parts = _partition_styles(plain, chunks, ["0:7:BOLD"]) + # Chunk 0 holds 0:3:BOLD, chunk 1 holds 0:3:BOLD (length=3 each, "def" only + # since the space was trimmed by lstrip). + assert parts[0] == ["0:3:BOLD"] + assert parts[1] == ["0:3:BOLD"] + + +# --------------------------------------------------------------------------- +# Adjacency, nesting, and malformed input +# --------------------------------------------------------------------------- + + +def test_bold_italic_combo_outer_bold_inner_italic(): + """`**_combo_**` carries both BOLD and ITALIC over the same span.""" + plain, styles = _markdown_to_signal("**_combo_**") + assert plain == "combo" + sd = styles_for(plain, styles) + assert set(sd.get("combo", [])) == {"BOLD", "ITALIC"} + + +def test_bold_and_italic_adjacent_no_separator(): + """`**bold***italic*` produces BOLD on `bold` and ITALIC on `italic`.""" + plain, styles = _markdown_to_signal("**bold***italic*") + assert plain == "bolditalic" + sd = styles_for(plain, styles) + assert sd.get("bold") == ["BOLD"] + assert sd.get("italic") == ["ITALIC"] + + +def test_unclosed_bold_falls_through_as_plain(): + """An unmatched `**` opener round-trips as literal text with no style.""" + plain, styles = _markdown_to_signal("**bold") + assert plain == "**bold" + assert styles == [] + + +def test_unclosed_inline_code_falls_through_as_plain(): + """An unmatched backtick round-trips as literal text with no style.""" + plain, styles = _markdown_to_signal("use `grep") + assert plain == "use `grep" + assert styles == [] + + +def test_inline_code_inside_blockquote(): + """Blockquote prefix is stripped; inline code becomes MONOSPACE.""" + plain, styles = _markdown_to_signal("> use `grep`") + assert plain == "use grep" + sd = styles_for(plain, styles) + assert sd.get("grep") == ["MONOSPACE"] + + +def test_header_with_inner_bold_produces_contiguous_bold_ranges(): + """`# **wrap** me` — header forces BOLD over the whole line; the inner `**` + splits the run, yielding two contiguous BOLD ranges that together cover + "wrap me". This is intentional — Signal renders adjacent same-style ranges + as a single visual span. + """ + plain, styles = _markdown_to_signal("# **wrap** me") + assert plain == "wrap me" + # Both ranges are BOLD; collectively they cover the whole "wrap me". + bold_ranges = [s for s in styles if s.endswith(":BOLD")] + assert len(bold_ranges) == 2 + covered = set() + for entry in bold_ranges: + start, length, _ = entry.split(":", 2) + for i in range(int(start), int(start) + int(length)): + covered.add(i) + assert covered == set(range(len(plain))) diff --git a/tests/providers/test_custom_provider.py b/tests/providers/test_custom_provider.py index 85314dc79..ee1f9a090 100644 --- a/tests/providers/test_custom_provider.py +++ b/tests/providers/test_custom_provider.py @@ -56,6 +56,35 @@ def test_custom_provider_parse_chunks_accepts_plain_text_chunks() -> None: assert result.content == "hello world" +def test_custom_provider_parse_chunks_deduplicates_parallel_tool_call_ids() -> None: + chunks = [{ + "choices": [{ + "finish_reason": "tool_calls", + "delta": { + "tool_calls": [ + { + "index": 0, + "id": "call_dup", + "function": {"name": "read_file", "arguments": '{"path":"a.txt"}'}, + }, + { + "index": 1, + "id": "call_dup", + "function": {"name": "read_file", "arguments": '{"path":"b.txt"}'}, + }, + ], + }, + }], + }] + + result = OpenAICompatProvider._parse_chunks(chunks) + ids = [tool_call.id for tool_call in result.tool_calls or []] + + assert ids[0] == "call_dup" + assert len(ids) == 2 + assert len(set(ids)) == 2 + + def test_local_provider_502_error_includes_reachability_hint() -> None: spec = find_by_name("ollama") with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): diff --git a/tests/utils/test_file_edit_events.py b/tests/utils/test_file_edit_events.py index 3ac4dc929..ec1046061 100644 --- a/tests/utils/test_file_edit_events.py +++ b/tests/utils/test_file_edit_events.py @@ -5,13 +5,13 @@ from pathlib import Path from types import SimpleNamespace from nanobot.utils.file_edit_events import ( + StreamingFileEditTracker, build_file_edit_end_event, build_file_edit_start_event, line_diff_stats, prepare_file_edit_tracker, prepare_file_edit_trackers, read_file_snapshot, - StreamingFileEditTracker, ) @@ -374,6 +374,43 @@ def test_streaming_tracker_applies_canonical_call_id_to_final_tool(tmp_path: Pat asyncio.run(run()) +def test_streaming_tracker_does_not_restore_duplicate_canonical_ids(tmp_path: Path) -> None: + events: list[dict] = [] + + async def emit(batch: list[dict]) -> None: + events.extend(batch) + + async def run() -> None: + tracker = StreamingFileEditTracker(workspace=tmp_path, tools={}, emit=emit) + await tracker.update({ + "index": 0, + "call_id": "call_dup", + "name": "write_file", + "arguments_delta": '{"path":"a.md","content":"one\\n"}', + }) + await tracker.update({ + "index": 1, + "call_id": "call_dup", + "name": "write_file", + "arguments_delta": '{"path":"b.md","content":"two\\n"}', + }) + final_a = SimpleNamespace( + id="call_dup", + name="write_file", + arguments={"path": "a.md", "content": "one\n"}, + ) + final_b = SimpleNamespace( + id="call_unique", + name="write_file", + arguments={"path": "b.md", "content": "two\n"}, + ) + tracker.apply_final_call_ids([final_a, final_b]) + assert final_a.id == "call_dup" + assert final_b.id == "call_unique" + + asyncio.run(run()) + + def test_streaming_edit_file_tracker_flushes_small_pending_count(tmp_path: Path) -> None: target = tmp_path / "small.py" target.write_text("old\n", encoding="utf-8") diff --git a/webui/src/hooks/useSessions.ts b/webui/src/hooks/useSessions.ts index c22751c65..7b468fc89 100644 --- a/webui/src/hooks/useSessions.ts +++ b/webui/src/hooks/useSessions.ts @@ -27,13 +27,25 @@ export function useSessions(): { const [loading, setLoading] = useState(true); const [error, setError] = useState(null); const tokenRef = useRef(token); + const optimisticKeysRef = useRef>(new Set()); tokenRef.current = token; const refresh = useCallback(async () => { try { setLoading(true); const rows = await listSessions(tokenRef.current); - setSessions(rows); + const serverKeys = new Set(rows.map((row) => row.key)); + setSessions((prev) => [ + ...rows, + ...prev.filter( + (session) => + optimisticKeysRef.current.has(session.key) && + !serverKeys.has(session.key), + ), + ]); + for (const key of Array.from(optimisticKeysRef.current)) { + if (serverKeys.has(key)) optimisticKeysRef.current.delete(key); + } setError(null); } catch (e) { const msg = @@ -57,6 +69,7 @@ export function useSessions(): { const createChat = useCallback(async (): Promise => { const chatId = await client.newChat(); const key = `websocket:${chatId}`; + optimisticKeysRef.current.add(key); // Optimistic insert; a subsequent refresh will replace it with the // authoritative row once the server persists the session. setSessions((prev) => [ @@ -77,6 +90,7 @@ export function useSessions(): { const deleteChat = useCallback( async (key: string) => { await apiDeleteSession(tokenRef.current, key); + optimisticKeysRef.current.delete(key); setSessions((prev) => prev.filter((s) => s.key !== key)); }, [], diff --git a/webui/src/tests/useSessions.test.tsx b/webui/src/tests/useSessions.test.tsx index 72df813e0..8e76e697e 100644 --- a/webui/src/tests/useSessions.test.tsx +++ b/webui/src/tests/useSessions.test.tsx @@ -157,6 +157,53 @@ describe("useSessions", () => { expect(api.listSessions).toHaveBeenCalledTimes(2); }); + it("keeps a newly created chat visible until the server session list catches up", async () => { + vi.mocked(api.listSessions) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([]) + .mockResolvedValueOnce([ + { + key: "websocket:chat-new", + channel: "websocket", + chatId: "chat-new", + createdAt: "2026-05-20T10:00:00Z", + updatedAt: "2026-05-20T10:01:00Z", + title: "Generated title", + preview: "First message", + }, + ]); + const client = fakeClient(); + client.newChat.mockResolvedValue("chat-new"); + + const { result } = renderHook(() => useSessions(), { + wrapper: wrap(client), + }); + + await waitFor(() => expect(result.current.loading).toBe(false)); + expect(result.current.sessions).toEqual([]); + + await act(async () => { + await result.current.createChat(); + }); + + expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]); + + await act(async () => { + await result.current.refresh(); + }); + + expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]); + expect(result.current.sessions[0]?.preview).toBe(""); + + await act(async () => { + await result.current.refresh(); + }); + + expect(result.current.sessions.map((s) => s.key)).toEqual(["websocket:chat-new"]); + expect(result.current.sessions[0]?.preview).toBe("First message"); + expect(result.current.sessions[0]?.title).toBe("Generated title"); + }); + it("passes through WebUI transcript user media as images and media", async () => { vi.mocked(api.fetchWebuiThread).mockResolvedValue({ schemaVersion: 3,