mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 09:02:32 +00:00
refactor(signal): hygiene cleanups around constants, typing, and config
- Hoist the cell-strip patterns to module level so they match the rest of the module's regex style and aren't reparsed on every call.
- Type the markdown transform callback and the mention id walker so the inline Callable signature is no longer an untyped Any.
- Add _HTTP_TIMEOUT_SECONDS alongside the other class-level tunables.
- Reject group_message_buffer_size <= 0 in a Pydantic field_validator rather than silently disabling the buffer at write time.
- Mark SignalConfig.allow_from as a computed_field so it shows up in model_dump() instead of being invisible to serialization.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
6ec6c9bb83
commit
d56bafa6d0
@ -8,12 +8,13 @@ import re
|
||||
import shutil
|
||||
import unicodedata
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from pydantic import Field
|
||||
from pydantic import Field, computed_field, field_validator
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -42,6 +43,18 @@ _SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\
|
||||
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
|
||||
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
||||
|
||||
# Patterns used to strip inline markdown when rendering table cells as plain
|
||||
# text. Defined separately from the styling regexes above because the cell
|
||||
# stripper needs a fixed, narrow subset (no single-asterisk italic, no
|
||||
# single-tilde strikethrough) and benefits from each pattern's group 1 being
|
||||
# the content directly.
|
||||
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
|
||||
(re.compile(r'\*\*(.+?)\*\*'), r'\1'),
|
||||
(re.compile(r'__(.+?)__'), r'\1'),
|
||||
(re.compile(r'~~(.+?)~~'), r'\1'),
|
||||
(re.compile(r'`([^`]+)`'), r'\1'),
|
||||
)
|
||||
|
||||
|
||||
def _utf16_len(s: str) -> int:
|
||||
"""UTF-16 code-unit length, matching Signal BodyRange semantics."""
|
||||
@ -50,10 +63,8 @@ def _utf16_len(s: str) -> int:
|
||||
|
||||
def _sig_strip_cell(s: str) -> str:
|
||||
"""Strip inline markdown from a table cell for plain-text rendering."""
|
||||
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
||||
s = re.sub(r'__(.+?)__', r'\1', s)
|
||||
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
||||
s = re.sub(r'`([^`]+)`', r'\1', s)
|
||||
for pattern, repl in _SIG_CELL_STRIP_PATTERNS:
|
||||
s = pattern.sub(repl, s)
|
||||
return s.strip()
|
||||
|
||||
|
||||
@ -132,7 +143,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||
# Phase 2 (run-based): process inline patterns.
|
||||
runs: list[_Run] = [_Run(text)]
|
||||
|
||||
def transform(pattern: re.Pattern, make_runs: Any) -> None:
|
||||
def transform(
|
||||
pattern: re.Pattern,
|
||||
make_runs: Callable[[re.Match, frozenset[str]], list[_Run]],
|
||||
) -> None:
|
||||
new_runs: list[_Run] = []
|
||||
for run in runs:
|
||||
if run.opaque:
|
||||
@ -284,6 +298,14 @@ class SignalConfig(Base):
|
||||
dm: SignalDMConfig = Field(default_factory=SignalDMConfig)
|
||||
group: SignalGroupConfig = Field(default_factory=SignalGroupConfig)
|
||||
|
||||
@field_validator("group_message_buffer_size")
|
||||
@classmethod
|
||||
def _validate_buffer_size(cls, v: int) -> int:
|
||||
if v <= 0:
|
||||
raise ValueError("group_message_buffer_size must be > 0")
|
||||
return v
|
||||
|
||||
@computed_field # type: ignore[prop-decorator]
|
||||
@property
|
||||
def allow_from(self) -> list[str]:
|
||||
"""Aggregate allowlist for the base-class is_allowed() check.
|
||||
@ -309,6 +331,7 @@ class SignalChannel(BaseChannel):
|
||||
display_name = "Signal"
|
||||
_TYPING_REFRESH_SECONDS = 10.0
|
||||
_MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB)
|
||||
_HTTP_TIMEOUT_SECONDS = 60.0
|
||||
|
||||
@classmethod
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
@ -351,7 +374,9 @@ class SignalChannel(BaseChannel):
|
||||
self.logger.info(f"Connecting to signal-cli daemon at {base_url}...")
|
||||
|
||||
# Create HTTP client
|
||||
self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url)
|
||||
self._http = httpx.AsyncClient(
|
||||
timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url
|
||||
)
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
@ -777,9 +802,6 @@ class SignalChannel(BaseChannel):
|
||||
message_text: The message content
|
||||
timestamp: Message timestamp
|
||||
"""
|
||||
if self.config.group_message_buffer_size <= 0:
|
||||
return
|
||||
|
||||
# Create buffer for this group if it doesn't exist
|
||||
if group_id not in self._group_buffers:
|
||||
self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size)
|
||||
@ -906,7 +928,7 @@ class SignalChannel(BaseChannel):
|
||||
"""Extract possible identifier fields from a mention payload."""
|
||||
ids: list[str] = []
|
||||
|
||||
def _walk(value: Any, depth: int = 0) -> None:
|
||||
def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None:
|
||||
if depth > 2:
|
||||
return
|
||||
if not isinstance(value, dict):
|
||||
|
||||
@ -448,10 +448,13 @@ class TestGroupBuffer:
|
||||
ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i)
|
||||
assert len(ch._group_buffers["g1"]) == 3
|
||||
|
||||
def test_zero_buffer_size_does_not_add(self):
|
||||
ch = _make_channel(group_buffer_size=0)
|
||||
ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000)
|
||||
assert "g1" not in ch._group_buffers
|
||||
def test_zero_buffer_size_rejected_by_validator(self):
|
||||
with pytest.raises(ValueError, match="group_message_buffer_size"):
|
||||
_make_channel(group_buffer_size=0)
|
||||
|
||||
def test_negative_buffer_size_rejected_by_validator(self):
|
||||
with pytest.raises(ValueError, match="group_message_buffer_size"):
|
||||
_make_channel(group_buffer_size=-1)
|
||||
|
||||
def test_context_limits_message_length(self):
|
||||
ch = _make_channel(group_buffer_size=5)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user