refactor(signal): hygiene cleanups around constants, typing, and config

- Hoist the cell-strip patterns to module level as precompiled regexes so
  they match the rest of the module's regex style and skip the per-call
  pattern-cache lookup that re.sub() performs for string patterns.
- Type the markdown transform callback with a concrete Callable signature
  instead of an untyped Any, and annotate the mention id walker's value
  parameter as dict[str, Any] | Any.
- Add _HTTP_TIMEOUT_SECONDS alongside the other class-level tunables.
- Reject group_message_buffer_size <= 0 in a Pydantic field_validator
  rather than silently disabling the buffer at write time.
- Mark SignalConfig.allow_from as a computed_field so it shows up in
  model_dump() instead of being invisible to serialization.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
Kaloyan Tenchov 2026-05-16 11:23:12 -04:00 committed by chengyongru
parent 6ec6c9bb83
commit d56bafa6d0
2 changed files with 40 additions and 15 deletions

View File

@ -8,12 +8,13 @@ import re
import shutil import shutil
import unicodedata import unicodedata
from collections import deque from collections import deque
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import httpx import httpx
from pydantic import Field from pydantic import Field, computed_field, field_validator
from nanobot.bus.events import OutboundMessage from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
@ -42,6 +43,18 @@ _SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL) _SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00') _SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
# Inline-markdown wrappers removed when a table cell is rendered as plain
# text: `**…**`, `__…__`, `~~…~~` and backtick spans. Kept apart from the
# styling regexes above because the cell stripper only needs this fixed,
# narrow subset (no single-asterisk italic, no single-tilde strikethrough).
# Every pattern exposes the wrapped content as group 1, so one replacement
# template serves all of them. Entries are listed in the order they are
# meant to be applied.
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = tuple(
    (re.compile(source), r'\1')
    for source in (
        r'\*\*(.+?)\*\*',
        r'__(.+?)__',
        r'~~(.+?)~~',
        r'`([^`]+)`',
    )
)
def _utf16_len(s: str) -> int: def _utf16_len(s: str) -> int:
"""UTF-16 code-unit length, matching Signal BodyRange semantics.""" """UTF-16 code-unit length, matching Signal BodyRange semantics."""
@ -50,10 +63,8 @@ def _utf16_len(s: str) -> int:
def _sig_strip_cell(s: str) -> str: def _sig_strip_cell(s: str) -> str:
"""Strip inline markdown from a table cell for plain-text rendering.""" """Strip inline markdown from a table cell for plain-text rendering."""
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s) for pattern, repl in _SIG_CELL_STRIP_PATTERNS:
s = re.sub(r'__(.+?)__', r'\1', s) s = pattern.sub(repl, s)
s = re.sub(r'~~(.+?)~~', r'\1', s)
s = re.sub(r'`([^`]+)`', r'\1', s)
return s.strip() return s.strip()
@ -132,7 +143,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
# Phase 2 (run-based): process inline patterns. # Phase 2 (run-based): process inline patterns.
runs: list[_Run] = [_Run(text)] runs: list[_Run] = [_Run(text)]
def transform(pattern: re.Pattern, make_runs: Any) -> None: def transform(
pattern: re.Pattern,
make_runs: Callable[[re.Match, frozenset[str]], list[_Run]],
) -> None:
new_runs: list[_Run] = [] new_runs: list[_Run] = []
for run in runs: for run in runs:
if run.opaque: if run.opaque:
@ -284,6 +298,14 @@ class SignalConfig(Base):
dm: SignalDMConfig = Field(default_factory=SignalDMConfig) dm: SignalDMConfig = Field(default_factory=SignalDMConfig)
group: SignalGroupConfig = Field(default_factory=SignalGroupConfig) group: SignalGroupConfig = Field(default_factory=SignalGroupConfig)
@field_validator("group_message_buffer_size")
@classmethod
def _validate_buffer_size(cls, v: int) -> int:
    """Reject non-positive values for ``group_message_buffer_size``.

    Registered through ``field_validator`` for the
    ``group_message_buffer_size`` field.

    Args:
        v: Candidate buffer size.

    Returns:
        ``v`` unchanged when it is greater than zero.

    Raises:
        ValueError: If ``v`` is zero or negative.
    """
    non_positive = v <= 0
    if non_positive:
        raise ValueError("group_message_buffer_size must be > 0")
    return v
@computed_field # type: ignore[prop-decorator]
@property @property
def allow_from(self) -> list[str]: def allow_from(self) -> list[str]:
"""Aggregate allowlist for the base-class is_allowed() check. """Aggregate allowlist for the base-class is_allowed() check.
@ -309,6 +331,7 @@ class SignalChannel(BaseChannel):
display_name = "Signal" display_name = "Signal"
_TYPING_REFRESH_SECONDS = 10.0 _TYPING_REFRESH_SECONDS = 10.0
_MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB) _MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB)
_HTTP_TIMEOUT_SECONDS = 60.0
@classmethod @classmethod
def default_config(cls) -> dict[str, Any]: def default_config(cls) -> dict[str, Any]:
@ -351,7 +374,9 @@ class SignalChannel(BaseChannel):
self.logger.info(f"Connecting to signal-cli daemon at {base_url}...") self.logger.info(f"Connecting to signal-cli daemon at {base_url}...")
# Create HTTP client # Create HTTP client
self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url) self._http = httpx.AsyncClient(
timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url
)
# Test connection # Test connection
try: try:
@ -777,9 +802,6 @@ class SignalChannel(BaseChannel):
message_text: The message content message_text: The message content
timestamp: Message timestamp timestamp: Message timestamp
""" """
if self.config.group_message_buffer_size <= 0:
return
# Create buffer for this group if it doesn't exist # Create buffer for this group if it doesn't exist
if group_id not in self._group_buffers: if group_id not in self._group_buffers:
self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size) self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size)
@ -906,7 +928,7 @@ class SignalChannel(BaseChannel):
"""Extract possible identifier fields from a mention payload.""" """Extract possible identifier fields from a mention payload."""
ids: list[str] = [] ids: list[str] = []
def _walk(value: Any, depth: int = 0) -> None: def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None:
if depth > 2: if depth > 2:
return return
if not isinstance(value, dict): if not isinstance(value, dict):

View File

@ -448,10 +448,13 @@ class TestGroupBuffer:
ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i) ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i)
assert len(ch._group_buffers["g1"]) == 3 assert len(ch._group_buffers["g1"]) == 3
def test_zero_buffer_size_does_not_add(self): def test_zero_buffer_size_rejected_by_validator(self):
ch = _make_channel(group_buffer_size=0) with pytest.raises(ValueError, match="group_message_buffer_size"):
ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000) _make_channel(group_buffer_size=0)
assert "g1" not in ch._group_buffers
def test_negative_buffer_size_rejected_by_validator(self):
    """Building a channel with group_buffer_size=-1 raises ValueError.

    The error message must mention ``group_message_buffer_size``.
    """
    expected_failure = pytest.raises(
        ValueError, match="group_message_buffer_size"
    )
    with expected_failure:
        _make_channel(group_buffer_size=-1)
def test_context_limits_message_length(self): def test_context_limits_message_length(self):
ch = _make_channel(group_buffer_size=5) ch = _make_channel(group_buffer_size=5)