Fix pairing for Weixin and Telegram DMs

This commit is contained in:
chengyongru 2026-06-05 10:40:35 +08:00 committed by Xubin Ren
parent d435cb0b21
commit 3da68ac7fe
4 changed files with 121 additions and 17 deletions

View File

@ -869,7 +869,9 @@ class TelegramChannel(BaseChannel):
return
user = update.effective_user
if not self.is_allowed(self._sender_id(user)):
sender_id = self._sender_id(user)
if not self.is_allowed(sender_id):
await self._send_pairing_code_if_private(sender_id, update.message, user)
return
await update.message.reply_text(
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
@ -881,7 +883,10 @@ class TelegramChannel(BaseChannel):
"""Handle /help command for allowed users only."""
if not update.message or not update.effective_user:
return
if not self.is_allowed(self._sender_id(update.effective_user)):
user = update.effective_user
sender_id = self._sender_id(user)
if not self.is_allowed(sender_id):
await self._send_pairing_code_if_private(sender_id, update.message, user)
return
await update.message.reply_text(build_help_text())
@ -891,6 +896,17 @@ class TelegramChannel(BaseChannel):
sid = str(user.id)
return f"{sid}|{user.username}" if user.username else sid
async def _send_pairing_code_if_private(self, sender_id: str, message, user) -> None:
if message.chat.type != "private":
return
await self._handle_message(
sender_id=sender_id,
chat_id=str(message.chat_id),
content="",
metadata=self._build_message_metadata(message, user),
is_dm=True,
)
@staticmethod
def _derive_topic_session_key(message) -> str | None:
"""Derive topic-scoped session key for Telegram chats with threads."""
@ -1149,6 +1165,7 @@ class TelegramChannel(BaseChannel):
user = update.effective_user
sender_id = self._sender_id(user)
if not self.is_allowed(sender_id):
await self._send_pairing_code_if_private(sender_id, message, user)
return
self._remember_thread_context(message)
@ -1186,6 +1203,7 @@ class TelegramChannel(BaseChannel):
chat_id = message.chat_id
sender_id = self._sender_id(user)
if not self.is_allowed(sender_id):
await self._send_pairing_code_if_private(sender_id, message, user)
return
self._remember_thread_context(message)

View File

@ -609,9 +609,6 @@ class WeixinChannel(BaseChannel):
if not from_user_id:
return
if not self.is_allowed(from_user_id):
return
# Deduplication by message_id
if msg_id in self._processed_ids:
return
@ -619,8 +616,51 @@ class WeixinChannel(BaseChannel):
while len(self._processed_ids) > 1000:
self._processed_ids.popitem(last=False)
# Cache context_token (required for all replies — inbound.ts:23-27)
ctx_token = msg.get("context_token", "")
if not self.is_allowed(from_user_id):
if from_user_id.endswith("@chatroom"):
await self._handle_message(
sender_id=from_user_id,
chat_id=from_user_id,
content="",
metadata={"message_id": msg_id},
is_dm=False,
)
return
if not ctx_token:
self.logger.warning(
"Access denied for sender {}; cannot send WeChat pairing code without context_token",
from_user_id,
)
return
had_ctx_token = from_user_id in self._context_tokens
previous_ctx_token = self._context_tokens.get(from_user_id, "")
had_ctx_token_at = from_user_id in self._context_token_at
previous_ctx_token_at = self._context_token_at.get(from_user_id, 0.0)
self._context_tokens[from_user_id] = ctx_token
self._context_token_at[from_user_id] = time.time()
try:
await self._handle_message(
sender_id=from_user_id,
chat_id=from_user_id,
content="",
metadata={"message_id": msg_id},
is_dm=True,
)
finally:
if had_ctx_token:
self._context_tokens[from_user_id] = previous_ctx_token
else:
self._context_tokens.pop(from_user_id, None)
if had_ctx_token_at:
self._context_token_at[from_user_id] = previous_ctx_token_at
else:
self._context_token_at.pop(from_user_id, None)
return
# Cache context_token (required for all replies — inbound.ts:23-27)
if ctx_token:
self._context_tokens[from_user_id] = ctx_token
self._context_token_at[from_user_id] = time.time()

View File

@ -1359,6 +1359,23 @@ async def test_forward_command_does_not_inject_reply_context() -> None:
assert handled[0]["content"] == "/new"
@pytest.mark.asyncio
async def test_forward_command_pairs_unauthorized_private_user(monkeypatch) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
)
await channel._forward_command(_make_telegram_update(text="/new", chat_type="private"), None)
assert len(channel._app.bot.sent_messages) == 1
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
@pytest.mark.asyncio
async def test_forward_command_preserves_dream_log_args_and_strips_bot_suffix() -> None:
channel = TelegramChannel(
@ -1439,55 +1456,69 @@ async def test_on_help_includes_restart_command() -> None:
@pytest.mark.asyncio
async def test_on_start_ignores_unauthorized_user_silently() -> None:
async def test_on_start_sends_pairing_code_to_unauthorized_private_user(monkeypatch) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
update = _make_telegram_update(text="/start", chat_type="private")
update.message.reply_text = AsyncMock()
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
)
await channel._on_start(update, None)
update.message.reply_text.assert_not_awaited()
assert len(channel._app.bot.sent_messages) == 1
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
@pytest.mark.asyncio
async def test_on_help_ignores_unauthorized_user_silently() -> None:
async def test_on_help_sends_pairing_code_to_unauthorized_private_user(monkeypatch) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
update = _make_telegram_update(text="/help", chat_type="private")
update.message.reply_text = AsyncMock()
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
)
await channel._on_help(update, None)
update.message.reply_text.assert_not_awaited()
assert len(channel._app.bot.sent_messages) == 1
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
@pytest.mark.asyncio
async def test_on_message_ignores_unauthorized_user_before_side_effects() -> None:
async def test_on_message_pairs_unauthorized_private_user_before_side_effects(
monkeypatch,
) -> None:
channel = TelegramChannel(
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
MessageBus(),
)
channel._app = _FakeApp(lambda: None)
started_typing: list[str] = []
handled: list[dict] = []
channel._start_typing = lambda chat_id: started_typing.append(chat_id)
channel._add_reaction = AsyncMock(return_value=None)
async def capture_handle(**kwargs) -> None:
handled.append(kwargs)
channel._handle_message = capture_handle
channel._download_message_media = AsyncMock(return_value=([], []))
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
)
await channel._on_message(_make_telegram_update(text="hello", chat_type="private"), None)
assert started_typing == []
channel._add_reaction.assert_not_awaited()
assert handled == []
channel._download_message_media.assert_not_awaited()
assert len(channel._app.bot.sent_messages) == 1
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
@pytest.mark.asyncio

View File

@ -130,14 +130,24 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None:
@pytest.mark.asyncio
async def test_process_message_ignores_unauthorized_sender_before_side_effects(tmp_path) -> None:
async def test_process_message_pairs_unauthorized_sender_before_media_side_effects(
monkeypatch,
tmp_path,
) -> None:
bus = MessageBus()
channel = WeixinChannel(
WeixinConfig(enabled=True, allow_from=["allowed-user"], state_dir=str(tmp_path)),
bus,
)
channel._client = object()
channel._token = "token"
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
channel._start_typing = AsyncMock()
channel._get_typing_ticket = AsyncMock(return_value="")
channel._send_text = AsyncMock()
monkeypatch.setattr(
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
)
await channel._process_message(
{
@ -154,6 +164,11 @@ async def test_process_message_ignores_unauthorized_sender_before_side_effects(t
assert channel._context_tokens == {}
channel._download_media_item.assert_not_awaited()
channel._start_typing.assert_not_awaited()
channel._send_text.assert_awaited_once()
send_args = channel._send_text.await_args.args
assert send_args[0] == "blocked-user"
assert "ABCD-EFGH" in send_args[1]
assert send_args[2] == "ctx-blocked"
assert bus.inbound_size == 0