This commit is contained in:
Kaloyan Tenchov 2026-05-16 12:52:33 -04:00 committed by chengyongru
parent 3874b3acf4
commit aed6b6967c
3 changed files with 55 additions and 67 deletions

View File

@ -32,17 +32,19 @@ class _Run:
opaque: bool = False # code / table content — skip further pattern processing
_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```')
_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`')
_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE)
_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE)
_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE)
_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE)
_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL)
_SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])')
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
_SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```")
_SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
_SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
_SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE)
_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE)
_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE)
_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL)
_SIG_ITALIC_RE = re.compile(
r"(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])"
)
_SIG_STRIKE_RE = re.compile(r"~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])", re.DOTALL)
_SIG_TOKEN_RE = re.compile(r"\x00C(\d+)\x00")
# Patterns used to strip inline markdown when rendering table cells as plain
# text. Defined separately from the styling regexes above because the cell
@ -50,10 +52,10 @@ _SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
# single-tilde strikethrough) and benefits from each pattern's group 1 being
# the content directly.
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
(re.compile(r'\*\*(.+?)\*\*'), r'\1'),
(re.compile(r'__(.+?)__'), r'\1'),
(re.compile(r'~~(.+?)~~'), r'\1'),
(re.compile(r'`([^`]+)`'), r'\1'),
(re.compile(r"\*\*(.+?)\*\*"), r"\1"),
(re.compile(r"__(.+?)__"), r"\1"),
(re.compile(r"~~(.+?)~~"), r"\1"),
(re.compile(r"`([^`]+)`"), r"\1"),
)
def _sig_render_table(table_lines: list[str]) -> str:
    """Render a markdown pipe-table as fixed-width plain text.

    Args:
        table_lines: Consecutive source lines that look like pipe-table rows
            (header, separator and body rows, in document order).

    Returns:
        The table as space-padded columns with a dashed rule under the first
        row.  If no separator row (``---`` / ``:--:`` cells) is found, or no
        content rows remain, the input lines are returned joined by newlines,
        unchanged.
    """

    def dw(s: str) -> int:
        # Display width: East Asian wide/fullwidth characters take two columns.
        return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)

    rows: list[list[str]] = []
    has_sep = False
    for line in table_lines:
        cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")]
        filled = [c for c in cells if c]
        # Require at least one non-empty cell: all() over an empty sequence is
        # vacuously True, which would misread a blank row as the separator.
        if filled and all(re.match(r"^:?-+:?$", c) for c in filled):
            has_sep = True
            continue
        rows.append(cells)

    if not rows or not has_sep:
        # Not a real table (or nothing to show): hand the lines back untouched.
        return "\n".join(table_lines)

    # Pad ragged rows so every row has the same number of columns.
    ncols = max(len(r) for r in rows)
    for r in rows:
        r.extend([""] * (ncols - len(r)))
    widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]

    def dr(cells: list[str]) -> str:
        # Pad by display width, not len(), so wide characters stay aligned.
        return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths))

    out = [dr(rows[0])]
    # Rule under the header, one dash per display column of each field.
    out.append(" ".join("-" * w for w in widths))
    out.extend(dr(row) for row in rows[1:])
    return "\n".join(out)
def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
@ -121,17 +123,17 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
text = _SIG_CODE_BLOCK_RE.sub(save_code, text)
# Detect and render pipe-tables line by line.
lines = text.split('\n')
lines = text.split("\n")
rebuilt: list[str] = []
i = 0
while i < len(lines):
if re.match(r'^\s*\|.+\|', lines[i]):
if re.match(r"^\s*\|.+\|", lines[i]):
tbl: list[str] = []
while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]):
while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]):
tbl.append(lines[i])
i += 1
rendered = _sig_render_table(tbl)
if rendered != '\n'.join(tbl):
if rendered != "\n".join(tbl):
protected.append(rendered)
rebuilt.append(f"\x00C{len(protected) - 1}\x00")
else:
@ -139,7 +141,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
else:
rebuilt.append(lines[i])
i += 1
text = '\n'.join(rebuilt)
text = "\n".join(rebuilt)
# Phase 2 (run-based): process inline patterns.
runs: list[_Run] = [_Run(text)]
@ -164,7 +166,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
runs[:] = new_runs
# Restore code/table placeholders as opaque MONOSPACE runs.
transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)])
transform(
_SIG_TOKEN_RE,
lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)],
)
# Inline code (opaque).
transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)])
@ -186,7 +191,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
link_text, url = m.group(1), m.group(2)
def _norm(u: str) -> str:
return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower()
return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower()
if _norm(url) == _norm(link_text):
return [_Run(url, s)]
@ -581,9 +586,7 @@ class SignalChannel(BaseChannel):
raise
@asynccontextmanager
async def _safe_handle(
self, action: str, payload: Any = None
) -> AsyncIterator[None]:
async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]:
"""Swallow and log any exception from a top-level handler block.
Logs `self.logger.error` with the action name, the exception, and a
@ -788,9 +791,7 @@ class SignalChannel(BaseChannel):
return False, chat_id
if self.config.dm.policy == "allowlist":
if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from):
self.logger.debug(
f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})"
)
self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})")
return False, chat_id
return True, chat_id
@ -817,9 +818,7 @@ class SignalChannel(BaseChannel):
if is_group_message:
buffer_context = self._get_group_buffer_context(chat_id)
if buffer_context:
content_parts.append(
f"[Recent group messages for context:]\n{buffer_context}\n---"
)
content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---")
if message_text:
if is_group_message:
@ -842,9 +841,7 @@ class SignalChannel(BaseChannel):
dest_path = media_dir / f"signal_{safe_filename(filename)}"
shutil.copy2(source_path, dest_path)
media_paths.append(str(dest_path))
media_type = (
content_type.split("/")[0] if "/" in content_type else "file"
)
media_type = content_type.split("/")[0] if "/" in content_type else "file"
if media_type not in ("image", "audio", "video"):
media_type = "file"
content_parts.append(f"[{media_type}: {dest_path}]")

View File

@ -610,16 +610,12 @@ class TestCheckInboundPolicy:
def test_group_open_without_mention_blocks(self):
    """With require_mention on, an open-policy group message lacking a mention is refused."""
    channel = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
    is_allowed, _ = self._call(
        channel,
        is_group_message=True,
        group_id="g1",
        message_text="plain talk",
    )
    assert is_allowed is False
def test_group_command_bypasses_mention_requirement(self):
    """A slash command in a group is accepted even though require_mention is on."""
    channel = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
    is_allowed, _ = self._call(
        channel,
        is_group_message=True,
        group_id="g1",
        message_text="/help",
    )
    assert is_allowed is True
def test_allowed_group_appends_to_buffer(self):
@ -703,9 +699,7 @@ class TestHandleDataMessageDM:
async def test_dm_allowlist_matches_uuid_case_insensitive(self):
"""UUID matching must be case-insensitive."""
uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890"
ch, handled = self._make_dm_channel(
policy="allowlist", allow_from=[uuid.lower()]
)
ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()])
params = _dm_envelope(source_number="+19995550001", source_uuid=uuid)
await ch._handle_receive_notification(params)
assert len(handled) == 1
@ -1076,9 +1070,7 @@ class TestCommandHandling:
ch, forwarded = _make_channel_with_capture(
group_enabled=True, group_policy="open", require_mention=True
)
params = _group_envelope(
source_number="+19995550001", group_id="grp==", message="/reset"
)
params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset")
await ch._handle_receive_notification(params)
assert len(forwarded) == 1
assert "/reset" in forwarded[0]["content"]
@ -1357,9 +1349,7 @@ def test_config_allow_from_aggregates_dm_and_group() -> None:
enabled=True,
phone_number="+10000000000",
dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
group=SignalGroupConfig(
enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]
),
group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]),
)
combined = config.allow_from
assert "+1111" in combined

View File

@ -1,7 +1,5 @@
"""Unit tests for the Signal markdown → plain text + textStyle converter."""
import pytest
from nanobot.channels.signal import _markdown_to_signal, _partition_styles
from nanobot.utils.helpers import split_message
@ -94,8 +92,9 @@ def test_inline_code():
def test_code_block():
    """A fenced code block keeps its body text and is reported with a MONOSPACE style."""
    converted, style_ranges = _markdown_to_signal("```\nprint('hi')\n```")
    assert "print('hi')" in converted
    # Accept either an exact MONOSPACE entry for the block body or MONOSPACE
    # appearing anywhere in the mapping returned by styles_for.
    exact_hit = styles_for(converted, style_ranges).get("print('hi')\n") == ["MONOSPACE"]
    assert exact_hit or "MONOSPACE" in str(styles_for(converted, style_ranges))
def test_code_block_with_lang():
@ -278,9 +277,7 @@ def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
start_s, length_s, _ = entry.split(":", 2)
start, length = int(start_s), int(length_s)
assert start >= 0
assert start + length <= limit, (
f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
)
assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
def test_bold_with_emoji_inside():
@ -413,7 +410,9 @@ def test_partition_styles_long_message_preserves_chunk_one_styles():
for entry in final_styles:
s, ln, _ = entry.split(":", 2)
start, length = int(s), int(ln)
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le")
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
"utf-16-le"
)
assert slice_ == "tail"
@ -444,7 +443,9 @@ def test_partition_styles_with_non_bmp_chunk_offset():
for entry in final_styles:
s, ln, _ = entry.split(":", 2)
start, length = int(s), int(ln)
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le")
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
"utf-16-le"
)
assert slice_ == "tail"