diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index f09ef95f7..9e2caae3f 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -13,7 +13,6 @@ import asyncio import base64 import hashlib import json -import mimetypes import os import re import time @@ -124,6 +123,8 @@ class WeixinChannel(BaseChannel): self._poll_task: asyncio.Task | None = None self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S self._session_pause_until: float = 0.0 + self._typing_tasks: dict[str, asyncio.Task] = {} + self._typing_tickets: dict[str, str] = {} # ------------------------------------------------------------------ # State persistence @@ -158,6 +159,15 @@ class WeixinChannel(BaseChannel): } else: self._context_tokens = {} + typing_tickets = data.get("typing_tickets", {}) + if isinstance(typing_tickets, dict): + self._typing_tickets = { + str(user_id): str(ticket) + for user_id, ticket in typing_tickets.items() + if str(user_id).strip() and str(ticket).strip() + } + else: + self._typing_tickets = {} base_url = data.get("base_url", "") if base_url: self.config.base_url = base_url @@ -173,6 +183,7 @@ class WeixinChannel(BaseChannel): "token": self._token, "get_updates_buf": self._get_updates_buf, "context_tokens": self._context_tokens, + "typing_tickets": self._typing_tickets, "base_url": self.config.base_url, } state_file.write_text(json.dumps(data, ensure_ascii=False)) @@ -415,6 +426,8 @@ class WeixinChannel(BaseChannel): self._running = False if self._poll_task and not self._poll_task.done(): self._poll_task.cancel() + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id, clear_remote=False) if self._client: await self._client.aclose() self._client = None @@ -631,6 +644,8 @@ class WeixinChannel(BaseChannel): len(content), ) + await self._start_typing(from_user_id, ctx_token) + await self._handle_message( sender_id=from_user_id, chat_id=from_user_id, @@ -720,6 +735,10 @@ class WeixinChannel(BaseChannel): logger.warning("WeChat send blocked: {}", e) return + is_progress = bool((msg.metadata or {}).get("_progress", False)) + if not is_progress: + await self._stop_typing(msg.chat_id, clear_remote=True) + content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") if not ctx_token: @@ -753,6 +772,85 @@ class WeixinChannel(BaseChannel): logger.error("Error sending WeChat message: {}", e) raise + async def _get_typing_ticket(self, user_id: str, context_token: str) -> str: + """Fetch and cache typing ticket for a user/context pair.""" + if not self._client or not self._token or not user_id or not context_token: + return "" + cached = self._typing_tickets.get(user_id, "") + if cached: + return cached + try: + data = await self._api_post( + "ilink/bot/getconfig", + { + "ilink_user_id": user_id, + "context_token": context_token, + }, + ) + except Exception as e: + logger.debug("WeChat getconfig failed for {}: {}", user_id, e) + return "" + ticket = str(data.get("typing_ticket") or "").strip() + if ticket: + self._typing_tickets[user_id] = ticket + self._save_state() + return ticket + + async def _send_typing_status(self, to_user_id: str, typing_ticket: str, status: int) -> None: + if not typing_ticket: + return + await self._api_post( + "ilink/bot/sendtyping", + { + "ilink_user_id": to_user_id, + "typing_ticket": typing_ticket, + "status": status, + }, + ) + + async def _start_typing(self, chat_id: str, context_token: str) -> None: + if not self._client or not self._token or not chat_id or not context_token: + return + await self._stop_typing(chat_id, clear_remote=False) + ticket = await self._get_typing_ticket(chat_id, context_token) + if not ticket: + return + try: + await self._send_typing_status(chat_id, ticket, 1) + except Exception as e: + logger.debug("WeChat typing indicator failed for {}: {}", chat_id, e) + return + + async def typing_loop() -> None: + try: + while self._running: + await asyncio.sleep(5) + await self._send_typing_status(chat_id, ticket, 1) + except asyncio.CancelledError: + pass + except Exception as e: + logger.debug("WeChat typing keepalive stopped for {}: {}", chat_id, e) + + self._typing_tasks[chat_id] = asyncio.create_task(typing_loop()) + + async def _stop_typing(self, chat_id: str, *, clear_remote: bool) -> None: + task = self._typing_tasks.pop(chat_id, None) + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if not clear_remote: + return + ticket = self._typing_tickets.get(chat_id, "") + if not ticket: + return + try: + await self._send_typing_status(chat_id, ticket, 2) + except Exception as e: + logger.debug("WeChat typing clear failed for {}: {}", chat_id, e) + async def _send_text( self, to_user_id: str, diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 54d9bd93f..35b01db8b 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -278,3 +278,77 @@ async def test_process_message_skips_bot_messages() -> None: ) assert bus.inbound_size == 0 + + +@pytest.mark.asyncio +async def test_process_message_fetches_typing_ticket_and_starts_typing() -> None: + channel, _bus = _make_channel() + channel._running = True + channel._client = object() + channel._token = "token" + channel._api_post = AsyncMock(return_value={"typing_ticket": "ticket-1"}) + + await channel._process_message( + { + "message_type": 1, + "message_id": "m-typing", + "from_user_id": "wx-user", + "context_token": "ctx-typing", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, + ], + } + ) + + assert channel._typing_tickets["wx-user"] == "ticket-1" + assert "wx-user" in channel._typing_tasks + await channel._stop_typing("wx-user", clear_remote=False) + + +@pytest.mark.asyncio +async def test_send_final_message_clears_typing_indicator() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._typing_tickets["wx-user"] = "ticket-2" + channel._send_text = AsyncMock() + channel._api_post = AsyncMock(return_value={}) + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") + channel._api_post.assert_awaited_once() + endpoint, body = channel._api_post.await_args.args + assert endpoint == "ilink/bot/sendtyping" + assert body["status"] == 2 + assert body["typing_ticket"] == "ticket-2" + + +@pytest.mark.asyncio +async def test_send_progress_message_keeps_typing_indicator() -> None: + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._typing_tickets["wx-user"] = "ticket-2" + channel._send_text = AsyncMock() + channel._api_post = AsyncMock(return_value={}) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "thinking", + "media": [], + "metadata": {"_progress": True}, + }, + )() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "thinking", "ctx-2") + channel._api_post.assert_not_awaited()