This commit is contained in:
Kaloyan Tenchov 2026-05-16 12:52:33 -04:00 committed by chengyongru
parent 3874b3acf4
commit aed6b6967c
3 changed files with 55 additions and 67 deletions

View File

@ -32,17 +32,19 @@ class _Run:
opaque: bool = False # code / table content — skip further pattern processing opaque: bool = False # code / table content — skip further pattern processing
_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```') _SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```")
_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`') _SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE) _SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE) _SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE)
_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE) _SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE)
_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE) _SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE)
_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)') _SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL) _SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL)
_SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])') _SIG_ITALIC_RE = re.compile(
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL) r"(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])"
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00') )
_SIG_STRIKE_RE = re.compile(r"~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])", re.DOTALL)
_SIG_TOKEN_RE = re.compile(r"\x00C(\d+)\x00")
# Patterns used to strip inline markdown when rendering table cells as plain # Patterns used to strip inline markdown when rendering table cells as plain
# text. Defined separately from the styling regexes above because the cell # text. Defined separately from the styling regexes above because the cell
@ -50,10 +52,10 @@ _SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
# single-tilde strikethrough) and benefits from each pattern's group 1 being # single-tilde strikethrough) and benefits from each pattern's group 1 being
# the content directly. # the content directly.
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = ( _SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
(re.compile(r'\*\*(.+?)\*\*'), r'\1'), (re.compile(r"\*\*(.+?)\*\*"), r"\1"),
(re.compile(r'__(.+?)__'), r'\1'), (re.compile(r"__(.+?)__"), r"\1"),
(re.compile(r'~~(.+?)~~'), r'\1'), (re.compile(r"~~(.+?)~~"), r"\1"),
(re.compile(r'`([^`]+)`'), r'\1'), (re.compile(r"`([^`]+)`"), r"\1"),
) )
@ -73,32 +75,32 @@ def _sig_render_table(table_lines: list[str]) -> str:
"""Render a markdown pipe-table as fixed-width plain text.""" """Render a markdown pipe-table as fixed-width plain text."""
def dw(s: str) -> int: def dw(s: str) -> int:
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s) return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)
rows: list[list[str]] = [] rows: list[list[str]] = []
has_sep = False has_sep = False
for line in table_lines: for line in table_lines:
cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')] cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")]
if all(re.match(r'^:?-+:?$', c) for c in cells if c): if all(re.match(r"^:?-+:?$", c) for c in cells if c):
has_sep = True has_sep = True
continue continue
rows.append(cells) rows.append(cells)
if not rows or not has_sep: if not rows or not has_sep:
return '\n'.join(table_lines) return "\n".join(table_lines)
ncols = max(len(r) for r in rows) ncols = max(len(r) for r in rows)
for r in rows: for r in rows:
r.extend([''] * (ncols - len(r))) r.extend([""] * (ncols - len(r)))
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)] widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
def dr(cells: list[str]) -> str: def dr(cells: list[str]) -> str:
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths)) return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths))
out = [dr(rows[0])] out = [dr(rows[0])]
out.append(' '.join('' * w for w in widths)) out.append(" ".join("" * w for w in widths))
for row in rows[1:]: for row in rows[1:]:
out.append(dr(row)) out.append(dr(row))
return '\n'.join(out) return "\n".join(out)
def _markdown_to_signal(text: str) -> tuple[str, list[str]]: def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
@ -121,17 +123,17 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
text = _SIG_CODE_BLOCK_RE.sub(save_code, text) text = _SIG_CODE_BLOCK_RE.sub(save_code, text)
# Detect and render pipe-tables line by line. # Detect and render pipe-tables line by line.
lines = text.split('\n') lines = text.split("\n")
rebuilt: list[str] = [] rebuilt: list[str] = []
i = 0 i = 0
while i < len(lines): while i < len(lines):
if re.match(r'^\s*\|.+\|', lines[i]): if re.match(r"^\s*\|.+\|", lines[i]):
tbl: list[str] = [] tbl: list[str] = []
while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]): while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]):
tbl.append(lines[i]) tbl.append(lines[i])
i += 1 i += 1
rendered = _sig_render_table(tbl) rendered = _sig_render_table(tbl)
if rendered != '\n'.join(tbl): if rendered != "\n".join(tbl):
protected.append(rendered) protected.append(rendered)
rebuilt.append(f"\x00C{len(protected) - 1}\x00") rebuilt.append(f"\x00C{len(protected) - 1}\x00")
else: else:
@ -139,7 +141,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
else: else:
rebuilt.append(lines[i]) rebuilt.append(lines[i])
i += 1 i += 1
text = '\n'.join(rebuilt) text = "\n".join(rebuilt)
# Phase 2 (run-based): process inline patterns. # Phase 2 (run-based): process inline patterns.
runs: list[_Run] = [_Run(text)] runs: list[_Run] = [_Run(text)]
@ -156,7 +158,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
pos = 0 pos = 0
for m in pattern.finditer(run.text): for m in pattern.finditer(run.text):
if m.start() > pos: if m.start() > pos:
new_runs.append(_Run(run.text[pos:m.start()], run.styles)) new_runs.append(_Run(run.text[pos : m.start()], run.styles))
new_runs.extend(make_runs(m, run.styles)) new_runs.extend(make_runs(m, run.styles))
pos = m.end() pos = m.end()
if pos < len(run.text): if pos < len(run.text):
@ -164,7 +166,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
runs[:] = new_runs runs[:] = new_runs
# Restore code/table placeholders as opaque MONOSPACE runs. # Restore code/table placeholders as opaque MONOSPACE runs.
transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)]) transform(
_SIG_TOKEN_RE,
lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)],
)
# Inline code (opaque). # Inline code (opaque).
transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)]) transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)])
@ -186,7 +191,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
link_text, url = m.group(1), m.group(2) link_text, url = m.group(1), m.group(2)
def _norm(u: str) -> str: def _norm(u: str) -> str:
return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower() return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower()
if _norm(url) == _norm(link_text): if _norm(url) == _norm(link_text):
return [_Run(url, s)] return [_Run(url, s)]
@ -581,9 +586,7 @@ class SignalChannel(BaseChannel):
raise raise
@asynccontextmanager @asynccontextmanager
async def _safe_handle( async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]:
self, action: str, payload: Any = None
) -> AsyncIterator[None]:
"""Swallow and log any exception from a top-level handler block. """Swallow and log any exception from a top-level handler block.
Logs `self.logger.error` with the action name, the exception, and a Logs `self.logger.error` with the action name, the exception, and a
@ -788,9 +791,7 @@ class SignalChannel(BaseChannel):
return False, chat_id return False, chat_id
if self.config.dm.policy == "allowlist": if self.config.dm.policy == "allowlist":
if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from): if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from):
self.logger.debug( self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})")
f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})"
)
return False, chat_id return False, chat_id
return True, chat_id return True, chat_id
@ -817,9 +818,7 @@ class SignalChannel(BaseChannel):
if is_group_message: if is_group_message:
buffer_context = self._get_group_buffer_context(chat_id) buffer_context = self._get_group_buffer_context(chat_id)
if buffer_context: if buffer_context:
content_parts.append( content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---")
f"[Recent group messages for context:]\n{buffer_context}\n---"
)
if message_text: if message_text:
if is_group_message: if is_group_message:
@ -842,9 +841,7 @@ class SignalChannel(BaseChannel):
dest_path = media_dir / f"signal_{safe_filename(filename)}" dest_path = media_dir / f"signal_{safe_filename(filename)}"
shutil.copy2(source_path, dest_path) shutil.copy2(source_path, dest_path)
media_paths.append(str(dest_path)) media_paths.append(str(dest_path))
media_type = ( media_type = content_type.split("/")[0] if "/" in content_type else "file"
content_type.split("/")[0] if "/" in content_type else "file"
)
if media_type not in ("image", "audio", "video"): if media_type not in ("image", "audio", "video"):
media_type = "file" media_type = "file"
content_parts.append(f"[{media_type}: {dest_path}]") content_parts.append(f"[{media_type}: {dest_path}]")

View File

@ -610,16 +610,12 @@ class TestCheckInboundPolicy:
def test_group_open_without_mention_blocks(self): def test_group_open_without_mention_blocks(self):
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
allowed, _ = self._call( allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk")
ch, is_group_message=True, group_id="g1", message_text="plain talk"
)
assert allowed is False assert allowed is False
def test_group_command_bypasses_mention_requirement(self): def test_group_command_bypasses_mention_requirement(self):
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True) ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
allowed, _ = self._call( allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help")
ch, is_group_message=True, group_id="g1", message_text="/help"
)
assert allowed is True assert allowed is True
def test_allowed_group_appends_to_buffer(self): def test_allowed_group_appends_to_buffer(self):
@ -703,9 +699,7 @@ class TestHandleDataMessageDM:
async def test_dm_allowlist_matches_uuid_case_insensitive(self): async def test_dm_allowlist_matches_uuid_case_insensitive(self):
"""UUID matching must be case-insensitive.""" """UUID matching must be case-insensitive."""
uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890" uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890"
ch, handled = self._make_dm_channel( ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()])
policy="allowlist", allow_from=[uuid.lower()]
)
params = _dm_envelope(source_number="+19995550001", source_uuid=uuid) params = _dm_envelope(source_number="+19995550001", source_uuid=uuid)
await ch._handle_receive_notification(params) await ch._handle_receive_notification(params)
assert len(handled) == 1 assert len(handled) == 1
@ -1076,9 +1070,7 @@ class TestCommandHandling:
ch, forwarded = _make_channel_with_capture( ch, forwarded = _make_channel_with_capture(
group_enabled=True, group_policy="open", require_mention=True group_enabled=True, group_policy="open", require_mention=True
) )
params = _group_envelope( params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset")
source_number="+19995550001", group_id="grp==", message="/reset"
)
await ch._handle_receive_notification(params) await ch._handle_receive_notification(params)
assert len(forwarded) == 1 assert len(forwarded) == 1
assert "/reset" in forwarded[0]["content"] assert "/reset" in forwarded[0]["content"]
@ -1357,9 +1349,7 @@ def test_config_allow_from_aggregates_dm_and_group() -> None:
enabled=True, enabled=True,
phone_number="+10000000000", phone_number="+10000000000",
dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]), dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
group=SignalGroupConfig( group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]),
enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]
),
) )
combined = config.allow_from combined = config.allow_from
assert "+1111" in combined assert "+1111" in combined

View File

@ -1,7 +1,5 @@
"""Unit tests for the Signal markdown → plain text + textStyle converter.""" """Unit tests for the Signal markdown → plain text + textStyle converter."""
import pytest
from nanobot.channels.signal import _markdown_to_signal, _partition_styles from nanobot.channels.signal import _markdown_to_signal, _partition_styles
from nanobot.utils.helpers import split_message from nanobot.utils.helpers import split_message
@ -94,8 +92,9 @@ def test_inline_code():
def test_code_block(): def test_code_block():
plain, styles = _markdown_to_signal("```\nprint('hi')\n```") plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
assert "print('hi')" in plain assert "print('hi')" in plain
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \ assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str(
"MONOSPACE" in str(styles_for(plain, styles)) styles_for(plain, styles)
)
def test_code_block_with_lang(): def test_code_block_with_lang():
@ -278,9 +277,7 @@ def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
start_s, length_s, _ = entry.split(":", 2) start_s, length_s, _ = entry.split(":", 2)
start, length = int(start_s), int(length_s) start, length = int(start_s), int(length_s)
assert start >= 0 assert start >= 0
assert start + length <= limit, ( assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
)
def test_bold_with_emoji_inside(): def test_bold_with_emoji_inside():
@ -413,7 +410,9 @@ def test_partition_styles_long_message_preserves_chunk_one_styles():
for entry in final_styles: for entry in final_styles:
s, ln, _ = entry.split(":", 2) s, ln, _ = entry.split(":", 2)
start, length = int(s), int(ln) start, length = int(s), int(ln)
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le") slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
"utf-16-le"
)
assert slice_ == "tail" assert slice_ == "tail"
@ -444,7 +443,9 @@ def test_partition_styles_with_non_bmp_chunk_offset():
for entry in final_styles: for entry in final_styles:
s, ln, _ = entry.split(":", 2) s, ln, _ = entry.split(":", 2)
start, length = int(s), int(ln) start, length = int(s), int(ln)
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le") slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
"utf-16-le"
)
assert slice_ == "tail" assert slice_ == "tail"