mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-21 09:02:32 +00:00
Cleanup
This commit is contained in:
parent
3874b3acf4
commit
aed6b6967c
@ -32,17 +32,19 @@ class _Run:
|
|||||||
opaque: bool = False # code / table content — skip further pattern processing
|
opaque: bool = False # code / table content — skip further pattern processing
|
||||||
|
|
||||||
|
|
||||||
_SIG_CODE_BLOCK_RE = re.compile(r'```(?:\w+)?\n?([\s\S]*?)```')
|
_SIG_CODE_BLOCK_RE = re.compile(r"```(?:\w+)?\n?([\s\S]*?)```")
|
||||||
_SIG_INLINE_CODE_RE = re.compile(r'`([^`\n]+)`')
|
_SIG_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
|
||||||
_SIG_HEADER_RE = re.compile(r'^#{1,6}\s+(.+)$', re.MULTILINE)
|
_SIG_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
|
||||||
_SIG_BLOCKQUOTE_RE = re.compile(r'^>\s*(.*)$', re.MULTILINE)
|
_SIG_BLOCKQUOTE_RE = re.compile(r"^>\s*(.*)$", re.MULTILINE)
|
||||||
_SIG_BULLET_RE = re.compile(r'^[-*]\s+', re.MULTILINE)
|
_SIG_BULLET_RE = re.compile(r"^[-*]\s+", re.MULTILINE)
|
||||||
_SIG_OLIST_RE = re.compile(r'^(\d+)\.\s+', re.MULTILINE)
|
_SIG_OLIST_RE = re.compile(r"^(\d+)\.\s+", re.MULTILINE)
|
||||||
_SIG_LINK_RE = re.compile(r'\[([^\]]+)\]\(([^)]+)\)')
|
_SIG_LINK_RE = re.compile(r"\[([^\]]+)\]\(([^)]+)\)")
|
||||||
_SIG_BOLD_RE = re.compile(r'\*\*(.+?)\*\*|__(.+?)__', re.DOTALL)
|
_SIG_BOLD_RE = re.compile(r"\*\*(.+?)\*\*|__(.+?)__", re.DOTALL)
|
||||||
_SIG_ITALIC_RE = re.compile(r'(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])')
|
_SIG_ITALIC_RE = re.compile(
|
||||||
_SIG_STRIKE_RE = re.compile(r'~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])', re.DOTALL)
|
r"(?<!\*)\*([^*\n]+)\*(?!\*)|(?<![a-zA-Z0-9_])_([^_\n]+)_(?![a-zA-Z0-9_])"
|
||||||
_SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
)
|
||||||
|
_SIG_STRIKE_RE = re.compile(r"~~(.+?)~~|(?<![~\w])~([^~\n]+)~(?![~\w])", re.DOTALL)
|
||||||
|
_SIG_TOKEN_RE = re.compile(r"\x00C(\d+)\x00")
|
||||||
|
|
||||||
# Patterns used to strip inline markdown when rendering table cells as plain
|
# Patterns used to strip inline markdown when rendering table cells as plain
|
||||||
# text. Defined separately from the styling regexes above because the cell
|
# text. Defined separately from the styling regexes above because the cell
|
||||||
@ -50,10 +52,10 @@ _SIG_TOKEN_RE = re.compile(r'\x00C(\d+)\x00')
|
|||||||
# single-tilde strikethrough) and benefits from each pattern's group 1 being
|
# single-tilde strikethrough) and benefits from each pattern's group 1 being
|
||||||
# the content directly.
|
# the content directly.
|
||||||
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
|
_SIG_CELL_STRIP_PATTERNS: tuple[tuple[re.Pattern, str], ...] = (
|
||||||
(re.compile(r'\*\*(.+?)\*\*'), r'\1'),
|
(re.compile(r"\*\*(.+?)\*\*"), r"\1"),
|
||||||
(re.compile(r'__(.+?)__'), r'\1'),
|
(re.compile(r"__(.+?)__"), r"\1"),
|
||||||
(re.compile(r'~~(.+?)~~'), r'\1'),
|
(re.compile(r"~~(.+?)~~"), r"\1"),
|
||||||
(re.compile(r'`([^`]+)`'), r'\1'),
|
(re.compile(r"`([^`]+)`"), r"\1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -73,32 +75,32 @@ def _sig_render_table(table_lines: list[str]) -> str:
|
|||||||
"""Render a markdown pipe-table as fixed-width plain text."""
|
"""Render a markdown pipe-table as fixed-width plain text."""
|
||||||
|
|
||||||
def dw(s: str) -> int:
|
def dw(s: str) -> int:
|
||||||
return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
|
return sum(2 if unicodedata.east_asian_width(c) in ("W", "F") else 1 for c in s)
|
||||||
|
|
||||||
rows: list[list[str]] = []
|
rows: list[list[str]] = []
|
||||||
has_sep = False
|
has_sep = False
|
||||||
for line in table_lines:
|
for line in table_lines:
|
||||||
cells = [_sig_strip_cell(c) for c in line.strip().strip('|').split('|')]
|
cells = [_sig_strip_cell(c) for c in line.strip().strip("|").split("|")]
|
||||||
if all(re.match(r'^:?-+:?$', c) for c in cells if c):
|
if all(re.match(r"^:?-+:?$", c) for c in cells if c):
|
||||||
has_sep = True
|
has_sep = True
|
||||||
continue
|
continue
|
||||||
rows.append(cells)
|
rows.append(cells)
|
||||||
if not rows or not has_sep:
|
if not rows or not has_sep:
|
||||||
return '\n'.join(table_lines)
|
return "\n".join(table_lines)
|
||||||
|
|
||||||
ncols = max(len(r) for r in rows)
|
ncols = max(len(r) for r in rows)
|
||||||
for r in rows:
|
for r in rows:
|
||||||
r.extend([''] * (ncols - len(r)))
|
r.extend([""] * (ncols - len(r)))
|
||||||
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
|
||||||
|
|
||||||
def dr(cells: list[str]) -> str:
|
def dr(cells: list[str]) -> str:
|
||||||
return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
|
return " ".join(f"{c}{' ' * (w - dw(c))}" for c, w in zip(cells, widths))
|
||||||
|
|
||||||
out = [dr(rows[0])]
|
out = [dr(rows[0])]
|
||||||
out.append(' '.join('─' * w for w in widths))
|
out.append(" ".join("─" * w for w in widths))
|
||||||
for row in rows[1:]:
|
for row in rows[1:]:
|
||||||
out.append(dr(row))
|
out.append(dr(row))
|
||||||
return '\n'.join(out)
|
return "\n".join(out)
|
||||||
|
|
||||||
|
|
||||||
def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
||||||
@ -121,17 +123,17 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
|||||||
text = _SIG_CODE_BLOCK_RE.sub(save_code, text)
|
text = _SIG_CODE_BLOCK_RE.sub(save_code, text)
|
||||||
|
|
||||||
# Detect and render pipe-tables line by line.
|
# Detect and render pipe-tables line by line.
|
||||||
lines = text.split('\n')
|
lines = text.split("\n")
|
||||||
rebuilt: list[str] = []
|
rebuilt: list[str] = []
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(lines):
|
while i < len(lines):
|
||||||
if re.match(r'^\s*\|.+\|', lines[i]):
|
if re.match(r"^\s*\|.+\|", lines[i]):
|
||||||
tbl: list[str] = []
|
tbl: list[str] = []
|
||||||
while i < len(lines) and re.match(r'^\s*\|.+\|', lines[i]):
|
while i < len(lines) and re.match(r"^\s*\|.+\|", lines[i]):
|
||||||
tbl.append(lines[i])
|
tbl.append(lines[i])
|
||||||
i += 1
|
i += 1
|
||||||
rendered = _sig_render_table(tbl)
|
rendered = _sig_render_table(tbl)
|
||||||
if rendered != '\n'.join(tbl):
|
if rendered != "\n".join(tbl):
|
||||||
protected.append(rendered)
|
protected.append(rendered)
|
||||||
rebuilt.append(f"\x00C{len(protected) - 1}\x00")
|
rebuilt.append(f"\x00C{len(protected) - 1}\x00")
|
||||||
else:
|
else:
|
||||||
@ -139,7 +141,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
|||||||
else:
|
else:
|
||||||
rebuilt.append(lines[i])
|
rebuilt.append(lines[i])
|
||||||
i += 1
|
i += 1
|
||||||
text = '\n'.join(rebuilt)
|
text = "\n".join(rebuilt)
|
||||||
|
|
||||||
# Phase 2 (run-based): process inline patterns.
|
# Phase 2 (run-based): process inline patterns.
|
||||||
runs: list[_Run] = [_Run(text)]
|
runs: list[_Run] = [_Run(text)]
|
||||||
@ -156,7 +158,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
|||||||
pos = 0
|
pos = 0
|
||||||
for m in pattern.finditer(run.text):
|
for m in pattern.finditer(run.text):
|
||||||
if m.start() > pos:
|
if m.start() > pos:
|
||||||
new_runs.append(_Run(run.text[pos:m.start()], run.styles))
|
new_runs.append(_Run(run.text[pos : m.start()], run.styles))
|
||||||
new_runs.extend(make_runs(m, run.styles))
|
new_runs.extend(make_runs(m, run.styles))
|
||||||
pos = m.end()
|
pos = m.end()
|
||||||
if pos < len(run.text):
|
if pos < len(run.text):
|
||||||
@ -164,7 +166,10 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
|||||||
runs[:] = new_runs
|
runs[:] = new_runs
|
||||||
|
|
||||||
# Restore code/table placeholders as opaque MONOSPACE runs.
|
# Restore code/table placeholders as opaque MONOSPACE runs.
|
||||||
transform(_SIG_TOKEN_RE, lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)])
|
transform(
|
||||||
|
_SIG_TOKEN_RE,
|
||||||
|
lambda m, s: [_Run(protected[int(m.group(1))], s | {"MONOSPACE"}, opaque=True)],
|
||||||
|
)
|
||||||
|
|
||||||
# Inline code (opaque).
|
# Inline code (opaque).
|
||||||
transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)])
|
transform(_SIG_INLINE_CODE_RE, lambda m, s: [_Run(m.group(1), s | {"MONOSPACE"}, opaque=True)])
|
||||||
@ -186,7 +191,7 @@ def _markdown_to_signal(text: str) -> tuple[str, list[str]]:
|
|||||||
link_text, url = m.group(1), m.group(2)
|
link_text, url = m.group(1), m.group(2)
|
||||||
|
|
||||||
def _norm(u: str) -> str:
|
def _norm(u: str) -> str:
|
||||||
return re.sub(r'^https?://(www\.)?', '', u).rstrip('/').lower()
|
return re.sub(r"^https?://(www\.)?", "", u).rstrip("/").lower()
|
||||||
|
|
||||||
if _norm(url) == _norm(link_text):
|
if _norm(url) == _norm(link_text):
|
||||||
return [_Run(url, s)]
|
return [_Run(url, s)]
|
||||||
@ -581,9 +586,7 @@ class SignalChannel(BaseChannel):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def _safe_handle(
|
async def _safe_handle(self, action: str, payload: Any = None) -> AsyncIterator[None]:
|
||||||
self, action: str, payload: Any = None
|
|
||||||
) -> AsyncIterator[None]:
|
|
||||||
"""Swallow and log any exception from a top-level handler block.
|
"""Swallow and log any exception from a top-level handler block.
|
||||||
|
|
||||||
Logs `self.logger.error` with the action name, the exception, and a
|
Logs `self.logger.error` with the action name, the exception, and a
|
||||||
@ -788,9 +791,7 @@ class SignalChannel(BaseChannel):
|
|||||||
return False, chat_id
|
return False, chat_id
|
||||||
if self.config.dm.policy == "allowlist":
|
if self.config.dm.policy == "allowlist":
|
||||||
if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from):
|
if not self._sender_matches_allowlist(sender_id, self.config.dm.allow_from):
|
||||||
self.logger.debug(
|
self.logger.debug(f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})")
|
||||||
f"Ignoring DM from {sender_id} (policy: {self.config.dm.policy})"
|
|
||||||
)
|
|
||||||
return False, chat_id
|
return False, chat_id
|
||||||
return True, chat_id
|
return True, chat_id
|
||||||
|
|
||||||
@ -817,9 +818,7 @@ class SignalChannel(BaseChannel):
|
|||||||
if is_group_message:
|
if is_group_message:
|
||||||
buffer_context = self._get_group_buffer_context(chat_id)
|
buffer_context = self._get_group_buffer_context(chat_id)
|
||||||
if buffer_context:
|
if buffer_context:
|
||||||
content_parts.append(
|
content_parts.append(f"[Recent group messages for context:]\n{buffer_context}\n---")
|
||||||
f"[Recent group messages for context:]\n{buffer_context}\n---"
|
|
||||||
)
|
|
||||||
|
|
||||||
if message_text:
|
if message_text:
|
||||||
if is_group_message:
|
if is_group_message:
|
||||||
@ -842,9 +841,7 @@ class SignalChannel(BaseChannel):
|
|||||||
dest_path = media_dir / f"signal_{safe_filename(filename)}"
|
dest_path = media_dir / f"signal_{safe_filename(filename)}"
|
||||||
shutil.copy2(source_path, dest_path)
|
shutil.copy2(source_path, dest_path)
|
||||||
media_paths.append(str(dest_path))
|
media_paths.append(str(dest_path))
|
||||||
media_type = (
|
media_type = content_type.split("/")[0] if "/" in content_type else "file"
|
||||||
content_type.split("/")[0] if "/" in content_type else "file"
|
|
||||||
)
|
|
||||||
if media_type not in ("image", "audio", "video"):
|
if media_type not in ("image", "audio", "video"):
|
||||||
media_type = "file"
|
media_type = "file"
|
||||||
content_parts.append(f"[{media_type}: {dest_path}]")
|
content_parts.append(f"[{media_type}: {dest_path}]")
|
||||||
|
|||||||
@ -610,16 +610,12 @@ class TestCheckInboundPolicy:
|
|||||||
|
|
||||||
def test_group_open_without_mention_blocks(self):
|
def test_group_open_without_mention_blocks(self):
|
||||||
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
|
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
|
||||||
allowed, _ = self._call(
|
allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="plain talk")
|
||||||
ch, is_group_message=True, group_id="g1", message_text="plain talk"
|
|
||||||
)
|
|
||||||
assert allowed is False
|
assert allowed is False
|
||||||
|
|
||||||
def test_group_command_bypasses_mention_requirement(self):
|
def test_group_command_bypasses_mention_requirement(self):
|
||||||
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
|
ch = _make_channel(group_enabled=True, group_policy="open", require_mention=True)
|
||||||
allowed, _ = self._call(
|
allowed, _ = self._call(ch, is_group_message=True, group_id="g1", message_text="/help")
|
||||||
ch, is_group_message=True, group_id="g1", message_text="/help"
|
|
||||||
)
|
|
||||||
assert allowed is True
|
assert allowed is True
|
||||||
|
|
||||||
def test_allowed_group_appends_to_buffer(self):
|
def test_allowed_group_appends_to_buffer(self):
|
||||||
@ -703,9 +699,7 @@ class TestHandleDataMessageDM:
|
|||||||
async def test_dm_allowlist_matches_uuid_case_insensitive(self):
|
async def test_dm_allowlist_matches_uuid_case_insensitive(self):
|
||||||
"""UUID matching must be case-insensitive."""
|
"""UUID matching must be case-insensitive."""
|
||||||
uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890"
|
uuid = "ABCDEF12-3456-7890-ABCD-EF1234567890"
|
||||||
ch, handled = self._make_dm_channel(
|
ch, handled = self._make_dm_channel(policy="allowlist", allow_from=[uuid.lower()])
|
||||||
policy="allowlist", allow_from=[uuid.lower()]
|
|
||||||
)
|
|
||||||
params = _dm_envelope(source_number="+19995550001", source_uuid=uuid)
|
params = _dm_envelope(source_number="+19995550001", source_uuid=uuid)
|
||||||
await ch._handle_receive_notification(params)
|
await ch._handle_receive_notification(params)
|
||||||
assert len(handled) == 1
|
assert len(handled) == 1
|
||||||
@ -1076,9 +1070,7 @@ class TestCommandHandling:
|
|||||||
ch, forwarded = _make_channel_with_capture(
|
ch, forwarded = _make_channel_with_capture(
|
||||||
group_enabled=True, group_policy="open", require_mention=True
|
group_enabled=True, group_policy="open", require_mention=True
|
||||||
)
|
)
|
||||||
params = _group_envelope(
|
params = _group_envelope(source_number="+19995550001", group_id="grp==", message="/reset")
|
||||||
source_number="+19995550001", group_id="grp==", message="/reset"
|
|
||||||
)
|
|
||||||
await ch._handle_receive_notification(params)
|
await ch._handle_receive_notification(params)
|
||||||
assert len(forwarded) == 1
|
assert len(forwarded) == 1
|
||||||
assert "/reset" in forwarded[0]["content"]
|
assert "/reset" in forwarded[0]["content"]
|
||||||
@ -1357,9 +1349,7 @@ def test_config_allow_from_aggregates_dm_and_group() -> None:
|
|||||||
enabled=True,
|
enabled=True,
|
||||||
phone_number="+10000000000",
|
phone_number="+10000000000",
|
||||||
dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
|
dm=SignalDMConfig(enabled=True, policy="allowlist", allow_from=["+1111", "+2222"]),
|
||||||
group=SignalGroupConfig(
|
group=SignalGroupConfig(enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]),
|
||||||
enabled=True, policy="allowlist", allow_from=["+3333", "+1111"]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
combined = config.allow_from
|
combined = config.allow_from
|
||||||
assert "+1111" in combined
|
assert "+1111" in combined
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
"""Unit tests for the Signal markdown → plain text + textStyle converter."""
|
"""Unit tests for the Signal markdown → plain text + textStyle converter."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from nanobot.channels.signal import _markdown_to_signal, _partition_styles
|
from nanobot.channels.signal import _markdown_to_signal, _partition_styles
|
||||||
from nanobot.utils.helpers import split_message
|
from nanobot.utils.helpers import split_message
|
||||||
|
|
||||||
@ -94,8 +92,9 @@ def test_inline_code():
|
|||||||
def test_code_block():
|
def test_code_block():
|
||||||
plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
|
plain, styles = _markdown_to_signal("```\nprint('hi')\n```")
|
||||||
assert "print('hi')" in plain
|
assert "print('hi')" in plain
|
||||||
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or \
|
assert styles_for(plain, styles).get("print('hi')\n") == ["MONOSPACE"] or "MONOSPACE" in str(
|
||||||
"MONOSPACE" in str(styles_for(plain, styles))
|
styles_for(plain, styles)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_code_block_with_lang():
|
def test_code_block_with_lang():
|
||||||
@ -278,9 +277,7 @@ def assert_within_utf16_bounds(plain: str, styles: list[str]) -> None:
|
|||||||
start_s, length_s, _ = entry.split(":", 2)
|
start_s, length_s, _ = entry.split(":", 2)
|
||||||
start, length = int(start_s), int(length_s)
|
start, length = int(start_s), int(length_s)
|
||||||
assert start >= 0
|
assert start >= 0
|
||||||
assert start + length <= limit, (
|
assert start + length <= limit, f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
|
||||||
f"range {entry} exceeds utf-16 length {limit} of {plain!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_bold_with_emoji_inside():
|
def test_bold_with_emoji_inside():
|
||||||
@ -413,7 +410,9 @@ def test_partition_styles_long_message_preserves_chunk_one_styles():
|
|||||||
for entry in final_styles:
|
for entry in final_styles:
|
||||||
s, ln, _ = entry.split(":", 2)
|
s, ln, _ = entry.split(":", 2)
|
||||||
start, length = int(s), int(ln)
|
start, length = int(s), int(ln)
|
||||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le")
|
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||||
|
"utf-16-le"
|
||||||
|
)
|
||||||
assert slice_ == "tail"
|
assert slice_ == "tail"
|
||||||
|
|
||||||
|
|
||||||
@ -444,7 +443,9 @@ def test_partition_styles_with_non_bmp_chunk_offset():
|
|||||||
for entry in final_styles:
|
for entry in final_styles:
|
||||||
s, ln, _ = entry.split(":", 2)
|
s, ln, _ = entry.split(":", 2)
|
||||||
start, length = int(s), int(ln)
|
start, length = int(s), int(ln)
|
||||||
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode("utf-16-le")
|
slice_ = final_chunk.encode("utf-16-le")[start * 2 : (start + length) * 2].decode(
|
||||||
|
"utf-16-le"
|
||||||
|
)
|
||||||
assert slice_ == "tail"
|
assert slice_ == "tail"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user