mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-07 02:05:51 +00:00
fix(channels): reject unauthorized inbound before side effects
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
1813fc5021
commit
4db50f2e32
@ -407,6 +407,12 @@ class EmailChannel(BaseChannel):
|
|||||||
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if not self.is_allowed(sender):
|
||||||
|
self._remember_processed_uid(uid, dedupe, cycle_uids)
|
||||||
|
if mark_seen:
|
||||||
|
client.store(imap_id, "+FLAGS", "\\Seen")
|
||||||
|
continue
|
||||||
|
|
||||||
subject = self._decode_header_value(parsed.get("Subject", ""))
|
subject = self._decode_header_value(parsed.get("Subject", ""))
|
||||||
date_value = parsed.get("Date", "")
|
date_value = parsed.get("Date", "")
|
||||||
message_id = parsed.get("Message-ID", "").strip()
|
message_id = parsed.get("Message-ID", "").strip()
|
||||||
|
|||||||
@ -1644,15 +1644,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
logger.debug("Feishu raw message: {}", message.content)
|
logger.debug("Feishu raw message: {}", message.content)
|
||||||
logger.debug("Feishu mentions: {}", getattr(message, "mentions", None))
|
logger.debug("Feishu mentions: {}", getattr(message, "mentions", None))
|
||||||
|
|
||||||
# Deduplication check
|
|
||||||
message_id = message.message_id
|
message_id = message.message_id
|
||||||
if message_id in self._processed_message_ids:
|
|
||||||
return
|
|
||||||
self._processed_message_ids[message_id] = None
|
|
||||||
|
|
||||||
# Trim cache
|
|
||||||
while len(self._processed_message_ids) > 1000:
|
|
||||||
self._processed_message_ids.popitem(last=False)
|
|
||||||
|
|
||||||
# Skip bot messages
|
# Skip bot messages
|
||||||
if sender.sender_type == "bot":
|
if sender.sender_type == "bot":
|
||||||
@ -1663,10 +1655,22 @@ class FeishuChannel(BaseChannel):
|
|||||||
chat_type = message.chat_type
|
chat_type = message.chat_type
|
||||||
msg_type = message.message_type
|
msg_type = message.message_type
|
||||||
|
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
return
|
||||||
|
|
||||||
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
if chat_type == "group" and not self._is_group_message_for_bot(message):
|
||||||
logger.debug("Feishu: skipping group message (not mentioned)")
|
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Deduplication check
|
||||||
|
if message_id in self._processed_message_ids:
|
||||||
|
return
|
||||||
|
self._processed_message_ids[message_id] = None
|
||||||
|
|
||||||
|
# Trim cache
|
||||||
|
while len(self._processed_message_ids) > 1000:
|
||||||
|
self._processed_message_ids.popitem(last=False)
|
||||||
|
|
||||||
# Add reaction (non-blocking — tracked background task)
|
# Add reaction (non-blocking — tracked background task)
|
||||||
task = asyncio.create_task(
|
task = asyncio.create_task(
|
||||||
self._add_reaction(message_id, self.config.react_emoji)
|
self._add_reaction(message_id, self.config.react_emoji)
|
||||||
|
|||||||
@ -474,24 +474,28 @@ class QQChannel(BaseChannel):
|
|||||||
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
|
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
|
||||||
"""Parse inbound message, download attachments, and publish to the bus."""
|
"""Parse inbound message, download attachments, and publish to the bus."""
|
||||||
try:
|
try:
|
||||||
if data.id in self._processed_ids:
|
|
||||||
return
|
|
||||||
self._processed_ids.append(data.id)
|
|
||||||
|
|
||||||
if is_group:
|
if is_group:
|
||||||
chat_id = data.group_openid
|
chat_id = data.group_openid
|
||||||
user_id = data.author.member_openid
|
user_id = data.author.member_openid
|
||||||
self._chat_type_cache[chat_id] = "group"
|
chat_type = "group"
|
||||||
else:
|
else:
|
||||||
chat_id = str(
|
chat_id = str(
|
||||||
getattr(data.author, "id", None)
|
getattr(data.author, "id", None)
|
||||||
or getattr(data.author, "user_openid", "unknown")
|
or getattr(data.author, "user_openid", "unknown")
|
||||||
)
|
)
|
||||||
user_id = chat_id
|
user_id = chat_id
|
||||||
self._chat_type_cache[chat_id] = "c2c"
|
chat_type = "c2c"
|
||||||
|
|
||||||
content = (data.content or "").strip()
|
content = (data.content or "").strip()
|
||||||
|
|
||||||
|
if not self.is_allowed(user_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
if data.id in self._processed_ids:
|
||||||
|
return
|
||||||
|
self._processed_ids.append(data.id)
|
||||||
|
self._chat_type_cache[chat_id] = chat_type
|
||||||
|
|
||||||
# the data used by tests don't contain attachments property
|
# the data used by tests don't contain attachments property
|
||||||
# so we use getattr with a default of [] to avoid AttributeError in tests
|
# so we use getattr with a default of [] to avoid AttributeError in tests
|
||||||
attachments = getattr(data, "attachments", None) or []
|
attachments = getattr(data, "attachments", None) or []
|
||||||
|
|||||||
@ -993,6 +993,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
message = update.message
|
message = update.message
|
||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
|
sender_id = self._sender_id(user)
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
return
|
||||||
self._remember_thread_context(message)
|
self._remember_thread_context(message)
|
||||||
|
|
||||||
# Strip @bot_username suffix if present
|
# Strip @bot_username suffix if present
|
||||||
@ -1004,7 +1007,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
content = self._normalize_telegram_command(content)
|
content = self._normalize_telegram_command(content)
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=self._sender_id(user),
|
sender_id=sender_id,
|
||||||
chat_id=str(message.chat_id),
|
chat_id=str(message.chat_id),
|
||||||
content=content,
|
content=content,
|
||||||
metadata=self._build_message_metadata(message, user),
|
metadata=self._build_message_metadata(message, user),
|
||||||
@ -1264,6 +1267,8 @@ class TelegramChannel(BaseChannel):
|
|||||||
if not chat_id:
|
if not chat_id:
|
||||||
logger.warning("Callback query without chat_id")
|
logger.warning("Callback query without chat_id")
|
||||||
return
|
return
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
return
|
||||||
button_label = query.data or ""
|
button_label = query.data or ""
|
||||||
await query.answer()
|
await query.answer()
|
||||||
if query.message:
|
if query.message:
|
||||||
|
|||||||
@ -11,13 +11,13 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.base import BaseChannel
|
from nanobot.channels.base import BaseChannel
|
||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||||
|
|
||||||
@ -204,6 +204,9 @@ class WecomChannel(BaseChannel):
|
|||||||
|
|
||||||
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
|
chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
|
||||||
|
|
||||||
|
if chat_id and not self.is_allowed(chat_id):
|
||||||
|
return
|
||||||
|
|
||||||
if chat_id and self.config.welcome_message:
|
if chat_id and self.config.welcome_message:
|
||||||
await self._client.reply_welcome(frame, {
|
await self._client.reply_welcome(frame, {
|
||||||
"msgtype": "text",
|
"msgtype": "text",
|
||||||
@ -233,6 +236,12 @@ class WecomChannel(BaseChannel):
|
|||||||
if not msg_id:
|
if not msg_id:
|
||||||
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
|
msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
|
||||||
|
|
||||||
|
# Extract sender info from "from" field (SDK format)
|
||||||
|
from_info = body.get("from", {})
|
||||||
|
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
return
|
||||||
|
|
||||||
# Deduplication check
|
# Deduplication check
|
||||||
if msg_id in self._processed_message_ids:
|
if msg_id in self._processed_message_ids:
|
||||||
return
|
return
|
||||||
@ -242,10 +251,6 @@ class WecomChannel(BaseChannel):
|
|||||||
while len(self._processed_message_ids) > 1000:
|
while len(self._processed_message_ids) > 1000:
|
||||||
self._processed_message_ids.popitem(last=False)
|
self._processed_message_ids.popitem(last=False)
|
||||||
|
|
||||||
# Extract sender info from "from" field (SDK format)
|
|
||||||
from_info = body.get("from", {})
|
|
||||||
sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
|
|
||||||
|
|
||||||
# For single chat, chatid is the sender's userid
|
# For single chat, chatid is the sender's userid
|
||||||
# For group chat, chatid is provided in body
|
# For group chat, chatid is provided in body
|
||||||
chat_type = body.get("chattype", "single")
|
chat_type = body.get("chattype", "single")
|
||||||
@ -424,9 +429,9 @@ class WecomChannel(BaseChannel):
|
|||||||
# MD5 is used for file integrity only, not cryptographic security
|
# MD5 is used for file integrity only, not cryptographic security
|
||||||
md5_hash = hashlib.md5(data).hexdigest()
|
md5_hash = hashlib.md5(data).hexdigest()
|
||||||
|
|
||||||
CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64)
|
chunk_size = 512 * 1024 # 512 KB raw (before base64)
|
||||||
mv = memoryview(data)
|
mv = memoryview(data)
|
||||||
chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)]
|
chunk_list = [bytes(mv[i : i + chunk_size]) for i in range(0, file_size, chunk_size)]
|
||||||
n_chunks = len(chunk_list)
|
n_chunks = len(chunk_list)
|
||||||
del mv, data
|
del mv, data
|
||||||
|
|
||||||
|
|||||||
@ -588,20 +588,24 @@ class WeixinChannel(BaseChannel):
|
|||||||
if msg.get("message_type") == MESSAGE_TYPE_BOT:
|
if msg.get("message_type") == MESSAGE_TYPE_BOT:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Deduplication by message_id
|
|
||||||
msg_id = str(msg.get("message_id", "") or msg.get("seq", ""))
|
msg_id = str(msg.get("message_id", "") or msg.get("seq", ""))
|
||||||
if not msg_id:
|
if not msg_id:
|
||||||
msg_id = f"{msg.get('from_user_id', '')}_{msg.get('create_time_ms', '')}"
|
msg_id = f"{msg.get('from_user_id', '')}_{msg.get('create_time_ms', '')}"
|
||||||
|
|
||||||
|
from_user_id = msg.get("from_user_id", "") or ""
|
||||||
|
if not from_user_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.is_allowed(from_user_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Deduplication by message_id
|
||||||
if msg_id in self._processed_ids:
|
if msg_id in self._processed_ids:
|
||||||
return
|
return
|
||||||
self._processed_ids[msg_id] = None
|
self._processed_ids[msg_id] = None
|
||||||
while len(self._processed_ids) > 1000:
|
while len(self._processed_ids) > 1000:
|
||||||
self._processed_ids.popitem(last=False)
|
self._processed_ids.popitem(last=False)
|
||||||
|
|
||||||
from_user_id = msg.get("from_user_id", "") or ""
|
|
||||||
if not from_user_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Cache context_token (required for all replies — inbound.ts:23-27)
|
# Cache context_token (required for all replies — inbound.ts:23-27)
|
||||||
ctx_token = msg.get("context_token", "")
|
ctx_token = msg.get("context_token", "")
|
||||||
if ctx_token:
|
if ctx_token:
|
||||||
|
|||||||
@ -8,8 +8,8 @@ import os
|
|||||||
import secrets
|
import secrets
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
from contextlib import suppress
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from contextlib import suppress
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
@ -214,13 +214,6 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
content = data.get("content", "")
|
content = data.get("content", "")
|
||||||
message_id = data.get("id", "")
|
message_id = data.get("id", "")
|
||||||
|
|
||||||
if message_id:
|
|
||||||
if message_id in self._processed_message_ids:
|
|
||||||
return
|
|
||||||
self._processed_message_ids[message_id] = None
|
|
||||||
while len(self._processed_message_ids) > 1000:
|
|
||||||
self._processed_message_ids.popitem(last=False)
|
|
||||||
|
|
||||||
# Extract just the phone number or lid as chat_id
|
# Extract just the phone number or lid as chat_id
|
||||||
is_group = data.get("isGroup", False)
|
is_group = data.get("isGroup", False)
|
||||||
was_mentioned = data.get("wasMentioned", False)
|
was_mentioned = data.get("wasMentioned", False)
|
||||||
@ -246,9 +239,19 @@ class WhatsAppChannel(BaseChannel):
|
|||||||
elif extracted and not phone_id:
|
elif extracted and not phone_id:
|
||||||
phone_id = extracted # best guess for bare values
|
phone_id = extracted # best guess for bare values
|
||||||
|
|
||||||
|
sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
if message_id:
|
||||||
|
if message_id in self._processed_message_ids:
|
||||||
|
return
|
||||||
|
self._processed_message_ids[message_id] = None
|
||||||
|
while len(self._processed_message_ids) > 1000:
|
||||||
|
self._processed_message_ids.popitem(last=False)
|
||||||
|
|
||||||
if phone_id and lid_id:
|
if phone_id and lid_id:
|
||||||
self._lid_to_phone[lid_id] = phone_id
|
self._lid_to_phone[lid_id] = phone_id
|
||||||
sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b
|
|
||||||
|
|
||||||
logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id)
|
logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id)
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,13 @@
|
|||||||
from email.message import EmailMessage
|
|
||||||
from datetime import date
|
|
||||||
from pathlib import Path
|
|
||||||
import imaplib
|
import imaplib
|
||||||
|
from datetime import date
|
||||||
|
from email.message import EmailMessage
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.email import EmailChannel
|
from nanobot.channels.email import EmailChannel, EmailConfig
|
||||||
from nanobot.channels.email import EmailConfig
|
|
||||||
|
|
||||||
|
|
||||||
def _make_config(**overrides) -> EmailConfig:
|
def _make_config(**overrides) -> EmailConfig:
|
||||||
@ -24,6 +23,7 @@ def _make_config(**overrides) -> EmailConfig:
|
|||||||
smtp_username="bot@example.com",
|
smtp_username="bot@example.com",
|
||||||
smtp_password="secret",
|
smtp_password="secret",
|
||||||
mark_seen=True,
|
mark_seen=True,
|
||||||
|
allow_from=["*"],
|
||||||
# Disable auth verification by default so existing tests are unaffected
|
# Disable auth verification by default so existing tests are unaffected
|
||||||
verify_dkim=False,
|
verify_dkim=False,
|
||||||
verify_spf=False,
|
verify_spf=False,
|
||||||
@ -707,8 +707,8 @@ def test_email_content_tagged_with_email_context(monkeypatch) -> None:
|
|||||||
|
|
||||||
def test_check_authentication_results_method() -> None:
|
def test_check_authentication_results_method() -> None:
|
||||||
"""Unit test for the _check_authentication_results static method."""
|
"""Unit test for the _check_authentication_results static method."""
|
||||||
from email.parser import BytesParser
|
|
||||||
from email import policy
|
from email import policy
|
||||||
|
from email.parser import BytesParser
|
||||||
|
|
||||||
# No Authentication-Results header
|
# No Authentication-Results header
|
||||||
msg_no_auth = EmailMessage()
|
msg_no_auth = EmailMessage()
|
||||||
@ -788,6 +788,32 @@ def _make_raw_email_with_attachment(
|
|||||||
return msg.as_bytes()
|
return msg.as_bytes()
|
||||||
|
|
||||||
|
|
||||||
|
def test_fetch_new_messages_ignores_unauthorized_sender_before_attachments(monkeypatch) -> None:
|
||||||
|
raw = _make_raw_email_with_attachment(from_addr="blocked@example.com")
|
||||||
|
fake = _make_fake_imap(raw)
|
||||||
|
monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake)
|
||||||
|
|
||||||
|
called = {"attachments": False}
|
||||||
|
|
||||||
|
def _extract_attachments(*_args, **_kwargs):
|
||||||
|
called["attachments"] = True
|
||||||
|
return []
|
||||||
|
|
||||||
|
monkeypatch.setattr(EmailChannel, "_extract_attachments", _extract_attachments)
|
||||||
|
|
||||||
|
cfg = _make_config(
|
||||||
|
allow_from=["allowed@example.com"],
|
||||||
|
allowed_attachment_types=["application/pdf"],
|
||||||
|
verify_dkim=False,
|
||||||
|
verify_spf=False,
|
||||||
|
)
|
||||||
|
channel = EmailChannel(cfg, MessageBus())
|
||||||
|
|
||||||
|
assert channel._fetch_new_messages() == []
|
||||||
|
assert called["attachments"] is False
|
||||||
|
assert fake.store_calls == [(b"1", "+FLAGS", "\\Seen")]
|
||||||
|
|
||||||
|
|
||||||
def test_extract_attachments_saves_pdf(tmp_path, monkeypatch) -> None:
|
def test_extract_attachments_saves_pdf(tmp_path, monkeypatch) -> None:
|
||||||
"""PDF attachment is saved to media dir and path returned in media list."""
|
"""PDF attachment is saved to media dir and path returned in media list."""
|
||||||
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
|
monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path)
|
||||||
|
|||||||
@ -806,3 +806,26 @@ def test_on_background_task_done_removes_from_set() -> None:
|
|||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
assert task not in channel._background_tasks
|
assert task not in channel._background_tasks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_ignores_unauthorized_sender_before_side_effects() -> None:
|
||||||
|
channel = _make_feishu_channel(group_policy="open")
|
||||||
|
channel.config.allow_from = ["ou_allowed"]
|
||||||
|
channel._add_reaction = AsyncMock()
|
||||||
|
channel._download_and_save_media = AsyncMock(return_value=("/tmp/audio.ogg", "[audio]"))
|
||||||
|
channel.transcribe_audio = AsyncMock(return_value="transcript")
|
||||||
|
channel._handle_message = AsyncMock()
|
||||||
|
|
||||||
|
event = _make_feishu_event(
|
||||||
|
msg_type="audio",
|
||||||
|
content='{"file_key": "file_1"}',
|
||||||
|
sender_open_id="ou_blocked",
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._on_message(event)
|
||||||
|
|
||||||
|
channel._add_reaction.assert_not_awaited()
|
||||||
|
channel._download_and_save_media.assert_not_awaited()
|
||||||
|
channel.transcribe_audio.assert_not_awaited()
|
||||||
|
channel._handle_message.assert_not_awaited()
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
|
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
|
||||||
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -182,6 +182,35 @@ async def test_send_media_failure_falls_back_to_text() -> None:
|
|||||||
assert "bad.png" in failure_calls[0]["content"]
|
assert "bad.png" in failure_calls[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_on_message_ignores_unauthorized_sender_before_attachments_and_ack() -> None:
|
||||||
|
channel = QQChannel(
|
||||||
|
QQConfig(
|
||||||
|
app_id="app",
|
||||||
|
secret="secret",
|
||||||
|
allow_from=["allowed-user"],
|
||||||
|
ack_message="Processing...",
|
||||||
|
),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._client = _FakeClient()
|
||||||
|
channel._handle_attachments = AsyncMock(return_value=(["/tmp/a.png"], ["file"], []))
|
||||||
|
channel._handle_message = AsyncMock()
|
||||||
|
|
||||||
|
data = SimpleNamespace(
|
||||||
|
id="msg-blocked",
|
||||||
|
content="hello",
|
||||||
|
author=SimpleNamespace(user_openid="blocked-user"),
|
||||||
|
attachments=[SimpleNamespace(filename="a.png")],
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._on_message(data, is_group=False)
|
||||||
|
|
||||||
|
channel._handle_attachments.assert_not_awaited()
|
||||||
|
channel._handle_message.assert_not_awaited()
|
||||||
|
assert channel._client.api.c2c_calls == []
|
||||||
|
|
||||||
|
|
||||||
# ── _on_message() exception handling ────────────────────────────────
|
# ── _on_message() exception handling ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1802,3 +1802,32 @@ async def test_send_uses_native_keyboard_when_flag_on() -> None:
|
|||||||
sent = channel._app.bot.sent_messages[0]
|
sent = channel._app.bot.sent_messages[0]
|
||||||
assert isinstance(sent.get("reply_markup"), InlineKeyboardMarkup)
|
assert isinstance(sent.get("reply_markup"), InlineKeyboardMarkup)
|
||||||
assert "[Yes]" not in sent["text"] # native keyboard owns the rendering
|
assert "[Yes]" not in sent["text"] # native keyboard owns the rendering
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_query_ignores_unauthorized_user_before_side_effects() -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], inline_keyboards=True),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._handle_message = AsyncMock()
|
||||||
|
|
||||||
|
query = SimpleNamespace(
|
||||||
|
id="cb_1",
|
||||||
|
data="Yes",
|
||||||
|
answer=AsyncMock(),
|
||||||
|
message=SimpleNamespace(
|
||||||
|
chat_id=123,
|
||||||
|
edit_reply_markup=AsyncMock(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
update = SimpleNamespace(
|
||||||
|
callback_query=query,
|
||||||
|
effective_user=SimpleNamespace(id=12345, username="alice", first_name="Alice"),
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._on_callback_query(update, None)
|
||||||
|
|
||||||
|
query.answer.assert_not_awaited()
|
||||||
|
query.message.edit_reply_markup.assert_not_awaited()
|
||||||
|
channel._handle_message.assert_not_awaited()
|
||||||
|
|||||||
@ -3,7 +3,6 @@
|
|||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -451,6 +450,39 @@ async def test_process_text_message() -> None:
|
|||||||
assert msg.metadata["msg_type"] == "text"
|
assert msg.metadata["msg_type"] == "text"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enter_chat_ignores_unauthorized_user_before_welcome() -> None:
|
||||||
|
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["allowed"]), MessageBus())
|
||||||
|
client = _FakeWeComClient()
|
||||||
|
channel._client = client
|
||||||
|
channel.config.welcome_message = "hello"
|
||||||
|
|
||||||
|
await channel._on_enter_chat(_FakeFrame(body={"chatid": "blocked"}))
|
||||||
|
|
||||||
|
client.reply_welcome.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_message_ignores_unauthorized_sender_before_download() -> None:
|
||||||
|
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["allowed"]), MessageBus())
|
||||||
|
client = _FakeWeComClient()
|
||||||
|
channel._client = client
|
||||||
|
channel._handle_message = AsyncMock()
|
||||||
|
|
||||||
|
frame = _FakeFrame(body={
|
||||||
|
"msgid": "msg_blocked",
|
||||||
|
"chatid": "chat1",
|
||||||
|
"from": {"userid": "blocked"},
|
||||||
|
"image": {"url": "https://example.com/img.png", "aeskey": "key123"},
|
||||||
|
})
|
||||||
|
|
||||||
|
await channel._process_message(frame, "image")
|
||||||
|
|
||||||
|
client.download_file.assert_not_awaited()
|
||||||
|
channel._handle_message.assert_not_awaited()
|
||||||
|
assert channel.bus.inbound_size == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_image_message() -> None:
|
async def test_process_image_message() -> None:
|
||||||
"""Image message: download success → media_paths non-empty."""
|
"""Image message: download success → media_paths non-empty."""
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from pathlib import Path
|
|||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
|
||||||
import httpx
|
import httpx
|
||||||
|
import pytest
|
||||||
|
|
||||||
import nanobot.channels.weixin as weixin_mod
|
import nanobot.channels.weixin as weixin_mod
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -15,10 +15,10 @@ from nanobot.channels.weixin import (
|
|||||||
ITEM_TEXT,
|
ITEM_TEXT,
|
||||||
MESSAGE_TYPE_BOT,
|
MESSAGE_TYPE_BOT,
|
||||||
WEIXIN_CHANNEL_VERSION,
|
WEIXIN_CHANNEL_VERSION,
|
||||||
_decrypt_aes_ecb,
|
|
||||||
_encrypt_aes_ecb,
|
|
||||||
WeixinChannel,
|
WeixinChannel,
|
||||||
WeixinConfig,
|
WeixinConfig,
|
||||||
|
_decrypt_aes_ecb,
|
||||||
|
_encrypt_aes_ecb,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -128,6 +128,34 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None:
|
|||||||
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_process_message_ignores_unauthorized_sender_before_side_effects(tmp_path) -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = WeixinChannel(
|
||||||
|
WeixinConfig(enabled=True, allow_from=["allowed-user"], state_dir=str(tmp_path)),
|
||||||
|
bus,
|
||||||
|
)
|
||||||
|
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
|
||||||
|
channel._start_typing = AsyncMock()
|
||||||
|
|
||||||
|
await channel._process_message(
|
||||||
|
{
|
||||||
|
"message_type": 1,
|
||||||
|
"message_id": "m-unauthorized",
|
||||||
|
"from_user_id": "blocked-user",
|
||||||
|
"context_token": "ctx-blocked",
|
||||||
|
"item_list": [
|
||||||
|
{"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert channel._context_tokens == {}
|
||||||
|
channel._download_media_item.assert_not_awaited()
|
||||||
|
channel._start_typing.assert_not_awaited()
|
||||||
|
assert bus.inbound_size == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None:
|
async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None:
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
|
|||||||
@ -116,7 +116,7 @@ async def test_send_when_disconnected_is_noop():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_group_policy_mention_skips_unmentioned_group_message():
|
async def test_group_policy_mention_skips_unmentioned_group_message():
|
||||||
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"], "groupPolicy": "mention"}, MagicMock())
|
||||||
ch._handle_message = AsyncMock()
|
ch._handle_message = AsyncMock()
|
||||||
|
|
||||||
await ch._handle_bridge_message(
|
await ch._handle_bridge_message(
|
||||||
@ -139,7 +139,7 @@ async def test_group_policy_mention_skips_unmentioned_group_message():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_group_policy_mention_accepts_mentioned_group_message():
|
async def test_group_policy_mention_accepts_mentioned_group_message():
|
||||||
ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock())
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"], "groupPolicy": "mention"}, MagicMock())
|
||||||
ch._handle_message = AsyncMock()
|
ch._handle_message = AsyncMock()
|
||||||
|
|
||||||
await ch._handle_bridge_message(
|
await ch._handle_bridge_message(
|
||||||
@ -166,7 +166,7 @@ async def test_group_policy_mention_accepts_mentioned_group_message():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sender_id_prefers_phone_jid_over_lid():
|
async def test_sender_id_prefers_phone_jid_over_lid():
|
||||||
"""sender_id should resolve to phone number when @s.whatsapp.net JID is present."""
|
"""sender_id should resolve to phone number when @s.whatsapp.net JID is present."""
|
||||||
ch = WhatsAppChannel({"enabled": True}, MagicMock())
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
|
||||||
ch._handle_message = AsyncMock()
|
ch._handle_message = AsyncMock()
|
||||||
|
|
||||||
await ch._handle_bridge_message(
|
await ch._handle_bridge_message(
|
||||||
@ -187,7 +187,7 @@ async def test_sender_id_prefers_phone_jid_over_lid():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_lid_to_phone_cache_resolves_lid_only_messages():
|
async def test_lid_to_phone_cache_resolves_lid_only_messages():
|
||||||
"""When only LID is present, a cached LID→phone mapping should be used."""
|
"""When only LID is present, a cached LID→phone mapping should be used."""
|
||||||
ch = WhatsAppChannel({"enabled": True}, MagicMock())
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
|
||||||
ch._handle_message = AsyncMock()
|
ch._handle_message = AsyncMock()
|
||||||
|
|
||||||
# First message: both phone and LID → builds cache
|
# First message: both phone and LID → builds cache
|
||||||
@ -220,7 +220,7 @@ async def test_lid_to_phone_cache_resolves_lid_only_messages():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_message_transcription_uses_media_path():
|
async def test_voice_message_transcription_uses_media_path():
|
||||||
"""Voice messages are transcribed when media path is available."""
|
"""Voice messages are transcribed when media path is available."""
|
||||||
ch = WhatsAppChannel({"enabled": True}, MagicMock())
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
|
||||||
ch.transcription_provider = "openai"
|
ch.transcription_provider = "openai"
|
||||||
ch.transcription_api_key = "sk-test"
|
ch.transcription_api_key = "sk-test"
|
||||||
ch._handle_message = AsyncMock()
|
ch._handle_message = AsyncMock()
|
||||||
@ -243,10 +243,32 @@ async def test_voice_message_transcription_uses_media_path():
|
|||||||
assert kwargs["content"].startswith("Hello world")
|
assert kwargs["content"].startswith("Hello world")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unauthorized_voice_message_does_not_transcribe() -> None:
|
||||||
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["allowed"]}, MagicMock())
|
||||||
|
ch._handle_message = AsyncMock()
|
||||||
|
ch.transcribe_audio = AsyncMock(return_value="Hello world")
|
||||||
|
|
||||||
|
await ch._handle_bridge_message(
|
||||||
|
json.dumps({
|
||||||
|
"type": "message",
|
||||||
|
"id": "v-blocked",
|
||||||
|
"sender": "blocked@s.whatsapp.net",
|
||||||
|
"pn": "",
|
||||||
|
"content": "[Voice Message]",
|
||||||
|
"timestamp": 1,
|
||||||
|
"media": ["/tmp/voice.ogg"],
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
ch.transcribe_audio.assert_not_awaited()
|
||||||
|
ch._handle_message.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_voice_message_no_media_shows_not_available():
|
async def test_voice_message_no_media_shows_not_available():
|
||||||
"""Voice messages without media produce a fallback placeholder."""
|
"""Voice messages without media produce a fallback placeholder."""
|
||||||
ch = WhatsAppChannel({"enabled": True}, MagicMock())
|
ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock())
|
||||||
ch._handle_message = AsyncMock()
|
ch._handle_message = AsyncMock()
|
||||||
|
|
||||||
await ch._handle_bridge_message(
|
await ch._handle_bridge_message(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user