mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-24 20:06:03 +00:00
feat(weixin): implement QR redirect handling
This commit is contained in:
parent
b1d5475681
commit
0207b541df
@ -259,6 +259,25 @@ class WeixinChannel(BaseChannel):
|
|||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
async def _api_get_with_base(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
base_url: str,
|
||||||
|
endpoint: str,
|
||||||
|
params: dict | None = None,
|
||||||
|
auth: bool = True,
|
||||||
|
extra_headers: dict[str, str] | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""GET helper that allows overriding base_url for QR redirect polling."""
|
||||||
|
assert self._client is not None
|
||||||
|
url = f"{base_url.rstrip('/')}/{endpoint}"
|
||||||
|
hdrs = self._make_headers(auth=auth)
|
||||||
|
if extra_headers:
|
||||||
|
hdrs.update(extra_headers)
|
||||||
|
resp = await self._client.get(url, params=params, headers=hdrs)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
async def _api_post(
|
async def _api_post(
|
||||||
self,
|
self,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
@ -299,12 +318,14 @@ class WeixinChannel(BaseChannel):
|
|||||||
refresh_count = 0
|
refresh_count = 0
|
||||||
qrcode_id, scan_url = await self._fetch_qr_code()
|
qrcode_id, scan_url = await self._fetch_qr_code()
|
||||||
self._print_qr_code(scan_url)
|
self._print_qr_code(scan_url)
|
||||||
|
current_poll_base_url = self.config.base_url
|
||||||
|
|
||||||
logger.info("Waiting for QR code scan...")
|
logger.info("Waiting for QR code scan...")
|
||||||
while self._running:
|
while self._running:
|
||||||
try:
|
try:
|
||||||
status_data = await self._api_get(
|
status_data = await self._api_get_with_base(
|
||||||
"ilink/bot/get_qrcode_status",
|
base_url=current_poll_base_url,
|
||||||
|
endpoint="ilink/bot/get_qrcode_status",
|
||||||
params={"qrcode": qrcode_id},
|
params={"qrcode": qrcode_id},
|
||||||
auth=False,
|
auth=False,
|
||||||
)
|
)
|
||||||
@ -333,6 +354,23 @@ class WeixinChannel(BaseChannel):
|
|||||||
return False
|
return False
|
||||||
elif status == "scaned":
|
elif status == "scaned":
|
||||||
logger.info("QR code scanned, waiting for confirmation...")
|
logger.info("QR code scanned, waiting for confirmation...")
|
||||||
|
elif status == "scaned_but_redirect":
|
||||||
|
redirect_host = str(status_data.get("redirect_host", "") or "").strip()
|
||||||
|
if redirect_host:
|
||||||
|
if redirect_host.startswith("http://") or redirect_host.startswith("https://"):
|
||||||
|
redirected_base = redirect_host
|
||||||
|
else:
|
||||||
|
redirected_base = f"https://{redirect_host}"
|
||||||
|
if redirected_base != current_poll_base_url:
|
||||||
|
logger.info(
|
||||||
|
"QR status redirect: switching polling host to {}",
|
||||||
|
redirected_base,
|
||||||
|
)
|
||||||
|
current_poll_base_url = redirected_base
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"QR status returned scaned_but_redirect but redirect_host is missing",
|
||||||
|
)
|
||||||
elif status == "expired":
|
elif status == "expired":
|
||||||
refresh_count += 1
|
refresh_count += 1
|
||||||
if refresh_count > MAX_QR_REFRESH_COUNT:
|
if refresh_count > MAX_QR_REFRESH_COUNT:
|
||||||
|
|||||||
@ -227,8 +227,12 @@ async def test_qr_login_refreshes_expired_qr_and_then_succeeds() -> None:
|
|||||||
channel._api_get = AsyncMock(
|
channel._api_get = AsyncMock(
|
||||||
side_effect=[
|
side_effect=[
|
||||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||||
{"status": "expired"},
|
|
||||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
channel._api_get_with_base = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"status": "expired"},
|
||||||
{
|
{
|
||||||
"status": "confirmed",
|
"status": "confirmed",
|
||||||
"bot_token": "token-2",
|
"bot_token": "token-2",
|
||||||
@ -254,12 +258,16 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
|||||||
channel._api_get = AsyncMock(
|
channel._api_get = AsyncMock(
|
||||||
side_effect=[
|
side_effect=[
|
||||||
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
{"qrcode": "qr-1", "qrcode_img_content": "url-1"},
|
||||||
{"status": "expired"},
|
|
||||||
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
{"qrcode": "qr-2", "qrcode_img_content": "url-2"},
|
||||||
{"status": "expired"},
|
|
||||||
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
|
{"qrcode": "qr-3", "qrcode_img_content": "url-3"},
|
||||||
{"status": "expired"},
|
|
||||||
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
|
{"qrcode": "qr-4", "qrcode_img_content": "url-4"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
channel._api_get_with_base = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"status": "expired"},
|
||||||
|
{"status": "expired"},
|
||||||
|
{"status": "expired"},
|
||||||
{"status": "expired"},
|
{"status": "expired"},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@ -269,6 +277,70 @@ async def test_qr_login_returns_false_after_too_many_expired_qr_codes() -> None:
|
|||||||
assert ok is False
|
assert ok is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qr_login_switches_polling_base_url_on_redirect_status() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._running = True
|
||||||
|
channel._save_state = lambda: None
|
||||||
|
channel._print_qr_code = lambda url: None
|
||||||
|
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||||
|
|
||||||
|
status_side_effect = [
|
||||||
|
{"status": "scaned_but_redirect", "redirect_host": "idc.redirect.test"},
|
||||||
|
{
|
||||||
|
"status": "confirmed",
|
||||||
|
"bot_token": "token-3",
|
||||||
|
"ilink_bot_id": "bot-3",
|
||||||
|
"baseurl": "https://example.test",
|
||||||
|
"ilink_user_id": "wx-user",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
channel._api_get = AsyncMock(side_effect=list(status_side_effect))
|
||||||
|
channel._api_get_with_base = AsyncMock(side_effect=list(status_side_effect))
|
||||||
|
|
||||||
|
ok = await channel._qr_login()
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert channel._token == "token-3"
|
||||||
|
assert channel._api_get_with_base.await_count == 2
|
||||||
|
first_call = channel._api_get_with_base.await_args_list[0]
|
||||||
|
second_call = channel._api_get_with_base.await_args_list[1]
|
||||||
|
assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||||
|
assert second_call.kwargs["base_url"] == "https://idc.redirect.test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_qr_login_redirect_without_host_keeps_current_polling_base_url() -> None:
|
||||||
|
channel, _bus = _make_channel()
|
||||||
|
channel._running = True
|
||||||
|
channel._save_state = lambda: None
|
||||||
|
channel._print_qr_code = lambda url: None
|
||||||
|
channel._fetch_qr_code = AsyncMock(return_value=("qr-1", "url-1"))
|
||||||
|
|
||||||
|
status_side_effect = [
|
||||||
|
{"status": "scaned_but_redirect"},
|
||||||
|
{
|
||||||
|
"status": "confirmed",
|
||||||
|
"bot_token": "token-4",
|
||||||
|
"ilink_bot_id": "bot-4",
|
||||||
|
"baseurl": "https://example.test",
|
||||||
|
"ilink_user_id": "wx-user",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
channel._api_get = AsyncMock(side_effect=list(status_side_effect))
|
||||||
|
channel._api_get_with_base = AsyncMock(side_effect=list(status_side_effect))
|
||||||
|
|
||||||
|
ok = await channel._qr_login()
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert channel._token == "token-4"
|
||||||
|
assert channel._api_get_with_base.await_count == 2
|
||||||
|
first_call = channel._api_get_with_base.await_args_list[0]
|
||||||
|
second_call = channel._api_get_with_base.await_args_list[1]
|
||||||
|
assert first_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||||
|
assert second_call.kwargs["base_url"] == "https://ilinkai.weixin.qq.com"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_process_message_skips_bot_messages() -> None:
|
async def test_process_message_skips_bot_messages() -> None:
|
||||||
channel, bus = _make_channel()
|
channel, bus = _make_channel()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user