diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py index 891cfd099..2266bc9f0 100644 --- a/nanobot/channels/weixin.py +++ b/nanobot/channels/weixin.py @@ -13,7 +13,6 @@ import asyncio import base64 import hashlib import json -import mimetypes import os import random import re @@ -158,6 +157,7 @@ class WeixinChannel(BaseChannel): self._poll_task: asyncio.Task | None = None self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S self._session_pause_until: float = 0.0 + self._typing_tasks: dict[str, asyncio.Task] = {} self._typing_tickets: dict[str, dict[str, Any]] = {} # ------------------------------------------------------------------ @@ -193,6 +193,15 @@ class WeixinChannel(BaseChannel): } else: self._context_tokens = {} + typing_tickets = data.get("typing_tickets", {}) + if isinstance(typing_tickets, dict): + self._typing_tickets = { + str(user_id): ticket + for user_id, ticket in typing_tickets.items() + if str(user_id).strip() and isinstance(ticket, dict) + } + else: + self._typing_tickets = {} base_url = data.get("base_url", "") if base_url: self.config.base_url = base_url @@ -207,6 +216,7 @@ class WeixinChannel(BaseChannel): "token": self._token, "get_updates_buf": self._get_updates_buf, "context_tokens": self._context_tokens, + "typing_tickets": self._typing_tickets, "base_url": self.config.base_url, } state_file.write_text(json.dumps(data, ensure_ascii=False)) @@ -488,6 +498,8 @@ class WeixinChannel(BaseChannel): self._running = False if self._poll_task and not self._poll_task.done(): self._poll_task.cancel() + for chat_id in list(self._typing_tasks): + await self._stop_typing(chat_id, clear_remote=False) if self._client: await self._client.aclose() self._client = None @@ -746,6 +758,15 @@ class WeixinChannel(BaseChannel): if not content: return + logger.info( + "WeChat inbound: from={} items={} bodyLen={}", + from_user_id, + ",".join(str(i.get("type", 0)) for i in item_list), + len(content), + ) + + await self._start_typing(from_user_id, ctx_token) + await self._handle_message( sender_id=from_user_id, chat_id=from_user_id, @@ -927,6 +948,10 @@ class WeixinChannel(BaseChannel): except RuntimeError: return + is_progress = bool((msg.metadata or {}).get("_progress", False)) + if not is_progress: + await self._stop_typing(msg.chat_id, clear_remote=True) + content = msg.content.strip() ctx_token = self._context_tokens.get(msg.chat_id, "") if not ctx_token: @@ -987,12 +1012,68 @@ class WeixinChannel(BaseChannel): except asyncio.CancelledError: pass - if typing_ticket: + if typing_ticket and not is_progress: try: await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL) except Exception: pass + async def _start_typing(self, chat_id: str, context_token: str = "") -> None: + """Start typing indicator immediately when a message is received.""" + if not self._client or not self._token or not chat_id: + return + await self._stop_typing(chat_id, clear_remote=False) + try: + ticket = await self._get_typing_ticket(chat_id, context_token) + if not ticket: + return + await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING) + except Exception as e: + logger.debug("WeChat typing indicator start failed for {}: {}", chat_id, e) + return + + stop_event = asyncio.Event() + + async def keepalive() -> None: + try: + while not stop_event.is_set(): + await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S) + if stop_event.is_set(): + break + try: + await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING) + except Exception: + pass + finally: + pass + + task = asyncio.create_task(keepalive()) + task._typing_stop_event = stop_event # type: ignore[attr-defined] + self._typing_tasks[chat_id] = task + + async def _stop_typing(self, chat_id: str, *, clear_remote: bool) -> None: + """Stop typing indicator for a chat.""" + task = self._typing_tasks.pop(chat_id, None) + if task and not task.done(): + stop_event = getattr(task, "_typing_stop_event", None) + if stop_event: + stop_event.set() + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + if not clear_remote: + return + entry = self._typing_tickets.get(chat_id) + ticket = str(entry.get("ticket", "") or "") if isinstance(entry, dict) else "" + if not ticket: + return + try: + await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL) + except Exception as e: + logger.debug("WeChat typing clear failed for {}: {}", chat_id, e) + async def _send_text( self, to_user_id: str, diff --git a/tests/channels/test_weixin_channel.py b/tests/channels/test_weixin_channel.py index 58fc30865..3a847411b 100644 --- a/tests/channels/test_weixin_channel.py +++ b/tests/channels/test_weixin_channel.py @@ -572,6 +572,85 @@ async def test_process_message_skips_bot_messages() -> None: assert bus.inbound_size == 0 +@pytest.mark.asyncio +async def test_process_message_starts_typing_on_inbound() -> None: + """Typing indicator fires immediately when user message arrives.""" + channel, _bus = _make_channel() + channel._running = True + channel._client = object() + channel._token = "token" + channel._start_typing = AsyncMock() + + await channel._process_message( + { + "message_type": 1, + "message_id": "m-typing", + "from_user_id": "wx-user", + "context_token": "ctx-typing", + "item_list": [ + {"type": ITEM_TEXT, "text_item": {"text": "hello"}}, + ], + } + ) + + channel._start_typing.assert_awaited_once_with("wx-user", "ctx-typing") + + +@pytest.mark.asyncio +async def test_send_final_message_clears_typing_indicator() -> None: + """Non-progress send should cancel typing status.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999} + channel._send_text = AsyncMock() + channel._api_post = AsyncMock(return_value={"ret": 0}) + + await channel.send( + type("Msg", (), {"chat_id": "wx-user", "content": "pong", "media": [], "metadata": {}})() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "pong", "ctx-2") + typing_cancel_calls = [ + c for c in channel._api_post.await_args_list + if c.args[0] == "ilink/bot/sendtyping" and c.args[1]["status"] == 2 + ] + assert len(typing_cancel_calls) >= 1 + + +@pytest.mark.asyncio +async def test_send_progress_message_keeps_typing_indicator() -> None: + """Progress messages must not cancel typing status.""" + channel, _bus = _make_channel() + channel._client = object() + channel._token = "token" + channel._context_tokens["wx-user"] = "ctx-2" + channel._typing_tickets["wx-user"] = {"ticket": "ticket-2", "next_fetch_at": 9999999999} + channel._send_text = AsyncMock() + channel._api_post = AsyncMock(return_value={"ret": 0}) + + await channel.send( + type( + "Msg", + (), + { + "chat_id": "wx-user", + "content": "thinking", + "media": [], + "metadata": {"_progress": True}, + }, + )() + ) + + channel._send_text.assert_awaited_once_with("wx-user", "thinking", "ctx-2") + typing_cancel_calls = [ + c for c in channel._api_post.await_args_list + if c.args and c.args[0] == "ilink/bot/sendtyping" and c.args[1].get("status") == 2 + ] + assert len(typing_cancel_calls) == 0 + + class _DummyHttpResponse: def __init__(self, *, headers: dict[str, str] | None = None, status_code: int = 200) -> None: self.headers = headers or {}