mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-27 21:35:51 +00:00
feat(weixin): add voice message, typing keepalive, getConfig cache, and QR polling resilience
This commit is contained in:
parent
0514233217
commit
26947db479
@ -15,6 +15,7 @@ import hashlib
|
|||||||
import json
|
import json
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@ -102,18 +103,23 @@ MAX_QR_REFRESH_COUNT = 3
|
|||||||
TYPING_STATUS_TYPING = 1
|
TYPING_STATUS_TYPING = 1
|
||||||
TYPING_STATUS_CANCEL = 2
|
TYPING_STATUS_CANCEL = 2
|
||||||
TYPING_TICKET_TTL_S = 24 * 60 * 60
|
TYPING_TICKET_TTL_S = 24 * 60 * 60
|
||||||
|
TYPING_KEEPALIVE_INTERVAL_S = 5
|
||||||
|
CONFIG_CACHE_INITIAL_RETRY_S = 2
|
||||||
|
CONFIG_CACHE_MAX_RETRY_S = 60 * 60
|
||||||
|
|
||||||
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
|
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
|
||||||
DEFAULT_LONG_POLL_TIMEOUT_S = 35
|
DEFAULT_LONG_POLL_TIMEOUT_S = 35
|
||||||
|
|
||||||
# Media-type codes for getuploadurl (1=image, 2=video, 3=file)
|
# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice)
|
||||||
UPLOAD_MEDIA_IMAGE = 1
|
UPLOAD_MEDIA_IMAGE = 1
|
||||||
UPLOAD_MEDIA_VIDEO = 2
|
UPLOAD_MEDIA_VIDEO = 2
|
||||||
UPLOAD_MEDIA_FILE = 3
|
UPLOAD_MEDIA_FILE = 3
|
||||||
|
UPLOAD_MEDIA_VOICE = 4
|
||||||
|
|
||||||
# File extensions considered as images / videos for outbound media
|
# File extensions considered as images / videos for outbound media
|
||||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"}
|
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"}
|
||||||
_VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
|
_VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
|
||||||
|
_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"}
|
||||||
|
|
||||||
|
|
||||||
def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool:
|
def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool:
|
||||||
@ -167,7 +173,7 @@ class WeixinChannel(BaseChannel):
|
|||||||
self._poll_task: asyncio.Task | None = None
|
self._poll_task: asyncio.Task | None = None
|
||||||
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
||||||
self._session_pause_until: float = 0.0
|
self._session_pause_until: float = 0.0
|
||||||
self._typing_tickets: dict[str, tuple[str, float]] = {}
|
self._typing_tickets: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# State persistence
|
# State persistence
|
||||||
@ -339,7 +345,16 @@ class WeixinChannel(BaseChannel):
|
|||||||
params={"qrcode": qrcode_id},
|
params={"qrcode": qrcode_id},
|
||||||
auth=False,
|
auth=False,
|
||||||
)
|
)
|
||||||
except httpx.TimeoutException:
|
except Exception as e:
|
||||||
|
if self._is_retryable_qr_poll_error(e):
|
||||||
|
logger.warning("QR polling temporary error, will retry: {}", e)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
continue
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not isinstance(status_data, dict):
|
||||||
|
logger.warning("QR polling got non-object response, continue waiting")
|
||||||
|
await asyncio.sleep(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
status = status_data.get("status", "")
|
status = status_data.get("status", "")
|
||||||
@ -408,6 +423,16 @@ class WeixinChannel(BaseChannel):
|
|||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_retryable_qr_poll_error(err: Exception) -> bool:
|
||||||
|
if isinstance(err, httpx.TimeoutException | httpx.TransportError):
|
||||||
|
return True
|
||||||
|
if isinstance(err, httpx.HTTPStatusError):
|
||||||
|
status_code = err.response.status_code if err.response is not None else 0
|
||||||
|
if status_code >= 500:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _print_qr_code(url: str) -> None:
|
def _print_qr_code(url: str) -> None:
|
||||||
try:
|
try:
|
||||||
@ -858,13 +883,11 @@ class WeixinChannel(BaseChannel):
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str:
|
async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str:
|
||||||
"""Get typing ticket for a user with simple per-user TTL cache."""
|
"""Get typing ticket with per-user refresh + failure backoff cache."""
|
||||||
now = time.time()
|
now = time.time()
|
||||||
cached = self._typing_tickets.get(user_id)
|
entry = self._typing_tickets.get(user_id)
|
||||||
if cached:
|
if entry and now < float(entry.get("next_fetch_at", 0)):
|
||||||
ticket, expires_at = cached
|
return str(entry.get("ticket", "") or "")
|
||||||
if ticket and now < expires_at:
|
|
||||||
return ticket
|
|
||||||
|
|
||||||
body: dict[str, Any] = {
|
body: dict[str, Any] = {
|
||||||
"ilink_user_id": user_id,
|
"ilink_user_id": user_id,
|
||||||
@ -874,9 +897,27 @@ class WeixinChannel(BaseChannel):
|
|||||||
data = await self._api_post("ilink/bot/getconfig", body)
|
data = await self._api_post("ilink/bot/getconfig", body)
|
||||||
if data.get("ret", 0) == 0:
|
if data.get("ret", 0) == 0:
|
||||||
ticket = str(data.get("typing_ticket", "") or "")
|
ticket = str(data.get("typing_ticket", "") or "")
|
||||||
if ticket:
|
self._typing_tickets[user_id] = {
|
||||||
self._typing_tickets[user_id] = (ticket, now + TYPING_TICKET_TTL_S)
|
"ticket": ticket,
|
||||||
return ticket
|
"ever_succeeded": True,
|
||||||
|
"next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S),
|
||||||
|
"retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
|
||||||
|
}
|
||||||
|
return ticket
|
||||||
|
|
||||||
|
prev_delay = float(entry.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S)) if entry else CONFIG_CACHE_INITIAL_RETRY_S
|
||||||
|
next_delay = min(prev_delay * 2, CONFIG_CACHE_MAX_RETRY_S)
|
||||||
|
if entry:
|
||||||
|
entry["next_fetch_at"] = now + next_delay
|
||||||
|
entry["retry_delay_s"] = next_delay
|
||||||
|
return str(entry.get("ticket", "") or "")
|
||||||
|
|
||||||
|
self._typing_tickets[user_id] = {
|
||||||
|
"ticket": "",
|
||||||
|
"ever_succeeded": False,
|
||||||
|
"next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S,
|
||||||
|
"retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
|
||||||
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
||||||
@ -891,6 +932,16 @@ class WeixinChannel(BaseChannel):
|
|||||||
}
|
}
|
||||||
await self._api_post("ilink/bot/sendtyping", body)
|
await self._api_post("ilink/bot/sendtyping", body)
|
||||||
|
|
||||||
|
async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None:
|
||||||
|
while not stop_event.is_set():
|
||||||
|
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
|
||||||
|
if stop_event.is_set():
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("WeChat sendtyping(keepalive) failed for {}: {}", user_id, e)
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
if not self._client or not self._token:
|
if not self._client or not self._token:
|
||||||
logger.warning("WeChat client not initialized or not authenticated")
|
logger.warning("WeChat client not initialized or not authenticated")
|
||||||
@ -923,6 +974,13 @@ class WeixinChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug("WeChat sendtyping(start) failed for {}: {}", msg.chat_id, e)
|
logger.debug("WeChat sendtyping(start) failed for {}: {}", msg.chat_id, e)
|
||||||
|
|
||||||
|
typing_keepalive_stop = asyncio.Event()
|
||||||
|
typing_keepalive_task: asyncio.Task | None = None
|
||||||
|
if typing_ticket:
|
||||||
|
typing_keepalive_task = asyncio.create_task(
|
||||||
|
self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# --- Send media files first (following Telegram channel pattern) ---
|
# --- Send media files first (following Telegram channel pattern) ---
|
||||||
for media_path in (msg.media or []):
|
for media_path in (msg.media or []):
|
||||||
@ -947,6 +1005,14 @@ class WeixinChannel(BaseChannel):
|
|||||||
logger.error("Error sending WeChat message: {}", e)
|
logger.error("Error sending WeChat message: {}", e)
|
||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
|
if typing_keepalive_task:
|
||||||
|
typing_keepalive_stop.set()
|
||||||
|
typing_keepalive_task.cancel()
|
||||||
|
try:
|
||||||
|
await typing_keepalive_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
if typing_ticket:
|
if typing_ticket:
|
||||||
try:
|
try:
|
||||||
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
||||||
@ -1025,6 +1091,10 @@ class WeixinChannel(BaseChannel):
|
|||||||
upload_type = UPLOAD_MEDIA_VIDEO
|
upload_type = UPLOAD_MEDIA_VIDEO
|
||||||
item_type = ITEM_VIDEO
|
item_type = ITEM_VIDEO
|
||||||
item_key = "video_item"
|
item_key = "video_item"
|
||||||
|
elif ext in _VOICE_EXTS:
|
||||||
|
upload_type = UPLOAD_MEDIA_VOICE
|
||||||
|
item_type = ITEM_VOICE
|
||||||
|
item_key = "voice_item"
|
||||||
else:
|
else:
|
||||||
upload_type = UPLOAD_MEDIA_FILE
|
upload_type = UPLOAD_MEDIA_FILE
|
||||||
item_type = ITEM_FILE
|
item_type = ITEM_FILE
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import httpx
|
||||||
|
|
||||||
import nanobot.channels.weixin as weixin_mod
|
import nanobot.channels.weixin as weixin_mod
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
@ -595,6 +596,158 @@ async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None:
|
|||||||
assert "&filekey=" in cdn_url
|
assert "&filekey=" in cdn_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_media_voice_file_uses_voice_item_and_voice_upload_type(tmp_path) -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
|
||||||
|
media_file = tmp_path / "voice.mp3"
|
||||||
|
media_file.write_bytes(b"voice-bytes")
|
||||||
|
|
||||||
|
cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "voice-dl-param"}))
|
||||||
|
channel._client = SimpleNamespace(post=cdn_post)
|
||||||
|
channel._api_post = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"upload_full_url": "https://upload-full.example.test/voice?foo=bar"},
|
||||||
|
{"ret": 0},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel._send_media_file("wx-user", str(media_file), "ctx-voice")
|
||||||
|
|
||||||
|
getupload_body = channel._api_post.await_args_list[0].args[1]
|
||||||
|
assert getupload_body["media_type"] == 4
|
||||||
|
|
||||||
|
sendmessage_body = channel._api_post.await_args_list[1].args[1]
|
||||||
|
item = sendmessage_body["msg"]["item_list"][0]
|
||||||
|
assert item["type"] == 3
|
||||||
|
assert "voice_item" in item
|
||||||
|
assert "file_item" not in item
|
||||||
|
assert item["voice_item"]["media"]["encrypt_query_param"] == "voice-dl-param"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_typing_uses_keepalive_until_send_finishes() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
channel._context_tokens["wx-user"] = "ctx-typing-loop"
|
||||||
|
async def _api_post_side_effect(endpoint: str, _body: dict | None = None, *, auth: bool = True):
|
||||||
|
if endpoint == "ilink/bot/getconfig":
|
||||||
|
return {"ret": 0, "typing_ticket": "ticket-keepalive"}
|
||||||
|
return {"ret": 0}
|
||||||
|
|
||||||
|
channel._api_post = AsyncMock(side_effect=_api_post_side_effect)
|
||||||
|
|
||||||
|
async def _slow_send_text(*_args, **_kwargs) -> None:
|
||||||
|
await asyncio.sleep(0.03)
|
||||||
|
|
||||||
|
channel._send_text = AsyncMock(side_effect=_slow_send_text)
|
||||||
|
|
||||||
|
old_interval = weixin_mod.TYPING_KEEPALIVE_INTERVAL_S
|
||||||
|
weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = 0.01
|
||||||
|
try:
|
||||||
|
await channel.send(
|
||||||
|
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = old_interval
|
||||||
|
|
||||||
|
status_calls = [
|
||||||
|
c.args[1]["status"]
|
||||||
|
for c in channel._api_post.await_args_list
|
||||||
|
if c.args and c.args[0] == "ilink/bot/sendtyping"
|
||||||
|
]
|
||||||
|
assert status_calls.count(1) >= 2
|
||||||
|
assert status_calls[-1] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_typing_ticket_failure_uses_backoff_and_cached_ticket(monkeypatch) -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._client = object()
|
||||||
|
channel._token = "token"
|
||||||
|
|
||||||
|
now = {"value": 1000.0}
|
||||||
|
monkeypatch.setattr(weixin_mod.time, "time", lambda: now["value"])
|
||||||
|
monkeypatch.setattr(weixin_mod.random, "random", lambda: 0.5)
|
||||||
|
|
||||||
|
channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-ok"})
|
||||||
|
first = await channel._get_typing_ticket("wx-user", "ctx-1")
|
||||||
|
assert first == "ticket-ok"
|
||||||
|
|
||||||
|
# force refresh window reached
|
||||||
|
now["value"] = now["value"] + (12 * 60 * 60) + 1
|
||||||
|
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "temporary failure"})
|
||||||
|
|
||||||
|
# On refresh failure, should still return cached ticket and apply backoff.
|
||||||
|
second = await channel._get_typing_ticket("wx-user", "ctx-2")
|
||||||
|
assert second == "ticket-ok"
|
||||||
|
assert channel._api_post.await_count == 1
|
||||||
|
|
||||||
|
# Before backoff expiry, no extra fetch should happen.
|
||||||
|
now["value"] += 1
|
||||||
|
third = await channel._get_typing_ticket("wx-user", "ctx-3")
|
||||||
|
assert third == "ticket-ok"
|
||||||
|
assert channel._api_post.await_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qr_login_treats_temporary_connect_error_as_wait_and_recovers() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._running = True
|
||||||
|
channel._save_state = lambda: None
|
||||||
|
channel._print_qr_code = lambda url: None
|
||||||
|
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||||
|
|
||||||
|
request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
|
||||||
|
channel._api_get_with_base = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
httpx.ConnectError("temporary network", request=request),
|
||||||
|
{
|
||||||
|
"status": "confirmed",
|
||||||
|
"bot_token": "token-net-ok",
|
||||||
|
"ilink_bot_id": "bot-id",
|
||||||
|
"baseurl": "https://example.test",
|
||||||
|
"ilink_user_id": "wx-user",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ok = await channel._qr_login()
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert channel._token == "token-net-ok"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qr_login_treats_5xx_gateway_response_error_as_wait_and_recovers() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._running = True
|
||||||
|
channel._save_state = lambda: None
|
||||||
|
channel._print_qr_code = lambda url: None
|
||||||
|
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||||
|
|
||||||
|
request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
|
||||||
|
response = httpx.Response(status_code=524, request=request)
|
||||||
|
channel._api_get_with_base = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
httpx.HTTPStatusError("gateway timeout", request=request, response=response),
|
||||||
|
{
|
||||||
|
"status": "confirmed",
|
||||||
|
"bot_token": "token-5xx-ok",
|
||||||
|
"ilink_bot_id": "bot-id",
|
||||||
|
"baseurl": "https://example.test",
|
||||||
|
"ilink_user_id": "wx-user",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
ok = await channel._qr_login()
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert channel._token == "token-5xx-ok"
|
||||||
|
|
||||||
|
|
||||||
def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None:
|
def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None:
|
||||||
key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef")
|
key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef")
|
||||||
plaintext = b"hello-weixin-padding"
|
plaintext = b"hello-weixin-padding"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user