mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 17:12:32 +00:00
Cleanup
This commit is contained in:
parent
3874b3acf4
commit
aed6b6967c
@ -32,17 +32,19 @@ class _Run:
|
||||
opaque: bool = False # code / table content — skip further pattern processing
|
||||
|
||||
|
||||
_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```')
|
||||
_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`')
|
||||
_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE)
|
||||
_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE)
|
||||
_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE)
|
||||
_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE)
|
||||
_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
|
||||
_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL)
|
||||
_SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])')
|
||||
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
|
||||
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
||||
_SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```")
|
||||
_SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
|
||||
_SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||
_SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE)
|
||||
_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE)
|
||||
_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE)
|
||||
_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
||||
_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL)
|
||||
_SIG_ITALIC_RE = re.compile(
|
||||
r"(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])"
|
||||
)
|
||||
_SIG_STRIKE_RE = re.compile(r"~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])", re.DOTALL)
|
||||
_SIG_TOKEN_RE = re.compile(r"\x00C(\d+)\x00")
|
||||
|
||||
# Patterns used to strip inline markdown when rendering table cells as plain
|
||||
# text. Defined separately from the styling regexes above because the cell
|
||||
@ -50,10 +52,10 @@ _SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
||||
# single-tilde strikethrough) and benefits from each pattern's group 1 being
|
||||
# the content directly.
|
||||
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
|
||||
(re.compile(r'\*\*(.+?)\*\*'), r'\1'),
|
||||
(re.compile(r'__(.+?)__'), r'\1'),
|
||||
(re.compile(r'~~(.+?)~~'), r'\1'),
|
||||
(re.compile(r'`([^`]+)`'), r'\1'),
|
||||
(re.compile(r"\*\*(.+?)\*\*"), r"\1"),
|
||||
(re.compile(r"__(.+?)__"), r"\1"),
|
||||
(re.compile(r"~~(.+?)~~"), r"\1"),
|
||||
(re.compile(r"`([^`]+)`"), r"\1"),
|
||||
)
|
||||
|
||||
|
||||
@ -73,32 +75,32 @@ def _sig_render_table(table_lines: list[str]) -> str:
|
||||
"""Render a markdown pipe-table as fixed-width plain text."""
|
||||
|
||||
def dw(s: str) -> int:
|
||||
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
||||
return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)
|
||||
|
||||
rows: list[list[str]] = []
|
||||
has_sep = False
|
||||
for line in table_lines:
|
||||
cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')]
|
||||
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
||||
cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")]
|
||||
if all(re.match(r"^:?-+:?$", c) for c in cells if c):
|
||||
has_sep = True
|
||||
continue
|
||||
rows.append(cells)
|
||||
if not rows or not has_sep:
|
||||
return '\n'.join(table_lines)
|
||||
return "\n".join(table_lines)
|
||||
|
||||
ncols = max(len(r) for r in rows)
|
||||
for r in rows:
|
||||
r.extend([''] * (ncols - len(r)))
|
||||
r.extend([""] * (ncols - len(r)))
|
||||
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||
|
||||
def dr(cells: list[str]) -> str:
|
||||
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
||||
return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths))
|
||||
|
||||
out = [dr(rows[0])]
|
||||
out.append(' '.join('─' * w for w in widths))
|
||||
out.append(" ".join("─" * w for w in widths))
|
||||
for row in rows[1:]:
|
||||
out.append(dr(row))
|
||||
return '\n'.join(out)
|
||||
return "\n".join(out)
|
||||
|
||||
|
||||
def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||
@ -121,17 +123,17 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||
text = _SIG_CODE_BLOCK_RE.sub(save_code, text)
|
||||
|
||||
# Detect and render pipe-tables line by line.
|
||||
lines = text.split('\n')
|
||||
lines = text.split("\n")
|
||||
rebuilt: list[str] = []
|
||||
i = 0
|
||||
while i < len(lines):
|
||||
if re.match(r'^\s*\|.+\|', lines[i]):
|
||||
if re.match(r"^\s*\|.+\|", lines[i]):
|
||||
tbl: list[str] = []
|
||||
while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]):
|
||||
while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]):
|
||||
tbl.append(lines[i])
|
||||
i += 1
|
||||
rendered = _sig_render_table(tbl)
|
||||
if rendered != '\n'.join(tbl):
|
||||
if rendered != "\n".join(tbl):
|
||||
protected.append(rendered)
|
||||
rebuilt.append(f"\x00C{len(protected) - 1}\x00")
|
||||
else:
|
||||
@ -139,7 +141,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||
else:
|
||||
rebuilt.append(lines[i])
|
||||
i += 1
|
||||
text = '\n'.join(rebuilt)
|
||||
text = "\n".join(rebuilt)
|
||||
|
||||
# Phase 2 (run-based): process inline patterns.
|
||||
runs: list[_Run] = [_Run(text)]
|
||||
@ -164,7 +166,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||
runs[:] = new_runs
|
||||
|
||||
# Restore code/table placeholders as opaque MONOSPACE runs.
|
||||
transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)])
|
||||
transform(
|
||||
_SIG_TOKEN_RE,
|
||||
lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)],
|
||||
)
|
||||
|
||||
# Inline code (opaque).
|
||||
transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)])
|
||||
@ -186,7 +191,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||
link_text, url = m.group(1), m.group(2)
|
||||
|
||||
def _norm(u: str) -> str:
|
||||
return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower()
|
||||
return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower()
|
||||
|
||||
if _norm(url) == _norm(link_text):
|
||||
return [_Run(url, s)]
|
||||
@ -581,9 +586,7 @@ class SignalChannel(BaseChannel):
|
||||
raise
|
||||
|
||||
@asynccontextmanager
|
||||
async def _safe_handle(
|
||||
self, action: str, payload: Any = None
|
||||
) -> AsyncIterator[None]:
|
||||
async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]:
|
||||
"""Swallow and log any exception from a top-level handler block.
|
||||
|
||||
Logs `self.logger.error` with the action name, the exception, and a
|
||||
@ -788,9 +791,7 @@ class SignalChannel(BaseChannel):
|
||||
return False, chat_id
|
||||
if self.config.dm.policy == "allowlist":
|
||||
if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from):
|
||||
self.logger.debug(
|
||||
f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})"
|
||||
)
|
||||
self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})")
|
||||
return False, chat_id
|
||||
return True, chat_id
|
||||
|
||||
@ -817,9 +818,7 @@ class SignalChannel(BaseChannel):
|
||||
if is_group_message:
|
||||
buffer_context = self._get_group_buffer_context(chat_id)
|
||||
if buffer_context:
|
||||
content_parts.append(
|
||||
f"[Recent group messages for context:]\n{buffer_context}\n---"
|
||||
)
|
||||
content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---")
|
||||
|
||||
if message_text:
|
||||
if is_group_message:
|
||||
@ -842,9 +841,7 @@ class SignalChannel(BaseChannel):
|
||||
dest_path = media_dir / f"signal_{safe_filename(filename)}"
|
||||
shutil.copy2(source_path, dest_path)
|
||||
media_paths.append(str(dest_path))
|
||||
media_type = (
|
||||
content_type.split("/")[0] if "/" in content_type else "file"
|
||||
)
|
||||
media_type = content_type.split("/")[0] if "/" in content_type else "file"
|
||||
if media_type not in ("image", "audio", "video"):
|
||||
media_type = "file"
|
||||
content_parts.append(f"[{media_type}: {dest_path}]")
|
||||
|
||||
@ -610,16 +610,12 @@ class TestCheckInboundPolicy:
|
||||
|
||||
def test_group_open_without_mention_blocks(self):
|
||||
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
|
||||
allowed, _ = self._call(
|
||||
ch, is_group_message=True, group_id="g1", message_text="plain talk"
|
||||
)
|
||||
allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk")
|
||||
assert allowed is False
|
||||
|
||||
def test_group_command_bypasses_mention_requirement(self):
|
||||
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
|
||||
allowed, _ = self._call(
|
||||
ch, is_group_message=True, group_id="g1", message_text="/help"
|
||||
)
|
||||
allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help")
|
||||
assert allowed is True
|
||||
|
||||
def test_allowed_group_appends_to_buffer(self):
|
||||
@ -703,9 +699,7 @@ class TestHandleDataMessageDM:
|
||||
async def test_dm_allowlist_matches_uuid_case_insensitive(self):
|
||||
"""UUID matching must be case-insensitive."""
|
||||
uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890"
|
||||
ch, handled = self._make_dm_channel(
|
||||
policy="allowlist", allow_from=[uuid.lower()]
|
||||
)
|
||||
ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()])
|
||||
params = _dm_envelope(source_number="+19995550001", source_uuid=uuid)
|
||||
await ch._handle_receive_notification(params)
|
||||
assert len(handled) == 1
|
||||
@ -1076,9 +1070,7 @@ class TestCommandHandling:
|
||||
ch, forwarded = _make_channel_with_capture(
|
||||
group_enabled=True, group_policy="open", require_mention=True
|
||||
)
|
||||
params = _group_envelope(
|
||||
source_number="+19995550001", group_id="grp==", message="/reset"
|
||||
)
|
||||
params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset")
|
||||
await ch._handle_receive_notification(params)
|
||||
assert len(forwarded) == 1
|
||||
assert "/reset" in forwarded[0]["content"]
|
||||
@ -1357,9 +1349,7 @@ def test_config_allow_from_aggregates_dm_and_group() -> None:
|
||||
enabled=True,
|
||||
phone_number="+10000000000",
|
||||
dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
|
||||
group=SignalGroupConfig(
|
||||
enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]
|
||||
),
|
||||
group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]),
|
||||
)
|
||||
combined = config.allow_from
|
||||
assert "+1111" in combined
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""Unit tests for the Signal markdown → plain text + textStyle converter."""
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.channels.signal import _markdown_to_signal, _partition_styles
|
||||
from nanobot.utils.helpers import split_message
|
||||
|
||||
@ -94,8 +92,9 @@ def test_inline_code():
|
||||
def test_code_block():
|
||||
plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
|
||||
assert "print('hi')" in plain
|
||||
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \
|
||||
"MONOSPACE" in str(styles_for(plain, styles))
|
||||
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str(
|
||||
styles_for(plain, styles)
|
||||
)
|
||||
|
||||
|
||||
def test_code_block_with_lang():
|
||||
@ -278,9 +277,7 @@ def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
|
||||
start_s, length_s, _ = entry.split(":", 2)
|
||||
start, length = int(start_s), int(length_s)
|
||||
assert start >= 0
|
||||
assert start + length <= limit, (
|
||||
f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
|
||||
)
|
||||
assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
|
||||
|
||||
|
||||
def test_bold_with_emoji_inside():
|
||||
@ -413,7 +410,9 @@ def test_partition_styles_long_message_preserves_chunk_one_styles():
|
||||
for entry in final_styles:
|
||||
s, ln, _ = entry.split(":", 2)
|
||||
start, length = int(s), int(ln)
|
||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le")
|
||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||
"utf-16-le"
|
||||
)
|
||||
assert slice_ == "tail"
|
||||
|
||||
|
||||
@ -444,7 +443,9 @@ def test_partition_styles_with_non_bmp_chunk_offset():
|
||||
for entry in final_styles:
|
||||
s, ln, _ = entry.split(":", 2)
|
||||
start, length = int(s), int(ln)
|
||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le")
|
||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||
"utf-16-le"
|
||||
)
|
||||
assert slice_ == "tail"
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user