mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 09:02:32 +00:00
refactor(signal): hygiene cleanups around constants, typing, and config
- Hoist the cell-strip patterns to module level so they match the rest of the module's regex style and aren't reparsed on every call. - Type the markdown transform callback and the mention id walker so the inline Callable signature is no longer an untyped Any. - Add _HTTP_TIMEOUT_SECONDS alongside the other class-level tunables. - Reject group_message_buffer_size <= 0 in a Pydantic field_validator rather than silently disabling the buffer at write time. - Mark SignalConfig.allow_from as a computed_field so it shows up in model_dump() instead of being invisible to serialization. Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
parent
6ec6c9bb83
commit
d56bafa6d0
@ -8,12 +8,13 @@ import re
|
|||||||
import shutil
|
import shutil
|
||||||
import unicodedata
|
import unicodedata
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from collections.abc import Callable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from pydantic import Field
|
from pydantic import Field, computed_field, field_validator
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -42,6 +43,18 @@ _SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\
|
|||||||
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
|
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
|
||||||
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
||||||
|
|
||||||
|
# Patterns used to strip inline markdown when rendering table cells as plain
|
||||||
|
# text. Defined separately from the styling regexes above because the cell
|
||||||
|
# stripper needs a fixed, narrow subset (no single-asterisk italic, no
|
||||||
|
# single-tilde strikethrough) and benefits from each pattern's group 1 being
|
||||||
|
# the content directly.
|
||||||
|
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
|
||||||
|
(re.compile(r'\*\*(.+?)\*\*'), r'\1'),
|
||||||
|
(re.compile(r'__(.+?)__'), r'\1'),
|
||||||
|
(re.compile(r'~~(.+?)~~'), r'\1'),
|
||||||
|
(re.compile(r'`([^`]+)`'), r'\1'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _utf16_len(s: str) -> int:
|
def _utf16_len(s: str) -> int:
|
||||||
"""UTF-16 code-unit length, matching Signal BodyRange semantics."""
|
"""UTF-16 code-unit length, matching Signal BodyRange semantics."""
|
||||||
@ -50,10 +63,8 @@ def _utf16_len(s: str) -> int:
|
|||||||
|
|
||||||
def _sig_strip_cell(s: str) -> str:
|
def _sig_strip_cell(s: str) -> str:
|
||||||
"""Strip inline markdown from a table cell for plain-text rendering."""
|
"""Strip inline markdown from a table cell for plain-text rendering."""
|
||||||
s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
|
for pattern, repl in _SIG_CELL_STRIP_PATTERNS:
|
||||||
s = re.sub(r'__(.+?)__', r'\1', s)
|
s = pattern.sub(repl, s)
|
||||||
s = re.sub(r'~~(.+?)~~', r'\1', s)
|
|
||||||
s = re.sub(r'`([^`]+)`', r'\1', s)
|
|
||||||
return s.strip()
|
return s.strip()
|
||||||
|
|
||||||
|
|
||||||
@ -132,7 +143,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
|||||||
# Phase 2 (run-based): process inline patterns.
|
# Phase 2 (run-based): process inline patterns.
|
||||||
runs: list[_Run] = [_Run(text)]
|
runs: list[_Run] = [_Run(text)]
|
||||||
|
|
||||||
def transform(pattern: re.Pattern, make_runs: Any) -> None:
|
def transform(
|
||||||
|
pattern: re.Pattern,
|
||||||
|
make_runs: Callable[[re.Match, frozenset[str]], list[_Run]],
|
||||||
|
) -> None:
|
||||||
new_runs: list[_Run] = []
|
new_runs: list[_Run] = []
|
||||||
for run in runs:
|
for run in runs:
|
||||||
if run.opaque:
|
if run.opaque:
|
||||||
@ -284,6 +298,14 @@ class SignalConfig(Base):
|
|||||||
dm: SignalDMConfig = Field(default_factory=SignalDMConfig)
|
dm: SignalDMConfig = Field(default_factory=SignalDMConfig)
|
||||||
group: SignalGroupConfig = Field(default_factory=SignalGroupConfig)
|
group: SignalGroupConfig = Field(default_factory=SignalGroupConfig)
|
||||||
|
|
||||||
|
@field_validator("group_message_buffer_size")
|
||||||
|
@classmethod
|
||||||
|
def _validate_buffer_size(cls, v: int) -> int:
|
||||||
|
if v <= 0:
|
||||||
|
raise ValueError("group_message_buffer_size must be > 0")
|
||||||
|
return v
|
||||||
|
|
||||||
|
@computed_field # type: ignore[prop-decorator]
|
||||||
@property
|
@property
|
||||||
def allow_from(self) -> list[str]:
|
def allow_from(self) -> list[str]:
|
||||||
"""Aggregate allowlist for the base-class is_allowed() check.
|
"""Aggregate allowlist for the base-class is_allowed() check.
|
||||||
@ -309,6 +331,7 @@ class SignalChannel(BaseChannel):
|
|||||||
display_name = "Signal"
|
display_name = "Signal"
|
||||||
_TYPING_REFRESH_SECONDS = 10.0
|
_TYPING_REFRESH_SECONDS = 10.0
|
||||||
_MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB)
|
_MAX_MESSAGE_LEN = 64_000 # signal-cli practical limit (protocol max ~64 KB)
|
||||||
|
_HTTP_TIMEOUT_SECONDS = 60.0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def default_config(cls) -> dict[str, Any]:
|
def default_config(cls) -> dict[str, Any]:
|
||||||
@ -351,7 +374,9 @@ class SignalChannel(BaseChannel):
|
|||||||
self.logger.info(f"Connecting to signal-cli daemon at {base_url}...")
|
self.logger.info(f"Connecting to signal-cli daemon at {base_url}...")
|
||||||
|
|
||||||
# Create HTTP client
|
# Create HTTP client
|
||||||
self._http = httpx.AsyncClient(timeout=60.0, base_url=base_url)
|
self._http = httpx.AsyncClient(
|
||||||
|
timeout=self._HTTP_TIMEOUT_SECONDS, base_url=base_url
|
||||||
|
)
|
||||||
|
|
||||||
# Test connection
|
# Test connection
|
||||||
try:
|
try:
|
||||||
@ -777,9 +802,6 @@ class SignalChannel(BaseChannel):
|
|||||||
message_text: The message content
|
message_text: The message content
|
||||||
timestamp: Message timestamp
|
timestamp: Message timestamp
|
||||||
"""
|
"""
|
||||||
if self.config.group_message_buffer_size <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create buffer for this group if it doesn't exist
|
# Create buffer for this group if it doesn't exist
|
||||||
if group_id not in self._group_buffers:
|
if group_id not in self._group_buffers:
|
||||||
self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size)
|
self._group_buffers[group_id] = deque(maxlen=self.config.group_message_buffer_size)
|
||||||
@ -906,7 +928,7 @@ class SignalChannel(BaseChannel):
|
|||||||
"""Extract possible identifier fields from a mention payload."""
|
"""Extract possible identifier fields from a mention payload."""
|
||||||
ids: list[str] = []
|
ids: list[str] = []
|
||||||
|
|
||||||
def _walk(value: Any, depth: int = 0) -> None:
|
def _walk(value: dict[str, Any] | Any, depth: int = 0) -> None:
|
||||||
if depth > 2:
|
if depth > 2:
|
||||||
return
|
return
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
|
|||||||
@ -448,10 +448,13 @@ class TestGroupBuffer:
|
|||||||
ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i)
|
ch._add_to_group_buffer("g1", "Alice", "+1111", f"msg{i}", i)
|
||||||
assert len(ch._group_buffers["g1"]) == 3
|
assert len(ch._group_buffers["g1"]) == 3
|
||||||
|
|
||||||
def test_zero_buffer_size_does_not_add(self):
|
def test_zero_buffer_size_rejected_by_validator(self):
|
||||||
ch = _make_channel(group_buffer_size=0)
|
with pytest.raises(ValueError, match="group_message_buffer_size"):
|
||||||
ch._add_to_group_buffer("g1", "Alice", "+1111", "msg", 1000)
|
_make_channel(group_buffer_size=0)
|
||||||
assert "g1" not in ch._group_buffers
|
|
||||||
|
def test_negative_buffer_size_rejected_by_validator(self):
|
||||||
|
with pytest.raises(ValueError, match="group_message_buffer_size"):
|
||||||
|
_make_channel(group_buffer_size=-1)
|
||||||
|
|
||||||
def test_context_limits_message_length(self):
|
def test_context_limits_message_length(self):
|
||||||
ch = _make_channel(group_buffer_size=5)
|
ch = _make_channel(group_buffer_size=5)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user