mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
Fix pairing for Weixin and Telegram DMs
This commit is contained in:
parent
d435cb0b21
commit
3da68ac7fe
@ -869,7 +869,9 @@ class TelegramChannel(BaseChannel):
|
||||
return
|
||||
|
||||
user = update.effective_user
|
||||
if not self.is_allowed(self._sender_id(user)):
|
||||
sender_id = self._sender_id(user)
|
||||
if not self.is_allowed(sender_id):
|
||||
await self._send_pairing_code_if_private(sender_id, update.message, user)
|
||||
return
|
||||
await update.message.reply_text(
|
||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||
@ -881,7 +883,10 @@ class TelegramChannel(BaseChannel):
|
||||
"""Handle /help command for allowed users only."""
|
||||
if not update.message or not update.effective_user:
|
||||
return
|
||||
if not self.is_allowed(self._sender_id(update.effective_user)):
|
||||
user = update.effective_user
|
||||
sender_id = self._sender_id(user)
|
||||
if not self.is_allowed(sender_id):
|
||||
await self._send_pairing_code_if_private(sender_id, update.message, user)
|
||||
return
|
||||
await update.message.reply_text(build_help_text())
|
||||
|
||||
@ -891,6 +896,17 @@ class TelegramChannel(BaseChannel):
|
||||
sid = str(user.id)
|
||||
return f"{sid}|{user.username}" if user.username else sid
|
||||
|
||||
async def _send_pairing_code_if_private(self, sender_id: str, message, user) -> None:
|
||||
if message.chat.type != "private":
|
||||
return
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=str(message.chat_id),
|
||||
content="",
|
||||
metadata=self._build_message_metadata(message, user),
|
||||
is_dm=True,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _derive_topic_session_key(message) -> str | None:
|
||||
"""Derive topic-scoped session key for Telegram chats with threads."""
|
||||
@ -1149,6 +1165,7 @@ class TelegramChannel(BaseChannel):
|
||||
user = update.effective_user
|
||||
sender_id = self._sender_id(user)
|
||||
if not self.is_allowed(sender_id):
|
||||
await self._send_pairing_code_if_private(sender_id, message, user)
|
||||
return
|
||||
self._remember_thread_context(message)
|
||||
|
||||
@ -1186,6 +1203,7 @@ class TelegramChannel(BaseChannel):
|
||||
chat_id = message.chat_id
|
||||
sender_id = self._sender_id(user)
|
||||
if not self.is_allowed(sender_id):
|
||||
await self._send_pairing_code_if_private(sender_id, message, user)
|
||||
return
|
||||
self._remember_thread_context(message)
|
||||
|
||||
|
||||
@ -609,9 +609,6 @@ class WeixinChannel(BaseChannel):
|
||||
if not from_user_id:
|
||||
return
|
||||
|
||||
if not self.is_allowed(from_user_id):
|
||||
return
|
||||
|
||||
# Deduplication by message_id
|
||||
if msg_id in self._processed_ids:
|
||||
return
|
||||
@ -619,8 +616,51 @@ class WeixinChannel(BaseChannel):
|
||||
while len(self._processed_ids) > 1000:
|
||||
self._processed_ids.popitem(last=False)
|
||||
|
||||
# Cache context_token (required for all replies — inbound.ts:23-27)
|
||||
ctx_token = msg.get("context_token", "")
|
||||
if not self.is_allowed(from_user_id):
|
||||
if from_user_id.endswith("@chatroom"):
|
||||
await self._handle_message(
|
||||
sender_id=from_user_id,
|
||||
chat_id=from_user_id,
|
||||
content="",
|
||||
metadata={"message_id": msg_id},
|
||||
is_dm=False,
|
||||
)
|
||||
return
|
||||
|
||||
if not ctx_token:
|
||||
self.logger.warning(
|
||||
"Access denied for sender {}; cannot send WeChat pairing code without context_token",
|
||||
from_user_id,
|
||||
)
|
||||
return
|
||||
|
||||
had_ctx_token = from_user_id in self._context_tokens
|
||||
previous_ctx_token = self._context_tokens.get(from_user_id, "")
|
||||
had_ctx_token_at = from_user_id in self._context_token_at
|
||||
previous_ctx_token_at = self._context_token_at.get(from_user_id, 0.0)
|
||||
self._context_tokens[from_user_id] = ctx_token
|
||||
self._context_token_at[from_user_id] = time.time()
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=from_user_id,
|
||||
chat_id=from_user_id,
|
||||
content="",
|
||||
metadata={"message_id": msg_id},
|
||||
is_dm=True,
|
||||
)
|
||||
finally:
|
||||
if had_ctx_token:
|
||||
self._context_tokens[from_user_id] = previous_ctx_token
|
||||
else:
|
||||
self._context_tokens.pop(from_user_id, None)
|
||||
if had_ctx_token_at:
|
||||
self._context_token_at[from_user_id] = previous_ctx_token_at
|
||||
else:
|
||||
self._context_token_at.pop(from_user_id, None)
|
||||
return
|
||||
|
||||
# Cache context_token (required for all replies — inbound.ts:23-27)
|
||||
if ctx_token:
|
||||
self._context_tokens[from_user_id] = ctx_token
|
||||
self._context_token_at[from_user_id] = time.time()
|
||||
|
||||
@ -1359,6 +1359,23 @@ async def test_forward_command_does_not_inject_reply_context() -> None:
|
||||
assert handled[0]["content"] == "/new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_command_pairs_unauthorized_private_user(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||
)
|
||||
|
||||
await channel._forward_command(_make_telegram_update(text="/new", chat_type="private"), None)
|
||||
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_command_preserves_dream_log_args_and_strips_bot_suffix() -> None:
|
||||
channel = TelegramChannel(
|
||||
@ -1439,55 +1456,69 @@ async def test_on_help_includes_restart_command() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_start_ignores_unauthorized_user_silently() -> None:
|
||||
async def test_on_start_sends_pairing_code_to_unauthorized_private_user(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
update = _make_telegram_update(text="/start", chat_type="private")
|
||||
update.message.reply_text = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||
)
|
||||
|
||||
await channel._on_start(update, None)
|
||||
|
||||
update.message.reply_text.assert_not_awaited()
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_help_ignores_unauthorized_user_silently() -> None:
|
||||
async def test_on_help_sends_pairing_code_to_unauthorized_private_user(monkeypatch) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
update = _make_telegram_update(text="/help", chat_type="private")
|
||||
update.message.reply_text = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||
)
|
||||
|
||||
await channel._on_help(update, None)
|
||||
|
||||
update.message.reply_text.assert_not_awaited()
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_ignores_unauthorized_user_before_side_effects() -> None:
|
||||
async def test_on_message_pairs_unauthorized_private_user_before_side_effects(
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
channel = TelegramChannel(
|
||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||
MessageBus(),
|
||||
)
|
||||
channel._app = _FakeApp(lambda: None)
|
||||
started_typing: list[str] = []
|
||||
handled: list[dict] = []
|
||||
channel._start_typing = lambda chat_id: started_typing.append(chat_id)
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
async def capture_handle(**kwargs) -> None:
|
||||
handled.append(kwargs)
|
||||
|
||||
channel._handle_message = capture_handle
|
||||
channel._download_message_media = AsyncMock(return_value=([], []))
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||
)
|
||||
|
||||
await channel._on_message(_make_telegram_update(text="hello", chat_type="private"), None)
|
||||
|
||||
assert started_typing == []
|
||||
channel._add_reaction.assert_not_awaited()
|
||||
assert handled == []
|
||||
channel._download_message_media.assert_not_awaited()
|
||||
assert len(channel._app.bot.sent_messages) == 1
|
||||
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -130,14 +130,24 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_ignores_unauthorized_sender_before_side_effects(tmp_path) -> None:
|
||||
async def test_process_message_pairs_unauthorized_sender_before_media_side_effects(
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
) -> None:
|
||||
bus = MessageBus()
|
||||
channel = WeixinChannel(
|
||||
WeixinConfig(enabled=True, allow_from=["allowed-user"], state_dir=str(tmp_path)),
|
||||
bus,
|
||||
)
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
|
||||
channel._start_typing = AsyncMock()
|
||||
channel._get_typing_ticket = AsyncMock(return_value="")
|
||||
channel._send_text = AsyncMock()
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||
)
|
||||
|
||||
await channel._process_message(
|
||||
{
|
||||
@ -154,6 +164,11 @@ async def test_process_message_ignores_unauthorized_sender_before_side_effects(t
|
||||
assert channel._context_tokens == {}
|
||||
channel._download_media_item.assert_not_awaited()
|
||||
channel._start_typing.assert_not_awaited()
|
||||
channel._send_text.assert_awaited_once()
|
||||
send_args = channel._send_text.await_args.args
|
||||
assert send_args[0] == "blocked-user"
|
||||
assert "ABCD-EFGH" in send_args[1]
|
||||
assert send_args[2] == "ctx-blocked"
|
||||
assert bus.inbound_size == 0
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user