diff --git a/nanobot/channels/email.py b/nanobot/channels/email.py index 36cafc995..401da7bb6 100644 --- a/nanobot/channels/email.py +++ b/nanobot/channels/email.py @@ -407,6 +407,12 @@ class EmailChannel(BaseChannel): self._remember_processed_uid(uid, dedupe, cycle_uids) continue + if not self.is_allowed(sender): + self._remember_processed_uid(uid, dedupe, cycle_uids) + if mark_seen: + client.store(imap_id, "+FLAGS", "\\Seen") + continue + subject = self._decode_header_value(parsed.get("Subject", "")) date_value = parsed.get("Date", "") message_id = parsed.get("Message-ID", "").strip() diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 7e65b6210..6fe8b9d5f 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -1644,15 +1644,7 @@ class FeishuChannel(BaseChannel): logger.debug("Feishu raw message: {}", message.content) logger.debug("Feishu mentions: {}", getattr(message, "mentions", None)) - # Deduplication check message_id = message.message_id - if message_id in self._processed_message_ids: - return - self._processed_message_ids[message_id] = None - - # Trim cache - while len(self._processed_message_ids) > 1000: - self._processed_message_ids.popitem(last=False) # Skip bot messages if sender.sender_type == "bot": @@ -1663,10 +1655,22 @@ class FeishuChannel(BaseChannel): chat_type = message.chat_type msg_type = message.message_type + if not self.is_allowed(sender_id): + return + if chat_type == "group" and not self._is_group_message_for_bot(message): logger.debug("Feishu: skipping group message (not mentioned)") return + # Deduplication check + if message_id in self._processed_message_ids: + return + self._processed_message_ids[message_id] = None + + # Trim cache + while len(self._processed_message_ids) > 1000: + self._processed_message_ids.popitem(last=False) + # Add reaction (non-blocking — tracked background task) task = asyncio.create_task( self._add_reaction(message_id, self.config.react_emoji) diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index 00338229a..ef70cc943 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -474,24 +474,28 @@ class QQChannel(BaseChannel): async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None: """Parse inbound message, download attachments, and publish to the bus.""" try: - if data.id in self._processed_ids: - return - self._processed_ids.append(data.id) - if is_group: chat_id = data.group_openid user_id = data.author.member_openid - self._chat_type_cache[chat_id] = "group" + chat_type = "group" else: chat_id = str( getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown") ) user_id = chat_id - self._chat_type_cache[chat_id] = "c2c" + chat_type = "c2c" content = (data.content or "").strip() + if not self.is_allowed(user_id): + return + + if data.id in self._processed_ids: + return + self._processed_ids.append(data.id) + self._chat_type_cache[chat_id] = chat_type + # the data used by tests don't contain attachments property # so we use getattr with a default of [] to avoid AttributeError in tests attachments = getattr(data, "attachments", None) or [] diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py index eecb73225..492b3ef50 100644 --- a/nanobot/channels/telegram.py +++ b/nanobot/channels/telegram.py @@ -993,6 +993,9 @@ class TelegramChannel(BaseChannel): return message = update.message user = update.effective_user + sender_id = self._sender_id(user) + if not self.is_allowed(sender_id): + return self._remember_thread_context(message) # Strip @bot_username suffix if present @@ -1004,7 +1007,7 @@ class TelegramChannel(BaseChannel): content = self._normalize_telegram_command(content) await self._handle_message( - sender_id=self._sender_id(user), + sender_id=sender_id, chat_id=str(message.chat_id), content=content, metadata=self._build_message_metadata(message, user), @@ -1264,6 +1267,8 @@ class TelegramChannel(BaseChannel): if not chat_id: logger.warning("Callback query without chat_id") return + if not self.is_allowed(sender_id): + return button_label = query.data or "" await query.answer() if query.message: diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 69bdf3f08..ce3e7ed51 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -11,13 +11,13 @@ from pathlib import Path from typing import Any from loguru import logger +from pydantic import Field from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base -from pydantic import Field WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None @@ -204,6 +204,9 @@ class WecomChannel(BaseChannel): chat_id = body.get("chatid", "") if isinstance(body, dict) else "" + if chat_id and not self.is_allowed(chat_id): + return + if chat_id and self.config.welcome_message: await self._client.reply_welcome(frame, { "msgtype": "text", @@ -233,6 +236,12 @@ class WecomChannel(BaseChannel): if not msg_id: msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}" + # Extract sender info from "from" field (SDK format) + from_info = body.get("from", {}) + sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown" + if not self.is_allowed(sender_id): + return + # Deduplication check if msg_id in self._processed_message_ids: return @@ -242,10 +251,6 @@ class WecomChannel(BaseChannel): while len(self._processed_message_ids) > 1000: self._processed_message_ids.popitem(last=False) - # Extract sender info from "from" field (SDK format) - from_info = body.get("from", {}) - sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown" - # For single chat, chatid is the sender's userid # For group chat, chatid is provided in body chat_type = body.get("chattype", "single") @@ -424,9 +429,9 @@ class WecomChannel(BaseChannel): # MD5 is used for file integrity only, not cryptographic security md5_hash = hashlib.md5(data).hexdigest() - CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64) + chunk_size = 512 * 1024 # 512 KB raw (before base64) mv = memoryview(data) - chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)] + chunk_list = [bytes(mv[i : i + chunk_size]) for i in range(0, file_size, chunk_size)] n_chunks = len(chunk_list) del mv, data diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 68fbed85d..af82984b2 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -588,20 +588,24 @@ class WeixinChannel(BaseChannel): if msg.get("message_type") == MESSAGE_TYPE_BOT: return - # Deduplication by message_id msg_id = str(msg.get("message_id", "") or msg.get("seq", "")) if not msg_id: msg_id = f"{msg.get('from_user_id', '')}_{msg.get('create_time_ms', '')}" + + from_user_id = msg.get("from_user_id", "") or "" + if not from_user_id: + return + + if not self.is_allowed(from_user_id): + return + + # Deduplication by message_id if msg_id in self._processed_ids: return self._processed_ids[msg_id] = None while len(self._processed_ids) > 1000: self._processed_ids.popitem(last=False) - from_user_id = msg.get("from_user_id", "") or "" - if not from_user_id: - return - # Cache context_token (required for all replies — inbound.ts:23-27) ctx_token = msg.get("context_token", "") if ctx_token: diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py index 74d53203f..26869de18 100644 --- a/nanobot/channels/whatsapp.py +++ b/nanobot/channels/whatsapp.py @@ -8,8 +8,8 @@ import os import secrets import shutil import subprocess -from contextlib import suppress from collections import OrderedDict +from contextlib import suppress from pathlib import Path from typing import Any, Literal @@ -214,13 +214,6 @@ class WhatsAppChannel(BaseChannel): content = data.get("content", "") message_id = data.get("id", "") - if message_id: - if message_id in self._processed_message_ids: - return - self._processed_message_ids[message_id] = None - while len(self._processed_message_ids) > 1000: - self._processed_message_ids.popitem(last=False) - # Extract just the phone number or lid as chat_id is_group = data.get("isGroup", False) was_mentioned = data.get("wasMentioned", False) @@ -246,9 +239,19 @@ class WhatsAppChannel(BaseChannel): elif extracted and not phone_id: phone_id = extracted # best guess for bare values + sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b + if not self.is_allowed(sender_id): + return + + if message_id: + if message_id in self._processed_message_ids: + return + self._processed_message_ids[message_id] = None + while len(self._processed_message_ids) > 1000: + self._processed_message_ids.popitem(last=False) + if phone_id and lid_id: self._lid_to_phone[lid_id] = phone_id - sender_id = phone_id or self._lid_to_phone.get(lid_id, "") or lid_id or id_a or id_b logger.info("Sender phone={} lid={} → sender_id={}", phone_id or "(empty)", lid_id or "(empty)", sender_id) diff --git a/tests/channels/test_email_channel.py b/tests/channels/test_email_channel.py index 98343522c..cb5aed45e 100644 --- a/tests/channels/test_email_channel.py +++ b/tests/channels/test_email_channel.py @@ -1,14 +1,13 @@ -from email.message import EmailMessage -from datetime import date -from pathlib import Path import imaplib +from datetime import date +from email.message import EmailMessage +from pathlib import Path import pytest from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus -from nanobot.channels.email import EmailChannel -from nanobot.channels.email import EmailConfig +from nanobot.channels.email import EmailChannel, EmailConfig def _make_config(**overrides) -> EmailConfig: @@ -24,6 +23,7 @@ def _make_config(**overrides) -> EmailConfig: smtp_username="bot@example.com", smtp_password="secret", mark_seen=True, + allow_from=["*"], # Disable auth verification by default so existing tests are unaffected verify_dkim=False, verify_spf=False, @@ -707,8 +707,8 @@ def test_email_content_tagged_with_email_context(monkeypatch) -> None: def test_check_authentication_results_method() -> None: """Unit test for the _check_authentication_results static method.""" - from email.parser import BytesParser from email import policy + from email.parser import BytesParser # No Authentication-Results header msg_no_auth = EmailMessage() @@ -788,6 +788,32 @@ def _make_raw_email_with_attachment( return msg.as_bytes() +def test_fetch_new_messages_ignores_unauthorized_sender_before_attachments(monkeypatch) -> None: + raw = _make_raw_email_with_attachment(from_addr="blocked@example.com") + fake = _make_fake_imap(raw) + monkeypatch.setattr("nanobot.channels.email.imaplib.IMAP4_SSL", lambda _h, _p: fake) + + called = {"attachments": False} + + def _extract_attachments(*_args, **_kwargs): + called["attachments"] = True + return [] + + monkeypatch.setattr(EmailChannel, "_extract_attachments", _extract_attachments) + + cfg = _make_config( + allow_from=["allowed@example.com"], + allowed_attachment_types=["application/pdf"], + verify_dkim=False, + verify_spf=False, + ) + channel = EmailChannel(cfg, MessageBus()) + + assert channel._fetch_new_messages() == [] + assert called["attachments"] is False + assert fake.store_calls == [(b"1", "+FLAGS", "\\Seen")] + + def test_extract_attachments_saves_pdf(tmp_path, monkeypatch) -> None: """PDF attachment is saved to media dir and path returned in media list.""" monkeypatch.setattr("nanobot.channels.email.get_media_dir", lambda ch: tmp_path) diff --git a/tests/channels/test_feishu_reply.py b/tests/channels/test_feishu_reply.py index 31d3a1d71..cc7e21e5f 100644 --- a/tests/channels/test_feishu_reply.py +++ b/tests/channels/test_feishu_reply.py @@ -806,3 +806,26 @@ def test_on_background_task_done_removes_from_set() -> None: loop.close() assert task not in channel._background_tasks + + +@pytest.mark.asyncio +async def test_on_message_ignores_unauthorized_sender_before_side_effects() -> None: + channel = _make_feishu_channel(group_policy="open") + channel.config.allow_from = ["ou_allowed"] + channel._add_reaction = AsyncMock() + channel._download_and_save_media = AsyncMock(return_value=("/tmp/audio.ogg", "[audio]")) + channel.transcribe_audio = AsyncMock(return_value="transcript") + channel._handle_message = AsyncMock() + + event = _make_feishu_event( + msg_type="audio", + content='{"file_key": "file_1"}', + sender_open_id="ou_blocked", + ) + + await channel._on_message(event) + + channel._add_reaction.assert_not_awaited() + channel._download_and_save_media.assert_not_awaited() + channel.transcribe_audio.assert_not_awaited() + channel._handle_message.assert_not_awaited() diff --git a/tests/channels/test_qq_media.py b/tests/channels/test_qq_media.py index 80a5ad20e..e2de72f28 100644 --- a/tests/channels/test_qq_media.py +++ b/tests/channels/test_qq_media.py @@ -1,7 +1,7 @@ """Tests for QQ channel media support: helpers, send, inbound, and upload.""" from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -182,6 +182,35 @@ async def test_send_media_failure_falls_back_to_text() -> None: assert "bad.png" in failure_calls[0]["content"] +@pytest.mark.asyncio +async def test_on_message_ignores_unauthorized_sender_before_attachments_and_ack() -> None: + channel = QQChannel( + QQConfig( + app_id="app", + secret="secret", + allow_from=["allowed-user"], + ack_message="Processing...", + ), + MessageBus(), + ) + channel._client = _FakeClient() + channel._handle_attachments = AsyncMock(return_value=(["/tmp/a.png"], ["file"], [])) + channel._handle_message = AsyncMock() + + data = SimpleNamespace( + id="msg-blocked", + content="hello", + author=SimpleNamespace(user_openid="blocked-user"), + attachments=[SimpleNamespace(filename="a.png")], + ) + + await channel._on_message(data, is_group=False) + + channel._handle_attachments.assert_not_awaited() + channel._handle_message.assert_not_awaited() + assert channel._client.api.c2c_calls == [] + + # ── _on_message() exception handling ──────────────────────────────── diff --git a/tests/channels/test_telegram_channel.py b/tests/channels/test_telegram_channel.py index 591df84f4..2ae5cce9f 100644 --- a/tests/channels/test_telegram_channel.py +++ b/tests/channels/test_telegram_channel.py @@ -1802,3 +1802,32 @@ async def test_send_uses_native_keyboard_when_flag_on() -> None: sent = channel._app.bot.sent_messages[0] assert isinstance(sent.get("reply_markup"), InlineKeyboardMarkup) assert "[Yes]" not in sent["text"] # native keyboard owns the rendering + + +@pytest.mark.asyncio +async def test_callback_query_ignores_unauthorized_user_before_side_effects() -> None: + channel = TelegramChannel( + TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], inline_keyboards=True), + MessageBus(), + ) + channel._handle_message = AsyncMock() + + query = SimpleNamespace( + id="cb_1", + data="Yes", + answer=AsyncMock(), + message=SimpleNamespace( + chat_id=123, + edit_reply_markup=AsyncMock(), + ), + ) + update = SimpleNamespace( + callback_query=query, + effective_user=SimpleNamespace(id=12345, username="alice", first_name="Alice"), + ) + + await channel._on_callback_query(update, None) + + query.answer.assert_not_awaited() + query.message.edit_reply_markup.assert_not_awaited() + channel._handle_message.assert_not_awaited() diff --git a/tests/channels/test_wecom_channel.py b/tests/channels/test_wecom_channel.py index a8ed3c0e9..7cb61ab82 100644 --- a/tests/channels/test_wecom_channel.py +++ b/tests/channels/test_wecom_channel.py @@ -3,7 +3,6 @@ import os import tempfile from pathlib import Path -from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -451,6 +450,39 @@ async def test_process_text_message() -> None: assert msg.metadata["msg_type"] == "text" +@pytest.mark.asyncio +async def test_enter_chat_ignores_unauthorized_user_before_welcome() -> None: + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["allowed"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel.config.welcome_message = "hello" + + await channel._on_enter_chat(_FakeFrame(body={"chatid": "blocked"})) + + client.reply_welcome.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_process_message_ignores_unauthorized_sender_before_download() -> None: + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["allowed"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._handle_message = AsyncMock() + + frame = _FakeFrame(body={ + "msgid": "msg_blocked", + "chatid": "chat1", + "from": {"userid": "blocked"}, + "image": {"url": "https://example.com/img.png", "aeskey": "key123"}, + }) + + await channel._process_message(frame, "image") + + client.download_file.assert_not_awaited() + channel._handle_message.assert_not_awaited() + assert channel.bus.inbound_size == 0 + + @pytest.mark.asyncio async def test_process_image_message() -> None: """Image message: download success → media_paths non-empty.""" diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 2b455fca6..4b9b294a9 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -5,8 +5,8 @@ from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock -import pytest import httpx +import pytest import nanobot.channels.weixin as weixin_mod from nanobot.bus.queue import MessageBus @@ -15,10 +15,10 @@ from nanobot.channels.weixin import ( ITEM_TEXT, MESSAGE_TYPE_BOT, WEIXIN_CHANNEL_VERSION, - _decrypt_aes_ecb, - _encrypt_aes_ecb, WeixinChannel, WeixinConfig, + _decrypt_aes_ecb, + _encrypt_aes_ecb, ) @@ -128,6 +128,34 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None: channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") +@pytest.mark.asyncio +async def test_process_message_ignores_unauthorized_sender_before_side_effects(tmp_path) -> None: + bus = MessageBus() + channel = WeixinChannel( + WeixinConfig(enabled=True, allow_from=["allowed-user"], state_dir=str(tmp_path)), + bus, + ) + channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg") + channel._start_typing = AsyncMock() + + await channel._process_message( + { + "message_type": 1, + "message_id": "m-unauthorized", + "from_user_id": "blocked-user", + "context_token": "ctx-blocked", + "item_list": [ + {"type": ITEM_IMAGE, "image_item": {"media": {"encrypt_query_param": "x"}}}, + ], + } + ) + + assert channel._context_tokens == {} + channel._download_media_item.assert_not_awaited() + channel._start_typing.assert_not_awaited() + assert bus.inbound_size == 0 + + @pytest.mark.asyncio async def test_process_message_persists_context_token_to_state_file(tmp_path) -> None: bus = MessageBus() diff --git a/tests/channels/test_whatsapp_channel.py b/tests/channels/test_whatsapp_channel.py index b61033677..6229723a5 100644 --- a/tests/channels/test_whatsapp_channel.py +++ b/tests/channels/test_whatsapp_channel.py @@ -116,7 +116,7 @@ async def test_send_when_disconnected_is_noop(): @pytest.mark.asyncio async def test_group_policy_mention_skips_unmentioned_group_message(): - ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock()) + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"], "groupPolicy": "mention"}, MagicMock()) ch._handle_message = AsyncMock() await ch._handle_bridge_message( @@ -139,7 +139,7 @@ async def test_group_policy_mention_skips_unmentioned_group_message(): @pytest.mark.asyncio async def test_group_policy_mention_accepts_mentioned_group_message(): - ch = WhatsAppChannel({"enabled": True, "groupPolicy": "mention"}, MagicMock()) + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"], "groupPolicy": "mention"}, MagicMock()) ch._handle_message = AsyncMock() await ch._handle_bridge_message( @@ -166,7 +166,7 @@ async def test_group_policy_mention_accepts_mentioned_group_message(): @pytest.mark.asyncio async def test_sender_id_prefers_phone_jid_over_lid(): """sender_id should resolve to phone number when @s.whatsapp.net JID is present.""" - ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock()) ch._handle_message = AsyncMock() await ch._handle_bridge_message( @@ -187,7 +187,7 @@ async def test_sender_id_prefers_phone_jid_over_lid(): @pytest.mark.asyncio async def test_lid_to_phone_cache_resolves_lid_only_messages(): """When only LID is present, a cached LID→phone mapping should be used.""" - ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock()) ch._handle_message = AsyncMock() # First message: both phone and LID → builds cache @@ -220,7 +220,7 @@ async def test_lid_to_phone_cache_resolves_lid_only_messages(): @pytest.mark.asyncio async def test_voice_message_transcription_uses_media_path(): """Voice messages are transcribed when media path is available.""" - ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock()) ch.transcription_provider = "openai" ch.transcription_api_key = "sk-test" ch._handle_message = AsyncMock() @@ -243,10 +243,32 @@ async def test_voice_message_transcription_uses_media_path(): assert kwargs["content"].startswith("Hello world") +@pytest.mark.asyncio +async def test_unauthorized_voice_message_does_not_transcribe() -> None: + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["allowed"]}, MagicMock()) + ch._handle_message = AsyncMock() + ch.transcribe_audio = AsyncMock(return_value="Hello world") + + await ch._handle_bridge_message( + json.dumps({ + "type": "message", + "id": "v-blocked", + "sender": "blocked@s.whatsapp.net", + "pn": "", + "content": "[Voice Message]", + "timestamp": 1, + "media": ["/tmp/voice.ogg"], + }) + ) + + ch.transcribe_audio.assert_not_awaited() + ch._handle_message.assert_not_awaited() + + @pytest.mark.asyncio async def test_voice_message_no_media_shows_not_available(): """Voice messages without media produce a fallback placeholder.""" - ch = WhatsAppChannel({"enabled": True}, MagicMock()) + ch = WhatsAppChannel({"enabled": True, "allowFrom": ["*"]}, MagicMock()) ch._handle_message = AsyncMock() await ch._handle_bridge_message(