mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
Fix pairing for Weixin and Telegram DMs
This commit is contained in:
parent
d435cb0b21
commit
3da68ac7fe
@ -869,7 +869,9 @@ class TelegramChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
if not self.is_allowed(self._sender_id(user)):
|
sender_id = self._sender_id(user)
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
await self._send_pairing_code_if_private(sender_id, update.message, user)
|
||||||
return
|
return
|
||||||
await update.message.reply_text(
|
await update.message.reply_text(
|
||||||
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
|
||||||
@ -881,7 +883,10 @@ class TelegramChannel(BaseChannel):
|
|||||||
"""Handle /help command for allowed users only."""
|
"""Handle /help command for allowed users only."""
|
||||||
if not update.message or not update.effective_user:
|
if not update.message or not update.effective_user:
|
||||||
return
|
return
|
||||||
if not self.is_allowed(self._sender_id(update.effective_user)):
|
user = update.effective_user
|
||||||
|
sender_id = self._sender_id(user)
|
||||||
|
if not self.is_allowed(sender_id):
|
||||||
|
await self._send_pairing_code_if_private(sender_id, update.message, user)
|
||||||
return
|
return
|
||||||
await update.message.reply_text(build_help_text())
|
await update.message.reply_text(build_help_text())
|
||||||
|
|
||||||
@ -891,6 +896,17 @@ class TelegramChannel(BaseChannel):
|
|||||||
sid = str(user.id)
|
sid = str(user.id)
|
||||||
return f"{sid}|{user.username}" if user.username else sid
|
return f"{sid}|{user.username}" if user.username else sid
|
||||||
|
|
||||||
|
async def _send_pairing_code_if_private(self, sender_id: str, message, user) -> None:
|
||||||
|
if message.chat.type != "private":
|
||||||
|
return
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=sender_id,
|
||||||
|
chat_id=str(message.chat_id),
|
||||||
|
content="",
|
||||||
|
metadata=self._build_message_metadata(message, user),
|
||||||
|
is_dm=True,
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _derive_topic_session_key(message) -> str | None:
|
def _derive_topic_session_key(message) -> str | None:
|
||||||
"""Derive topic-scoped session key for Telegram chats with threads."""
|
"""Derive topic-scoped session key for Telegram chats with threads."""
|
||||||
@ -1149,6 +1165,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
user = update.effective_user
|
user = update.effective_user
|
||||||
sender_id = self._sender_id(user)
|
sender_id = self._sender_id(user)
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
|
await self._send_pairing_code_if_private(sender_id, message, user)
|
||||||
return
|
return
|
||||||
self._remember_thread_context(message)
|
self._remember_thread_context(message)
|
||||||
|
|
||||||
@ -1186,6 +1203,7 @@ class TelegramChannel(BaseChannel):
|
|||||||
chat_id = message.chat_id
|
chat_id = message.chat_id
|
||||||
sender_id = self._sender_id(user)
|
sender_id = self._sender_id(user)
|
||||||
if not self.is_allowed(sender_id):
|
if not self.is_allowed(sender_id):
|
||||||
|
await self._send_pairing_code_if_private(sender_id, message, user)
|
||||||
return
|
return
|
||||||
self._remember_thread_context(message)
|
self._remember_thread_context(message)
|
||||||
|
|
||||||
|
|||||||
@ -609,9 +609,6 @@ class WeixinChannel(BaseChannel):
|
|||||||
if not from_user_id:
|
if not from_user_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.is_allowed(from_user_id):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Deduplication by message_id
|
# Deduplication by message_id
|
||||||
if msg_id in self._processed_ids:
|
if msg_id in self._processed_ids:
|
||||||
return
|
return
|
||||||
@ -619,8 +616,51 @@ class WeixinChannel(BaseChannel):
|
|||||||
while len(self._processed_ids) > 1000:
|
while len(self._processed_ids) > 1000:
|
||||||
self._processed_ids.popitem(last=False)
|
self._processed_ids.popitem(last=False)
|
||||||
|
|
||||||
# Cache context_token (required for all replies — inbound.ts:23-27)
|
|
||||||
ctx_token = msg.get("context_token", "")
|
ctx_token = msg.get("context_token", "")
|
||||||
|
if not self.is_allowed(from_user_id):
|
||||||
|
if from_user_id.endswith("@chatroom"):
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=from_user_id,
|
||||||
|
chat_id=from_user_id,
|
||||||
|
content="",
|
||||||
|
metadata={"message_id": msg_id},
|
||||||
|
is_dm=False,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not ctx_token:
|
||||||
|
self.logger.warning(
|
||||||
|
"Access denied for sender {}; cannot send WeChat pairing code without context_token",
|
||||||
|
from_user_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
had_ctx_token = from_user_id in self._context_tokens
|
||||||
|
previous_ctx_token = self._context_tokens.get(from_user_id, "")
|
||||||
|
had_ctx_token_at = from_user_id in self._context_token_at
|
||||||
|
previous_ctx_token_at = self._context_token_at.get(from_user_id, 0.0)
|
||||||
|
self._context_tokens[from_user_id] = ctx_token
|
||||||
|
self._context_token_at[from_user_id] = time.time()
|
||||||
|
try:
|
||||||
|
await self._handle_message(
|
||||||
|
sender_id=from_user_id,
|
||||||
|
chat_id=from_user_id,
|
||||||
|
content="",
|
||||||
|
metadata={"message_id": msg_id},
|
||||||
|
is_dm=True,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if had_ctx_token:
|
||||||
|
self._context_tokens[from_user_id] = previous_ctx_token
|
||||||
|
else:
|
||||||
|
self._context_tokens.pop(from_user_id, None)
|
||||||
|
if had_ctx_token_at:
|
||||||
|
self._context_token_at[from_user_id] = previous_ctx_token_at
|
||||||
|
else:
|
||||||
|
self._context_token_at.pop(from_user_id, None)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cache context_token (required for all replies — inbound.ts:23-27)
|
||||||
if ctx_token:
|
if ctx_token:
|
||||||
self._context_tokens[from_user_id] = ctx_token
|
self._context_tokens[from_user_id] = ctx_token
|
||||||
self._context_token_at[from_user_id] = time.time()
|
self._context_token_at[from_user_id] = time.time()
|
||||||
|
|||||||
@ -1359,6 +1359,23 @@ async def test_forward_command_does_not_inject_reply_context() -> None:
|
|||||||
assert handled[0]["content"] == "/new"
|
assert handled[0]["content"] == "/new"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_forward_command_pairs_unauthorized_private_user(monkeypatch) -> None:
|
||||||
|
channel = TelegramChannel(
|
||||||
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||||
|
MessageBus(),
|
||||||
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._forward_command(_make_telegram_update(text="/new", chat_type="private"), None)
|
||||||
|
|
||||||
|
assert len(channel._app.bot.sent_messages) == 1
|
||||||
|
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_forward_command_preserves_dream_log_args_and_strips_bot_suffix() -> None:
|
async def test_forward_command_preserves_dream_log_args_and_strips_bot_suffix() -> None:
|
||||||
channel = TelegramChannel(
|
channel = TelegramChannel(
|
||||||
@ -1439,55 +1456,69 @@ async def test_on_help_includes_restart_command() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_start_ignores_unauthorized_user_silently() -> None:
|
async def test_on_start_sends_pairing_code_to_unauthorized_private_user(monkeypatch) -> None:
|
||||||
channel = TelegramChannel(
|
channel = TelegramChannel(
|
||||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||||
MessageBus(),
|
MessageBus(),
|
||||||
)
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
update = _make_telegram_update(text="/start", chat_type="private")
|
update = _make_telegram_update(text="/start", chat_type="private")
|
||||||
update.message.reply_text = AsyncMock()
|
update.message.reply_text = AsyncMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||||
|
)
|
||||||
|
|
||||||
await channel._on_start(update, None)
|
await channel._on_start(update, None)
|
||||||
|
|
||||||
update.message.reply_text.assert_not_awaited()
|
update.message.reply_text.assert_not_awaited()
|
||||||
|
assert len(channel._app.bot.sent_messages) == 1
|
||||||
|
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_help_ignores_unauthorized_user_silently() -> None:
|
async def test_on_help_sends_pairing_code_to_unauthorized_private_user(monkeypatch) -> None:
|
||||||
channel = TelegramChannel(
|
channel = TelegramChannel(
|
||||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||||
MessageBus(),
|
MessageBus(),
|
||||||
)
|
)
|
||||||
|
channel._app = _FakeApp(lambda: None)
|
||||||
update = _make_telegram_update(text="/help", chat_type="private")
|
update = _make_telegram_update(text="/help", chat_type="private")
|
||||||
update.message.reply_text = AsyncMock()
|
update.message.reply_text = AsyncMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||||
|
)
|
||||||
|
|
||||||
await channel._on_help(update, None)
|
await channel._on_help(update, None)
|
||||||
|
|
||||||
update.message.reply_text.assert_not_awaited()
|
update.message.reply_text.assert_not_awaited()
|
||||||
|
assert len(channel._app.bot.sent_messages) == 1
|
||||||
|
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_on_message_ignores_unauthorized_user_before_side_effects() -> None:
|
async def test_on_message_pairs_unauthorized_private_user_before_side_effects(
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
channel = TelegramChannel(
|
channel = TelegramChannel(
|
||||||
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
TelegramConfig(enabled=True, token="123:abc", allow_from=["999"], group_policy="open"),
|
||||||
MessageBus(),
|
MessageBus(),
|
||||||
)
|
)
|
||||||
channel._app = _FakeApp(lambda: None)
|
channel._app = _FakeApp(lambda: None)
|
||||||
started_typing: list[str] = []
|
started_typing: list[str] = []
|
||||||
handled: list[dict] = []
|
|
||||||
channel._start_typing = lambda chat_id: started_typing.append(chat_id)
|
channel._start_typing = lambda chat_id: started_typing.append(chat_id)
|
||||||
channel._add_reaction = AsyncMock(return_value=None)
|
channel._add_reaction = AsyncMock(return_value=None)
|
||||||
|
channel._download_message_media = AsyncMock(return_value=([], []))
|
||||||
async def capture_handle(**kwargs) -> None:
|
monkeypatch.setattr(
|
||||||
handled.append(kwargs)
|
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||||
|
)
|
||||||
channel._handle_message = capture_handle
|
|
||||||
|
|
||||||
await channel._on_message(_make_telegram_update(text="hello", chat_type="private"), None)
|
await channel._on_message(_make_telegram_update(text="hello", chat_type="private"), None)
|
||||||
|
|
||||||
assert started_typing == []
|
assert started_typing == []
|
||||||
channel._add_reaction.assert_not_awaited()
|
channel._add_reaction.assert_not_awaited()
|
||||||
assert handled == []
|
channel._download_message_media.assert_not_awaited()
|
||||||
|
assert len(channel._app.bot.sent_messages) == 1
|
||||||
|
assert "ABCD-EFGH" in channel._app.bot.sent_messages[0]["text"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@ -130,14 +130,24 @@ async def test_process_message_caches_context_token_and_send_uses_it() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_ignores_unauthorized_sender_before_side_effects(tmp_path) -> None:
|
async def test_process_message_pairs_unauthorized_sender_before_media_side_effects(
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
) -> None:
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
channel = WeixinChannel(
|
channel = WeixinChannel(
|
||||||
WeixinConfig(enabled=True, allow_from=["allowed-user"], state_dir=str(tmp_path)),
|
WeixinConfig(enabled=True, allow_from=["allowed-user"], state_dir=str(tmp_path)),
|
||||||
bus,
|
bus,
|
||||||
)
|
)
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
|
channel._download_media_item = AsyncMock(return_value="/tmp/test.jpg")
|
||||||
channel._start_typing = AsyncMock()
|
channel._start_typing = AsyncMock()
|
||||||
|
channel._get_typing_ticket = AsyncMock(return_value="")
|
||||||
|
channel._send_text = AsyncMock()
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.channels.base.generate_code", lambda _ch, _sid: "ABCD-EFGH"
|
||||||
|
)
|
||||||
|
|
||||||
await channel._process_message(
|
await channel._process_message(
|
||||||
{
|
{
|
||||||
@ -154,6 +164,11 @@ async def test_process_message_ignores_unauthorized_sender_before_side_effects(t
|
|||||||
assert channel._context_tokens == {}
|
assert channel._context_tokens == {}
|
||||||
channel._download_media_item.assert_not_awaited()
|
channel._download_media_item.assert_not_awaited()
|
||||||
channel._start_typing.assert_not_awaited()
|
channel._start_typing.assert_not_awaited()
|
||||||
|
channel._send_text.assert_awaited_once()
|
||||||
|
send_args = channel._send_text.await_args.args
|
||||||
|
assert send_args[0] == "blocked-user"
|
||||||
|
assert "ABCD-EFGH" in send_args[1]
|
||||||
|
assert send_args[2] == "ctx-blocked"
|
||||||
assert bus.inbound_size == 0
|
assert bus.inbound_size == 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user