mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 09:22:36 +00:00
feat(weixin): add voice message, typing keepalive, getConfig cache, and QR polling resilience
This commit is contained in:
parent
0514233217
commit
26947db479
@ -15,6 +15,7 @@ import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
@ -102,18 +103,23 @@ MAX_QR_REFRESH_COUNT = 3
|
||||
TYPING_STATUS_TYPING = 1
|
||||
TYPING_STATUS_CANCEL = 2
|
||||
TYPING_TICKET_TTL_S = 24 * 60 * 60
|
||||
TYPING_KEEPALIVE_INTERVAL_S = 5
|
||||
CONFIG_CACHE_INITIAL_RETRY_S = 2
|
||||
CONFIG_CACHE_MAX_RETRY_S = 60 * 60
|
||||
|
||||
# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
|
||||
DEFAULT_LONG_POLL_TIMEOUT_S = 35
|
||||
|
||||
# Media-type codes for getuploadurl (1=image, 2=video, 3=file)
|
||||
# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice)
|
||||
UPLOAD_MEDIA_IMAGE = 1
|
||||
UPLOAD_MEDIA_VIDEO = 2
|
||||
UPLOAD_MEDIA_FILE = 3
|
||||
UPLOAD_MEDIA_VOICE = 4
|
||||
|
||||
# File extensions considered as images / videos for outbound media
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"}
|
||||
_VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
|
||||
_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"}
|
||||
|
||||
|
||||
def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool:
|
||||
@ -167,7 +173,7 @@ class WeixinChannel(BaseChannel):
|
||||
self._poll_task: asyncio.Task | None = None
|
||||
self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
|
||||
self._session_pause_until: float = 0.0
|
||||
self._typing_tickets: dict[str, tuple[str, float]] = {}
|
||||
self._typing_tickets: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State persistence
|
||||
@ -339,7 +345,16 @@ class WeixinChannel(BaseChannel):
|
||||
params={"qrcode": qrcode_id},
|
||||
auth=False,
|
||||
)
|
||||
except httpx.TimeoutException:
|
||||
except Exception as e:
|
||||
if self._is_retryable_qr_poll_error(e):
|
||||
logger.warning("QR polling temporary error, will retry: {}", e)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
raise
|
||||
|
||||
if not isinstance(status_data, dict):
|
||||
logger.warning("QR polling got non-object response, continue waiting")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
status = status_data.get("status", "")
|
||||
@ -408,6 +423,16 @@ class WeixinChannel(BaseChannel):
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_retryable_qr_poll_error(err: Exception) -> bool:
|
||||
if isinstance(err, httpx.TimeoutException | httpx.TransportError):
|
||||
return True
|
||||
if isinstance(err, httpx.HTTPStatusError):
|
||||
status_code = err.response.status_code if err.response is not None else 0
|
||||
if status_code >= 500:
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _print_qr_code(url: str) -> None:
|
||||
try:
|
||||
@ -858,13 +883,11 @@ class WeixinChannel(BaseChannel):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str:
|
||||
"""Get typing ticket for a user with simple per-user TTL cache."""
|
||||
"""Get typing ticket with per-user refresh + failure backoff cache."""
|
||||
now = time.time()
|
||||
cached = self._typing_tickets.get(user_id)
|
||||
if cached:
|
||||
ticket, expires_at = cached
|
||||
if ticket and now < expires_at:
|
||||
return ticket
|
||||
entry = self._typing_tickets.get(user_id)
|
||||
if entry and now < float(entry.get("next_fetch_at", 0)):
|
||||
return str(entry.get("ticket", "") or "")
|
||||
|
||||
body: dict[str, Any] = {
|
||||
"ilink_user_id": user_id,
|
||||
@ -874,9 +897,27 @@ class WeixinChannel(BaseChannel):
|
||||
data = await self._api_post("ilink/bot/getconfig", body)
|
||||
if data.get("ret", 0) == 0:
|
||||
ticket = str(data.get("typing_ticket", "") or "")
|
||||
if ticket:
|
||||
self._typing_tickets[user_id] = (ticket, now + TYPING_TICKET_TTL_S)
|
||||
return ticket
|
||||
self._typing_tickets[user_id] = {
|
||||
"ticket": ticket,
|
||||
"ever_succeeded": True,
|
||||
"next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S),
|
||||
"retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
|
||||
}
|
||||
return ticket
|
||||
|
||||
prev_delay = float(entry.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S)) if entry else CONFIG_CACHE_INITIAL_RETRY_S
|
||||
next_delay = min(prev_delay * 2, CONFIG_CACHE_MAX_RETRY_S)
|
||||
if entry:
|
||||
entry["next_fetch_at"] = now + next_delay
|
||||
entry["retry_delay_s"] = next_delay
|
||||
return str(entry.get("ticket", "") or "")
|
||||
|
||||
self._typing_tickets[user_id] = {
|
||||
"ticket": "",
|
||||
"ever_succeeded": False,
|
||||
"next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S,
|
||||
"retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
|
||||
}
|
||||
return ""
|
||||
|
||||
async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
|
||||
@ -891,6 +932,16 @@ class WeixinChannel(BaseChannel):
|
||||
}
|
||||
await self._api_post("ilink/bot/sendtyping", body)
|
||||
|
||||
async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None:
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
|
||||
if stop_event.is_set():
|
||||
break
|
||||
try:
|
||||
await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING)
|
||||
except Exception as e:
|
||||
logger.debug("WeChat sendtyping(keepalive) failed for {}: {}", user_id, e)
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
if not self._client or not self._token:
|
||||
logger.warning("WeChat client not initialized or not authenticated")
|
||||
@ -923,6 +974,13 @@ class WeixinChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.debug("WeChat sendtyping(start) failed for {}: {}", msg.chat_id, e)
|
||||
|
||||
typing_keepalive_stop = asyncio.Event()
|
||||
typing_keepalive_task: asyncio.Task | None = None
|
||||
if typing_ticket:
|
||||
typing_keepalive_task = asyncio.create_task(
|
||||
self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop)
|
||||
)
|
||||
|
||||
try:
|
||||
# --- Send media files first (following Telegram channel pattern) ---
|
||||
for media_path in (msg.media or []):
|
||||
@ -947,6 +1005,14 @@ class WeixinChannel(BaseChannel):
|
||||
logger.error("Error sending WeChat message: {}", e)
|
||||
raise
|
||||
finally:
|
||||
if typing_keepalive_task:
|
||||
typing_keepalive_stop.set()
|
||||
typing_keepalive_task.cancel()
|
||||
try:
|
||||
await typing_keepalive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if typing_ticket:
|
||||
try:
|
||||
await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
|
||||
@ -1025,6 +1091,10 @@ class WeixinChannel(BaseChannel):
|
||||
upload_type = UPLOAD_MEDIA_VIDEO
|
||||
item_type = ITEM_VIDEO
|
||||
item_key = "video_item"
|
||||
elif ext in _VOICE_EXTS:
|
||||
upload_type = UPLOAD_MEDIA_VOICE
|
||||
item_type = ITEM_VOICE
|
||||
item_key = "voice_item"
|
||||
else:
|
||||
upload_type = UPLOAD_MEDIA_FILE
|
||||
item_type = ITEM_FILE
|
||||
|
||||
@ -6,6 +6,7 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
import nanobot.channels.weixin as weixin_mod
|
||||
from nanobot.bus.queue import MessageBus
|
||||
@ -595,6 +596,158 @@ async def test_send_media_falls_back_to_upload_param_url(tmp_path) -> None:
|
||||
assert "&filekey=" in cdn_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_voice_file_uses_voice_item_and_voice_upload_type(tmp_path) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
|
||||
media_file = tmp_path / "voice.mp3"
|
||||
media_file.write_bytes(b"voice-bytes")
|
||||
|
||||
cdn_post = AsyncMock(return_value=_DummyHttpResponse(headers={"x-encrypted-param": "voice-dl-param"}))
|
||||
channel._client = SimpleNamespace(post=cdn_post)
|
||||
channel._api_post = AsyncMock(
|
||||
side_effect=[
|
||||
{"upload_full_url": "https://upload-full.example.test/voice?foo=bar"},
|
||||
{"ret": 0},
|
||||
]
|
||||
)
|
||||
|
||||
await channel._send_media_file("wx-user", str(media_file), "ctx-voice")
|
||||
|
||||
getupload_body = channel._api_post.await_args_list[0].args[1]
|
||||
assert getupload_body["media_type"] == 4
|
||||
|
||||
sendmessage_body = channel._api_post.await_args_list[1].args[1]
|
||||
item = sendmessage_body["msg"]["item_list"][0]
|
||||
assert item["type"] == 3
|
||||
assert "voice_item" in item
|
||||
assert "file_item" not in item
|
||||
assert item["voice_item"]["media"]["encrypt_query_param"] == "voice-dl-param"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_typing_uses_keepalive_until_send_finishes() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
channel._context_tokens["wx-user"] = "ctx-typing-loop"
|
||||
async def _api_post_side_effect(endpoint: str, _body: dict | None = None, *, auth: bool = True):
|
||||
if endpoint == "ilink/bot/getconfig":
|
||||
return {"ret": 0, "typing_ticket": "ticket-keepalive"}
|
||||
return {"ret": 0}
|
||||
|
||||
channel._api_post = AsyncMock(side_effect=_api_post_side_effect)
|
||||
|
||||
async def _slow_send_text(*_args, **_kwargs) -> None:
|
||||
await asyncio.sleep(0.03)
|
||||
|
||||
channel._send_text = AsyncMock(side_effect=_slow_send_text)
|
||||
|
||||
old_interval = weixin_mod.TYPING_KEEPALIVE_INTERVAL_S
|
||||
weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = 0.01
|
||||
try:
|
||||
await channel.send(
|
||||
type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})()
|
||||
)
|
||||
finally:
|
||||
weixin_mod.TYPING_KEEPALIVE_INTERVAL_S = old_interval
|
||||
|
||||
status_calls = [
|
||||
c.args[1]["status"]
|
||||
for c in channel._api_post.await_args_list
|
||||
if c.args and c.args[0] == "ilink/bot/sendtyping"
|
||||
]
|
||||
assert status_calls.count(1) >= 2
|
||||
assert status_calls[-1] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_typing_ticket_failure_uses_backoff_and_cached_ticket(monkeypatch) -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._client = object()
|
||||
channel._token = "token"
|
||||
|
||||
now = {"value": 1000.0}
|
||||
monkeypatch.setattr(weixin_mod.time, "time", lambda: now["value"])
|
||||
monkeypatch.setattr(weixin_mod.random, "random", lambda: 0.5)
|
||||
|
||||
channel._api_post = AsyncMock(return_value={"ret": 0, "typing_ticket": "ticket-ok"})
|
||||
first = await channel._get_typing_ticket("wx-user", "ctx-1")
|
||||
assert first == "ticket-ok"
|
||||
|
||||
# force refresh window reached
|
||||
now["value"] = now["value"] + (12 * 60 * 60) + 1
|
||||
channel._api_post = AsyncMock(return_value={"ret": 1, "errmsg": "temporary failure"})
|
||||
|
||||
# On refresh failure, should still return cached ticket and apply backoff.
|
||||
second = await channel._get_typing_ticket("wx-user", "ctx-2")
|
||||
assert second == "ticket-ok"
|
||||
assert channel._api_post.await_count == 1
|
||||
|
||||
# Before backoff expiry, no extra fetch should happen.
|
||||
now["value"] += 1
|
||||
third = await channel._get_typing_ticket("wx-user", "ctx-3")
|
||||
assert third == "ticket-ok"
|
||||
assert channel._api_post.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_treats_temporary_connect_error_as_wait_and_recovers() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||
|
||||
request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
httpx.ConnectError("temporary network", request=request),
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-net-ok",
|
||||
"ilink_bot_id": "bot-id",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-net-ok"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_qr_login_treats_5xx_gateway_response_error_as_wait_and_recovers() -> None:
|
||||
channel, _bus = _make_channel()
|
||||
channel._running = True
|
||||
channel._save_state = lambda: None
|
||||
channel._print_qr_code = lambda url: None
|
||||
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||
|
||||
request = httpx.Request("GET", "https://ilinkai.weixin.qq.com/ilink/bot/get_qrcode_status")
|
||||
response = httpx.Response(status_code=524, request=request)
|
||||
channel._api_get_with_base = AsyncMock(
|
||||
side_effect=[
|
||||
httpx.HTTPStatusError("gateway timeout", request=request, response=response),
|
||||
{
|
||||
"status": "confirmed",
|
||||
"bot_token": "token-5xx-ok",
|
||||
"ilink_bot_id": "bot-id",
|
||||
"baseurl": "https://example.test",
|
||||
"ilink_user_id": "wx-user",
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
ok = await channel._qr_login()
|
||||
|
||||
assert ok is True
|
||||
assert channel._token == "token-5xx-ok"
|
||||
|
||||
|
||||
def test_decrypt_aes_ecb_strips_valid_pkcs7_padding() -> None:
|
||||
key_b64 = "MDEyMzQ1Njc4OWFiY2RlZg==" # base64("0123456789abcdef")
|
||||
plaintext = b"hello-weixin-padding"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user