mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
fix(weixin): correct PKCS7 unpadding for AES-ECB; support full_url for media download
This commit is contained in:
parent
5bdb7a90b1
commit
3823042290
@ -685,9 +685,10 @@ class WeixinChannel(BaseChannel):
|
||||
"""Download + AES-decrypt a media item. Returns local path or None."""
|
||||
try:
|
||||
media = typed_item.get("media") or {}
|
||||
encrypt_query_param = media.get("encrypt_query_param", "")
|
||||
encrypt_query_param = str(media.get("encrypt_query_param", "") or "")
|
||||
full_url = str(media.get("full_url", "") or "").strip()
|
||||
|
||||
if not encrypt_query_param:
|
||||
if not encrypt_query_param and not full_url:
|
||||
return None
|
||||
|
||||
# Resolve AES key (media-download.ts:43-45, pic-decrypt.ts:40-52)
|
||||
@ -704,7 +705,10 @@ class WeixinChannel(BaseChannel):
|
||||
elif media_aes_key_b64:
|
||||
aes_key_b64 = media_aes_key_b64
|
||||
|
||||
# Build CDN download URL with proper URL-encoding (cdn-url.ts:7)
|
||||
# Prefer server-provided full_url, fallback to encrypted_query_param URL construction.
|
||||
if full_url:
|
||||
cdn_url = full_url
|
||||
else:
|
||||
cdn_url = (
|
||||
f"{self.config.cdn_base_url}/download"
|
||||
f"?encrypted_query_param={quote(encrypt_query_param)}"
|
||||
@ -727,7 +731,8 @@ class WeixinChannel(BaseChannel):
|
||||
ext = _ext_for_type(media_type)
|
||||
if not filename:
|
||||
ts = int(time.time())
|
||||
h = abs(hash(encrypt_query_param)) % 100000
|
||||
hash_seed = encrypt_query_param or full_url
|
||||
h = abs(hash(hash_seed)) % 100000
|
||||
filename = f"{media_type}_{ts}_{h}{ext}"
|
||||
safe_name = os.path.basename(filename)
|
||||
file_path = media_dir / safe_name
|
||||
@ -1045,24 +1050,43 @@ def _decrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
|
||||
logger.warning("Failed to parse AES key, returning raw data: {}", e)
|
||||
return data
|
||||
|
||||
decrypted: bytes | None = None
|
||||
|
||||
try:
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
cipher = AES.new(key, AES.MODE_ECB)
|
||||
return cipher.decrypt(data) # pycryptodome auto-strips PKCS7 with unpad
|
||||
decrypted = cipher.decrypt(data)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if decrypted is None:
|
||||
try:
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
|
||||
decryptor = cipher_obj.decryptor()
|
||||
return decryptor.update(data) + decryptor.finalize()
|
||||
decrypted = decryptor.update(data) + decryptor.finalize()
|
||||
except ImportError:
|
||||
logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'")
|
||||
return data
|
||||
|
||||
return _pkcs7_unpad_safe(decrypted)
|
||||
|
||||
|
||||
def _pkcs7_unpad_safe(data: bytes, block_size: int = 16) -> bytes:
|
||||
"""Safely remove PKCS7 padding when valid; otherwise return original bytes."""
|
||||
if not data:
|
||||
return data
|
||||
if len(data) % block_size != 0:
|
||||
return data
|
||||
pad_len = data[-1]
|
||||
if pad_len < 1 or pad_len > block_size:
|
||||
return data
|
||||
if data[-pad_len:] != bytes([pad_len]) * pad_len:
|
||||
return data
|
||||
return data[:-pad_len]
|
||||
|
||||
|
||||
def _ext_for_type(media_type: str) -> str:
|
||||
return {
|
||||
|
||||
@ -7,12 +7,15 @@ from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
import nanobot.channels.weixin as weixin_mod
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.weixin import (
|
||||
ITEM_IMAGE,
|
||||
ITEM_TEXT,
|
||||
MESSAGE_TYPE_BOT,
|
||||
WEIXIN_CHANNEL_VERSION,
|
||||
_decrypt_aes_ecb,
|
||||
_encrypt_aes_ecb,
|
||||
WeixinChannel,
|
||||
WeixinConfig,
|
||||
)
|
||||
@ -340,3 +343,63 @@ async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None:
|
||||
cdn_url = cdn_post.await_args_list[0].args[0]
|
||||
assert cdn_url.startswith(f"{channel.config.cdn_base_url}/upload?encrypted_query_param=enc-need-fallback")
|
||||
assert "&filekey=" in cdn_url
|
||||
|
||||
|
||||
def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None:
    """Round-trip a payload through the AES-ECB helpers and expect it back intact.

    The payload is 20 bytes long, i.e. not a multiple of the 16-byte AES
    block, so the decrypt side must drop whatever padding the encrypt side
    appended for the comparison to hold.
    """
    # base64("0123456789abcdef")
    aes_key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg=="
    message = b"hello-weixin-padding"

    roundtrip = _decrypt_aes_ecb(_encrypt_aes_ecb(message, aes_key_b64), aes_key_b64)

    assert roundtrip == message
|
||||
|
||||
|
||||
class _DummyDownloadResponse:
|
||||
def __init__(self, content: bytes, status_code: int = 200) -> None:
|
||||
self.content = content
|
||||
self.status_code = status_code
|
||||
|
||||
def raise_for_status(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_media_item_uses_full_url_when_present(tmp_path, monkeypatch) -> None:
    """When the media item carries ``full_url``, that exact URL is fetched.

    The item also carries an ``encrypt_query_param``; the assertion on the
    mocked client's ``get`` verifies it was awaited once with ``full_url``
    only, and the saved file must contain the bytes the mock returned.
    """
    channel, _bus = _make_channel()
    # Patch through monkeypatch so the module-level get_media_dir is restored
    # after the test; a plain attribute assignment would leak the tmp_path
    # lambda into every test that runs afterwards.
    monkeypatch.setattr(weixin_mod, "get_media_dir", lambda _name: tmp_path)

    full_url = "https://cdn.example.test/download/full"
    channel._client = SimpleNamespace(
        get=AsyncMock(return_value=_DummyDownloadResponse(content=b"raw-image-bytes"))
    )

    item = {
        "media": {
            "full_url": full_url,
            "encrypt_query_param": "enc-fallback-should-not-be-used",
        },
    }
    saved_path = await channel._download_media_item(item, "image")

    assert saved_path is not None
    assert Path(saved_path).read_bytes() == b"raw-image-bytes"
    channel._client.get.assert_awaited_once_with(full_url)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_download_media_item_falls_back_to_encrypt_query_param(tmp_path, monkeypatch) -> None:
    """Without ``full_url``, the download URL is built from ``encrypt_query_param``.

    The URL passed to the mocked client's ``get`` must start with the
    configured CDN base URL followed by
    ``/download?encrypted_query_param=<value>``, and the saved file must
    contain the bytes the mock returned.
    """
    channel, _bus = _make_channel()
    # Patch through monkeypatch so the module-level get_media_dir is restored
    # after the test; a plain attribute assignment would leak the tmp_path
    # lambda into every test that runs afterwards.
    monkeypatch.setattr(weixin_mod, "get_media_dir", lambda _name: tmp_path)

    channel._client = SimpleNamespace(
        get=AsyncMock(return_value=_DummyDownloadResponse(content=b"fallback-bytes"))
    )

    item = {"media": {"encrypt_query_param": "enc-fallback"}}
    saved_path = await channel._download_media_item(item, "image")

    assert saved_path is not None
    assert Path(saved_path).read_bytes() == b"fallback-bytes"
    # Only the first positional argument of the first awaited call is checked.
    called_url = channel._client.get.await_args_list[0].args[0]
    assert called_url.startswith(f"{channel.config.cdn_base_url}/download?encrypted_query_param=enc-fallback")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user