tags)
+ if cls._MD_LINK_RE.search(stripped):
+ return "post"
+
+ # Short plain text → text format
+ if len(stripped) <= cls._TEXT_MAX_LEN:
+ return "text"
+
+ # Medium plain text without any formatting → post format
+ return "post"
+
+ @classmethod
+ def _markdown_to_post(cls, content: str) -> str:
+ """Convert markdown content to Feishu post message JSON.
+
+ Handles links ``[text](url)`` as ``a`` tags; everything else as ``text`` tags.
+ Each line becomes a paragraph (row) in the post body.
+ """
+ lines = content.strip().split("\n")
+ paragraphs: list[list[dict]] = []
+
+ for line in lines:
+ elements: list[dict] = []
+ last_end = 0
+
+ for m in cls._MD_LINK_RE.finditer(line):
+ # Text before this link
+ before = line[last_end:m.start()]
+ if before:
+ elements.append({"tag": "text", "text": before})
+ elements.append({
+ "tag": "a",
+ "text": m.group(1),
+ "href": m.group(2),
+ })
+ last_end = m.end()
+
+ # Remaining text after last link
+ remaining = line[last_end:]
+ if remaining:
+ elements.append({"tag": "text", "text": remaining})
+
+ # Empty line → empty paragraph for spacing
+ if not elements:
+ elements.append({"tag": "text", "text": ""})
+
+ paragraphs.append(elements)
+
+ post_body = {
+ "zh_cn": {
+ "content": paragraphs,
+ }
+ }
+ return json.dumps(post_body, ensure_ascii=False)
+
+ _IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".ico", ".tiff", ".tif"}
+ _AUDIO_EXTS = {".opus"}
+ _VIDEO_EXTS = {".mp4", ".mov", ".avi"}
+ _FILE_TYPE_MAP = {
+ ".opus": "opus", ".mp4": "mp4", ".pdf": "pdf", ".doc": "doc", ".docx": "doc",
+ ".xls": "xls", ".xlsx": "xls", ".ppt": "ppt", ".pptx": "ppt",
+ }
+
+ def _upload_image_sync(self, file_path: str) -> str | None:
+ """Upload an image to Feishu and return the image_key."""
+ from lark_oapi.api.im.v1 import CreateImageRequest, CreateImageRequestBody
+ try:
+ with open(file_path, "rb") as f:
+ request = CreateImageRequest.builder() \
+ .request_body(
+ CreateImageRequestBody.builder()
+ .image_type("message")
+ .image(f)
+ .build()
+ ).build()
+ response = self._client.im.v1.image.create(request)
+ if response.success():
+ image_key = response.data.image_key
+ logger.debug("Uploaded image {}: {}", os.path.basename(file_path), image_key)
+ return image_key
+ else:
+ logger.error("Failed to upload image: code={}, msg={}", response.code, response.msg)
+ return None
+ except Exception as e:
+ logger.error("Error uploading image {}: {}", file_path, e)
+ return None
+
+ def _upload_file_sync(self, file_path: str) -> str | None:
+ """Upload a file to Feishu and return the file_key."""
+ from lark_oapi.api.im.v1 import CreateFileRequest, CreateFileRequestBody
+ ext = os.path.splitext(file_path)[1].lower()
+ file_type = self._FILE_TYPE_MAP.get(ext, "stream")
+ file_name = os.path.basename(file_path)
+ try:
+ with open(file_path, "rb") as f:
+ request = CreateFileRequest.builder() \
+ .request_body(
+ CreateFileRequestBody.builder()
+ .file_type(file_type)
+ .file_name(file_name)
+ .file(f)
+ .build()
+ ).build()
+ response = self._client.im.v1.file.create(request)
+ if response.success():
+ file_key = response.data.file_key
+ logger.debug("Uploaded file {}: {}", file_name, file_key)
+ return file_key
+ else:
+ logger.error("Failed to upload file: code={}, msg={}", response.code, response.msg)
+ return None
+ except Exception as e:
+ logger.error("Error uploading file {}: {}", file_path, e)
+ return None
+
+ def _download_image_sync(self, message_id: str, image_key: str) -> tuple[bytes | None, str | None]:
+ """Download an image from Feishu message by message_id and image_key."""
+ from lark_oapi.api.im.v1 import GetMessageResourceRequest
+ try:
+ request = GetMessageResourceRequest.builder() \
+ .message_id(message_id) \
+ .file_key(image_key) \
+ .type("image") \
+ .build()
+ response = self._client.im.v1.message_resource.get(request)
+ if response.success():
+ file_data = response.file
+ # GetMessageResourceRequest returns BytesIO, need to read bytes
+ if hasattr(file_data, 'read'):
+ file_data = file_data.read()
+ return file_data, response.file_name
+ else:
+ logger.error("Failed to download image: code={}, msg={}", response.code, response.msg)
+ return None, None
+ except Exception as e:
+ logger.error("Error downloading image {}: {}", image_key, e)
+ return None, None
+
+ def _download_file_sync(
+ self, message_id: str, file_key: str, resource_type: str = "file"
+ ) -> tuple[bytes | None, str | None]:
+ """Download a file/audio/media from a Feishu message by message_id and file_key."""
+ from lark_oapi.api.im.v1 import GetMessageResourceRequest
+
+ # Feishu resource download API only accepts 'image' or 'file' as type.
+ # Both 'audio' and 'media' (video) messages use type='file' for download.
+ if resource_type in ("audio", "media"):
+ resource_type = "file"
+
+ try:
+ request = (
+ GetMessageResourceRequest.builder()
+ .message_id(message_id)
+ .file_key(file_key)
+ .type(resource_type)
+ .build()
+ )
+ response = self._client.im.v1.message_resource.get(request)
+ if response.success():
+ file_data = response.file
+ if hasattr(file_data, "read"):
+ file_data = file_data.read()
+ return file_data, response.file_name
+ else:
+ logger.error("Failed to download {}: code={}, msg={}", resource_type, response.code, response.msg)
+ return None, None
+ except Exception:
+ logger.exception("Error downloading {} {}", resource_type, file_key)
+ return None, None
+
+ async def _download_and_save_media(
+ self,
+ msg_type: str,
+ content_json: dict,
+ message_id: str | None = None
+ ) -> tuple[str | None, str]:
+ """
+ Download media from Feishu and save to local disk.
+
+ Returns:
+ (file_path, content_text) - file_path is None if download failed
+ """
+ loop = asyncio.get_running_loop()
+ media_dir = get_media_dir("feishu")
+
+ data, filename = None, None
+
+ if msg_type == "image":
+ image_key = content_json.get("image_key")
+ if image_key and message_id:
+ data, filename = await loop.run_in_executor(
+ None, self._download_image_sync, message_id, image_key
+ )
+ if not filename:
+ filename = f"{image_key[:16]}.jpg"
+
+ elif msg_type in ("audio", "file", "media"):
+ file_key = content_json.get("file_key")
+ if file_key and message_id:
+ data, filename = await loop.run_in_executor(
+ None, self._download_file_sync, message_id, file_key, msg_type
+ )
+ if not filename:
+ filename = file_key[:16]
+ if msg_type == "audio" and not filename.endswith(".opus"):
+ filename = f"{filename}.opus"
+
+ if data and filename:
+ file_path = media_dir / filename
+ file_path.write_bytes(data)
+ logger.debug("Downloaded {} to {}", msg_type, file_path)
+ return str(file_path), f"[{msg_type}: {filename}]"
+
+ return None, f"[{msg_type}: download failed]"
+
+ _REPLY_CONTEXT_MAX_LEN = 200
+
+ def _get_message_content_sync(self, message_id: str) -> str | None:
+ """Fetch the text content of a Feishu message by ID (synchronous).
+
+ Returns a "[Reply to: ...]" context string, or None on failure.
+ """
+ from lark_oapi.api.im.v1 import GetMessageRequest
+ try:
+ request = GetMessageRequest.builder().message_id(message_id).build()
+ response = self._client.im.v1.message.get(request)
+ if not response.success():
+ logger.debug(
+ "Feishu: could not fetch parent message {}: code={}, msg={}",
+ message_id, response.code, response.msg,
+ )
+ return None
+ items = getattr(response.data, "items", None)
+ if not items:
+ return None
+ msg_obj = items[0]
+ raw_content = getattr(msg_obj, "body", None)
+ raw_content = getattr(raw_content, "content", None) if raw_content else None
+ if not raw_content:
+ return None
+ try:
+ content_json = json.loads(raw_content)
+ except (json.JSONDecodeError, TypeError):
+ return None
+ msg_type = getattr(msg_obj, "msg_type", "")
+ if msg_type == "text":
+ text = content_json.get("text", "").strip()
+ elif msg_type == "post":
+ text, _ = _extract_post_content(content_json)
+ text = text.strip()
+ else:
+ text = ""
+ if not text:
+ return None
+ if len(text) > self._REPLY_CONTEXT_MAX_LEN:
+ text = text[: self._REPLY_CONTEXT_MAX_LEN] + "..."
+ return f"[Reply to: {text}]"
+ except Exception as e:
+ logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
+ return None
+
+ def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
+ """Reply to an existing Feishu message using the Reply API (synchronous)."""
+ from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
+ try:
+ request = ReplyMessageRequest.builder() \
+ .message_id(parent_message_id) \
+ .request_body(
+ ReplyMessageRequestBody.builder()
+ .msg_type(msg_type)
+ .content(content)
+ .build()
+ ).build()
+ response = self._client.im.v1.message.reply(request)
+ if not response.success():
+ logger.error(
+ "Failed to reply to Feishu message {}: code={}, msg={}, log_id={}",
+ parent_message_id, response.code, response.msg, response.get_log_id()
+ )
+ return False
+ logger.debug("Feishu reply sent to message {}", parent_message_id)
+ return True
+ except Exception as e:
+ logger.error("Error replying to Feishu message {}: {}", parent_message_id, e)
+ return False
+
+ def _send_message_sync(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> str | None:
+ """Send a single message and return the message_id on success."""
+ from lark_oapi.api.im.v1 import CreateMessageRequest, CreateMessageRequestBody
+ try:
+ request = CreateMessageRequest.builder() \
+ .receive_id_type(receive_id_type) \
+ .request_body(
+ CreateMessageRequestBody.builder()
+ .receive_id(receive_id)
+ .msg_type(msg_type)
+ .content(content)
+ .build()
+ ).build()
+ response = self._client.im.v1.message.create(request)
+ if not response.success():
+ logger.error(
+ "Failed to send Feishu {} message: code={}, msg={}, log_id={}",
+ msg_type, response.code, response.msg, response.get_log_id()
+ )
+ return None
+ msg_id = getattr(response.data, "message_id", None)
+ logger.debug("Feishu {} message sent to {}: {}", msg_type, receive_id, msg_id)
+ return msg_id
+ except Exception as e:
+ logger.error("Error sending Feishu {} message: {}", msg_type, e)
+ return None
+
+ def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
+ """Create a CardKit streaming card, send it to chat, return card_id."""
+ from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
+ card_json = {
+ "schema": "2.0",
+ "config": {"wide_screen_mode": True, "update_multi": True, "streaming_mode": True},
+ "body": {"elements": [{"tag": "markdown", "content": "", "element_id": _STREAM_ELEMENT_ID}]},
+ }
+ try:
+ request = CreateCardRequest.builder().request_body(
+ CreateCardRequestBody.builder()
+ .type("card_json")
+ .data(json.dumps(card_json, ensure_ascii=False))
+ .build()
+ ).build()
+ response = self._client.cardkit.v1.card.create(request)
+ if not response.success():
+ logger.warning("Failed to create streaming card: code={}, msg={}", response.code, response.msg)
+ return None
+ card_id = getattr(response.data, "card_id", None)
+ if card_id:
+ message_id = self._send_message_sync(
+ receive_id_type, chat_id, "interactive",
+ json.dumps({"type": "card", "data": {"card_id": card_id}}),
+ )
+ if message_id:
+ return card_id
+ logger.warning("Created streaming card {} but failed to send it to {}", card_id, chat_id)
+ return None
+ except Exception as e:
+ logger.warning("Error creating streaming card: {}", e)
+ return None
+
+ def _stream_update_text_sync(self, card_id: str, content: str, sequence: int) -> bool:
+ """Stream-update the markdown element on a CardKit card (typewriter effect)."""
+ from lark_oapi.api.cardkit.v1 import ContentCardElementRequest, ContentCardElementRequestBody
+ try:
+ request = ContentCardElementRequest.builder() \
+ .card_id(card_id) \
+ .element_id(_STREAM_ELEMENT_ID) \
+ .request_body(
+ ContentCardElementRequestBody.builder()
+ .content(content).sequence(sequence).build()
+ ).build()
+ response = self._client.cardkit.v1.card_element.content(request)
+ if not response.success():
+ logger.warning("Failed to stream-update card {}: code={}, msg={}", card_id, response.code, response.msg)
+ return False
+ return True
+ except Exception as e:
+ logger.warning("Error stream-updating card {}: {}", card_id, e)
+ return False
+
+ def _close_streaming_mode_sync(self, card_id: str, sequence: int) -> bool:
+ """Turn off CardKit streaming_mode so the chat list preview exits the streaming placeholder.
+
+ Per Feishu docs, streaming cards keep a generating-style summary in the session list until
+ streaming_mode is set to false via card settings (after final content update).
+ Sequence must strictly exceed the previous card OpenAPI operation on this entity.
+ """
+ from lark_oapi.api.cardkit.v1 import SettingsCardRequest, SettingsCardRequestBody
+ settings_payload = json.dumps({"config": {"streaming_mode": False}}, ensure_ascii=False)
+ try:
+ request = SettingsCardRequest.builder() \
+ .card_id(card_id) \
+ .request_body(
+ SettingsCardRequestBody.builder()
+ .settings(settings_payload)
+ .sequence(sequence)
+ .uuid(str(uuid.uuid4()))
+ .build()
+ ).build()
+ response = self._client.cardkit.v1.card.settings(request)
+ if not response.success():
+ logger.warning(
+ "Failed to close streaming on card {}: code={}, msg={}",
+ card_id, response.code, response.msg,
+ )
+ return False
+ return True
+ except Exception as e:
+ logger.warning("Error closing streaming on card {}: {}", card_id, e)
+ return False
+
+ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
+ """Progressive streaming via CardKit: create card on first delta, stream-update on subsequent."""
+ if not self._client:
+ return
+ meta = metadata or {}
+ loop = asyncio.get_running_loop()
+ rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id"
+
+ # --- stream end: final update or fallback ---
+ if meta.get("_stream_end"):
+ if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
+ await self._remove_reaction(message_id, reaction_id)
+
+ buf = self._stream_bufs.pop(chat_id, None)
+ if not buf or not buf.text:
+ return
+ if buf.card_id:
+ buf.sequence += 1
+ await loop.run_in_executor(
+ None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence,
+ )
+ # Required so the chat list preview exits the streaming placeholder (Feishu streaming card docs).
+ buf.sequence += 1
+ await loop.run_in_executor(
+ None, self._close_streaming_mode_sync, buf.card_id, buf.sequence,
+ )
+ else:
+ for chunk in self._split_elements_by_table_limit(self._build_card_elements(buf.text)):
+ card = json.dumps({"config": {"wide_screen_mode": True}, "elements": chunk}, ensure_ascii=False)
+ await loop.run_in_executor(None, self._send_message_sync, rid_type, chat_id, "interactive", card)
+ return
+
+ # --- accumulate delta ---
+ buf = self._stream_bufs.get(chat_id)
+ if buf is None:
+ buf = _FeishuStreamBuf()
+ self._stream_bufs[chat_id] = buf
+ buf.text += delta
+ if not buf.text.strip():
+ return
+
+ now = time.monotonic()
+ if buf.card_id is None:
+ card_id = await loop.run_in_executor(None, self._create_streaming_card_sync, rid_type, chat_id)
+ if card_id:
+ buf.card_id = card_id
+ buf.sequence = 1
+ await loop.run_in_executor(None, self._stream_update_text_sync, card_id, buf.text, 1)
+ buf.last_edit = now
+ elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
+ buf.sequence += 1
+ await loop.run_in_executor(None, self._stream_update_text_sync, buf.card_id, buf.text, buf.sequence)
+ buf.last_edit = now
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send a message through Feishu, including media (images/files) if present."""
+ if not self._client:
+ logger.warning("Feishu client not initialized")
+ return
+
+ try:
+ receive_id_type = "chat_id" if msg.chat_id.startswith("oc_") else "open_id"
+ loop = asyncio.get_running_loop()
+
+ # Handle tool hint messages as code blocks in interactive cards.
+ # These are progress-only messages and should bypass normal reply routing.
+ if msg.metadata.get("_tool_hint"):
+ if msg.content and msg.content.strip():
+ await self._send_tool_hint_card(
+ receive_id_type, msg.chat_id, msg.content.strip()
+ )
+ return
+
+ # Determine whether the first message should quote the user's message.
+ # Only the very first send (media or text) in this call uses reply; subsequent
+ # chunks/media fall back to plain create to avoid redundant quote bubbles.
+ reply_message_id: str | None = None
+ if (
+ self.config.reply_to_message
+ and not msg.metadata.get("_progress", False)
+ ):
+ reply_message_id = msg.metadata.get("message_id") or None
+ # For topic group messages, always reply to keep context in thread
+ elif msg.metadata.get("thread_id"):
+ reply_message_id = msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
+
+ first_send = True # tracks whether the reply has already been used
+
+ def _do_send(m_type: str, content: str) -> None:
+ """Send via reply (first message) or create (subsequent)."""
+ nonlocal first_send
+ if reply_message_id and first_send:
+ first_send = False
+ ok = self._reply_message_sync(reply_message_id, m_type, content)
+ if ok:
+ return
+ # Fall back to regular send if reply fails
+ self._send_message_sync(receive_id_type, msg.chat_id, m_type, content)
+
+ for file_path in msg.media:
+ if not os.path.isfile(file_path):
+ logger.warning("Media file not found: {}", file_path)
+ continue
+ ext = os.path.splitext(file_path)[1].lower()
+ if ext in self._IMAGE_EXTS:
+ key = await loop.run_in_executor(None, self._upload_image_sync, file_path)
+ if key:
+ await loop.run_in_executor(
+ None, _do_send,
+ "image", json.dumps({"image_key": key}, ensure_ascii=False),
+ )
+ else:
+ key = await loop.run_in_executor(None, self._upload_file_sync, file_path)
+ if key:
+ # Use msg_type "audio" for audio, "video" for video, "file" for documents.
+ # Feishu requires these specific msg_types for inline playback.
+ # Note: "media" is only valid as a tag inside "post" messages, not as a standalone msg_type.
+ if ext in self._AUDIO_EXTS:
+ media_type = "audio"
+ elif ext in self._VIDEO_EXTS:
+ media_type = "video"
+ else:
+ media_type = "file"
+ await loop.run_in_executor(
+ None, _do_send,
+ media_type, json.dumps({"file_key": key}, ensure_ascii=False),
+ )
+
+ if msg.content and msg.content.strip():
+ fmt = self._detect_msg_format(msg.content)
+
+ if fmt == "text":
+ # Short plain text – send as simple text message
+ text_body = json.dumps({"text": msg.content.strip()}, ensure_ascii=False)
+ await loop.run_in_executor(None, _do_send, "text", text_body)
+
+ elif fmt == "post":
+ # Medium content with links – send as rich-text post
+ post_body = self._markdown_to_post(msg.content)
+ await loop.run_in_executor(None, _do_send, "post", post_body)
+
+ else:
+ # Complex / long content – send as interactive card
+ elements = self._build_card_elements(msg.content)
+ for chunk in self._split_elements_by_table_limit(elements):
+ card = {"config": {"wide_screen_mode": True}, "elements": chunk}
+ await loop.run_in_executor(
+ None, _do_send,
+ "interactive", json.dumps(card, ensure_ascii=False),
+ )
+
+ except Exception as e:
+ logger.error("Error sending Feishu message: {}", e)
+ raise
+
+ def _on_message_sync(self, data: Any) -> None:
+ """
+ Sync handler for incoming messages (called from WebSocket thread).
+ Schedules async handling in the main event loop.
+ """
+ if self._loop and self._loop.is_running():
+ asyncio.run_coroutine_threadsafe(self._on_message(data), self._loop)
+
+ async def _on_message(self, data: Any) -> None:
+ """Handle incoming message from Feishu."""
+ try:
+ event = data.event
+ message = event.message
+ sender = event.sender
+
+ # Deduplication check
+ message_id = message.message_id
+ if message_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[message_id] = None
+
+ # Trim cache
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
+ # Skip bot messages
+ if sender.sender_type == "bot":
+ return
+
+ sender_id = sender.sender_id.open_id if sender.sender_id else "unknown"
+ chat_id = message.chat_id
+ chat_type = message.chat_type
+ msg_type = message.message_type
+
+ if chat_type == "group" and not self._is_group_message_for_bot(message):
+ logger.debug("Feishu: skipping group message (not mentioned)")
+ return
+
+ # Add reaction
+ reaction_id = await self._add_reaction(message_id, self.config.react_emoji)
+
+ # Parse content
+ content_parts = []
+ media_paths = []
+
+ try:
+ content_json = json.loads(message.content) if message.content else {}
+ except json.JSONDecodeError:
+ content_json = {}
+
+ if msg_type == "text":
+ text = content_json.get("text", "")
+ if text:
+ content_parts.append(text)
+
+ elif msg_type == "post":
+ text, image_keys = _extract_post_content(content_json)
+ if text:
+ content_parts.append(text)
+ # Download images embedded in post
+ for img_key in image_keys:
+ file_path, content_text = await self._download_and_save_media(
+ "image", {"image_key": img_key}, message_id
+ )
+ if file_path:
+ media_paths.append(file_path)
+ content_parts.append(content_text)
+
+ elif msg_type in ("image", "audio", "file", "media"):
+ file_path, content_text = await self._download_and_save_media(msg_type, content_json, message_id)
+ if file_path:
+ media_paths.append(file_path)
+
+ if msg_type == "audio" and file_path:
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ content_text = f"[transcription: {transcription}]"
+
+ content_parts.append(content_text)
+
+ elif msg_type in ("share_chat", "share_user", "interactive", "share_calendar_event", "system", "merge_forward"):
+ # Handle share cards and interactive messages
+ text = _extract_share_card_content(content_json, msg_type)
+ if text:
+ content_parts.append(text)
+
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
+
+ # Extract reply context (parent/root message IDs)
+ parent_id = getattr(message, "parent_id", None) or None
+ root_id = getattr(message, "root_id", None) or None
+ thread_id = getattr(message, "thread_id", None) or None
+
+ # Prepend quoted message text when the user replied to another message
+ if parent_id and self._client:
+ loop = asyncio.get_running_loop()
+ reply_ctx = await loop.run_in_executor(
+ None, self._get_message_content_sync, parent_id
+ )
+ if reply_ctx:
+ content_parts.insert(0, reply_ctx)
+
+ content = "\n".join(content_parts) if content_parts else ""
+
+ if not content and not media_paths:
+ return
+
+ # Forward to message bus
+ reply_to = chat_id if chat_type == "group" else sender_id
+ await self._handle_message(
+ sender_id=sender_id,
+ chat_id=reply_to,
+ content=content,
+ media=media_paths,
+ metadata={
+ "message_id": message_id,
+ "reaction_id": reaction_id,
+ "chat_type": chat_type,
+ "msg_type": msg_type,
+ "parent_id": parent_id,
+ "root_id": root_id,
+ "thread_id": thread_id,
+ }
+ )
+
+ except Exception as e:
+ logger.error("Error processing Feishu message: {}", e)
+
+ def _on_reaction_created(self, data: Any) -> None:
+ """Ignore reaction events so they do not generate SDK noise."""
+ pass
+
+ def _on_message_read(self, data: Any) -> None:
+ """Ignore read events so they do not generate SDK noise."""
+ pass
+
+ def _on_bot_p2p_chat_entered(self, data: Any) -> None:
+ """Ignore p2p-enter events when a user opens a bot chat."""
+ logger.debug("Bot entered p2p chat (user opened chat window)")
+ pass
+
+ @staticmethod
+ def _format_tool_hint_lines(tool_hint: str) -> str:
+ """Split tool hints across lines on top-level call separators only."""
+ parts: list[str] = []
+ buf: list[str] = []
+ depth = 0
+ in_string = False
+ quote_char = ""
+ escaped = False
+
+ for i, ch in enumerate(tool_hint):
+ buf.append(ch)
+
+ if in_string:
+ if escaped:
+ escaped = False
+ elif ch == "\\":
+ escaped = True
+ elif ch == quote_char:
+ in_string = False
+ continue
+
+ if ch in {'"', "'"}:
+ in_string = True
+ quote_char = ch
+ continue
+
+ if ch == "(":
+ depth += 1
+ continue
+
+ if ch == ")" and depth > 0:
+ depth -= 1
+ continue
+
+ if ch == "," and depth == 0:
+ next_char = tool_hint[i + 1] if i + 1 < len(tool_hint) else ""
+ if next_char == " ":
+ parts.append("".join(buf).rstrip())
+ buf = []
+
+ if buf:
+ parts.append("".join(buf).strip())
+
+ return "\n".join(part for part in parts if part)
+
+ async def _send_tool_hint_card(self, receive_id_type: str, receive_id: str, tool_hint: str) -> None:
+ """Send tool hint as an interactive card with formatted code block.
+
+ Args:
+ receive_id_type: "chat_id" or "open_id"
+ receive_id: The target chat or user ID
+ tool_hint: Formatted tool hint string (e.g., 'web_search("q"), read_file("path")')
+ """
+ loop = asyncio.get_running_loop()
+
+ # Put each top-level tool call on its own line without altering commas inside arguments.
+ formatted_code = self._format_tool_hint_lines(tool_hint)
+
+ card = {
+ "config": {"wide_screen_mode": True},
+ "elements": [
+ {
+ "tag": "markdown",
+ "content": f"**Tool Calls**\n\n```text\n{formatted_code}\n```"
+ }
+ ]
+ }
+
+ await loop.run_in_executor(
+ None, self._send_message_sync,
+ receive_id_type, receive_id, "interactive",
+ json.dumps(card, ensure_ascii=False),
+ )
diff --git a/nanobot/channels/manager.py b/nanobot/channels/manager.py
index 73c3334de..1f26f4d7a 100644
--- a/nanobot/channels/manager.py
+++ b/nanobot/channels/manager.py
@@ -1,5 +1,7 @@
"""Channel manager for coordinating chat channels."""
+from __future__ import annotations
+
import asyncio
from typing import Any
@@ -9,75 +11,113 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.schema import Config
+from nanobot.utils.restart import consume_restart_notice_from_env, format_restart_completed_message
+
+# Retry delays for message sending (exponential backoff: 1s, 2s, 4s)
+_SEND_RETRY_DELAYS = (1, 2, 4)
class ChannelManager:
"""
Manages chat channels and coordinates message routing.
-
+
Responsibilities:
- Initialize enabled channels (Telegram, WhatsApp, etc.)
- Start/stop channels
- Route outbound messages
"""
-
+
def __init__(self, config: Config, bus: MessageBus):
self.config = config
self.bus = bus
self.channels: dict[str, BaseChannel] = {}
self._dispatch_task: asyncio.Task | None = None
-
+
self._init_channels()
-
+
def _init_channels(self) -> None:
- """Initialize channels based on config."""
-
- # Telegram channel
- if self.config.channels.telegram.enabled:
+ """Initialize channels discovered via pkgutil scan + entry_points plugins."""
+ from nanobot.channels.registry import discover_all
+
+ groq_key = self.config.providers.groq.api_key
+
+ for name, cls in discover_all().items():
+ section = getattr(self.config.channels, name, None)
+ if section is None:
+ continue
+ enabled = (
+ section.get("enabled", False)
+ if isinstance(section, dict)
+ else getattr(section, "enabled", False)
+ )
+ if not enabled:
+ continue
try:
- from nanobot.channels.telegram import TelegramChannel
- self.channels["telegram"] = TelegramChannel(
- self.config.channels.telegram,
- self.bus,
- groq_api_key=self.config.providers.groq.api_key,
+ channel = cls(section, self.bus)
+ channel.transcription_api_key = groq_key
+ self.channels[name] = channel
+ logger.info("{} channel enabled", cls.display_name)
+ except Exception as e:
+ logger.warning("{} channel not available: {}", name, e)
+
+ self._validate_allow_from()
+
+ def _validate_allow_from(self) -> None:
+ for name, ch in self.channels.items():
+ if getattr(ch.config, "allow_from", None) == []:
+ raise SystemExit(
+ f'Error: "{name}" has empty allowFrom (denies all). '
+ f'Set ["*"] to allow everyone, or add specific user IDs.'
)
- logger.info("Telegram channel enabled")
- except ImportError as e:
- logger.warning(f"Telegram channel not available: {e}")
-
- # WhatsApp channel
- if self.config.channels.whatsapp.enabled:
- try:
- from nanobot.channels.whatsapp import WhatsAppChannel
- self.channels["whatsapp"] = WhatsAppChannel(
- self.config.channels.whatsapp, self.bus
- )
- logger.info("WhatsApp channel enabled")
- except ImportError as e:
- logger.warning(f"WhatsApp channel not available: {e}")
-
+
+ async def _start_channel(self, name: str, channel: BaseChannel) -> None:
+ """Start a channel and log any exceptions."""
+ try:
+ await channel.start()
+ except Exception as e:
+ logger.error("Failed to start channel {}: {}", name, e)
+
async def start_all(self) -> None:
- """Start WhatsApp channel and the outbound dispatcher."""
+ """Start all channels and the outbound dispatcher."""
if not self.channels:
logger.warning("No channels enabled")
return
-
+
# Start outbound dispatcher
self._dispatch_task = asyncio.create_task(self._dispatch_outbound())
-
- # Start WhatsApp channel
+
+ # Start channels
tasks = []
for name, channel in self.channels.items():
- logger.info(f"Starting {name} channel...")
- tasks.append(asyncio.create_task(channel.start()))
-
+ logger.info("Starting {} channel...", name)
+ tasks.append(asyncio.create_task(self._start_channel(name, channel)))
+
+ self._notify_restart_done_if_needed()
+
# Wait for all to complete (they should run forever)
await asyncio.gather(*tasks, return_exceptions=True)
-
+
+ def _notify_restart_done_if_needed(self) -> None:
+ """Send restart completion message when runtime env markers are present."""
+ notice = consume_restart_notice_from_env()
+ if not notice:
+ return
+ target = self.channels.get(notice.channel)
+ if not target:
+ return
+ asyncio.create_task(self._send_with_retry(
+ target,
+ OutboundMessage(
+ channel=notice.channel,
+ chat_id=notice.chat_id,
+ content=format_restart_completed_message(notice.started_at_raw),
+ ),
+ ))
+
async def stop_all(self) -> None:
"""Stop all channels and the dispatcher."""
logger.info("Stopping all channels...")
-
+
# Stop dispatcher
if self._dispatch_task:
self._dispatch_task.cancel()
@@ -85,44 +125,149 @@ class ChannelManager:
await self._dispatch_task
except asyncio.CancelledError:
pass
-
+
# Stop all channels
for name, channel in self.channels.items():
try:
await channel.stop()
- logger.info(f"Stopped {name} channel")
+ logger.info("Stopped {} channel", name)
except Exception as e:
- logger.error(f"Error stopping {name}: {e}")
-
+ logger.error("Error stopping {}: {}", name, e)
+
async def _dispatch_outbound(self) -> None:
"""Dispatch outbound messages to the appropriate channel."""
logger.info("Outbound dispatcher started")
-
+
+ # Buffer for messages that couldn't be processed during delta coalescing
+ # (since asyncio.Queue doesn't support push_front)
+ pending: list[OutboundMessage] = []
+
while True:
try:
- msg = await asyncio.wait_for(
- self.bus.consume_outbound(),
- timeout=1.0
- )
-
+ # First check pending buffer before waiting on queue
+ if pending:
+ msg = pending.pop(0)
+ else:
+ msg = await asyncio.wait_for(
+ self.bus.consume_outbound(),
+ timeout=1.0
+ )
+
+ if msg.metadata.get("_progress"):
+ if msg.metadata.get("_tool_hint") and not self.config.channels.send_tool_hints:
+ continue
+ if not msg.metadata.get("_tool_hint") and not self.config.channels.send_progress:
+ continue
+
+ # Coalesce consecutive _stream_delta messages for the same (channel, chat_id)
+ # to reduce API calls and improve streaming latency
+ if msg.metadata.get("_stream_delta") and not msg.metadata.get("_stream_end"):
+ msg, extra_pending = self._coalesce_stream_deltas(msg)
+ pending.extend(extra_pending)
+
channel = self.channels.get(msg.channel)
if channel:
- try:
- await channel.send(msg)
- except Exception as e:
- logger.error(f"Error sending to {msg.channel}: {e}")
+ await self._send_with_retry(channel, msg)
else:
- logger.warning(f"Unknown channel: {msg.channel}")
-
+ logger.warning("Unknown channel: {}", msg.channel)
+
except asyncio.TimeoutError:
continue
except asyncio.CancelledError:
break
-
+
+ @staticmethod
+ async def _send_once(channel: BaseChannel, msg: OutboundMessage) -> None:
+ """Send one outbound message without retry policy."""
+ if msg.metadata.get("_stream_delta") or msg.metadata.get("_stream_end"):
+ await channel.send_delta(msg.chat_id, msg.content, msg.metadata)
+ elif not msg.metadata.get("_streamed"):
+ await channel.send(msg)
+
+ def _coalesce_stream_deltas(
+ self, first_msg: OutboundMessage
+ ) -> tuple[OutboundMessage, list[OutboundMessage]]:
+ """Merge consecutive _stream_delta messages for the same (channel, chat_id).
+
+ This reduces the number of API calls when the queue has accumulated multiple
+ deltas, which happens when LLM generates faster than the channel can process.
+
+ Returns:
+ tuple of (merged_message, list_of_non_matching_messages)
+ """
+ target_key = (first_msg.channel, first_msg.chat_id)
+ combined_content = first_msg.content
+ final_metadata = dict(first_msg.metadata or {})
+ non_matching: list[OutboundMessage] = []
+
+ # Only merge consecutive deltas. As soon as we hit any other message,
+ # stop and hand that boundary back to the dispatcher via `pending`.
+ while True:
+ try:
+ next_msg = self.bus.outbound.get_nowait()
+ except asyncio.QueueEmpty:
+ break
+
+ # Check if this message belongs to the same stream
+ same_target = (next_msg.channel, next_msg.chat_id) == target_key
+ is_delta = next_msg.metadata and next_msg.metadata.get("_stream_delta")
+ is_end = next_msg.metadata and next_msg.metadata.get("_stream_end")
+
+ if same_target and is_delta and not final_metadata.get("_stream_end"):
+ # Accumulate content
+ combined_content += next_msg.content
+ # If we see _stream_end, remember it and stop coalescing this stream
+ if is_end:
+ final_metadata["_stream_end"] = True
+ # Stream ended - stop coalescing this stream
+ break
+ else:
+ # First non-matching message defines the coalescing boundary.
+ non_matching.append(next_msg)
+ break
+
+ merged = OutboundMessage(
+ channel=first_msg.channel,
+ chat_id=first_msg.chat_id,
+ content=combined_content,
+ metadata=final_metadata,
+ )
+ return merged, non_matching
+
+ async def _send_with_retry(self, channel: BaseChannel, msg: OutboundMessage) -> None:
+ """Send a message with retry on failure using exponential backoff.
+
+ Note: CancelledError is re-raised to allow graceful shutdown.
+ """
+ max_attempts = max(self.config.channels.send_max_retries, 1)
+
+ for attempt in range(max_attempts):
+ try:
+ await self._send_once(channel, msg)
+ return # Send succeeded
+ except asyncio.CancelledError:
+ raise # Propagate cancellation for graceful shutdown
+ except Exception as e:
+ if attempt == max_attempts - 1:
+ logger.error(
+ "Failed to send to {} after {} attempts: {} - {}",
+ msg.channel, max_attempts, type(e).__name__, e
+ )
+ return
+ delay = _SEND_RETRY_DELAYS[min(attempt, len(_SEND_RETRY_DELAYS) - 1)]
+ logger.warning(
+ "Send to {} failed (attempt {}/{}): {}, retrying in {}s",
+ msg.channel, attempt + 1, max_attempts, type(e).__name__, delay
+ )
+ try:
+ await asyncio.sleep(delay)
+ except asyncio.CancelledError:
+ raise # Propagate cancellation during sleep
+
def get_channel(self, name: str) -> BaseChannel | None:
"""Get a channel by name."""
return self.channels.get(name)
-
+
def get_status(self) -> dict[str, Any]:
"""Get status of all channels."""
return {
@@ -132,7 +277,7 @@ class ChannelManager:
}
for name, channel in self.channels.items()
}
-
+
@property
def enabled_channels(self) -> list[str]:
"""Get list of enabled channel names."""
diff --git a/nanobot/channels/matrix.py b/nanobot/channels/matrix.py
new file mode 100644
index 000000000..bc6d9398a
--- /dev/null
+++ b/nanobot/channels/matrix.py
@@ -0,0 +1,847 @@
+"""Matrix (Element) channel — inbound sync + outbound message/media delivery."""
+
+import asyncio
+import logging
+import mimetypes
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Literal, TypeAlias
+
+from loguru import logger
+from pydantic import Field
+
+try:
+ import nh3
+ from mistune import create_markdown
+ from nio import (
+ AsyncClient,
+ AsyncClientConfig,
+ ContentRepositoryConfigError,
+ DownloadError,
+ InviteEvent,
+ JoinError,
+ MatrixRoom,
+ MemoryDownloadResponse,
+ RoomEncryptedMedia,
+ RoomMessage,
+ RoomMessageMedia,
+ RoomMessageText,
+ RoomSendError,
+ RoomTypingError,
+ SyncError,
+ UploadError, RoomSendResponse,
+)
+ from nio.crypto.attachments import decrypt_attachment
+ from nio.exceptions import EncryptionError
+except ImportError as e:
+ raise ImportError(
+ "Matrix dependencies not installed. Run: pip install nanobot-ai[matrix]"
+ ) from e
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_data_dir, get_media_dir
+from nanobot.config.schema import Base
+from nanobot.utils.helpers import safe_filename
+
+TYPING_NOTICE_TIMEOUT_MS = 30_000
+# Must stay below TYPING_NOTICE_TIMEOUT_MS so the indicator doesn't expire mid-processing.
+TYPING_KEEPALIVE_INTERVAL_MS = 20_000
+MATRIX_HTML_FORMAT = "org.matrix.custom.html"
+_ATTACH_MARKER = "[attachment: {}]"
+_ATTACH_TOO_LARGE = "[attachment: {} - too large]"
+_ATTACH_FAILED = "[attachment: {} - download failed]"
+_ATTACH_UPLOAD_FAILED = "[attachment: {} - upload failed]"
+_DEFAULT_ATTACH_NAME = "attachment"
+_MSGTYPE_MAP = {"m.image": "image", "m.audio": "audio", "m.video": "video", "m.file": "file"}
+
+MATRIX_MEDIA_EVENT_FILTER = (RoomMessageMedia, RoomEncryptedMedia)
+MatrixMediaEvent: TypeAlias = RoomMessageMedia | RoomEncryptedMedia
+
+MATRIX_MARKDOWN = create_markdown(
+ escape=True,
+ plugins=["table", "strikethrough", "url", "superscript", "subscript"],
+)
+
+MATRIX_ALLOWED_HTML_TAGS = {
+ "p", "a", "strong", "em", "del", "code", "pre", "blockquote",
+ "ul", "ol", "li", "h1", "h2", "h3", "h4", "h5", "h6",
+ "hr", "br", "table", "thead", "tbody", "tr", "th", "td",
+ "caption", "sup", "sub", "img",
+}
+MATRIX_ALLOWED_HTML_ATTRIBUTES: dict[str, set[str]] = {
+ "a": {"href"}, "code": {"class"}, "ol": {"start"},
+ "img": {"src", "alt", "title", "width", "height"},
+}
+MATRIX_ALLOWED_URL_SCHEMES = {"https", "http", "matrix", "mailto", "mxc"}
+
+
+def _filter_matrix_html_attribute(tag: str, attr: str, value: str) -> str | None:
+ """Filter attribute values to a safe Matrix-compatible subset."""
+ if tag == "a" and attr == "href":
+ return value if value.lower().startswith(("https://", "http://", "matrix:", "mailto:")) else None
+ if tag == "img" and attr == "src":
+ return value if value.lower().startswith("mxc://") else None
+ if tag == "code" and attr == "class":
+ classes = [c for c in value.split() if c.startswith("language-") and not c.startswith("language-_")]
+ return " ".join(classes) if classes else None
+ return value
+
+
+MATRIX_HTML_CLEANER = nh3.Cleaner(
+ tags=MATRIX_ALLOWED_HTML_TAGS,
+ attributes=MATRIX_ALLOWED_HTML_ATTRIBUTES,
+ attribute_filter=_filter_matrix_html_attribute,
+ url_schemes=MATRIX_ALLOWED_URL_SCHEMES,
+ strip_comments=True,
+ link_rel="noopener noreferrer",
+)
+
+@dataclass
+class _StreamBuf:
+ """
+ Represents a buffer for managing LLM response stream data.
+
+ :ivar text: Stores the text content of the buffer.
+ :type text: str
+ :ivar event_id: Identifier for the associated event. None indicates no
+ specific event association.
+ :type event_id: str | None
+ :ivar last_edit: Timestamp of the most recent edit to the buffer.
+ :type last_edit: float
+ """
+ text: str = ""
+ event_id: str | None = None
+ last_edit: float = 0.0
+
+def _render_markdown_html(text: str) -> str | None:
+ """Render markdown to sanitized HTML; returns None for plain text."""
+ try:
+ formatted = MATRIX_HTML_CLEANER.clean(MATRIX_MARKDOWN(text)).strip()
+ except Exception:
+ return None
+ if not formatted:
+ return None
+ # Skip formatted_body for plain text
to keep payload minimal.
+ if formatted.startswith("") and formatted.endswith("
"):
+ inner = formatted[3:-4]
+ if "<" not in inner and ">" not in inner:
+ return None
+ return formatted
+
+
+def _build_matrix_text_content(
+ text: str,
+ event_id: str | None = None,
+ thread_relates_to: dict[str, object] | None = None,
+) -> dict[str, object]:
+ """
+ Constructs and returns a dictionary representing the matrix text content with optional
+ HTML formatting and reference to an existing event for replacement. This function is
+ primarily used to create content payloads compatible with the Matrix messaging protocol.
+
+ :param text: The plain text content to include in the message.
+ :type text: str
+ :param event_id: Optional ID of the event to replace. If provided, the function will
+ include information indicating that the message is a replacement of the specified
+ event.
+ :type event_id: str | None
+ :param thread_relates_to: Optional Matrix thread relation metadata. For edits this is
+ stored in ``m.new_content`` so the replacement remains in the same thread.
+ :type thread_relates_to: dict[str, object] | None
+ :return: A dictionary containing the matrix text content, potentially enriched with
+ HTML formatting and replacement metadata if applicable.
+ :rtype: dict[str, object]
+ """
+ content: dict[str, object] = {"msgtype": "m.text", "body": text, "m.mentions": {}}
+ if html := _render_markdown_html(text):
+ content["format"] = MATRIX_HTML_FORMAT
+ content["formatted_body"] = html
+ if event_id:
+ content["m.new_content"] = {
+ "body": text,
+ "msgtype": "m.text",
+ }
+ content["m.relates_to"] = {
+ "rel_type": "m.replace",
+ "event_id": event_id,
+ }
+ if thread_relates_to:
+ content["m.new_content"]["m.relates_to"] = thread_relates_to
+ elif thread_relates_to:
+ content["m.relates_to"] = thread_relates_to
+
+ return content
+
+
+class _NioLoguruHandler(logging.Handler):
+ """Route matrix-nio stdlib logs into Loguru."""
+
+ def emit(self, record: logging.LogRecord) -> None:
+ try:
+ level = logger.level(record.levelname).name
+ except ValueError:
+ level = record.levelno
+ frame, depth = logging.currentframe(), 2
+ while frame and frame.f_code.co_filename == logging.__file__:
+ frame, depth = frame.f_back, depth + 1
+ logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())
+
+
+def _configure_nio_logging_bridge() -> None:
+ """Bridge matrix-nio logs to Loguru (idempotent)."""
+ nio_logger = logging.getLogger("nio")
+ if not any(isinstance(h, _NioLoguruHandler) for h in nio_logger.handlers):
+ nio_logger.handlers = [_NioLoguruHandler()]
+ nio_logger.propagate = False
+
+
+class MatrixConfig(Base):
+ """Matrix (Element) channel configuration."""
+
+ enabled: bool = False
+ homeserver: str = "https://matrix.org"
+ access_token: str = ""
+ user_id: str = ""
+ device_id: str = ""
+ e2ee_enabled: bool = True
+ sync_stop_grace_seconds: int = 2
+ max_media_bytes: int = 20 * 1024 * 1024
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: Literal["open", "mention", "allowlist"] = "open"
+ group_allow_from: list[str] = Field(default_factory=list)
+ allow_room_mentions: bool = False,
+ streaming: bool = False
+
+
+class MatrixChannel(BaseChannel):
+ """Matrix (Element) channel using long-polling sync."""
+
+ name = "matrix"
+ display_name = "Matrix"
+ _STREAM_EDIT_INTERVAL = 2 # min seconds between edit_message_text calls
+ monotonic_time = time.monotonic
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return MatrixConfig().model_dump(by_alias=True)
+
+ def __init__(
+ self,
+ config: Any,
+ bus: MessageBus,
+ *,
+ restrict_to_workspace: bool = False,
+ workspace: str | Path | None = None,
+ ):
+ if isinstance(config, dict):
+ config = MatrixConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.client: AsyncClient | None = None
+ self._sync_task: asyncio.Task | None = None
+ self._typing_tasks: dict[str, asyncio.Task] = {}
+ self._restrict_to_workspace = bool(restrict_to_workspace)
+ self._workspace = (
+ Path(workspace).expanduser().resolve(strict=False) if workspace is not None else None
+ )
+ self._server_upload_limit_bytes: int | None = None
+ self._server_upload_limit_checked = False
+ self._stream_bufs: dict[str, _StreamBuf] = {}
+
+
+ async def start(self) -> None:
+ """Start Matrix client and begin sync loop."""
+ self._running = True
+ _configure_nio_logging_bridge()
+
+ store_path = get_data_dir() / "matrix-store"
+ store_path.mkdir(parents=True, exist_ok=True)
+
+ self.client = AsyncClient(
+ homeserver=self.config.homeserver, user=self.config.user_id,
+ store_path=store_path,
+ config=AsyncClientConfig(store_sync_tokens=True, encryption_enabled=self.config.e2ee_enabled),
+ )
+ self.client.user_id = self.config.user_id
+ self.client.access_token = self.config.access_token
+ self.client.device_id = self.config.device_id
+
+ self._register_event_callbacks()
+ self._register_response_callbacks()
+
+ if not self.config.e2ee_enabled:
+ logger.warning("Matrix E2EE disabled; encrypted rooms may be undecryptable.")
+
+ if self.config.device_id:
+ try:
+ self.client.load_store()
+ except Exception:
+ logger.exception("Matrix store load failed; restart may replay recent messages.")
+ else:
+ logger.warning("Matrix device_id empty; restart may replay recent messages.")
+
+ self._sync_task = asyncio.create_task(self._sync_loop())
+
+ async def stop(self) -> None:
+ """Stop the Matrix channel with graceful sync shutdown."""
+ self._running = False
+ for room_id in list(self._typing_tasks):
+ await self._stop_typing_keepalive(room_id, clear_typing=False)
+ if self.client:
+ self.client.stop_sync_forever()
+ if self._sync_task:
+ try:
+ await asyncio.wait_for(asyncio.shield(self._sync_task),
+ timeout=self.config.sync_stop_grace_seconds)
+ except (asyncio.TimeoutError, asyncio.CancelledError):
+ self._sync_task.cancel()
+ try:
+ await self._sync_task
+ except asyncio.CancelledError:
+ pass
+ if self.client:
+ await self.client.close()
+
+ def _is_workspace_path_allowed(self, path: Path) -> bool:
+ """Check path is inside workspace (when restriction enabled)."""
+ if not self._restrict_to_workspace or not self._workspace:
+ return True
+ try:
+ path.resolve(strict=False).relative_to(self._workspace)
+ return True
+ except ValueError:
+ return False
+
+ def _collect_outbound_media_candidates(self, media: list[str]) -> list[Path]:
+ """Deduplicate and resolve outbound attachment paths."""
+ seen: set[str] = set()
+ candidates: list[Path] = []
+ for raw in media:
+ if not isinstance(raw, str) or not raw.strip():
+ continue
+ path = Path(raw.strip()).expanduser()
+ try:
+ key = str(path.resolve(strict=False))
+ except OSError:
+ key = str(path)
+ if key not in seen:
+ seen.add(key)
+ candidates.append(path)
+ return candidates
+
+ @staticmethod
+ def _build_outbound_attachment_content(
+ *, filename: str, mime: str, size_bytes: int,
+ mxc_url: str, encryption_info: dict[str, Any] | None = None,
+ ) -> dict[str, Any]:
+ """Build Matrix content payload for an uploaded file/image/audio/video."""
+ prefix = mime.split("/")[0]
+ msgtype = {"image": "m.image", "audio": "m.audio", "video": "m.video"}.get(prefix, "m.file")
+ content: dict[str, Any] = {
+ "msgtype": msgtype, "body": filename, "filename": filename,
+ "info": {"mimetype": mime, "size": size_bytes}, "m.mentions": {},
+ }
+ if encryption_info:
+ content["file"] = {**encryption_info, "url": mxc_url}
+ else:
+ content["url"] = mxc_url
+ return content
+
+ def _is_encrypted_room(self, room_id: str) -> bool:
+ if not self.client:
+ return False
+ room = getattr(self.client, "rooms", {}).get(room_id)
+ return bool(getattr(room, "encrypted", False))
+
+ async def _send_room_content(self, room_id: str,
+ content: dict[str, Any]) -> None | RoomSendResponse | RoomSendError:
+ """Send m.room.message with E2EE options."""
+ if not self.client:
+ return None
+ kwargs: dict[str, Any] = {"room_id": room_id, "message_type": "m.room.message", "content": content}
+
+ if self.config.e2ee_enabled:
+ kwargs["ignore_unverified_devices"] = True
+ response = await self.client.room_send(**kwargs)
+ return response
+
+ async def _resolve_server_upload_limit_bytes(self) -> int | None:
+ """Query homeserver upload limit once per channel lifecycle."""
+ if self._server_upload_limit_checked:
+ return self._server_upload_limit_bytes
+ self._server_upload_limit_checked = True
+ if not self.client:
+ return None
+ try:
+ response = await self.client.content_repository_config()
+ except Exception:
+ return None
+ upload_size = getattr(response, "upload_size", None)
+ if isinstance(upload_size, int) and upload_size > 0:
+ self._server_upload_limit_bytes = upload_size
+ return upload_size
+ return None
+
+ async def _effective_media_limit_bytes(self) -> int:
+ """min(local config, server advertised) — 0 blocks all uploads."""
+ local_limit = max(int(self.config.max_media_bytes), 0)
+ server_limit = await self._resolve_server_upload_limit_bytes()
+ if server_limit is None:
+ return local_limit
+ return min(local_limit, server_limit) if local_limit else 0
+
+ async def _upload_and_send_attachment(
+ self, room_id: str, path: Path, limit_bytes: int,
+ relates_to: dict[str, Any] | None = None,
+ ) -> str | None:
+ """Upload one local file to Matrix and send it as a media message. Returns failure marker or None."""
+ if not self.client:
+ return _ATTACH_UPLOAD_FAILED.format(path.name or _DEFAULT_ATTACH_NAME)
+
+ resolved = path.expanduser().resolve(strict=False)
+ filename = safe_filename(resolved.name) or _DEFAULT_ATTACH_NAME
+ fail = _ATTACH_UPLOAD_FAILED.format(filename)
+
+ if not resolved.is_file() or not self._is_workspace_path_allowed(resolved):
+ return fail
+ try:
+ size_bytes = resolved.stat().st_size
+ except OSError:
+ return fail
+ if limit_bytes <= 0 or size_bytes > limit_bytes:
+ return _ATTACH_TOO_LARGE.format(filename)
+
+ mime = mimetypes.guess_type(filename, strict=False)[0] or "application/octet-stream"
+ try:
+ with resolved.open("rb") as f:
+ upload_result = await self.client.upload(
+ f, content_type=mime, filename=filename,
+ encrypt=self.config.e2ee_enabled and self._is_encrypted_room(room_id),
+ filesize=size_bytes,
+ )
+ except Exception:
+ return fail
+
+ upload_response = upload_result[0] if isinstance(upload_result, tuple) else upload_result
+ encryption_info = upload_result[1] if isinstance(upload_result, tuple) and isinstance(upload_result[1], dict) else None
+ if isinstance(upload_response, UploadError):
+ return fail
+ mxc_url = getattr(upload_response, "content_uri", None)
+ if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
+ return fail
+
+ content = self._build_outbound_attachment_content(
+ filename=filename, mime=mime, size_bytes=size_bytes,
+ mxc_url=mxc_url, encryption_info=encryption_info,
+ )
+ if relates_to:
+ content["m.relates_to"] = relates_to
+ try:
+ await self._send_room_content(room_id, content)
+ except Exception:
+ return fail
+ return None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send outbound content; clear typing for non-progress messages."""
+ if not self.client:
+ return
+ text = msg.content or ""
+ candidates = self._collect_outbound_media_candidates(msg.media)
+ relates_to = self._build_thread_relates_to(msg.metadata)
+ is_progress = bool((msg.metadata or {}).get("_progress"))
+ try:
+ failures: list[str] = []
+ if candidates:
+ limit_bytes = await self._effective_media_limit_bytes()
+ for path in candidates:
+ if fail := await self._upload_and_send_attachment(
+ room_id=msg.chat_id,
+ path=path,
+ limit_bytes=limit_bytes,
+ relates_to=relates_to,
+ ):
+ failures.append(fail)
+ if failures:
+ text = f"{text.rstrip()}\n{chr(10).join(failures)}" if text.strip() else "\n".join(failures)
+ if text or not candidates:
+ content = _build_matrix_text_content(text)
+ if relates_to:
+ content["m.relates_to"] = relates_to
+ await self._send_room_content(msg.chat_id, content)
+ finally:
+ if not is_progress:
+ await self._stop_typing_keepalive(msg.chat_id, clear_typing=True)
+
+ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
+ meta = metadata or {}
+ relates_to = self._build_thread_relates_to(metadata)
+
+ if meta.get("_stream_end"):
+ buf = self._stream_bufs.pop(chat_id, None)
+ if not buf or not buf.event_id or not buf.text:
+ return
+
+ await self._stop_typing_keepalive(chat_id, clear_typing=True)
+
+ content = _build_matrix_text_content(
+ buf.text,
+ buf.event_id,
+ thread_relates_to=relates_to,
+ )
+ await self._send_room_content(chat_id, content)
+ return
+
+ buf = self._stream_bufs.get(chat_id)
+ if buf is None:
+ buf = _StreamBuf()
+ self._stream_bufs[chat_id] = buf
+ buf.text += delta
+
+ if not buf.text.strip():
+ return
+
+ now = self.monotonic_time()
+
+ if not buf.last_edit or (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
+ try:
+ content = _build_matrix_text_content(
+ buf.text,
+ buf.event_id,
+ thread_relates_to=relates_to,
+ )
+ response = await self._send_room_content(chat_id, content)
+ buf.last_edit = now
+ if not buf.event_id:
+ # we are editing the same message all the time, so only the first time the event id needs to be set
+ buf.event_id = response.event_id
+ except Exception:
+ await self._stop_typing_keepalive(chat_id, clear_typing=True)
+ pass
+
+
+ def _register_event_callbacks(self) -> None:
+ self.client.add_event_callback(self._on_message, RoomMessageText)
+ self.client.add_event_callback(self._on_media_message, MATRIX_MEDIA_EVENT_FILTER)
+ self.client.add_event_callback(self._on_room_invite, InviteEvent)
+
+ def _register_response_callbacks(self) -> None:
+ self.client.add_response_callback(self._on_sync_error, SyncError)
+ self.client.add_response_callback(self._on_join_error, JoinError)
+ self.client.add_response_callback(self._on_send_error, RoomSendError)
+
+ def _log_response_error(self, label: str, response: Any) -> None:
+ """Log Matrix response errors — auth errors at ERROR level, rest at WARNING."""
+ code = getattr(response, "status_code", None)
+ is_auth = code in {"M_UNKNOWN_TOKEN", "M_FORBIDDEN", "M_UNAUTHORIZED"}
+ is_fatal = is_auth or getattr(response, "soft_logout", False)
+ (logger.error if is_fatal else logger.warning)("Matrix {} failed: {}", label, response)
+
+ async def _on_sync_error(self, response: SyncError) -> None:
+ self._log_response_error("sync", response)
+
+ async def _on_join_error(self, response: JoinError) -> None:
+ self._log_response_error("join", response)
+
+ async def _on_send_error(self, response: RoomSendError) -> None:
+ self._log_response_error("send", response)
+
+ async def _set_typing(self, room_id: str, typing: bool) -> None:
+ """Best-effort typing indicator update."""
+ if not self.client:
+ return
+ try:
+ response = await self.client.room_typing(room_id=room_id, typing_state=typing,
+ timeout=TYPING_NOTICE_TIMEOUT_MS)
+ if isinstance(response, RoomTypingError):
+ logger.debug("Matrix typing failed for {}: {}", room_id, response)
+ except Exception:
+ pass
+
+ async def _start_typing_keepalive(self, room_id: str) -> None:
+ """Start periodic typing refresh (spec-recommended keepalive)."""
+ await self._stop_typing_keepalive(room_id, clear_typing=False)
+ await self._set_typing(room_id, True)
+ if not self._running:
+ return
+
+ async def loop() -> None:
+ try:
+ while self._running:
+ await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_MS / 1000)
+ await self._set_typing(room_id, True)
+ except asyncio.CancelledError:
+ pass
+
+ self._typing_tasks[room_id] = asyncio.create_task(loop())
+
+ async def _stop_typing_keepalive(self, room_id: str, *, clear_typing: bool) -> None:
+ if task := self._typing_tasks.pop(room_id, None):
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ if clear_typing:
+ await self._set_typing(room_id, False)
+
+ async def _sync_loop(self) -> None:
+ while self._running:
+ try:
+ await self.client.sync_forever(timeout=30000, full_state=True)
+ except asyncio.CancelledError:
+ break
+ except Exception:
+ await asyncio.sleep(2)
+
+ async def _on_room_invite(self, room: MatrixRoom, event: InviteEvent) -> None:
+ if self.is_allowed(event.sender):
+ await self.client.join(room.room_id)
+
+ def _is_direct_room(self, room: MatrixRoom) -> bool:
+ count = getattr(room, "member_count", None)
+ return isinstance(count, int) and count <= 2
+
+ def _is_bot_mentioned(self, event: RoomMessage) -> bool:
+ """Check m.mentions payload for bot mention."""
+ source = getattr(event, "source", None)
+ if not isinstance(source, dict):
+ return False
+ mentions = (source.get("content") or {}).get("m.mentions")
+ if not isinstance(mentions, dict):
+ return False
+ user_ids = mentions.get("user_ids")
+ if isinstance(user_ids, list) and self.config.user_id in user_ids:
+ return True
+ return bool(self.config.allow_room_mentions and mentions.get("room") is True)
+
+ def _should_process_message(self, room: MatrixRoom, event: RoomMessage) -> bool:
+ """Apply sender and room policy checks."""
+ if not self.is_allowed(event.sender):
+ return False
+ if self._is_direct_room(room):
+ return True
+ policy = self.config.group_policy
+ if policy == "open":
+ return True
+ if policy == "allowlist":
+ return room.room_id in (self.config.group_allow_from or [])
+ if policy == "mention":
+ return self._is_bot_mentioned(event)
+ return False
+
+ def _media_dir(self) -> Path:
+ return get_media_dir("matrix")
+
+ @staticmethod
+ def _event_source_content(event: RoomMessage) -> dict[str, Any]:
+ source = getattr(event, "source", None)
+ if not isinstance(source, dict):
+ return {}
+ content = source.get("content")
+ return content if isinstance(content, dict) else {}
+
+ def _event_thread_root_id(self, event: RoomMessage) -> str | None:
+ relates_to = self._event_source_content(event).get("m.relates_to")
+ if not isinstance(relates_to, dict) or relates_to.get("rel_type") != "m.thread":
+ return None
+ root_id = relates_to.get("event_id")
+ return root_id if isinstance(root_id, str) and root_id else None
+
+ def _thread_metadata(self, event: RoomMessage) -> dict[str, str] | None:
+ if not (root_id := self._event_thread_root_id(event)):
+ return None
+ meta: dict[str, str] = {"thread_root_event_id": root_id}
+ if isinstance(reply_to := getattr(event, "event_id", None), str) and reply_to:
+ meta["thread_reply_to_event_id"] = reply_to
+ return meta
+
+ @staticmethod
+ def _build_thread_relates_to(metadata: dict[str, Any] | None) -> dict[str, Any] | None:
+ if not metadata:
+ return None
+ root_id = metadata.get("thread_root_event_id")
+ if not isinstance(root_id, str) or not root_id:
+ return None
+ reply_to = metadata.get("thread_reply_to_event_id") or metadata.get("event_id")
+ if not isinstance(reply_to, str) or not reply_to:
+ return None
+ return {"rel_type": "m.thread", "event_id": root_id,
+ "m.in_reply_to": {"event_id": reply_to}, "is_falling_back": True}
+
+ def _event_attachment_type(self, event: MatrixMediaEvent) -> str:
+ msgtype = self._event_source_content(event).get("msgtype")
+ return _MSGTYPE_MAP.get(msgtype, "file")
+
+ @staticmethod
+ def _is_encrypted_media_event(event: MatrixMediaEvent) -> bool:
+ return (isinstance(getattr(event, "key", None), dict)
+ and isinstance(getattr(event, "hashes", None), dict)
+ and isinstance(getattr(event, "iv", None), str))
+
+ def _event_declared_size_bytes(self, event: MatrixMediaEvent) -> int | None:
+ info = self._event_source_content(event).get("info")
+ size = info.get("size") if isinstance(info, dict) else None
+ return size if isinstance(size, int) and size >= 0 else None
+
+ def _event_mime(self, event: MatrixMediaEvent) -> str | None:
+ info = self._event_source_content(event).get("info")
+ if isinstance(info, dict) and isinstance(m := info.get("mimetype"), str) and m:
+ return m
+ m = getattr(event, "mimetype", None)
+ return m if isinstance(m, str) and m else None
+
+ def _event_filename(self, event: MatrixMediaEvent, attachment_type: str) -> str:
+ body = getattr(event, "body", None)
+ if isinstance(body, str) and body.strip():
+ if candidate := safe_filename(Path(body).name):
+ return candidate
+ return _DEFAULT_ATTACH_NAME if attachment_type == "file" else attachment_type
+
+ def _build_attachment_path(self, event: MatrixMediaEvent, attachment_type: str,
+ filename: str, mime: str | None) -> Path:
+ safe_name = safe_filename(Path(filename).name) or _DEFAULT_ATTACH_NAME
+ suffix = Path(safe_name).suffix
+ if not suffix and mime:
+ if guessed := mimetypes.guess_extension(mime, strict=False):
+ safe_name, suffix = f"{safe_name}{guessed}", guessed
+ stem = (Path(safe_name).stem or attachment_type)[:72]
+ suffix = suffix[:16]
+ event_id = safe_filename(str(getattr(event, "event_id", "") or "evt").lstrip("$"))
+ event_prefix = (event_id[:24] or "evt").strip("_")
+ return self._media_dir() / f"{event_prefix}_{stem}{suffix}"
+
+ async def _download_media_bytes(self, mxc_url: str) -> bytes | None:
+ if not self.client:
+ return None
+ response = await self.client.download(mxc=mxc_url)
+ if isinstance(response, DownloadError):
+ logger.warning("Matrix download failed for {}: {}", mxc_url, response)
+ return None
+ body = getattr(response, "body", None)
+ if isinstance(body, (bytes, bytearray)):
+ return bytes(body)
+ if isinstance(response, MemoryDownloadResponse):
+ return bytes(response.body)
+ if isinstance(body, (str, Path)):
+ path = Path(body)
+ if path.is_file():
+ try:
+ return path.read_bytes()
+ except OSError:
+ return None
+ return None
+
+ def _decrypt_media_bytes(self, event: MatrixMediaEvent, ciphertext: bytes) -> bytes | None:
+ key_obj, hashes, iv = getattr(event, "key", None), getattr(event, "hashes", None), getattr(event, "iv", None)
+ key = key_obj.get("k") if isinstance(key_obj, dict) else None
+ sha256 = hashes.get("sha256") if isinstance(hashes, dict) else None
+ if not all(isinstance(v, str) for v in (key, sha256, iv)):
+ return None
+ try:
+ return decrypt_attachment(ciphertext, key, sha256, iv)
+ except (EncryptionError, ValueError, TypeError):
+ logger.warning("Matrix decrypt failed for event {}", getattr(event, "event_id", ""))
+ return None
+
+ async def _fetch_media_attachment(
+ self, room: MatrixRoom, event: MatrixMediaEvent,
+ ) -> tuple[dict[str, Any] | None, str]:
+ """Download, decrypt if needed, and persist a Matrix attachment."""
+ atype = self._event_attachment_type(event)
+ mime = self._event_mime(event)
+ filename = self._event_filename(event, atype)
+ mxc_url = getattr(event, "url", None)
+ fail = _ATTACH_FAILED.format(filename)
+
+ if not isinstance(mxc_url, str) or not mxc_url.startswith("mxc://"):
+ return None, fail
+
+ limit_bytes = await self._effective_media_limit_bytes()
+ declared = self._event_declared_size_bytes(event)
+ if declared is not None and declared > limit_bytes:
+ return None, _ATTACH_TOO_LARGE.format(filename)
+
+ downloaded = await self._download_media_bytes(mxc_url)
+ if downloaded is None:
+ return None, fail
+
+ encrypted = self._is_encrypted_media_event(event)
+ data = downloaded
+ if encrypted:
+ if (data := self._decrypt_media_bytes(event, downloaded)) is None:
+ return None, fail
+
+ if len(data) > limit_bytes:
+ return None, _ATTACH_TOO_LARGE.format(filename)
+
+ path = self._build_attachment_path(event, atype, filename, mime)
+ try:
+ path.write_bytes(data)
+ except OSError:
+ return None, fail
+
+ attachment = {
+ "type": atype, "mime": mime, "filename": filename,
+ "event_id": str(getattr(event, "event_id", "") or ""),
+ "encrypted": encrypted, "size_bytes": len(data),
+ "path": str(path), "mxc_url": mxc_url,
+ }
+ return attachment, _ATTACH_MARKER.format(path)
+
+ def _base_metadata(self, room: MatrixRoom, event: RoomMessage) -> dict[str, Any]:
+ """Build common metadata for text and media handlers."""
+ meta: dict[str, Any] = {"room": getattr(room, "display_name", room.room_id)}
+ if isinstance(eid := getattr(event, "event_id", None), str) and eid:
+ meta["event_id"] = eid
+ if thread := self._thread_metadata(event):
+ meta.update(thread)
+ return meta
+
+ async def _on_message(self, room: MatrixRoom, event: RoomMessageText) -> None:
+ if event.sender == self.config.user_id or not self._should_process_message(room, event):
+ return
+ await self._start_typing_keepalive(room.room_id)
+ try:
+ await self._handle_message(
+ sender_id=event.sender, chat_id=room.room_id,
+ content=event.body, metadata=self._base_metadata(room, event),
+ )
+ except Exception:
+ await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+ raise
+
+ async def _on_media_message(self, room: MatrixRoom, event: MatrixMediaEvent) -> None:
+ if event.sender == self.config.user_id or not self._should_process_message(room, event):
+ return
+ attachment, marker = await self._fetch_media_attachment(room, event)
+ parts: list[str] = []
+ if isinstance(body := getattr(event, "body", None), str) and body.strip():
+ parts.append(body.strip())
+
+ if attachment and attachment.get("type") == "audio":
+ transcription = await self.transcribe_audio(attachment["path"])
+ if transcription:
+ parts.append(f"[transcription: {transcription}]")
+ else:
+ parts.append(marker)
+ elif marker:
+ parts.append(marker)
+
+ await self._start_typing_keepalive(room.room_id)
+ try:
+ meta = self._base_metadata(room, event)
+ meta["attachments"] = []
+ if attachment:
+ meta["attachments"] = [attachment]
+ await self._handle_message(
+ sender_id=event.sender, chat_id=room.room_id,
+ content="\n".join(parts),
+ media=[attachment["path"]] if attachment else [],
+ metadata=meta,
+ )
+ except Exception:
+ await self._stop_typing_keepalive(room.room_id, clear_typing=True)
+ raise
diff --git a/nanobot/channels/mochat.py b/nanobot/channels/mochat.py
new file mode 100644
index 000000000..0b02aec62
--- /dev/null
+++ b/nanobot/channels/mochat.py
@@ -0,0 +1,947 @@
+"""Mochat channel implementation using Socket.IO with HTTP polling fallback."""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from collections import deque
+from dataclasses import dataclass, field
+from datetime import datetime
+from typing import Any
+
+import httpx
+from loguru import logger
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_runtime_subdir
+from nanobot.config.schema import Base
+from pydantic import Field
+
+try:
+ import socketio
+ SOCKETIO_AVAILABLE = True
+except ImportError:
+ socketio = None
+ SOCKETIO_AVAILABLE = False
+
+try:
+ import msgpack # noqa: F401
+ MSGPACK_AVAILABLE = True
+except ImportError:
+ MSGPACK_AVAILABLE = False
+
+MAX_SEEN_MESSAGE_IDS = 2000
+CURSOR_SAVE_DEBOUNCE_S = 0.5
+
+
+# ---------------------------------------------------------------------------
+# Data classes
+# ---------------------------------------------------------------------------
+
+@dataclass
+class MochatBufferedEntry:
+ """Buffered inbound entry for delayed dispatch."""
+ raw_body: str
+ author: str
+ sender_name: str = ""
+ sender_username: str = ""
+ timestamp: int | None = None
+ message_id: str = ""
+ group_id: str = ""
+
+
+@dataclass
+class DelayState:
+ """Per-target delayed message state."""
+ entries: list[MochatBufferedEntry] = field(default_factory=list)
+ lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+ timer: asyncio.Task | None = None
+
+
+@dataclass
+class MochatTarget:
+ """Outbound target resolution result."""
+ id: str
+ is_panel: bool
+
+
+# ---------------------------------------------------------------------------
+# Pure helpers
+# ---------------------------------------------------------------------------
+
+def _safe_dict(value: Any) -> dict:
+ """Return *value* if it's a dict, else empty dict."""
+ return value if isinstance(value, dict) else {}
+
+
+def _str_field(src: dict, *keys: str) -> str:
+ """Return the first non-empty str value found for *keys*, stripped."""
+ for k in keys:
+ v = src.get(k)
+ if isinstance(v, str) and v.strip():
+ return v.strip()
+ return ""
+
+
+def _make_synthetic_event(
+ message_id: str, author: str, content: Any,
+ meta: Any, group_id: str, converse_id: str,
+ timestamp: Any = None, *, author_info: Any = None,
+) -> dict[str, Any]:
+ """Build a synthetic ``message.add`` event dict."""
+ payload: dict[str, Any] = {
+ "messageId": message_id, "author": author,
+ "content": content, "meta": _safe_dict(meta),
+ "groupId": group_id, "converseId": converse_id,
+ }
+ if author_info is not None:
+ payload["authorInfo"] = _safe_dict(author_info)
+ return {
+ "type": "message.add",
+ "timestamp": timestamp or datetime.utcnow().isoformat(),
+ "payload": payload,
+ }
+
+
+def normalize_mochat_content(content: Any) -> str:
+ """Normalize content payload to text."""
+ if isinstance(content, str):
+ return content.strip()
+ if content is None:
+ return ""
+ try:
+ return json.dumps(content, ensure_ascii=False)
+ except TypeError:
+ return str(content)
+
+
+def resolve_mochat_target(raw: str) -> MochatTarget:
+ """Resolve id and target kind from user-provided target string."""
+ trimmed = (raw or "").strip()
+ if not trimmed:
+ return MochatTarget(id="", is_panel=False)
+
+ lowered = trimmed.lower()
+ cleaned, forced_panel = trimmed, False
+ for prefix in ("mochat:", "group:", "channel:", "panel:"):
+ if lowered.startswith(prefix):
+ cleaned = trimmed[len(prefix):].strip()
+ forced_panel = prefix in {"group:", "channel:", "panel:"}
+ break
+
+ if not cleaned:
+ return MochatTarget(id="", is_panel=False)
+ return MochatTarget(id=cleaned, is_panel=forced_panel or not cleaned.startswith("session_"))
+
+
+def extract_mention_ids(value: Any) -> list[str]:
+ """Extract mention ids from heterogeneous mention payload."""
+ if not isinstance(value, list):
+ return []
+ ids: list[str] = []
+ for item in value:
+ if isinstance(item, str):
+ if item.strip():
+ ids.append(item.strip())
+ elif isinstance(item, dict):
+ for key in ("id", "userId", "_id"):
+ candidate = item.get(key)
+ if isinstance(candidate, str) and candidate.strip():
+ ids.append(candidate.strip())
+ break
+ return ids
+
+
+def resolve_was_mentioned(payload: dict[str, Any], agent_user_id: str) -> bool:
+ """Resolve mention state from payload metadata and text fallback."""
+ meta = payload.get("meta")
+ if isinstance(meta, dict):
+ if meta.get("mentioned") is True or meta.get("wasMentioned") is True:
+ return True
+ for f in ("mentions", "mentionIds", "mentionedUserIds", "mentionedUsers"):
+ if agent_user_id and agent_user_id in extract_mention_ids(meta.get(f)):
+ return True
+ if not agent_user_id:
+ return False
+ content = payload.get("content")
+ if not isinstance(content, str) or not content:
+ return False
+ return f"<@{agent_user_id}>" in content or f"@{agent_user_id}" in content
+
+
+def resolve_require_mention(config: MochatConfig, session_id: str, group_id: str) -> bool:
+ """Resolve mention requirement for group/panel conversations."""
+ groups = config.groups or {}
+ for key in (group_id, session_id, "*"):
+ if key and key in groups:
+ return bool(groups[key].require_mention)
+ return bool(config.mention.require_in_groups)
+
+
+def build_buffered_body(entries: list[MochatBufferedEntry], is_group: bool) -> str:
+ """Build text body from one or more buffered entries."""
+ if not entries:
+ return ""
+ if len(entries) == 1:
+ return entries[0].raw_body
+ lines: list[str] = []
+ for entry in entries:
+ if not entry.raw_body:
+ continue
+ if is_group:
+ label = entry.sender_name.strip() or entry.sender_username.strip() or entry.author
+ if label:
+ lines.append(f"{label}: {entry.raw_body}")
+ continue
+ lines.append(entry.raw_body)
+ return "\n".join(lines).strip()
+
+
+def parse_timestamp(value: Any) -> int | None:
+ """Parse event timestamp to epoch milliseconds."""
+ if not isinstance(value, str) or not value.strip():
+ return None
+ try:
+ return int(datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp() * 1000)
+ except ValueError:
+ return None
+
+
+# ---------------------------------------------------------------------------
+# Config classes
+# ---------------------------------------------------------------------------
+
+class MochatMentionConfig(Base):
+ """Mochat mention behavior configuration."""
+
+ require_in_groups: bool = False
+
+
+class MochatGroupRule(Base):
+ """Mochat per-group mention requirement."""
+
+ require_mention: bool = False
+
+
+class MochatConfig(Base):
+ """Mochat channel configuration."""
+
+ enabled: bool = False
+ base_url: str = "https://mochat.io"
+ socket_url: str = ""
+ socket_path: str = "/socket.io"
+ socket_disable_msgpack: bool = False
+ socket_reconnect_delay_ms: int = 1000
+ socket_max_reconnect_delay_ms: int = 10000
+ socket_connect_timeout_ms: int = 10000
+ refresh_interval_ms: int = 30000
+ watch_timeout_ms: int = 25000
+ watch_limit: int = 100
+ retry_delay_ms: int = 500
+ max_retry_attempts: int = 0
+ claw_token: str = ""
+ agent_user_id: str = ""
+ sessions: list[str] = Field(default_factory=list)
+ panels: list[str] = Field(default_factory=list)
+ allow_from: list[str] = Field(default_factory=list)
+ mention: MochatMentionConfig = Field(default_factory=MochatMentionConfig)
+ groups: dict[str, MochatGroupRule] = Field(default_factory=dict)
+ reply_delay_mode: str = "non-mention"
+ reply_delay_ms: int = 120000
+
+
+# ---------------------------------------------------------------------------
+# Channel
+# ---------------------------------------------------------------------------
+
+class MochatChannel(BaseChannel):
+ """Mochat channel using socket.io with fallback polling workers."""
+
+ name = "mochat"
+ display_name = "Mochat"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return MochatConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = MochatConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: MochatConfig = config
+ self._http: httpx.AsyncClient | None = None
+ self._socket: Any = None
+ self._ws_connected = self._ws_ready = False
+
+ self._state_dir = get_runtime_subdir("mochat")
+ self._cursor_path = self._state_dir / "session_cursors.json"
+ self._session_cursor: dict[str, int] = {}
+ self._cursor_save_task: asyncio.Task | None = None
+
+ self._session_set: set[str] = set()
+ self._panel_set: set[str] = set()
+ self._auto_discover_sessions = self._auto_discover_panels = False
+
+ self._cold_sessions: set[str] = set()
+ self._session_by_converse: dict[str, str] = {}
+
+ self._seen_set: dict[str, set[str]] = {}
+ self._seen_queue: dict[str, deque[str]] = {}
+ self._delay_states: dict[str, DelayState] = {}
+
+ self._fallback_mode = False
+ self._session_fallback_tasks: dict[str, asyncio.Task] = {}
+ self._panel_fallback_tasks: dict[str, asyncio.Task] = {}
+ self._refresh_task: asyncio.Task | None = None
+ self._target_locks: dict[str, asyncio.Lock] = {}
+
+ # ---- lifecycle ---------------------------------------------------------
+
+ async def start(self) -> None:
+ """Start Mochat channel workers and websocket connection."""
+ if not self.config.claw_token:
+ logger.error("Mochat claw_token not configured")
+ return
+
+ self._running = True
+ self._http = httpx.AsyncClient(timeout=30.0)
+ self._state_dir.mkdir(parents=True, exist_ok=True)
+ await self._load_session_cursors()
+ self._seed_targets_from_config()
+ await self._refresh_targets(subscribe_new=False)
+
+ if not await self._start_socket_client():
+ await self._ensure_fallback_workers()
+
+ self._refresh_task = asyncio.create_task(self._refresh_loop())
+ while self._running:
+ await asyncio.sleep(1)
+
+ async def stop(self) -> None:
+ """Stop all workers and clean up resources."""
+ self._running = False
+ if self._refresh_task:
+ self._refresh_task.cancel()
+ self._refresh_task = None
+
+ await self._stop_fallback_workers()
+ await self._cancel_delay_timers()
+
+ if self._socket:
+ try:
+ await self._socket.disconnect()
+ except Exception:
+ pass
+ self._socket = None
+
+ if self._cursor_save_task:
+ self._cursor_save_task.cancel()
+ self._cursor_save_task = None
+ await self._save_session_cursors()
+
+ if self._http:
+ await self._http.aclose()
+ self._http = None
+ self._ws_connected = self._ws_ready = False
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send outbound message to session or panel."""
+ if not self.config.claw_token:
+ logger.warning("Mochat claw_token missing, skip send")
+ return
+
+ parts = ([msg.content.strip()] if msg.content and msg.content.strip() else [])
+ if msg.media:
+ parts.extend(m for m in msg.media if isinstance(m, str) and m.strip())
+ content = "\n".join(parts).strip()
+ if not content:
+ return
+
+ target = resolve_mochat_target(msg.chat_id)
+ if not target.id:
+ logger.warning("Mochat outbound target is empty")
+ return
+
+ is_panel = (target.is_panel or target.id in self._panel_set) and not target.id.startswith("session_")
+ try:
+ if is_panel:
+ await self._api_send("/api/claw/groups/panels/send", "panelId", target.id,
+ content, msg.reply_to, self._read_group_id(msg.metadata))
+ else:
+ await self._api_send("/api/claw/sessions/send", "sessionId", target.id,
+ content, msg.reply_to)
+ except Exception as e:
+ logger.error("Failed to send Mochat message: {}", e)
+ raise
+
+ # ---- config / init helpers ---------------------------------------------
+
+ def _seed_targets_from_config(self) -> None:
+ sessions, self._auto_discover_sessions = self._normalize_id_list(self.config.sessions)
+ panels, self._auto_discover_panels = self._normalize_id_list(self.config.panels)
+ self._session_set.update(sessions)
+ self._panel_set.update(panels)
+ for sid in sessions:
+ if sid not in self._session_cursor:
+ self._cold_sessions.add(sid)
+
+ @staticmethod
+ def _normalize_id_list(values: list[str]) -> tuple[list[str], bool]:
+ cleaned = [str(v).strip() for v in values if str(v).strip()]
+ return sorted({v for v in cleaned if v != "*"}), "*" in cleaned
+
+ # ---- websocket ---------------------------------------------------------
+
+ async def _start_socket_client(self) -> bool:
+ if not SOCKETIO_AVAILABLE:
+ logger.warning("python-socketio not installed, Mochat using polling fallback")
+ return False
+
+ serializer = "default"
+ if not self.config.socket_disable_msgpack:
+ if MSGPACK_AVAILABLE:
+ serializer = "msgpack"
+ else:
+ logger.warning("msgpack not installed but socket_disable_msgpack=false; using JSON")
+
+ client = socketio.AsyncClient(
+ reconnection=True,
+ reconnection_attempts=self.config.max_retry_attempts or None,
+ reconnection_delay=max(0.1, self.config.socket_reconnect_delay_ms / 1000.0),
+ reconnection_delay_max=max(0.1, self.config.socket_max_reconnect_delay_ms / 1000.0),
+ logger=False, engineio_logger=False, serializer=serializer,
+ )
+
+ @client.event
+ async def connect() -> None:
+ self._ws_connected, self._ws_ready = True, False
+ logger.info("Mochat websocket connected")
+ subscribed = await self._subscribe_all()
+ self._ws_ready = subscribed
+ await (self._stop_fallback_workers() if subscribed else self._ensure_fallback_workers())
+
+ @client.event
+ async def disconnect() -> None:
+ if not self._running:
+ return
+ self._ws_connected = self._ws_ready = False
+ logger.warning("Mochat websocket disconnected")
+ await self._ensure_fallback_workers()
+
+ @client.event
+ async def connect_error(data: Any) -> None:
+ logger.error("Mochat websocket connect error: {}", data)
+
+ @client.on("claw.session.events")
+ async def on_session_events(payload: dict[str, Any]) -> None:
+ await self._handle_watch_payload(payload, "session")
+
+ @client.on("claw.panel.events")
+ async def on_panel_events(payload: dict[str, Any]) -> None:
+ await self._handle_watch_payload(payload, "panel")
+
+ for ev in ("notify:chat.inbox.append", "notify:chat.message.add",
+ "notify:chat.message.update", "notify:chat.message.recall",
+ "notify:chat.message.delete"):
+ client.on(ev, self._build_notify_handler(ev))
+
+ socket_url = (self.config.socket_url or self.config.base_url).strip().rstrip("/")
+ socket_path = (self.config.socket_path or "/socket.io").strip().lstrip("/")
+
+ try:
+ self._socket = client
+ await client.connect(
+ socket_url, transports=["websocket"], socketio_path=socket_path,
+ auth={"token": self.config.claw_token},
+ wait_timeout=max(1.0, self.config.socket_connect_timeout_ms / 1000.0),
+ )
+ return True
+ except Exception as e:
+ logger.error("Failed to connect Mochat websocket: {}", e)
+ try:
+ await client.disconnect()
+ except Exception:
+ pass
+ self._socket = None
+ return False
+
+ def _build_notify_handler(self, event_name: str):
+ async def handler(payload: Any) -> None:
+ if event_name == "notify:chat.inbox.append":
+ await self._handle_notify_inbox_append(payload)
+ elif event_name.startswith("notify:chat.message."):
+ await self._handle_notify_chat_message(payload)
+ return handler
+
+ # ---- subscribe ---------------------------------------------------------
+
+ async def _subscribe_all(self) -> bool:
+ ok = await self._subscribe_sessions(sorted(self._session_set))
+ ok = await self._subscribe_panels(sorted(self._panel_set)) and ok
+ if self._auto_discover_sessions or self._auto_discover_panels:
+ await self._refresh_targets(subscribe_new=True)
+ return ok
+
+ async def _subscribe_sessions(self, session_ids: list[str]) -> bool:
+ if not session_ids:
+ return True
+ for sid in session_ids:
+ if sid not in self._session_cursor:
+ self._cold_sessions.add(sid)
+
+ ack = await self._socket_call("com.claw.im.subscribeSessions", {
+ "sessionIds": session_ids, "cursors": self._session_cursor,
+ "limit": self.config.watch_limit,
+ })
+ if not ack.get("result"):
+ logger.error("Mochat subscribeSessions failed: {}", ack.get('message', 'unknown error'))
+ return False
+
+ data = ack.get("data")
+ items: list[dict[str, Any]] = []
+ if isinstance(data, list):
+ items = [i for i in data if isinstance(i, dict)]
+ elif isinstance(data, dict):
+ sessions = data.get("sessions")
+ if isinstance(sessions, list):
+ items = [i for i in sessions if isinstance(i, dict)]
+ elif "sessionId" in data:
+ items = [data]
+ for p in items:
+ await self._handle_watch_payload(p, "session")
+ return True
+
+ async def _subscribe_panels(self, panel_ids: list[str]) -> bool:
+ if not self._auto_discover_panels and not panel_ids:
+ return True
+ ack = await self._socket_call("com.claw.im.subscribePanels", {"panelIds": panel_ids})
+ if not ack.get("result"):
+ logger.error("Mochat subscribePanels failed: {}", ack.get('message', 'unknown error'))
+ return False
+ return True
+
+ async def _socket_call(self, event_name: str, payload: dict[str, Any]) -> dict[str, Any]:
+ if not self._socket:
+ return {"result": False, "message": "socket not connected"}
+ try:
+ raw = await self._socket.call(event_name, payload, timeout=10)
+ except Exception as e:
+ return {"result": False, "message": str(e)}
+ return raw if isinstance(raw, dict) else {"result": True, "data": raw}
+
+ # ---- refresh / discovery -----------------------------------------------
+
+ async def _refresh_loop(self) -> None:
+ interval_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
+ while self._running:
+ await asyncio.sleep(interval_s)
+ try:
+ await self._refresh_targets(subscribe_new=self._ws_ready)
+ except Exception as e:
+ logger.warning("Mochat refresh failed: {}", e)
+ if self._fallback_mode:
+ await self._ensure_fallback_workers()
+
+ async def _refresh_targets(self, subscribe_new: bool) -> None:
+ if self._auto_discover_sessions:
+ await self._refresh_sessions_directory(subscribe_new)
+ if self._auto_discover_panels:
+ await self._refresh_panels(subscribe_new)
+
+ async def _refresh_sessions_directory(self, subscribe_new: bool) -> None:
+ try:
+ response = await self._post_json("/api/claw/sessions/list", {})
+ except Exception as e:
+ logger.warning("Mochat listSessions failed: {}", e)
+ return
+
+ sessions = response.get("sessions")
+ if not isinstance(sessions, list):
+ return
+
+ new_ids: list[str] = []
+ for s in sessions:
+ if not isinstance(s, dict):
+ continue
+ sid = _str_field(s, "sessionId")
+ if not sid:
+ continue
+ if sid not in self._session_set:
+ self._session_set.add(sid)
+ new_ids.append(sid)
+ if sid not in self._session_cursor:
+ self._cold_sessions.add(sid)
+ cid = _str_field(s, "converseId")
+ if cid:
+ self._session_by_converse[cid] = sid
+
+ if not new_ids:
+ return
+ if self._ws_ready and subscribe_new:
+ await self._subscribe_sessions(new_ids)
+ if self._fallback_mode:
+ await self._ensure_fallback_workers()
+
+ async def _refresh_panels(self, subscribe_new: bool) -> None:
+ try:
+ response = await self._post_json("/api/claw/groups/get", {})
+ except Exception as e:
+ logger.warning("Mochat getWorkspaceGroup failed: {}", e)
+ return
+
+ raw_panels = response.get("panels")
+ if not isinstance(raw_panels, list):
+ return
+
+ new_ids: list[str] = []
+ for p in raw_panels:
+ if not isinstance(p, dict):
+ continue
+ pt = p.get("type")
+ if isinstance(pt, int) and pt != 0:
+ continue
+ pid = _str_field(p, "id", "_id")
+ if pid and pid not in self._panel_set:
+ self._panel_set.add(pid)
+ new_ids.append(pid)
+
+ if not new_ids:
+ return
+ if self._ws_ready and subscribe_new:
+ await self._subscribe_panels(new_ids)
+ if self._fallback_mode:
+ await self._ensure_fallback_workers()
+
+ # ---- fallback workers --------------------------------------------------
+
+ async def _ensure_fallback_workers(self) -> None:
+ if not self._running:
+ return
+ self._fallback_mode = True
+ for sid in sorted(self._session_set):
+ t = self._session_fallback_tasks.get(sid)
+ if not t or t.done():
+ self._session_fallback_tasks[sid] = asyncio.create_task(self._session_watch_worker(sid))
+ for pid in sorted(self._panel_set):
+ t = self._panel_fallback_tasks.get(pid)
+ if not t or t.done():
+ self._panel_fallback_tasks[pid] = asyncio.create_task(self._panel_poll_worker(pid))
+
+ async def _stop_fallback_workers(self) -> None:
+ self._fallback_mode = False
+ tasks = [*self._session_fallback_tasks.values(), *self._panel_fallback_tasks.values()]
+ for t in tasks:
+ t.cancel()
+ if tasks:
+ await asyncio.gather(*tasks, return_exceptions=True)
+ self._session_fallback_tasks.clear()
+ self._panel_fallback_tasks.clear()
+
+ async def _session_watch_worker(self, session_id: str) -> None:
+ while self._running and self._fallback_mode:
+ try:
+ payload = await self._post_json("/api/claw/sessions/watch", {
+ "sessionId": session_id, "cursor": self._session_cursor.get(session_id, 0),
+ "timeoutMs": self.config.watch_timeout_ms, "limit": self.config.watch_limit,
+ })
+ await self._handle_watch_payload(payload, "session")
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.warning("Mochat watch fallback error ({}): {}", session_id, e)
+ await asyncio.sleep(max(0.1, self.config.retry_delay_ms / 1000.0))
+
+ async def _panel_poll_worker(self, panel_id: str) -> None:
+ sleep_s = max(1.0, self.config.refresh_interval_ms / 1000.0)
+ while self._running and self._fallback_mode:
+ try:
+ resp = await self._post_json("/api/claw/groups/panels/messages", {
+ "panelId": panel_id, "limit": min(100, max(1, self.config.watch_limit)),
+ })
+ msgs = resp.get("messages")
+ if isinstance(msgs, list):
+ for m in reversed(msgs):
+ if not isinstance(m, dict):
+ continue
+ evt = _make_synthetic_event(
+ message_id=str(m.get("messageId") or ""),
+ author=str(m.get("author") or ""),
+ content=m.get("content"),
+ meta=m.get("meta"), group_id=str(resp.get("groupId") or ""),
+ converse_id=panel_id, timestamp=m.get("createdAt"),
+ author_info=m.get("authorInfo"),
+ )
+ await self._process_inbound_event(panel_id, evt, "panel")
+ except asyncio.CancelledError:
+ break
+ except Exception as e:
+ logger.warning("Mochat panel polling error ({}): {}", panel_id, e)
+ await asyncio.sleep(sleep_s)
+
+ # ---- inbound event processing ------------------------------------------
+
+ async def _handle_watch_payload(self, payload: dict[str, Any], target_kind: str) -> None:
+ if not isinstance(payload, dict):
+ return
+ target_id = _str_field(payload, "sessionId")
+ if not target_id:
+ return
+
+ lock = self._target_locks.setdefault(f"{target_kind}:{target_id}", asyncio.Lock())
+ async with lock:
+ prev = self._session_cursor.get(target_id, 0) if target_kind == "session" else 0
+ pc = payload.get("cursor")
+ if target_kind == "session" and isinstance(pc, int) and pc >= 0:
+ self._mark_session_cursor(target_id, pc)
+
+ raw_events = payload.get("events")
+ if not isinstance(raw_events, list):
+ return
+ if target_kind == "session" and target_id in self._cold_sessions:
+ self._cold_sessions.discard(target_id)
+ return
+
+ for event in raw_events:
+ if not isinstance(event, dict):
+ continue
+ seq = event.get("seq")
+ if target_kind == "session" and isinstance(seq, int) and seq > self._session_cursor.get(target_id, prev):
+ self._mark_session_cursor(target_id, seq)
+ if event.get("type") == "message.add":
+ await self._process_inbound_event(target_id, event, target_kind)
+
+ async def _process_inbound_event(self, target_id: str, event: dict[str, Any], target_kind: str) -> None:
+ payload = event.get("payload")
+ if not isinstance(payload, dict):
+ return
+
+ author = _str_field(payload, "author")
+ if not author or (self.config.agent_user_id and author == self.config.agent_user_id):
+ return
+ if not self.is_allowed(author):
+ return
+
+ message_id = _str_field(payload, "messageId")
+ seen_key = f"{target_kind}:{target_id}"
+ if message_id and self._remember_message_id(seen_key, message_id):
+ return
+
+ raw_body = normalize_mochat_content(payload.get("content")) or "[empty message]"
+ ai = _safe_dict(payload.get("authorInfo"))
+ sender_name = _str_field(ai, "nickname", "email")
+ sender_username = _str_field(ai, "agentId")
+
+ group_id = _str_field(payload, "groupId")
+ is_group = bool(group_id)
+ was_mentioned = resolve_was_mentioned(payload, self.config.agent_user_id)
+ require_mention = target_kind == "panel" and is_group and resolve_require_mention(self.config, target_id, group_id)
+ use_delay = target_kind == "panel" and self.config.reply_delay_mode == "non-mention"
+
+ if require_mention and not was_mentioned and not use_delay:
+ return
+
+ entry = MochatBufferedEntry(
+ raw_body=raw_body, author=author, sender_name=sender_name,
+ sender_username=sender_username, timestamp=parse_timestamp(event.get("timestamp")),
+ message_id=message_id, group_id=group_id,
+ )
+
+ if use_delay:
+ delay_key = seen_key
+ if was_mentioned:
+ await self._flush_delayed_entries(delay_key, target_id, target_kind, "mention", entry)
+ else:
+ await self._enqueue_delayed_entry(delay_key, target_id, target_kind, entry)
+ return
+
+ await self._dispatch_entries(target_id, target_kind, [entry], was_mentioned)
+
+ # ---- dedup / buffering -------------------------------------------------
+
+ def _remember_message_id(self, key: str, message_id: str) -> bool:
+ seen_set = self._seen_set.setdefault(key, set())
+ seen_queue = self._seen_queue.setdefault(key, deque())
+ if message_id in seen_set:
+ return True
+ seen_set.add(message_id)
+ seen_queue.append(message_id)
+ while len(seen_queue) > MAX_SEEN_MESSAGE_IDS:
+ seen_set.discard(seen_queue.popleft())
+ return False
+
+ async def _enqueue_delayed_entry(self, key: str, target_id: str, target_kind: str, entry: MochatBufferedEntry) -> None:
+ state = self._delay_states.setdefault(key, DelayState())
+ async with state.lock:
+ state.entries.append(entry)
+ if state.timer:
+ state.timer.cancel()
+ state.timer = asyncio.create_task(self._delay_flush_after(key, target_id, target_kind))
+
+ async def _delay_flush_after(self, key: str, target_id: str, target_kind: str) -> None:
+ await asyncio.sleep(max(0, self.config.reply_delay_ms) / 1000.0)
+ await self._flush_delayed_entries(key, target_id, target_kind, "timer", None)
+
+ async def _flush_delayed_entries(self, key: str, target_id: str, target_kind: str, reason: str, entry: MochatBufferedEntry | None) -> None:
+ state = self._delay_states.setdefault(key, DelayState())
+ async with state.lock:
+ if entry:
+ state.entries.append(entry)
+ current = asyncio.current_task()
+ if state.timer and state.timer is not current:
+ state.timer.cancel()
+ state.timer = None
+ entries = state.entries[:]
+ state.entries.clear()
+ if entries:
+ await self._dispatch_entries(target_id, target_kind, entries, reason == "mention")
+
+ async def _dispatch_entries(self, target_id: str, target_kind: str, entries: list[MochatBufferedEntry], was_mentioned: bool) -> None:
+ if not entries:
+ return
+ last = entries[-1]
+ is_group = bool(last.group_id)
+ body = build_buffered_body(entries, is_group) or "[empty message]"
+ await self._handle_message(
+ sender_id=last.author, chat_id=target_id, content=body,
+ metadata={
+ "message_id": last.message_id, "timestamp": last.timestamp,
+ "is_group": is_group, "group_id": last.group_id,
+ "sender_name": last.sender_name, "sender_username": last.sender_username,
+ "target_kind": target_kind, "was_mentioned": was_mentioned,
+ "buffered_count": len(entries),
+ },
+ )
+
+ async def _cancel_delay_timers(self) -> None:
+ for state in self._delay_states.values():
+ if state.timer:
+ state.timer.cancel()
+ self._delay_states.clear()
+
+ # ---- notify handlers ---------------------------------------------------
+
+ async def _handle_notify_chat_message(self, payload: Any) -> None:
+ if not isinstance(payload, dict):
+ return
+ group_id = _str_field(payload, "groupId")
+ panel_id = _str_field(payload, "converseId", "panelId")
+ if not group_id or not panel_id:
+ return
+ if self._panel_set and panel_id not in self._panel_set:
+ return
+
+ evt = _make_synthetic_event(
+ message_id=str(payload.get("_id") or payload.get("messageId") or ""),
+ author=str(payload.get("author") or ""),
+ content=payload.get("content"), meta=payload.get("meta"),
+ group_id=group_id, converse_id=panel_id,
+ timestamp=payload.get("createdAt"), author_info=payload.get("authorInfo"),
+ )
+ await self._process_inbound_event(panel_id, evt, "panel")
+
+ async def _handle_notify_inbox_append(self, payload: Any) -> None:
+ if not isinstance(payload, dict) or payload.get("type") != "message":
+ return
+ detail = payload.get("payload")
+ if not isinstance(detail, dict):
+ return
+ if _str_field(detail, "groupId"):
+ return
+ converse_id = _str_field(detail, "converseId")
+ if not converse_id:
+ return
+
+ session_id = self._session_by_converse.get(converse_id)
+ if not session_id:
+ await self._refresh_sessions_directory(self._ws_ready)
+ session_id = self._session_by_converse.get(converse_id)
+ if not session_id:
+ return
+
+ evt = _make_synthetic_event(
+ message_id=str(detail.get("messageId") or payload.get("_id") or ""),
+ author=str(detail.get("messageAuthor") or ""),
+ content=str(detail.get("messagePlainContent") or detail.get("messageSnippet") or ""),
+ meta={"source": "notify:chat.inbox.append", "converseId": converse_id},
+ group_id="", converse_id=converse_id, timestamp=payload.get("createdAt"),
+ )
+ await self._process_inbound_event(session_id, evt, "session")
+
+ # ---- cursor persistence ------------------------------------------------
+
+ def _mark_session_cursor(self, session_id: str, cursor: int) -> None:
+ if cursor < 0 or cursor < self._session_cursor.get(session_id, 0):
+ return
+ self._session_cursor[session_id] = cursor
+ if not self._cursor_save_task or self._cursor_save_task.done():
+ self._cursor_save_task = asyncio.create_task(self._save_cursor_debounced())
+
+ async def _save_cursor_debounced(self) -> None:
+ await asyncio.sleep(CURSOR_SAVE_DEBOUNCE_S)
+ await self._save_session_cursors()
+
+ async def _load_session_cursors(self) -> None:
+ if not self._cursor_path.exists():
+ return
+ try:
+ data = json.loads(self._cursor_path.read_text("utf-8"))
+ except Exception as e:
+ logger.warning("Failed to read Mochat cursor file: {}", e)
+ return
+ cursors = data.get("cursors") if isinstance(data, dict) else None
+ if isinstance(cursors, dict):
+ for sid, cur in cursors.items():
+ if isinstance(sid, str) and isinstance(cur, int) and cur >= 0:
+ self._session_cursor[sid] = cur
+
+ async def _save_session_cursors(self) -> None:
+ try:
+ self._state_dir.mkdir(parents=True, exist_ok=True)
+ self._cursor_path.write_text(json.dumps({
+ "schemaVersion": 1, "updatedAt": datetime.utcnow().isoformat(),
+ "cursors": self._session_cursor,
+ }, ensure_ascii=False, indent=2) + "\n", "utf-8")
+ except Exception as e:
+ logger.warning("Failed to save Mochat cursor file: {}", e)
+
+ # ---- HTTP helpers ------------------------------------------------------
+
+ async def _post_json(self, path: str, payload: dict[str, Any]) -> dict[str, Any]:
+ if not self._http:
+ raise RuntimeError("Mochat HTTP client not initialized")
+ url = f"{self.config.base_url.strip().rstrip('/')}{path}"
+ response = await self._http.post(url, headers={
+ "Content-Type": "application/json", "X-Claw-Token": self.config.claw_token,
+ }, json=payload)
+ if not response.is_success:
+ raise RuntimeError(f"Mochat HTTP {response.status_code}: {response.text[:200]}")
+ try:
+ parsed = response.json()
+ except Exception:
+ parsed = response.text
+ if isinstance(parsed, dict) and isinstance(parsed.get("code"), int):
+ if parsed["code"] != 200:
+ msg = str(parsed.get("message") or parsed.get("name") or "request failed")
+ raise RuntimeError(f"Mochat API error: {msg} (code={parsed['code']})")
+ data = parsed.get("data")
+ return data if isinstance(data, dict) else {}
+ return parsed if isinstance(parsed, dict) else {}
+
+ async def _api_send(self, path: str, id_key: str, id_val: str,
+ content: str, reply_to: str | None, group_id: str | None = None) -> dict[str, Any]:
+ """Unified send helper for session and panel messages."""
+ body: dict[str, Any] = {id_key: id_val, "content": content}
+ if reply_to:
+ body["replyTo"] = reply_to
+ if group_id:
+ body["groupId"] = group_id
+ return await self._post_json(path, body)
+
+ @staticmethod
+ def _read_group_id(metadata: dict[str, Any]) -> str | None:
+ if not isinstance(metadata, dict):
+ return None
+ value = metadata.get("group_id") or metadata.get("groupId")
+ return value.strip() if isinstance(value, str) and value.strip() else None
diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py
new file mode 100644
index 000000000..bef2cf27a
--- /dev/null
+++ b/nanobot/channels/qq.py
@@ -0,0 +1,651 @@
+"""QQ channel implementation using botpy SDK.
+
+Inbound:
+- Parse QQ botpy messages (C2C / Group)
+- Download attachments to media dir using chunked streaming write (memory-safe)
+- Publish to Nanobot bus via BaseChannel._handle_message()
+- Content includes a clear, actionable "Received files:" list with local paths
+
+Outbound:
+- Send attachments (msg.media) first via QQ rich media API (base64 upload + msg_type=7)
+- Then send text (plain or markdown)
+- msg.media supports local paths, file:// paths, and http(s) URLs
+
+Notes:
+- QQ restricts many audio/video formats. We conservatively classify as image vs file.
+- Attachment structures differ across botpy versions; we try multiple field candidates.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import mimetypes
+import os
+import re
+import time
+from collections import deque
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Literal
+from urllib.parse import unquote, urlparse
+
+import aiohttp
+from loguru import logger
+from pydantic import Field
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.schema import Base
+from nanobot.security.network import validate_url_target
+
+try:
+ from nanobot.config.paths import get_media_dir
+except Exception: # pragma: no cover
+ get_media_dir = None # type: ignore
+
+try:
+ import botpy
+ from botpy.http import Route
+
+ QQ_AVAILABLE = True
+except ImportError: # pragma: no cover
+ QQ_AVAILABLE = False
+ botpy = None
+ Route = None
+
+if TYPE_CHECKING:
+ from botpy.message import BaseMessage, C2CMessage, GroupMessage
+ from botpy.types.message import Media
+
+
+# QQ rich media file_type: 1=image, 4=file
+# (2=voice, 3=video are restricted; we only use image vs file)
+QQ_FILE_TYPE_IMAGE = 1
+QQ_FILE_TYPE_FILE = 4
+
+_IMAGE_EXTS = {
+ ".png",
+ ".jpg",
+ ".jpeg",
+ ".gif",
+ ".bmp",
+ ".webp",
+ ".tif",
+ ".tiff",
+ ".ico",
+ ".svg",
+}
+
+# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
+_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)
+
+
+def _sanitize_filename(name: str) -> str:
+ """Sanitize filename to avoid traversal and problematic chars."""
+ name = (name or "").strip()
+ name = Path(name).name
+ name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
+ return name
+
+
+def _is_image_name(name: str) -> bool:
+ return Path(name).suffix.lower() in _IMAGE_EXTS
+
+
+def _guess_send_file_type(filename: str) -> int:
+ """Conservative send type: images -> 1, else -> 4."""
+ ext = Path(filename).suffix.lower()
+ mime, _ = mimetypes.guess_type(filename)
+ if ext in _IMAGE_EXTS or (mime and mime.startswith("image/")):
+ return QQ_FILE_TYPE_IMAGE
+ return QQ_FILE_TYPE_FILE
+
+
+def _make_bot_class(channel: QQChannel) -> type[botpy.Client]:
+ """Create a botpy Client subclass bound to the given channel."""
+ intents = botpy.Intents(public_messages=True, direct_message=True)
+
+ class _Bot(botpy.Client):
+ def __init__(self):
+ # Disable botpy's file log — nanobot uses loguru; default "botpy.log" fails on read-only fs
+ super().__init__(intents=intents, ext_handlers=False)
+
+ async def on_ready(self):
+ logger.info("QQ bot ready: {}", self.robot.name)
+
+ async def on_c2c_message_create(self, message: C2CMessage):
+ await channel._on_message(message, is_group=False)
+
+ async def on_group_at_message_create(self, message: GroupMessage):
+ await channel._on_message(message, is_group=True)
+
+ async def on_direct_message_create(self, message):
+ await channel._on_message(message, is_group=False)
+
+ return _Bot
+
+
+class QQConfig(Base):
+ """QQ channel configuration using botpy SDK."""
+
+ enabled: bool = False
+ app_id: str = ""
+ secret: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ msg_format: Literal["plain", "markdown"] = "plain"
+ ack_message: str = "⏳ Processing..."
+
+ # Optional: directory to save inbound attachments. If empty, use nanobot get_media_dir("qq").
+ media_dir: str = ""
+
+ # Download tuning
+ download_chunk_size: int = 1024 * 256 # 256KB
+ download_max_bytes: int = 1024 * 1024 * 200 # 200MB safety limit
+
+
+class QQChannel(BaseChannel):
+ """QQ channel using botpy SDK with WebSocket connection."""
+
+ name = "qq"
+ display_name = "QQ"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return QQConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = QQConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: QQConfig = config
+
+ self._client: botpy.Client | None = None
+ self._http: aiohttp.ClientSession | None = None
+
+ self._processed_ids: deque[str] = deque(maxlen=1000)
+ self._msg_seq: int = 1 # used to avoid QQ API dedup
+ self._chat_type_cache: dict[str, str] = {}
+
+ self._media_root: Path = self._init_media_root()
+
+ # ---------------------------
+ # Lifecycle
+ # ---------------------------
+
+ def _init_media_root(self) -> Path:
+ """Choose a directory for saving inbound attachments."""
+ if self.config.media_dir:
+ root = Path(self.config.media_dir).expanduser()
+ elif get_media_dir:
+ try:
+ root = Path(get_media_dir("qq"))
+ except Exception:
+ root = Path.home() / ".nanobot" / "media" / "qq"
+ else:
+ root = Path.home() / ".nanobot" / "media" / "qq"
+
+ root.mkdir(parents=True, exist_ok=True)
+ logger.info("QQ media directory: {}", str(root))
+ return root
+
+ async def start(self) -> None:
+ """Start the QQ bot with auto-reconnect loop."""
+ if not QQ_AVAILABLE:
+ logger.error("QQ SDK not installed. Run: pip install qq-botpy")
+ return
+
+ if not self.config.app_id or not self.config.secret:
+ logger.error("QQ app_id and secret not configured")
+ return
+
+ self._running = True
+ self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
+
+ self._client = _make_bot_class(self)()
+ logger.info("QQ bot started (C2C & Group supported)")
+ await self._run_bot()
+
+ async def _run_bot(self) -> None:
+ """Run the bot connection with auto-reconnect."""
+ while self._running:
+ try:
+ await self._client.start(appid=self.config.app_id, secret=self.config.secret)
+ except Exception as e:
+ logger.warning("QQ bot error: {}", e)
+ if self._running:
+ logger.info("Reconnecting QQ bot in 5 seconds...")
+ await asyncio.sleep(5)
+
+ async def stop(self) -> None:
+ """Stop bot and cleanup resources."""
+ self._running = False
+ if self._client:
+ try:
+ await self._client.close()
+ except Exception:
+ pass
+ self._client = None
+
+ if self._http:
+ try:
+ await self._http.close()
+ except Exception:
+ pass
+ self._http = None
+
+ logger.info("QQ bot stopped")
+
+ # ---------------------------
+ # Outbound (send)
+ # ---------------------------
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send attachments first, then text."""
+ if not self._client:
+ logger.warning("QQ client not initialized")
+ return
+
+ msg_id = msg.metadata.get("message_id")
+ chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
+ is_group = chat_type == "group"
+
+ # 1) Send media
+ for media_ref in msg.media or []:
+ ok = await self._send_media(
+ chat_id=msg.chat_id,
+ media_ref=media_ref,
+ msg_id=msg_id,
+ is_group=is_group,
+ )
+ if not ok:
+ filename = (
+ os.path.basename(urlparse(media_ref).path)
+ or os.path.basename(media_ref)
+ or "file"
+ )
+ await self._send_text_only(
+ chat_id=msg.chat_id,
+ is_group=is_group,
+ msg_id=msg_id,
+ content=f"[Attachment send failed: {filename}]",
+ )
+
+ # 2) Send text
+ if msg.content and msg.content.strip():
+ await self._send_text_only(
+ chat_id=msg.chat_id,
+ is_group=is_group,
+ msg_id=msg_id,
+ content=msg.content.strip(),
+ )
+
+ async def _send_text_only(
+ self,
+ chat_id: str,
+ is_group: bool,
+ msg_id: str | None,
+ content: str,
+ ) -> None:
+ """Send a plain/markdown text message."""
+ if not self._client:
+ return
+
+ self._msg_seq += 1
+ use_markdown = self.config.msg_format == "markdown"
+ payload: dict[str, Any] = {
+ "msg_type": 2 if use_markdown else 0,
+ "msg_id": msg_id,
+ "msg_seq": self._msg_seq,
+ }
+ if use_markdown:
+ payload["markdown"] = {"content": content}
+ else:
+ payload["content"] = content
+
+ if is_group:
+ await self._client.api.post_group_message(group_openid=chat_id, **payload)
+ else:
+ await self._client.api.post_c2c_message(openid=chat_id, **payload)
+
+ async def _send_media(
+ self,
+ chat_id: str,
+ media_ref: str,
+ msg_id: str | None,
+ is_group: bool,
+ ) -> bool:
+ """Read bytes -> base64 upload -> msg_type=7 send."""
+ if not self._client:
+ return False
+
+ data, filename = await self._read_media_bytes(media_ref)
+ if not data or not filename:
+ return False
+
+ try:
+ file_type = _guess_send_file_type(filename)
+ file_data_b64 = base64.b64encode(data).decode()
+
+ media_obj = await self._post_base64file(
+ chat_id=chat_id,
+ is_group=is_group,
+ file_type=file_type,
+ file_data=file_data_b64,
+ file_name=filename,
+ srv_send_msg=False,
+ )
+ if not media_obj:
+ logger.error("QQ media upload failed: empty response")
+ return False
+
+ self._msg_seq += 1
+ if is_group:
+ await self._client.api.post_group_message(
+ group_openid=chat_id,
+ msg_type=7,
+ msg_id=msg_id,
+ msg_seq=self._msg_seq,
+ media=media_obj,
+ )
+ else:
+ await self._client.api.post_c2c_message(
+ openid=chat_id,
+ msg_type=7,
+ msg_id=msg_id,
+ msg_seq=self._msg_seq,
+ media=media_obj,
+ )
+
+ logger.info("QQ media sent: {}", filename)
+ return True
+ except Exception as e:
+ logger.error("QQ send media failed filename={} err={}", filename, e)
+ return False
+
+ async def _read_media_bytes(self, media_ref: str) -> tuple[bytes | None, str | None]:
+ """Read bytes from http(s) or local file path; return (data, filename)."""
+ media_ref = (media_ref or "").strip()
+ if not media_ref:
+ return None, None
+
+ # Local file: plain path or file:// URI
+ if not media_ref.startswith("http://") and not media_ref.startswith("https://"):
+ try:
+ if media_ref.startswith("file://"):
+ parsed = urlparse(media_ref)
+ # Windows: path in netloc; Unix: path in path
+ raw = parsed.path or parsed.netloc
+ local_path = Path(unquote(raw))
+ else:
+ local_path = Path(os.path.expanduser(media_ref))
+
+ if not local_path.is_file():
+ logger.warning("QQ outbound media file not found: {}", str(local_path))
+ return None, None
+
+ data = await asyncio.to_thread(local_path.read_bytes)
+ return data, local_path.name
+ except Exception as e:
+ logger.warning("QQ outbound media read error ref={} err={}", media_ref, e)
+ return None, None
+
+ # Remote URL
+ ok, err = validate_url_target(media_ref)
+ if not ok:
+ logger.warning("QQ outbound media URL validation failed url={} err={}", media_ref, err)
+ return None, None
+
+ if not self._http:
+ self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
+ try:
+ async with self._http.get(media_ref, allow_redirects=True) as resp:
+ if resp.status >= 400:
+ logger.warning(
+ "QQ outbound media download failed status={} url={}",
+ resp.status,
+ media_ref,
+ )
+ return None, None
+ data = await resp.read()
+ if not data:
+ return None, None
+ filename = os.path.basename(urlparse(media_ref).path) or "file.bin"
+ return data, filename
+ except Exception as e:
+ logger.warning("QQ outbound media download error url={} err={}", media_ref, e)
+ return None, None
+
+ # https://github.com/tencent-connect/botpy/issues/198
+ # https://bot.q.qq.com/wiki/develop/api-v2/server-inter/message/send-receive/rich-media.html
+ async def _post_base64file(
+ self,
+ chat_id: str,
+ is_group: bool,
+ file_type: int,
+ file_data: str,
+ file_name: str | None = None,
+ srv_send_msg: bool = False,
+ ) -> Media:
+ """Upload base64-encoded file and return Media object."""
+ if not self._client:
+ raise RuntimeError("QQ client not initialized")
+
+ if is_group:
+ endpoint = "/v2/groups/{group_openid}/files"
+ id_key = "group_openid"
+ else:
+ endpoint = "/v2/users/{openid}/files"
+ id_key = "openid"
+
+ payload = {
+ id_key: chat_id,
+ "file_type": file_type,
+ "file_data": file_data,
+ "file_name": file_name,
+ "srv_send_msg": srv_send_msg,
+ }
+ route = Route("POST", endpoint, **{id_key: chat_id})
+ return await self._client.api._http.request(route, json=payload)
+
+ # ---------------------------
+ # Inbound (receive)
+ # ---------------------------
+
+ async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
+ """Parse inbound message, download attachments, and publish to the bus."""
+ if data.id in self._processed_ids:
+ return
+ self._processed_ids.append(data.id)
+
+ if is_group:
+ chat_id = data.group_openid
+ user_id = data.author.member_openid
+ self._chat_type_cache[chat_id] = "group"
+ else:
+ chat_id = str(
+ getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
+ )
+ user_id = chat_id
+ self._chat_type_cache[chat_id] = "c2c"
+
+ content = (data.content or "").strip()
+
+ # the data used by tests don't contain attachments property
+ # so we use getattr with a default of [] to avoid AttributeError in tests
+ attachments = getattr(data, "attachments", None) or []
+ media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
+
+ # Compose content that always contains actionable saved paths
+ if recv_lines:
+ tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
+ file_block = "Received files:\n" + "\n".join(recv_lines)
+ content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
+
+ if not content and not media_paths:
+ return
+
+ if self.config.ack_message:
+ try:
+ await self._send_text_only(
+ chat_id=chat_id,
+ is_group=is_group,
+ msg_id=data.id,
+ content=self.config.ack_message,
+ )
+ except Exception:
+ logger.debug("QQ ack message failed for chat_id={}", chat_id)
+
+ await self._handle_message(
+ sender_id=user_id,
+ chat_id=chat_id,
+ content=content,
+ media=media_paths if media_paths else None,
+ metadata={
+ "message_id": data.id,
+ "attachments": att_meta,
+ },
+ )
+
+ async def _handle_attachments(
+ self,
+ attachments: list[BaseMessage._Attachments],
+ ) -> tuple[list[str], list[str], list[dict[str, Any]]]:
+ """Extract, download (chunked), and format attachments for agent consumption."""
+ media_paths: list[str] = []
+ recv_lines: list[str] = []
+ att_meta: list[dict[str, Any]] = []
+
+ if not attachments:
+ return media_paths, recv_lines, att_meta
+
+ for att in attachments:
+ url, filename, ctype = att.url, att.filename, att.content_type
+
+ logger.info("Downloading file from QQ: {}", filename or url)
+ local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
+
+ att_meta.append(
+ {
+ "url": url,
+ "filename": filename,
+ "content_type": ctype,
+ "saved_path": local_path,
+ }
+ )
+
+ if local_path:
+ media_paths.append(local_path)
+ shown_name = filename or os.path.basename(local_path)
+ recv_lines.append(f"- {shown_name}\n saved: {local_path}")
+ else:
+ shown_name = filename or url
+ recv_lines.append(f"- {shown_name}\n saved: [download failed]")
+
+ return media_paths, recv_lines, att_meta
+
+ async def _download_to_media_dir_chunked(
+ self,
+ url: str,
+ filename_hint: str = "",
+ ) -> str | None:
+ """Download an inbound attachment using streaming chunk write.
+
+ Uses chunked streaming to avoid loading large files into memory.
+ Enforces a max download size and writes to a .part temp file
+ that is atomically renamed on success.
+ """
+ if not self._http:
+ self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
+
+ safe = _sanitize_filename(filename_hint)
+ ts = int(time.time() * 1000)
+ tmp_path: Path | None = None
+
+ try:
+ async with self._http.get(
+ url,
+ timeout=aiohttp.ClientTimeout(total=120),
+ allow_redirects=True,
+ ) as resp:
+ if resp.status != 200:
+ logger.warning("QQ download failed: status={} url={}", resp.status, url)
+ return None
+
+ ctype = (resp.headers.get("Content-Type") or "").lower()
+
+ # Infer extension: url -> filename_hint -> content-type -> fallback
+ ext = Path(urlparse(url).path).suffix
+ if not ext:
+ ext = Path(filename_hint).suffix
+ if not ext:
+ if "png" in ctype:
+ ext = ".png"
+ elif "jpeg" in ctype or "jpg" in ctype:
+ ext = ".jpg"
+ elif "gif" in ctype:
+ ext = ".gif"
+ elif "webp" in ctype:
+ ext = ".webp"
+ elif "pdf" in ctype:
+ ext = ".pdf"
+ else:
+ ext = ".bin"
+
+ if safe:
+ if not Path(safe).suffix:
+ safe = safe + ext
+ filename = safe
+ else:
+ filename = f"qq_file_{ts}{ext}"
+
+ target = self._media_root / filename
+ if target.exists():
+ target = self._media_root / f"{target.stem}_{ts}{target.suffix}"
+
+ tmp_path = target.with_suffix(target.suffix + ".part")
+
+ # Stream write
+ downloaded = 0
+ chunk_size = max(1024, int(self.config.download_chunk_size or 262144))
+ max_bytes = max(
+ 1024 * 1024, int(self.config.download_max_bytes or (200 * 1024 * 1024))
+ )
+
+ def _open_tmp():
+ tmp_path.parent.mkdir(parents=True, exist_ok=True)
+ return open(tmp_path, "wb") # noqa: SIM115
+
+ f = await asyncio.to_thread(_open_tmp)
+ try:
+ async for chunk in resp.content.iter_chunked(chunk_size):
+ if not chunk:
+ continue
+ downloaded += len(chunk)
+ if downloaded > max_bytes:
+ logger.warning(
+ "QQ download exceeded max_bytes={} url={} -> abort",
+ max_bytes,
+ url,
+ )
+ return None
+ await asyncio.to_thread(f.write, chunk)
+ finally:
+ await asyncio.to_thread(f.close)
+
+ # Atomic rename
+ await asyncio.to_thread(os.replace, tmp_path, target)
+ tmp_path = None # mark as moved
+ logger.info("QQ file saved: {}", str(target))
+ return str(target)
+
+ except Exception as e:
+ logger.error("QQ download error: {}", e)
+ return None
+ finally:
+ # Cleanup partial file
+ if tmp_path is not None:
+ try:
+ tmp_path.unlink(missing_ok=True)
+ except Exception:
+ pass
diff --git a/nanobot/channels/registry.py b/nanobot/channels/registry.py
new file mode 100644
index 000000000..04effc77d
--- /dev/null
+++ b/nanobot/channels/registry.py
@@ -0,0 +1,71 @@
+"""Auto-discovery for built-in channel modules and external plugins."""
+
+from __future__ import annotations
+
+import importlib
+import pkgutil
+from typing import TYPE_CHECKING
+
+from loguru import logger
+
+if TYPE_CHECKING:
+ from nanobot.channels.base import BaseChannel
+
+_INTERNAL = frozenset({"base", "manager", "registry"})
+
+
+def discover_channel_names() -> list[str]:
+ """Return all built-in channel module names by scanning the package (zero imports)."""
+ import nanobot.channels as pkg
+
+ return [
+ name
+ for _, name, ispkg in pkgutil.iter_modules(pkg.__path__)
+ if name not in _INTERNAL and not ispkg
+ ]
+
+
+def load_channel_class(module_name: str) -> type[BaseChannel]:
+ """Import *module_name* and return the first BaseChannel subclass found."""
+ from nanobot.channels.base import BaseChannel as _Base
+
+ mod = importlib.import_module(f"nanobot.channels.{module_name}")
+ for attr in dir(mod):
+ obj = getattr(mod, attr)
+ if isinstance(obj, type) and issubclass(obj, _Base) and obj is not _Base:
+ return obj
+ raise ImportError(f"No BaseChannel subclass in nanobot.channels.{module_name}")
+
+
+def discover_plugins() -> dict[str, type[BaseChannel]]:
+ """Discover external channel plugins registered via entry_points."""
+ from importlib.metadata import entry_points
+
+ plugins: dict[str, type[BaseChannel]] = {}
+ for ep in entry_points(group="nanobot.channels"):
+ try:
+ cls = ep.load()
+ plugins[ep.name] = cls
+ except Exception as e:
+ logger.warning("Failed to load channel plugin '{}': {}", ep.name, e)
+ return plugins
+
+
+def discover_all() -> dict[str, type[BaseChannel]]:
+ """Return all channels: built-in (pkgutil) merged with external (entry_points).
+
+ Built-in channels take priority — an external plugin cannot shadow a built-in name.
+ """
+ builtin: dict[str, type[BaseChannel]] = {}
+ for modname in discover_channel_names():
+ try:
+ builtin[modname] = load_channel_class(modname)
+ except ImportError as e:
+ logger.debug("Skipping built-in channel '{}': {}", modname, e)
+
+ external = discover_plugins()
+ shadowed = set(external) & set(builtin)
+ if shadowed:
+ logger.warning("Plugin(s) shadowed by built-in channels (ignored): {}", shadowed)
+
+ return {**external, **builtin}
diff --git a/nanobot/channels/slack.py b/nanobot/channels/slack.py
new file mode 100644
index 000000000..2503f6a2d
--- /dev/null
+++ b/nanobot/channels/slack.py
@@ -0,0 +1,344 @@
+"""Slack channel implementation using Socket Mode."""
+
+import asyncio
+import re
+from typing import Any
+
+from loguru import logger
+from slack_sdk.socket_mode.request import SocketModeRequest
+from slack_sdk.socket_mode.response import SocketModeResponse
+from slack_sdk.socket_mode.websockets import SocketModeClient
+from slack_sdk.web.async_client import AsyncWebClient
+from slackify_markdown import slackify_markdown
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from pydantic import Field
+
+from nanobot.channels.base import BaseChannel
+from nanobot.config.schema import Base
+
+
+class SlackDMConfig(Base):
+ """Slack DM policy configuration."""
+
+ enabled: bool = True
+ policy: str = "open"
+ allow_from: list[str] = Field(default_factory=list)
+
+
+class SlackConfig(Base):
+ """Slack channel configuration."""
+
+ enabled: bool = False
+ mode: str = "socket"
+ webhook_path: str = "/slack/events"
+ bot_token: str = ""
+ app_token: str = ""
+ user_token_read_only: bool = True
+ reply_in_thread: bool = True
+ react_emoji: str = "eyes"
+ done_emoji: str = "white_check_mark"
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: str = "mention"
+ group_allow_from: list[str] = Field(default_factory=list)
+ dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
+
+
+class SlackChannel(BaseChannel):
+ """Slack channel using Socket Mode."""
+
+ name = "slack"
+ display_name = "Slack"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return SlackConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = SlackConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: SlackConfig = config
+ self._web_client: AsyncWebClient | None = None
+ self._socket_client: SocketModeClient | None = None
+ self._bot_user_id: str | None = None
+
+ async def start(self) -> None:
+ """Start the Slack Socket Mode client."""
+ if not self.config.bot_token or not self.config.app_token:
+ logger.error("Slack bot/app token not configured")
+ return
+ if self.config.mode != "socket":
+ logger.error("Unsupported Slack mode: {}", self.config.mode)
+ return
+
+ self._running = True
+
+ self._web_client = AsyncWebClient(token=self.config.bot_token)
+ self._socket_client = SocketModeClient(
+ app_token=self.config.app_token,
+ web_client=self._web_client,
+ )
+
+ self._socket_client.socket_mode_request_listeners.append(self._on_socket_request)
+
+ # Resolve bot user ID for mention handling
+ try:
+ auth = await self._web_client.auth_test()
+ self._bot_user_id = auth.get("user_id")
+ logger.info("Slack bot connected as {}", self._bot_user_id)
+ except Exception as e:
+ logger.warning("Slack auth_test failed: {}", e)
+
+ logger.info("Starting Slack Socket Mode client...")
+ await self._socket_client.connect()
+
+ while self._running:
+ await asyncio.sleep(1)
+
+ async def stop(self) -> None:
+ """Stop the Slack client."""
+ self._running = False
+ if self._socket_client:
+ try:
+ await self._socket_client.close()
+ except Exception as e:
+ logger.warning("Slack socket close failed: {}", e)
+ self._socket_client = None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send a message through Slack."""
+ if not self._web_client:
+ logger.warning("Slack client not running")
+ return
+ try:
+ slack_meta = msg.metadata.get("slack", {}) if msg.metadata else {}
+ thread_ts = slack_meta.get("thread_ts")
+ channel_type = slack_meta.get("channel_type")
+ # Slack DMs don't use threads; channel/group replies may keep thread_ts.
+ thread_ts_param = thread_ts if thread_ts and channel_type != "im" else None
+
+ # Slack rejects empty text payloads. Keep media-only messages media-only,
+ # but send a single blank message when the bot has no text or files to send.
+ if msg.content or not (msg.media or []):
+ await self._web_client.chat_postMessage(
+ channel=msg.chat_id,
+ text=self._to_mrkdwn(msg.content) if msg.content else " ",
+ thread_ts=thread_ts_param,
+ )
+
+ for media_path in msg.media or []:
+ try:
+ await self._web_client.files_upload_v2(
+ channel=msg.chat_id,
+ file=media_path,
+ thread_ts=thread_ts_param,
+ )
+ except Exception as e:
+ logger.error("Failed to upload file {}: {}", media_path, e)
+
+ # Update reaction emoji when the final (non-progress) response is sent
+ if not (msg.metadata or {}).get("_progress"):
+ event = slack_meta.get("event", {})
+ await self._update_react_emoji(msg.chat_id, event.get("ts"))
+
+ except Exception as e:
+ logger.error("Error sending Slack message: {}", e)
+ raise
+
+ async def _on_socket_request(
+ self,
+ client: SocketModeClient,
+ req: SocketModeRequest,
+ ) -> None:
+ """Handle incoming Socket Mode requests."""
+ if req.type != "events_api":
+ return
+
+ # Acknowledge right away
+ await client.send_socket_mode_response(
+ SocketModeResponse(envelope_id=req.envelope_id)
+ )
+
+ payload = req.payload or {}
+ event = payload.get("event") or {}
+ event_type = event.get("type")
+
+ # Handle app mentions or plain messages
+ if event_type not in ("message", "app_mention"):
+ return
+
+ sender_id = event.get("user")
+ chat_id = event.get("channel")
+
+ # Ignore bot/system messages (any subtype = not a normal user message)
+ if event.get("subtype"):
+ return
+ if self._bot_user_id and sender_id == self._bot_user_id:
+ return
+
+ # Avoid double-processing: Slack sends both `message` and `app_mention`
+ # for mentions in channels. Prefer `app_mention`.
+ text = event.get("text") or ""
+ if event_type == "message" and self._bot_user_id and f"<@{self._bot_user_id}>" in text:
+ return
+
+ # Debug: log basic event shape
+ logger.debug(
+ "Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
+ event_type,
+ event.get("subtype"),
+ sender_id,
+ chat_id,
+ event.get("channel_type"),
+ text[:80],
+ )
+ if not sender_id or not chat_id:
+ return
+
+ channel_type = event.get("channel_type") or ""
+
+ if not self._is_allowed(sender_id, chat_id, channel_type):
+ return
+
+ if channel_type != "im" and not self._should_respond_in_channel(event_type, text, chat_id):
+ return
+
+ text = self._strip_bot_mention(text)
+
+ thread_ts = event.get("thread_ts")
+ if self.config.reply_in_thread and not thread_ts:
+ thread_ts = event.get("ts")
+ # Add :eyes: reaction to the triggering message (best-effort)
+ try:
+ if self._web_client and event.get("ts"):
+ await self._web_client.reactions_add(
+ channel=chat_id,
+ name=self.config.react_emoji,
+ timestamp=event.get("ts"),
+ )
+ except Exception as e:
+ logger.debug("Slack reactions_add failed: {}", e)
+
+ # Thread-scoped session key for channel/group messages
+ session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
+
+ try:
+ await self._handle_message(
+ sender_id=sender_id,
+ chat_id=chat_id,
+ content=text,
+ metadata={
+ "slack": {
+ "event": event,
+ "thread_ts": thread_ts,
+ "channel_type": channel_type,
+ },
+ },
+ session_key=session_key,
+ )
+ except Exception:
+ logger.exception("Error handling Slack message from {}", sender_id)
+
+ async def _update_react_emoji(self, chat_id: str, ts: str | None) -> None:
+ """Remove the in-progress reaction and optionally add a done reaction."""
+ if not self._web_client or not ts:
+ return
+ try:
+ await self._web_client.reactions_remove(
+ channel=chat_id,
+ name=self.config.react_emoji,
+ timestamp=ts,
+ )
+ except Exception as e:
+ logger.debug("Slack reactions_remove failed: {}", e)
+ if self.config.done_emoji:
+ try:
+ await self._web_client.reactions_add(
+ channel=chat_id,
+ name=self.config.done_emoji,
+ timestamp=ts,
+ )
+ except Exception as e:
+ logger.debug("Slack done reaction failed: {}", e)
+
+ def _is_allowed(self, sender_id: str, chat_id: str, channel_type: str) -> bool:
+ if channel_type == "im":
+ if not self.config.dm.enabled:
+ return False
+ if self.config.dm.policy == "allowlist":
+ return sender_id in self.config.dm.allow_from
+ return True
+
+ # Group / channel messages
+ if self.config.group_policy == "allowlist":
+ return chat_id in self.config.group_allow_from
+ return True
+
+ def _should_respond_in_channel(self, event_type: str, text: str, chat_id: str) -> bool:
+ if self.config.group_policy == "open":
+ return True
+ if self.config.group_policy == "mention":
+ if event_type == "app_mention":
+ return True
+ return self._bot_user_id is not None and f"<@{self._bot_user_id}>" in text
+ if self.config.group_policy == "allowlist":
+ return chat_id in self.config.group_allow_from
+ return False
+
+ def _strip_bot_mention(self, text: str) -> str:
+ if not text or not self._bot_user_id:
+ return text
+ return re.sub(rf"<@{re.escape(self._bot_user_id)}>\s*", "", text).strip()
+
+ _TABLE_RE = re.compile(r"(?m)^\|.*\|$(?:\n\|[\s:|-]*\|$)(?:\n\|.*\|$)*")
+ _CODE_FENCE_RE = re.compile(r"```[\s\S]*?```")
+ _INLINE_CODE_RE = re.compile(r"`[^`]+`")
+ _LEFTOVER_BOLD_RE = re.compile(r"\*\*(.+?)\*\*")
+ _LEFTOVER_HEADER_RE = re.compile(r"^#{1,6}\s+(.+)$", re.MULTILINE)
+ _BARE_URL_RE = re.compile(r"(? str:
+ """Convert Markdown to Slack mrkdwn, including tables."""
+ if not text:
+ return ""
+ text = cls._TABLE_RE.sub(cls._convert_table, text)
+ return cls._fixup_mrkdwn(slackify_markdown(text))
+
+ @classmethod
+ def _fixup_mrkdwn(cls, text: str) -> str:
+ """Fix markdown artifacts that slackify_markdown misses."""
+ code_blocks: list[str] = []
+
+ def _save_code(m: re.Match) -> str:
+ code_blocks.append(m.group(0))
+ return f"\x00CB{len(code_blocks) - 1}\x00"
+
+ text = cls._CODE_FENCE_RE.sub(_save_code, text)
+ text = cls._INLINE_CODE_RE.sub(_save_code, text)
+ text = cls._LEFTOVER_BOLD_RE.sub(r"*\1*", text)
+ text = cls._LEFTOVER_HEADER_RE.sub(r"*\1*", text)
+ text = cls._BARE_URL_RE.sub(lambda m: m.group(0).replace("&", "&"), text)
+
+ for i, block in enumerate(code_blocks):
+ text = text.replace(f"\x00CB{i}\x00", block)
+ return text
+
+ @staticmethod
+ def _convert_table(match: re.Match) -> str:
+ """Convert a Markdown table to a Slack-readable list."""
+ lines = [ln.strip() for ln in match.group(0).strip().splitlines() if ln.strip()]
+ if len(lines) < 2:
+ return match.group(0)
+ headers = [h.strip() for h in lines[0].strip("|").split("|")]
+ start = 2 if re.fullmatch(r"[|\s:\-]+", lines[1]) else 1
+ rows: list[str] = []
+ for line in lines[start:]:
+ cells = [c.strip() for c in line.strip("|").split("|")]
+ cells = (cells + [""] * len(headers))[: len(headers)]
+ parts = [f"**{headers[i]}**: {cells[i]}" for i in range(len(headers)) if cells[i]]
+ if parts:
+ rows.append(" · ".join(parts))
+ return "\n".join(rows)
diff --git a/nanobot/channels/telegram.py b/nanobot/channels/telegram.py
index 23e1de00e..35f9ad620 100644
--- a/nanobot/channels/telegram.py
+++ b/nanobot/channels/telegram.py
@@ -1,16 +1,83 @@
"""Telegram channel implementation using python-telegram-bot."""
+from __future__ import annotations
+
import asyncio
import re
+import time
+import unicodedata
+from dataclasses import dataclass, field
+from typing import Any, Literal
from loguru import logger
-from telegram import Update
-from telegram.ext import Application, MessageHandler, filters, ContextTypes
+from pydantic import Field
+from telegram import BotCommand, ReactionTypeEmoji, ReplyParameters, Update
+from telegram.error import BadRequest, NetworkError, TimedOut
+from telegram.ext import Application, CommandHandler, ContextTypes, MessageHandler, filters
+from telegram.request import HTTPXRequest
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import TelegramConfig
+from nanobot.command.builtin import build_help_text
+from nanobot.config.paths import get_media_dir
+from nanobot.config.schema import Base
+from nanobot.security.network import validate_url_target
+from nanobot.utils.helpers import split_message
+
+TELEGRAM_MAX_MESSAGE_LEN = 4000 # Telegram message character limit
+TELEGRAM_REPLY_CONTEXT_MAX_LEN = TELEGRAM_MAX_MESSAGE_LEN # Max length for reply context in user message
+
+
+def _escape_telegram_html(text: str) -> str:
+ """Escape text for Telegram HTML parse mode."""
+ return text.replace("&", "&").replace("<", "<").replace(">", ">")
+
+
+def _tool_hint_to_telegram_blockquote(text: str) -> str:
+ """Render tool hints as an expandable blockquote (collapsed by default)."""
+ return f"{_escape_telegram_html(text)}
" if text else ""
+
+
+def _strip_md(s: str) -> str:
+ """Strip markdown inline formatting from text."""
+ s = re.sub(r'\*\*(.+?)\*\*', r'\1', s)
+ s = re.sub(r'__(.+?)__', r'\1', s)
+ s = re.sub(r'~~(.+?)~~', r'\1', s)
+ s = re.sub(r'`([^`]+)`', r'\1', s)
+ return s.strip()
+
+
+def _render_table_box(table_lines: list[str]) -> str:
+ """Convert markdown pipe-table to compact aligned text for display."""
+
+ def dw(s: str) -> int:
+ return sum(2 if unicodedata.east_asian_width(c) in ('W', 'F') else 1 for c in s)
+
+ rows: list[list[str]] = []
+ has_sep = False
+ for line in table_lines:
+ cells = [_strip_md(c) for c in line.strip().strip('|').split('|')]
+ if all(re.match(r'^:?-+:?$', c) for c in cells if c):
+ has_sep = True
+ continue
+ rows.append(cells)
+ if not rows or not has_sep:
+ return '\n'.join(table_lines)
+
+ ncols = max(len(r) for r in rows)
+ for r in rows:
+ r.extend([''] * (ncols - len(r)))
+ widths = [max(dw(r[c]) for r in rows) for c in range(ncols)]
+
+ def dr(cells: list[str]) -> str:
+ return ' '.join(f'{c}{" " * (w - dw(c))}' for c, w in zip(cells, widths))
+
+ out = [dr(rows[0])]
+ out.append(' '.join('─' * w for w in widths))
+ for row in rows[1:]:
+ out.append(dr(row))
+ return '\n'.join(out)
def _markdown_to_telegram_html(text: str) -> str:
@@ -19,277 +86,964 @@ def _markdown_to_telegram_html(text: str) -> str:
"""
if not text:
return ""
-
+
# 1. Extract and protect code blocks (preserve content from other processing)
code_blocks: list[str] = []
def save_code_block(m: re.Match) -> str:
code_blocks.append(m.group(1))
return f"\x00CB{len(code_blocks) - 1}\x00"
-
+
text = re.sub(r'```[\w]*\n?([\s\S]*?)```', save_code_block, text)
-
+
+ # 1.5. Convert markdown tables to box-drawing (reuse code_block placeholders)
+ lines = text.split('\n')
+ rebuilt: list[str] = []
+ li = 0
+ while li < len(lines):
+ if re.match(r'^\s*\|.+\|', lines[li]):
+ tbl: list[str] = []
+ while li < len(lines) and re.match(r'^\s*\|.+\|', lines[li]):
+ tbl.append(lines[li])
+ li += 1
+ box = _render_table_box(tbl)
+ if box != '\n'.join(tbl):
+ code_blocks.append(box)
+ rebuilt.append(f"\x00CB{len(code_blocks) - 1}\x00")
+ else:
+ rebuilt.extend(tbl)
+ else:
+ rebuilt.append(lines[li])
+ li += 1
+ text = '\n'.join(rebuilt)
+
# 2. Extract and protect inline code
inline_codes: list[str] = []
def save_inline_code(m: re.Match) -> str:
inline_codes.append(m.group(1))
return f"\x00IC{len(inline_codes) - 1}\x00"
-
+
text = re.sub(r'`([^`]+)`', save_inline_code, text)
-
+
# 3. Headers # Title -> just the title text
text = re.sub(r'^#{1,6}\s+(.+)$', r'\1', text, flags=re.MULTILINE)
-
+
# 4. Blockquotes > text -> just the text (before HTML escaping)
text = re.sub(r'^>\s*(.*)$', r'\1', text, flags=re.MULTILINE)
-
+
# 5. Escape HTML special characters
- text = text.replace("&", "&").replace("<", "<").replace(">", ">")
-
+ text = _escape_telegram_html(text)
+
# 6. Links [text](url) - must be before bold/italic to handle nested cases
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'\1', text)
-
+
# 7. Bold **text** or __text__
text = re.sub(r'\*\*(.+?)\*\*', r'\1', text)
text = re.sub(r'__(.+?)__', r'\1', text)
-
+
# 8. Italic _text_ (avoid matching inside words like some_var_name)
text = re.sub(r'(?\1', text)
-
+
# 9. Strikethrough ~~text~~
text = re.sub(r'~~(.+?)~~', r'\1', text)
-
+
# 10. Bullet lists - item -> • item
text = re.sub(r'^[-*]\s+', '• ', text, flags=re.MULTILINE)
-
+
# 11. Restore inline code with HTML tags
for i, code in enumerate(inline_codes):
# Escape HTML in code content
- escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
+ escaped = _escape_telegram_html(code)
text = text.replace(f"\x00IC{i}\x00", f"{escaped}")
-
+
# 12. Restore code blocks with HTML tags
for i, code in enumerate(code_blocks):
# Escape HTML in code content
- escaped = code.replace("&", "&").replace("<", "<").replace(">", ">")
+ escaped = _escape_telegram_html(code)
text = text.replace(f"\x00CB{i}\x00", f"{escaped}
")
-
+
return text
+_SEND_MAX_RETRIES = 3
+_SEND_RETRY_BASE_DELAY = 0.5 # seconds, doubled each retry
+
+
+@dataclass
+class _StreamBuf:
+ """Per-chat streaming accumulator for progressive message editing."""
+ text: str = ""
+ message_id: int | None = None
+ last_edit: float = 0.0
+ stream_id: str | None = None
+
+
+class TelegramConfig(Base):
+ """Telegram channel configuration."""
+
+ enabled: bool = False
+ token: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ proxy: str | None = None
+ reply_to_message: bool = False
+ react_emoji: str = "👀"
+ group_policy: Literal["open", "mention"] = "mention"
+ connection_pool_size: int = 32
+ pool_timeout: float = 5.0
+ streaming: bool = True
+
+
class TelegramChannel(BaseChannel):
"""
Telegram channel using long polling.
-
+
Simple and reliable - no webhook/public IP needed.
"""
-
+
name = "telegram"
-
- def __init__(self, config: TelegramConfig, bus: MessageBus, groq_api_key: str = ""):
+ display_name = "Telegram"
+
+ # Commands registered with Telegram's command menu
+ BOT_COMMANDS = [
+ BotCommand("start", "Start the bot"),
+ BotCommand("new", "Start a new conversation"),
+ BotCommand("stop", "Stop the current task"),
+ BotCommand("restart", "Restart the bot"),
+ BotCommand("status", "Show bot status"),
+ BotCommand("dream", "Run Dream memory consolidation now"),
+ BotCommand("dream_log", "Show the latest Dream memory change"),
+ BotCommand("dream_restore", "Restore Dream memory to an earlier version"),
+ BotCommand("help", "Show available commands"),
+ ]
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return TelegramConfig().model_dump(by_alias=True)
+
+ _STREAM_EDIT_INTERVAL = 0.6 # min seconds between edit_message_text calls
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = TelegramConfig.model_validate(config)
super().__init__(config, bus)
self.config: TelegramConfig = config
- self.groq_api_key = groq_api_key
self._app: Application | None = None
self._chat_ids: dict[str, int] = {} # Map sender_id to chat_id for replies
-
+ self._typing_tasks: dict[str, asyncio.Task] = {} # chat_id -> typing loop task
+ self._media_group_buffers: dict[str, dict] = {}
+ self._media_group_tasks: dict[str, asyncio.Task] = {}
+ self._message_threads: dict[tuple[str, int], int] = {}
+ self._bot_user_id: int | None = None
+ self._bot_username: str | None = None
+ self._stream_bufs: dict[str, _StreamBuf] = {} # chat_id -> streaming state
+
+ def is_allowed(self, sender_id: str) -> bool:
+ """Preserve Telegram's legacy id|username allowlist matching."""
+ if super().is_allowed(sender_id):
+ return True
+
+ allow_list = getattr(self.config, "allow_from", [])
+ if not allow_list or "*" in allow_list:
+ return False
+
+ sender_str = str(sender_id)
+ if sender_str.count("|") != 1:
+ return False
+
+ sid, username = sender_str.split("|", 1)
+ if not sid.isdigit() or not username:
+ return False
+
+ return sid in allow_list or username in allow_list
+
+ @staticmethod
+ def _normalize_telegram_command(content: str) -> str:
+ """Map Telegram-safe command aliases back to canonical nanobot commands."""
+ if not content.startswith("/"):
+ return content
+ if content == "/dream_log" or content.startswith("/dream_log "):
+ return content.replace("/dream_log", "/dream-log", 1)
+ if content == "/dream_restore" or content.startswith("/dream_restore "):
+ return content.replace("/dream_restore", "/dream-restore", 1)
+ return content
+
async def start(self) -> None:
"""Start the Telegram bot with long polling."""
if not self.config.token:
logger.error("Telegram bot token not configured")
return
-
+
self._running = True
-
- # Build the application
- self._app = (
+
+ proxy = self.config.proxy or None
+
+ # Separate pools so long-polling (getUpdates) never starves outbound sends.
+ api_request = HTTPXRequest(
+ connection_pool_size=self.config.connection_pool_size,
+ pool_timeout=self.config.pool_timeout,
+ connect_timeout=30.0,
+ read_timeout=30.0,
+ proxy=proxy,
+ )
+ poll_request = HTTPXRequest(
+ connection_pool_size=4,
+ pool_timeout=self.config.pool_timeout,
+ connect_timeout=30.0,
+ read_timeout=30.0,
+ proxy=proxy,
+ )
+ builder = (
Application.builder()
.token(self.config.token)
- .build()
+ .request(api_request)
+ .get_updates_request(poll_request)
)
-
+ self._app = builder.build()
+ self._app.add_error_handler(self._on_error)
+
+ # Add command handlers (using Regex to support @username suffixes before bot initialization)
+ self._app.add_handler(MessageHandler(filters.Regex(r"^/start(?:@\w+)?$"), self._on_start))
+ self._app.add_handler(
+ MessageHandler(
+ filters.Regex(r"^/(new|stop|restart|status|dream)(?:@\w+)?(?:\s+.*)?$"),
+ self._forward_command,
+ )
+ )
+ self._app.add_handler(
+ MessageHandler(
+ filters.Regex(r"^/(dream-log|dream_log|dream-restore|dream_restore)(?:@\w+)?(?:\s+.*)?$"),
+ self._forward_command,
+ )
+ )
+ self._app.add_handler(MessageHandler(filters.Regex(r"^/help(?:@\w+)?$"), self._on_help))
+
# Add message handler for text, photos, voice, documents
self._app.add_handler(
MessageHandler(
- (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
- & ~filters.COMMAND,
+ (filters.TEXT | filters.PHOTO | filters.VOICE | filters.AUDIO | filters.Document.ALL)
+ & ~filters.COMMAND,
self._on_message
)
)
-
- # Add /start command handler
- from telegram.ext import CommandHandler
- self._app.add_handler(CommandHandler("start", self._on_start))
-
+
logger.info("Starting Telegram bot (polling mode)...")
-
+
# Initialize and start polling
await self._app.initialize()
await self._app.start()
-
- # Get bot info
+
+ # Get bot info and register command menu
bot_info = await self._app.bot.get_me()
- logger.info(f"Telegram bot @{bot_info.username} connected")
-
+ self._bot_user_id = getattr(bot_info, "id", None)
+ self._bot_username = getattr(bot_info, "username", None)
+ logger.info("Telegram bot @{} connected", bot_info.username)
+
+ try:
+ await self._app.bot.set_my_commands(self.BOT_COMMANDS)
+ logger.debug("Telegram bot commands registered")
+ except Exception as e:
+ logger.warning("Failed to register bot commands: {}", e)
+
# Start polling (this runs until stopped)
await self._app.updater.start_polling(
allowed_updates=["message"],
- drop_pending_updates=True # Ignore old messages on startup
+ drop_pending_updates=False, # Process pending messages on startup
+ error_callback=self._on_polling_error,
)
-
+
# Keep running until stopped
while self._running:
await asyncio.sleep(1)
-
+
async def stop(self) -> None:
"""Stop the Telegram bot."""
self._running = False
-
+
+ # Cancel all typing indicators
+ for chat_id in list(self._typing_tasks):
+ self._stop_typing(chat_id)
+
+ for task in self._media_group_tasks.values():
+ task.cancel()
+ self._media_group_tasks.clear()
+ self._media_group_buffers.clear()
+
if self._app:
logger.info("Stopping Telegram bot...")
await self._app.updater.stop()
await self._app.stop()
await self._app.shutdown()
self._app = None
-
+
+ @staticmethod
+ def _get_media_type(path: str) -> str:
+ """Guess media type from file extension."""
+ ext = path.rsplit(".", 1)[-1].lower() if "." in path else ""
+ if ext in ("jpg", "jpeg", "png", "gif", "webp"):
+ return "photo"
+ if ext == "ogg":
+ return "voice"
+ if ext in ("mp3", "m4a", "wav", "aac"):
+ return "audio"
+ return "document"
+
+ @staticmethod
+ def _is_remote_media_url(path: str) -> bool:
+ return path.startswith(("http://", "https://"))
+
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through Telegram."""
if not self._app:
logger.warning("Telegram bot not running")
return
-
+
+ # Only stop typing indicator and remove reaction for final responses
+ if not msg.metadata.get("_progress", False):
+ self._stop_typing(msg.chat_id)
+ if reply_to_message_id := msg.metadata.get("message_id"):
+ try:
+ await self._remove_reaction(msg.chat_id, int(reply_to_message_id))
+ except ValueError:
+ pass
+
try:
- # chat_id should be the Telegram chat ID (integer)
chat_id = int(msg.chat_id)
- # Convert markdown to Telegram HTML
- html_content = _markdown_to_telegram_html(msg.content)
- await self._app.bot.send_message(
- chat_id=chat_id,
- text=html_content,
- parse_mode="HTML"
- )
except ValueError:
- logger.error(f"Invalid chat_id: {msg.chat_id}")
- except Exception as e:
- # Fallback to plain text if HTML parsing fails
- logger.warning(f"HTML parse failed, falling back to plain text: {e}")
+ logger.error("Invalid chat_id: {}", msg.chat_id)
+ return
+ reply_to_message_id = msg.metadata.get("message_id")
+ message_thread_id = msg.metadata.get("message_thread_id")
+ if message_thread_id is None and reply_to_message_id is not None:
+ message_thread_id = self._message_threads.get((msg.chat_id, reply_to_message_id))
+ thread_kwargs = {}
+ if message_thread_id is not None:
+ thread_kwargs["message_thread_id"] = message_thread_id
+
+ reply_params = None
+ if self.config.reply_to_message:
+ if reply_to_message_id:
+ reply_params = ReplyParameters(
+ message_id=reply_to_message_id,
+ allow_sending_without_reply=True
+ )
+
+ # Send media files
+ for media_path in (msg.media or []):
try:
+ media_type = self._get_media_type(media_path)
+ sender = {
+ "photo": self._app.bot.send_photo,
+ "voice": self._app.bot.send_voice,
+ "audio": self._app.bot.send_audio,
+ }.get(media_type, self._app.bot.send_document)
+ param = "photo" if media_type == "photo" else media_type if media_type in ("voice", "audio") else "document"
+
+ # Telegram Bot API accepts HTTP(S) URLs directly for media params.
+ if self._is_remote_media_url(media_path):
+ ok, error = validate_url_target(media_path)
+ if not ok:
+ raise ValueError(f"unsafe media URL: {error}")
+ await self._call_with_retry(
+ sender,
+ chat_id=chat_id,
+ **{param: media_path},
+ reply_parameters=reply_params,
+ **thread_kwargs,
+ )
+ continue
+
+ with open(media_path, "rb") as f:
+ await sender(
+ chat_id=chat_id,
+ **{param: f},
+ reply_parameters=reply_params,
+ **thread_kwargs,
+ )
+ except Exception as e:
+ filename = media_path.rsplit("/", 1)[-1]
+ logger.error("Failed to send media {}: {}", media_path, e)
await self._app.bot.send_message(
- chat_id=int(msg.chat_id),
- text=msg.content
+ chat_id=chat_id,
+ text=f"[Failed to send: {filename}]",
+ reply_parameters=reply_params,
+ **thread_kwargs,
+ )
+
+ # Send text content
+ if msg.content and msg.content != "[empty message]":
+ render_as_blockquote = bool(msg.metadata.get("_tool_hint"))
+ for chunk in split_message(msg.content, TELEGRAM_MAX_MESSAGE_LEN):
+ await self._send_text(
+ chat_id, chunk, reply_params, thread_kwargs,
+ render_as_blockquote=render_as_blockquote,
+ )
+
+ async def _call_with_retry(self, fn, *args, **kwargs):
+ """Call an async Telegram API function with retry on pool/network timeout and RetryAfter."""
+ from telegram.error import RetryAfter
+
+ for attempt in range(1, _SEND_MAX_RETRIES + 1):
+ try:
+ return await fn(*args, **kwargs)
+ except TimedOut:
+ if attempt == _SEND_MAX_RETRIES:
+ raise
+ delay = _SEND_RETRY_BASE_DELAY * (2 ** (attempt - 1))
+ logger.warning(
+ "Telegram timeout (attempt {}/{}), retrying in {:.1f}s",
+ attempt, _SEND_MAX_RETRIES, delay,
+ )
+ await asyncio.sleep(delay)
+ except RetryAfter as e:
+ if attempt == _SEND_MAX_RETRIES:
+ raise
+ delay = float(e.retry_after)
+ logger.warning(
+ "Telegram Flood Control (attempt {}/{}), retrying in {:.1f}s",
+ attempt, _SEND_MAX_RETRIES, delay,
+ )
+ await asyncio.sleep(delay)
+
+ async def _send_text(
+ self,
+ chat_id: int,
+ text: str,
+ reply_params=None,
+ thread_kwargs: dict | None = None,
+ render_as_blockquote: bool = False,
+ ) -> None:
+ """Send a plain text message with HTML fallback."""
+ try:
+ html = _tool_hint_to_telegram_blockquote(text) if render_as_blockquote else _markdown_to_telegram_html(text)
+ await self._call_with_retry(
+ self._app.bot.send_message,
+ chat_id=chat_id, text=html, parse_mode="HTML",
+ reply_parameters=reply_params,
+ **(thread_kwargs or {}),
+ )
+ except Exception as e:
+ logger.warning("HTML parse failed, falling back to plain text: {}", e)
+ try:
+ await self._call_with_retry(
+ self._app.bot.send_message,
+ chat_id=chat_id,
+ text=text,
+ reply_parameters=reply_params,
+ **(thread_kwargs or {}),
)
except Exception as e2:
- logger.error(f"Error sending Telegram message: {e2}")
-
+ logger.error("Error sending Telegram message: {}", e2)
+ raise
+
+ @staticmethod
+ def _is_not_modified_error(exc: Exception) -> bool:
+ return isinstance(exc, BadRequest) and "message is not modified" in str(exc).lower()
+
+ async def send_delta(self, chat_id: str, delta: str, metadata: dict[str, Any] | None = None) -> None:
+ """Progressive message editing: send on first delta, edit on subsequent ones."""
+ if not self._app:
+ return
+ meta = metadata or {}
+ int_chat_id = int(chat_id)
+ stream_id = meta.get("_stream_id")
+
+ if meta.get("_stream_end"):
+ buf = self._stream_bufs.get(chat_id)
+ if not buf or not buf.message_id or not buf.text:
+ return
+ if stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id:
+ return
+ self._stop_typing(chat_id)
+ if reply_to_message_id := meta.get("message_id"):
+ try:
+ await self._remove_reaction(chat_id, int(reply_to_message_id))
+ except ValueError:
+ pass
+ try:
+ html = _markdown_to_telegram_html(buf.text)
+ await self._call_with_retry(
+ self._app.bot.edit_message_text,
+ chat_id=int_chat_id, message_id=buf.message_id,
+ text=html, parse_mode="HTML",
+ )
+ except Exception as e:
+ if self._is_not_modified_error(e):
+ logger.debug("Final stream edit already applied for {}", chat_id)
+ self._stream_bufs.pop(chat_id, None)
+ return
+ logger.debug("Final stream edit failed (HTML), trying plain: {}", e)
+ try:
+ await self._call_with_retry(
+ self._app.bot.edit_message_text,
+ chat_id=int_chat_id, message_id=buf.message_id,
+ text=buf.text,
+ )
+ except Exception as e2:
+ if self._is_not_modified_error(e2):
+ logger.debug("Final stream plain edit already applied for {}", chat_id)
+ self._stream_bufs.pop(chat_id, None)
+ return
+ logger.warning("Final stream edit failed: {}", e2)
+ raise # Let ChannelManager handle retry
+ self._stream_bufs.pop(chat_id, None)
+ return
+
+ buf = self._stream_bufs.get(chat_id)
+ if buf is None or (stream_id is not None and buf.stream_id is not None and buf.stream_id != stream_id):
+ buf = _StreamBuf(stream_id=stream_id)
+ self._stream_bufs[chat_id] = buf
+ elif buf.stream_id is None:
+ buf.stream_id = stream_id
+ buf.text += delta
+
+ if not buf.text.strip():
+ return
+
+ now = time.monotonic()
+ thread_kwargs = {}
+ if message_thread_id := meta.get("message_thread_id"):
+ thread_kwargs["message_thread_id"] = message_thread_id
+ if buf.message_id is None:
+ try:
+ sent = await self._call_with_retry(
+ self._app.bot.send_message,
+ chat_id=int_chat_id, text=buf.text,
+ **thread_kwargs,
+ )
+ buf.message_id = sent.message_id
+ buf.last_edit = now
+ except Exception as e:
+ logger.warning("Stream initial send failed: {}", e)
+ raise # Let ChannelManager handle retry
+ elif (now - buf.last_edit) >= self._STREAM_EDIT_INTERVAL:
+ try:
+ await self._call_with_retry(
+ self._app.bot.edit_message_text,
+ chat_id=int_chat_id, message_id=buf.message_id,
+ text=buf.text,
+ )
+ buf.last_edit = now
+ except Exception as e:
+ if self._is_not_modified_error(e):
+ buf.last_edit = now
+ return
+ logger.warning("Stream edit failed: {}", e)
+ raise # Let ChannelManager handle retry
+
async def _on_start(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle /start command."""
if not update.message or not update.effective_user:
return
-
+
user = update.effective_user
await update.message.reply_text(
f"👋 Hi {user.first_name}! I'm nanobot.\n\n"
- "Send me a message and I'll respond!"
+ "Send me a message and I'll respond!\n"
+ "Type /help to see available commands."
)
-
+
+ async def _on_help(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+ """Handle /help command, bypassing ACL so all users can access it."""
+ if not update.message:
+ return
+ await update.message.reply_text(build_help_text())
+
+ @staticmethod
+ def _sender_id(user) -> str:
+ """Build sender_id with username for allowlist matching."""
+ sid = str(user.id)
+ return f"{sid}|{user.username}" if user.username else sid
+
+ @staticmethod
+ def _derive_topic_session_key(message) -> str | None:
+ """Derive topic-scoped session key for Telegram chats with threads."""
+ message_thread_id = getattr(message, "message_thread_id", None)
+ if message_thread_id is None:
+ return None
+ return f"telegram:{message.chat_id}:topic:{message_thread_id}"
+
+ @staticmethod
+ def _build_message_metadata(message, user) -> dict:
+ """Build common Telegram inbound metadata payload."""
+ reply_to = getattr(message, "reply_to_message", None)
+ return {
+ "message_id": message.message_id,
+ "user_id": user.id,
+ "username": user.username,
+ "first_name": user.first_name,
+ "is_group": message.chat.type != "private",
+ "message_thread_id": getattr(message, "message_thread_id", None),
+ "is_forum": bool(getattr(message.chat, "is_forum", False)),
+ "reply_to_message_id": getattr(reply_to, "message_id", None) if reply_to else None,
+ }
+
+ async def _extract_reply_context(self, message) -> str | None:
+ """Extract text from the message being replied to, if any."""
+ reply = getattr(message, "reply_to_message", None)
+ if not reply:
+ return None
+ text = getattr(reply, "text", None) or getattr(reply, "caption", None) or ""
+ if len(text) > TELEGRAM_REPLY_CONTEXT_MAX_LEN:
+ text = text[:TELEGRAM_REPLY_CONTEXT_MAX_LEN] + "..."
+
+ if not text:
+ return None
+
+ bot_id, _ = await self._ensure_bot_identity()
+ reply_user = getattr(reply, "from_user", None)
+
+ if bot_id and reply_user and getattr(reply_user, "id", None) == bot_id:
+ return f"[Reply to bot: {text}]"
+ elif reply_user and getattr(reply_user, "username", None):
+ return f"[Reply to @{reply_user.username}: {text}]"
+ elif reply_user and getattr(reply_user, "first_name", None):
+ return f"[Reply to {reply_user.first_name}: {text}]"
+ else:
+ return f"[Reply to: {text}]"
+
+ async def _download_message_media(
+ self, msg, *, add_failure_content: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Download media from a message (current or reply). Returns (media_paths, content_parts)."""
+ media_file = None
+ media_type = None
+ if getattr(msg, "photo", None):
+ media_file = msg.photo[-1]
+ media_type = "image"
+ elif getattr(msg, "voice", None):
+ media_file = msg.voice
+ media_type = "voice"
+ elif getattr(msg, "audio", None):
+ media_file = msg.audio
+ media_type = "audio"
+ elif getattr(msg, "document", None):
+ media_file = msg.document
+ media_type = "file"
+ elif getattr(msg, "video", None):
+ media_file = msg.video
+ media_type = "video"
+ elif getattr(msg, "video_note", None):
+ media_file = msg.video_note
+ media_type = "video"
+ elif getattr(msg, "animation", None):
+ media_file = msg.animation
+ media_type = "animation"
+ if not media_file or not self._app:
+ return [], []
+ try:
+ file = await self._app.bot.get_file(media_file.file_id)
+ ext = self._get_extension(
+ media_type,
+ getattr(media_file, "mime_type", None),
+ getattr(media_file, "file_name", None),
+ )
+ media_dir = get_media_dir("telegram")
+ unique_id = getattr(media_file, "file_unique_id", media_file.file_id)
+ file_path = media_dir / f"{unique_id}{ext}"
+ await file.download_to_drive(str(file_path))
+ path_str = str(file_path)
+ if media_type in ("voice", "audio"):
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ logger.info("Transcribed {}: {}...", media_type, transcription[:50])
+ return [path_str], [f"[transcription: {transcription}]"]
+ return [path_str], [f"[{media_type}: {path_str}]"]
+ return [path_str], [f"[{media_type}: {path_str}]"]
+ except Exception as e:
+ logger.warning("Failed to download message media: {}", e)
+ if add_failure_content:
+ return [], [f"[{media_type}: download failed]"]
+ return [], []
+
+ async def _ensure_bot_identity(self) -> tuple[int | None, str | None]:
+ """Load bot identity once and reuse it for mention/reply checks."""
+ if self._bot_user_id is not None or self._bot_username is not None:
+ return self._bot_user_id, self._bot_username
+ if not self._app:
+ return None, None
+ bot_info = await self._app.bot.get_me()
+ self._bot_user_id = getattr(bot_info, "id", None)
+ self._bot_username = getattr(bot_info, "username", None)
+ return self._bot_user_id, self._bot_username
+
+ @staticmethod
+ def _has_mention_entity(
+ text: str,
+ entities,
+ bot_username: str,
+ bot_id: int | None,
+ ) -> bool:
+ """Check Telegram mention entities against the bot username."""
+ handle = f"@{bot_username}".lower()
+ for entity in entities or []:
+ entity_type = getattr(entity, "type", None)
+ if entity_type == "text_mention":
+ user = getattr(entity, "user", None)
+ if user is not None and bot_id is not None and getattr(user, "id", None) == bot_id:
+ return True
+ continue
+ if entity_type != "mention":
+ continue
+ offset = getattr(entity, "offset", None)
+ length = getattr(entity, "length", None)
+ if offset is None or length is None:
+ continue
+ if text[offset : offset + length].lower() == handle:
+ return True
+ return handle in text.lower()
+
+ async def _is_group_message_for_bot(self, message) -> bool:
+ """Allow group messages when policy is open, @mentioned, or replying to the bot."""
+ if message.chat.type == "private" or self.config.group_policy == "open":
+ return True
+
+ bot_id, bot_username = await self._ensure_bot_identity()
+ if bot_username:
+ text = message.text or ""
+ caption = message.caption or ""
+ if self._has_mention_entity(
+ text,
+ getattr(message, "entities", None),
+ bot_username,
+ bot_id,
+ ):
+ return True
+ if self._has_mention_entity(
+ caption,
+ getattr(message, "caption_entities", None),
+ bot_username,
+ bot_id,
+ ):
+ return True
+
+ reply_user = getattr(getattr(message, "reply_to_message", None), "from_user", None)
+ return bool(bot_id and reply_user and reply_user.id == bot_id)
+
+ def _remember_thread_context(self, message) -> None:
+ """Cache Telegram thread context by chat/message id for follow-up replies."""
+ message_thread_id = getattr(message, "message_thread_id", None)
+ if message_thread_id is None:
+ return
+ key = (str(message.chat_id), message.message_id)
+ self._message_threads[key] = message_thread_id
+ if len(self._message_threads) > 1000:
+ self._message_threads.pop(next(iter(self._message_threads)))
+
+ async def _forward_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
+ """Forward slash commands to the bus for unified handling in AgentLoop."""
+ if not update.message or not update.effective_user:
+ return
+ message = update.message
+ user = update.effective_user
+ self._remember_thread_context(message)
+
+ # Strip @bot_username suffix if present
+ content = message.text or ""
+ if content.startswith("/") and "@" in content:
+ cmd_part, *rest = content.split(" ", 1)
+ cmd_part = cmd_part.split("@")[0]
+ content = f"{cmd_part} {rest[0]}" if rest else cmd_part
+ content = self._normalize_telegram_command(content)
+
+ await self._handle_message(
+ sender_id=self._sender_id(user),
+ chat_id=str(message.chat_id),
+ content=content,
+ metadata=self._build_message_metadata(message, user),
+ session_key=self._derive_topic_session_key(message),
+ )
+
async def _on_message(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming messages (text, photos, voice, documents)."""
if not update.message or not update.effective_user:
return
-
+
message = update.message
user = update.effective_user
chat_id = message.chat_id
-
- # Use stable numeric ID, but keep username for allowlist compatibility
- sender_id = str(user.id)
- if user.username:
- sender_id = f"{sender_id}|{user.username}"
-
+ sender_id = self._sender_id(user)
+ self._remember_thread_context(message)
+
# Store chat_id for replies
self._chat_ids[sender_id] = chat_id
-
+
+ if not await self._is_group_message_for_bot(message):
+ return
+
# Build content from text and/or media
content_parts = []
media_paths = []
-
+
# Text content
if message.text:
content_parts.append(message.text)
if message.caption:
content_parts.append(message.caption)
-
- # Handle media files
- media_file = None
- media_type = None
-
- if message.photo:
- media_file = message.photo[-1] # Largest photo
- media_type = "image"
- elif message.voice:
- media_file = message.voice
- media_type = "voice"
- elif message.audio:
- media_file = message.audio
- media_type = "audio"
- elif message.document:
- media_file = message.document
- media_type = "file"
-
- # Download media if present
- if media_file and self._app:
- try:
- file = await self._app.bot.get_file(media_file.file_id)
- ext = self._get_extension(media_type, getattr(media_file, 'mime_type', None))
-
- # Save to workspace/media/
- from pathlib import Path
- media_dir = Path.home() / ".nanobot" / "media"
- media_dir.mkdir(parents=True, exist_ok=True)
-
- file_path = media_dir / f"{media_file.file_id[:16]}{ext}"
- await file.download_to_drive(str(file_path))
-
- media_paths.append(str(file_path))
-
- # Handle voice transcription
- if media_type == "voice" or media_type == "audio":
- from nanobot.providers.transcription import GroqTranscriptionProvider
- transcriber = GroqTranscriptionProvider(api_key=self.groq_api_key)
- transcription = await transcriber.transcribe(file_path)
- if transcription:
- logger.info(f"Transcribed {media_type}: {transcription[:50]}...")
- content_parts.append(f"[transcription: {transcription}]")
- else:
- content_parts.append(f"[{media_type}: {file_path}]")
- else:
- content_parts.append(f"[{media_type}: {file_path}]")
-
- logger.debug(f"Downloaded {media_type} to {file_path}")
- except Exception as e:
- logger.error(f"Failed to download media: {e}")
- content_parts.append(f"[{media_type}: download failed]")
-
+
+ # Download current message media
+ current_media_paths, current_media_parts = await self._download_message_media(
+ message, add_failure_content=True
+ )
+ media_paths.extend(current_media_paths)
+ content_parts.extend(current_media_parts)
+ if current_media_paths:
+ logger.debug("Downloaded message media to {}", current_media_paths[0])
+
+ # Reply context: text and/or media from the replied-to message
+ reply = getattr(message, "reply_to_message", None)
+ if reply is not None:
+ reply_ctx = await self._extract_reply_context(message)
+ reply_media, reply_media_parts = await self._download_message_media(reply)
+ if reply_media:
+ media_paths = reply_media + media_paths
+ logger.debug("Attached replied-to media: {}", reply_media[0])
+ tag = reply_ctx or (f"[Reply to: {reply_media_parts[0]}]" if reply_media_parts else None)
+ if tag:
+ content_parts.insert(0, tag)
content = "\n".join(content_parts) if content_parts else "[empty message]"
-
- logger.debug(f"Telegram message from {sender_id}: {content[:50]}...")
-
+
+ logger.debug("Telegram message from {}: {}...", sender_id, content[:50])
+
+ str_chat_id = str(chat_id)
+ metadata = self._build_message_metadata(message, user)
+ session_key = self._derive_topic_session_key(message)
+
+ # Telegram media groups: buffer briefly, forward as one aggregated turn.
+ if media_group_id := getattr(message, "media_group_id", None):
+ key = f"{str_chat_id}:{media_group_id}"
+ if key not in self._media_group_buffers:
+ self._media_group_buffers[key] = {
+ "sender_id": sender_id, "chat_id": str_chat_id,
+ "contents": [], "media": [],
+ "metadata": metadata,
+ "session_key": session_key,
+ }
+ self._start_typing(str_chat_id)
+ await self._add_reaction(str_chat_id, message.message_id, self.config.react_emoji)
+ buf = self._media_group_buffers[key]
+ if content and content != "[empty message]":
+ buf["contents"].append(content)
+ buf["media"].extend(media_paths)
+ if key not in self._media_group_tasks:
+ self._media_group_tasks[key] = asyncio.create_task(self._flush_media_group(key))
+ return
+
+ # Start typing indicator before processing
+ self._start_typing(str_chat_id)
+ await self._add_reaction(str_chat_id, message.message_id, self.config.react_emoji)
+
# Forward to the message bus
await self._handle_message(
sender_id=sender_id,
- chat_id=str(chat_id),
+ chat_id=str_chat_id,
content=content,
media=media_paths,
- metadata={
- "message_id": message.message_id,
- "user_id": user.id,
- "username": user.username,
- "first_name": user.first_name,
- "is_group": message.chat.type != "private"
- }
+ metadata=metadata,
+ session_key=session_key,
)
-
- def _get_extension(self, media_type: str, mime_type: str | None) -> str:
- """Get file extension based on media type."""
+
+ async def _flush_media_group(self, key: str) -> None:
+ """Wait briefly, then forward buffered media-group as one turn."""
+ try:
+ await asyncio.sleep(0.6)
+ if not (buf := self._media_group_buffers.pop(key, None)):
+ return
+ content = "\n".join(buf["contents"]) or "[empty message]"
+ await self._handle_message(
+ sender_id=buf["sender_id"], chat_id=buf["chat_id"],
+ content=content, media=list(dict.fromkeys(buf["media"])),
+ metadata=buf["metadata"],
+ session_key=buf.get("session_key"),
+ )
+ finally:
+ self._media_group_tasks.pop(key, None)
+
+ def _start_typing(self, chat_id: str) -> None:
+ """Start sending 'typing...' indicator for a chat."""
+ # Cancel any existing typing task for this chat
+ self._stop_typing(chat_id)
+ self._typing_tasks[chat_id] = asyncio.create_task(self._typing_loop(chat_id))
+
+ def _stop_typing(self, chat_id: str) -> None:
+ """Stop the typing indicator for a chat."""
+ task = self._typing_tasks.pop(chat_id, None)
+ if task and not task.done():
+ task.cancel()
+
+ async def _add_reaction(self, chat_id: str, message_id: int, emoji: str) -> None:
+ """Add emoji reaction to a message (best-effort, non-blocking)."""
+ if not self._app or not emoji:
+ return
+ try:
+ await self._app.bot.set_message_reaction(
+ chat_id=int(chat_id),
+ message_id=message_id,
+ reaction=[ReactionTypeEmoji(emoji=emoji)],
+ )
+ except Exception as e:
+ logger.debug("Telegram reaction failed: {}", e)
+
+ async def _remove_reaction(self, chat_id: str, message_id: int) -> None:
+ """Remove emoji reaction from a message (best-effort, non-blocking)."""
+ if not self._app:
+ return
+ try:
+ await self._app.bot.set_message_reaction(
+ chat_id=int(chat_id),
+ message_id=message_id,
+ reaction=[],
+ )
+ except Exception as e:
+ logger.debug("Telegram reaction removal failed: {}", e)
+
+ async def _typing_loop(self, chat_id: str) -> None:
+ """Repeatedly send 'typing' action until cancelled."""
+ try:
+ while self._app:
+ await self._app.bot.send_chat_action(chat_id=int(chat_id), action="typing")
+ await asyncio.sleep(4)
+ except asyncio.CancelledError:
+ pass
+ except Exception as e:
+ logger.debug("Typing indicator stopped for {}: {}", chat_id, e)
+
+ @staticmethod
+ def _format_telegram_error(exc: Exception) -> str:
+ """Return a short, readable error summary for logs."""
+ text = str(exc).strip()
+ if text:
+ return text
+ if exc.__cause__ is not None:
+ cause = exc.__cause__
+ cause_text = str(cause).strip()
+ if cause_text:
+ return f"{exc.__class__.__name__} ({cause_text})"
+ return f"{exc.__class__.__name__} ({cause.__class__.__name__})"
+ return exc.__class__.__name__
+
+ def _on_polling_error(self, exc: Exception) -> None:
+ """Keep long-polling network failures to a single readable line."""
+ summary = self._format_telegram_error(exc)
+ if isinstance(exc, (NetworkError, TimedOut)):
+ logger.warning("Telegram polling network issue: {}", summary)
+ else:
+ logger.error("Telegram polling error: {}", summary)
+
+ async def _on_error(self, update: object, context: ContextTypes.DEFAULT_TYPE) -> None:
+ """Log polling / handler errors instead of silently swallowing them."""
+ summary = self._format_telegram_error(context.error)
+
+ if isinstance(context.error, (NetworkError, TimedOut)):
+ logger.warning("Telegram network issue: {}", summary)
+ else:
+ logger.error("Telegram error: {}", summary)
+
+ def _get_extension(
+ self,
+ media_type: str,
+ mime_type: str | None,
+ filename: str | None = None,
+ ) -> str:
+ """Get file extension based on media type or original filename."""
if mime_type:
ext_map = {
"image/jpeg": ".jpg", "image/png": ".png", "image/gif": ".gif",
@@ -297,6 +1051,14 @@ class TelegramChannel(BaseChannel):
}
if mime_type in ext_map:
return ext_map[mime_type]
-
+
type_map = {"image": ".jpg", "voice": ".ogg", "audio": ".mp3", "file": ""}
- return type_map.get(media_type, "")
+ if ext := type_map.get(media_type, ""):
+ return ext
+
+ if filename:
+ from pathlib import Path
+
+ return "".join(Path(filename).suffixes)
+
+ return ""
diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py
new file mode 100644
index 000000000..05ad14825
--- /dev/null
+++ b/nanobot/channels/wecom.py
@@ -0,0 +1,371 @@
+"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
+
+import asyncio
+import importlib.util
+import os
+from collections import OrderedDict
+from typing import Any
+
+from loguru import logger
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_media_dir
+from nanobot.config.schema import Base
+from pydantic import Field
+
+WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
+
+class WecomConfig(Base):
+ """WeCom (Enterprise WeChat) AI Bot channel configuration."""
+
+ enabled: bool = False
+ bot_id: str = ""
+ secret: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ welcome_message: str = ""
+
+
+# Message type display mapping
+MSG_TYPE_MAP = {
+ "image": "[image]",
+ "voice": "[voice]",
+ "file": "[file]",
+ "mixed": "[mixed content]",
+}
+
+
+class WecomChannel(BaseChannel):
+ """
+ WeCom (Enterprise WeChat) channel using WebSocket long connection.
+
+ Uses WebSocket to receive events - no public IP or webhook required.
+
+ Requires:
+ - Bot ID and Secret from WeCom AI Bot platform
+ """
+
+ name = "wecom"
+ display_name = "WeCom"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return WecomConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = WecomConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: WecomConfig = config
+ self._client: Any = None
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
+ self._loop: asyncio.AbstractEventLoop | None = None
+ self._generate_req_id = None
+ # Store frame headers for each chat to enable replies
+ self._chat_frames: dict[str, Any] = {}
+
+ async def start(self) -> None:
+ """Start the WeCom bot with WebSocket long connection."""
+ if not WECOM_AVAILABLE:
+ logger.error("WeCom SDK not installed. Run: pip install nanobot-ai[wecom]")
+ return
+
+ if not self.config.bot_id or not self.config.secret:
+ logger.error("WeCom bot_id and secret not configured")
+ return
+
+ from wecom_aibot_sdk import WSClient, generate_req_id
+
+ self._running = True
+ self._loop = asyncio.get_running_loop()
+ self._generate_req_id = generate_req_id
+
+ # Create WebSocket client
+ self._client = WSClient({
+ "bot_id": self.config.bot_id,
+ "secret": self.config.secret,
+ "reconnect_interval": 1000,
+ "max_reconnect_attempts": -1, # Infinite reconnect
+ "heartbeat_interval": 30000,
+ })
+
+ # Register event handlers
+ self._client.on("connected", self._on_connected)
+ self._client.on("authenticated", self._on_authenticated)
+ self._client.on("disconnected", self._on_disconnected)
+ self._client.on("error", self._on_error)
+ self._client.on("message.text", self._on_text_message)
+ self._client.on("message.image", self._on_image_message)
+ self._client.on("message.voice", self._on_voice_message)
+ self._client.on("message.file", self._on_file_message)
+ self._client.on("message.mixed", self._on_mixed_message)
+ self._client.on("event.enter_chat", self._on_enter_chat)
+
+ logger.info("WeCom bot starting with WebSocket long connection")
+ logger.info("No public IP required - using WebSocket to receive events")
+
+ # Connect
+ await self._client.connect_async()
+
+ # Keep running until stopped
+ while self._running:
+ await asyncio.sleep(1)
+
+ async def stop(self) -> None:
+ """Stop the WeCom bot."""
+ self._running = False
+ if self._client:
+ await self._client.disconnect()
+ logger.info("WeCom bot stopped")
+
+ async def _on_connected(self, frame: Any) -> None:
+ """Handle WebSocket connected event."""
+ logger.info("WeCom WebSocket connected")
+
+ async def _on_authenticated(self, frame: Any) -> None:
+ """Handle authentication success event."""
+ logger.info("WeCom authenticated successfully")
+
+ async def _on_disconnected(self, frame: Any) -> None:
+ """Handle WebSocket disconnected event."""
+ reason = frame.body if hasattr(frame, 'body') else str(frame)
+ logger.warning("WeCom WebSocket disconnected: {}", reason)
+
+ async def _on_error(self, frame: Any) -> None:
+ """Handle error event."""
+ logger.error("WeCom error: {}", frame)
+
+ async def _on_text_message(self, frame: Any) -> None:
+ """Handle text message."""
+ await self._process_message(frame, "text")
+
+ async def _on_image_message(self, frame: Any) -> None:
+ """Handle image message."""
+ await self._process_message(frame, "image")
+
+ async def _on_voice_message(self, frame: Any) -> None:
+ """Handle voice message."""
+ await self._process_message(frame, "voice")
+
+ async def _on_file_message(self, frame: Any) -> None:
+ """Handle file message."""
+ await self._process_message(frame, "file")
+
+ async def _on_mixed_message(self, frame: Any) -> None:
+ """Handle mixed content message."""
+ await self._process_message(frame, "mixed")
+
+ async def _on_enter_chat(self, frame: Any) -> None:
+ """Handle enter_chat event (user opens chat with bot)."""
+ try:
+ # Extract body from WsFrame dataclass or dict
+ if hasattr(frame, 'body'):
+ body = frame.body or {}
+ elif isinstance(frame, dict):
+ body = frame.get("body", frame)
+ else:
+ body = {}
+
+ chat_id = body.get("chatid", "") if isinstance(body, dict) else ""
+
+ if chat_id and self.config.welcome_message:
+ await self._client.reply_welcome(frame, {
+ "msgtype": "text",
+ "text": {"content": self.config.welcome_message},
+ })
+ except Exception as e:
+ logger.error("Error handling enter_chat: {}", e)
+
+ async def _process_message(self, frame: Any, msg_type: str) -> None:
+ """Process incoming message and forward to bus."""
+ try:
+ # Extract body from WsFrame dataclass or dict
+ if hasattr(frame, 'body'):
+ body = frame.body or {}
+ elif isinstance(frame, dict):
+ body = frame.get("body", frame)
+ else:
+ body = {}
+
+ # Ensure body is a dict
+ if not isinstance(body, dict):
+ logger.warning("Invalid body type: {}", type(body))
+ return
+
+ # Extract message info
+ msg_id = body.get("msgid", "")
+ if not msg_id:
+ msg_id = f"{body.get('chatid', '')}_{body.get('sendertime', '')}"
+
+ # Deduplication check
+ if msg_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[msg_id] = None
+
+ # Trim cache
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
+ # Extract sender info from "from" field (SDK format)
+ from_info = body.get("from", {})
+ sender_id = from_info.get("userid", "unknown") if isinstance(from_info, dict) else "unknown"
+
+ # For single chat, chatid is the sender's userid
+ # For group chat, chatid is provided in body
+ chat_type = body.get("chattype", "single")
+ chat_id = body.get("chatid", sender_id)
+
+ content_parts = []
+
+ if msg_type == "text":
+ text = body.get("text", {}).get("content", "")
+ if text:
+ content_parts.append(text)
+
+ elif msg_type == "image":
+ image_info = body.get("image", {})
+ file_url = image_info.get("url", "")
+ aes_key = image_info.get("aeskey", "")
+
+ if file_url and aes_key:
+ file_path = await self._download_and_save_media(file_url, aes_key, "image")
+ if file_path:
+ filename = os.path.basename(file_path)
+ content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
+ else:
+ content_parts.append("[image: download failed]")
+ else:
+ content_parts.append("[image: download failed]")
+
+ elif msg_type == "voice":
+ voice_info = body.get("voice", {})
+ # Voice message already contains transcribed content from WeCom
+ voice_content = voice_info.get("content", "")
+ if voice_content:
+ content_parts.append(f"[voice] {voice_content}")
+ else:
+ content_parts.append("[voice]")
+
+ elif msg_type == "file":
+ file_info = body.get("file", {})
+ file_url = file_info.get("url", "")
+ aes_key = file_info.get("aeskey", "")
+ file_name = file_info.get("name", "unknown")
+
+ if file_url and aes_key:
+ file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
+ if file_path:
+ content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
+ else:
+ content_parts.append(f"[file: {file_name}: download failed]")
+ else:
+ content_parts.append(f"[file: {file_name}: download failed]")
+
+ elif msg_type == "mixed":
+ # Mixed content contains multiple message items
+ msg_items = body.get("mixed", {}).get("item", [])
+ for item in msg_items:
+ item_type = item.get("type", "")
+ if item_type == "text":
+ text = item.get("text", {}).get("content", "")
+ if text:
+ content_parts.append(text)
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(item_type, f"[{item_type}]"))
+
+ else:
+ content_parts.append(MSG_TYPE_MAP.get(msg_type, f"[{msg_type}]"))
+
+ content = "\n".join(content_parts) if content_parts else ""
+
+ if not content:
+ return
+
+ # Store frame for this chat to enable replies
+ self._chat_frames[chat_id] = frame
+
+ # Forward to message bus
+ # Note: media paths are included in content for broader model compatibility
+ await self._handle_message(
+ sender_id=sender_id,
+ chat_id=chat_id,
+ content=content,
+ media=None,
+ metadata={
+ "message_id": msg_id,
+ "msg_type": msg_type,
+ "chat_type": chat_type,
+ }
+ )
+
+ except Exception as e:
+ logger.error("Error processing WeCom message: {}", e)
+
+ async def _download_and_save_media(
+ self,
+ file_url: str,
+ aes_key: str,
+ media_type: str,
+ filename: str | None = None,
+ ) -> str | None:
+ """
+ Download and decrypt media from WeCom.
+
+ Returns:
+ file_path or None if download failed
+ """
+ try:
+ data, fname = await self._client.download_file(file_url, aes_key)
+
+ if not data:
+ logger.warning("Failed to download media from WeCom")
+ return None
+
+ media_dir = get_media_dir("wecom")
+ if not filename:
+ filename = fname or f"{media_type}_{hash(file_url) % 100000}"
+ filename = os.path.basename(filename)
+
+ file_path = media_dir / filename
+ file_path.write_bytes(data)
+ logger.debug("Downloaded {} to {}", media_type, file_path)
+ return str(file_path)
+
+ except Exception as e:
+ logger.error("Error downloading media: {}", e)
+ return None
+
+ async def send(self, msg: OutboundMessage) -> None:
+ """Send a message through WeCom."""
+ if not self._client:
+ logger.warning("WeCom client not initialized")
+ return
+
+ try:
+ content = msg.content.strip()
+ if not content:
+ return
+
+ # Get the stored frame for this chat
+ frame = self._chat_frames.get(msg.chat_id)
+ if not frame:
+ logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
+ return
+
+ # Use streaming reply for better UX
+ stream_id = self._generate_req_id("stream")
+
+ # Send as streaming message with finish=True
+ await self._client.reply_stream(
+ frame,
+ stream_id,
+ content,
+ finish=True,
+ )
+
+ logger.debug("WeCom message sent to {}", msg.chat_id)
+
+ except Exception as e:
+ logger.error("Error sending WeCom message: {}", e)
+ raise
diff --git a/nanobot/channels/weixin.py b/nanobot/channels/weixin.py
new file mode 100644
index 000000000..2266bc9f0
--- /dev/null
+++ b/nanobot/channels/weixin.py
@@ -0,0 +1,1380 @@
+"""Personal WeChat (微信) channel using HTTP long-poll API.
+
+Uses the ilinkai.weixin.qq.com API for personal WeChat messaging.
+No WebSocket, no local WeChat client needed — just HTTP requests with a
+bot token obtained via QR code login.
+
+Protocol reverse-engineered from ``@tencent-weixin/openclaw-weixin`` v1.0.3.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import base64
+import hashlib
+import json
+import os
+import random
+import re
+import time
+import uuid
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any
+from urllib.parse import quote
+
+import httpx
+from loguru import logger
+from pydantic import Field
+
+from nanobot.bus.events import OutboundMessage
+from nanobot.bus.queue import MessageBus
+from nanobot.channels.base import BaseChannel
+from nanobot.config.paths import get_media_dir, get_runtime_subdir
+from nanobot.config.schema import Base
+from nanobot.utils.helpers import split_message
+
+# ---------------------------------------------------------------------------
+# Protocol constants (from openclaw-weixin types.ts)
+# ---------------------------------------------------------------------------
+
+# MessageItemType
+ITEM_TEXT = 1
+ITEM_IMAGE = 2
+ITEM_VOICE = 3
+ITEM_FILE = 4
+ITEM_VIDEO = 5
+
+# MessageType (1 = inbound from user, 2 = outbound from bot)
+MESSAGE_TYPE_USER = 1
+MESSAGE_TYPE_BOT = 2
+
+# MessageState
+MESSAGE_STATE_FINISH = 2
+
+WEIXIN_MAX_MESSAGE_LEN = 4000
+WEIXIN_CHANNEL_VERSION = "2.1.1"
+ILINK_APP_ID = "bot"
+
+
+def _build_client_version(version: str) -> int:
+ """Encode semantic version as 0x00MMNNPP (major/minor/patch in one uint32)."""
+ parts = version.split(".")
+
+ def _as_int(idx: int) -> int:
+ try:
+ return int(parts[idx])
+ except Exception:
+ return 0
+
+ major = _as_int(0)
+ minor = _as_int(1)
+ patch = _as_int(2)
+ return ((major & 0xFF) << 16) | ((minor & 0xFF) << 8) | (patch & 0xFF)
+
+ILINK_APP_CLIENT_VERSION = _build_client_version(WEIXIN_CHANNEL_VERSION)
+BASE_INFO: dict[str, str] = {"channel_version": WEIXIN_CHANNEL_VERSION}
+
+# Session-expired error code
+ERRCODE_SESSION_EXPIRED = -14
+SESSION_PAUSE_DURATION_S = 60 * 60
+
+# Retry constants (matching the reference plugin's monitor.ts)
+MAX_CONSECUTIVE_FAILURES = 3
+BACKOFF_DELAY_S = 30
+RETRY_DELAY_S = 2
+MAX_QR_REFRESH_COUNT = 3
+TYPING_STATUS_TYPING = 1
+TYPING_STATUS_CANCEL = 2
+TYPING_TICKET_TTL_S = 24 * 60 * 60
+TYPING_KEEPALIVE_INTERVAL_S = 5
+CONFIG_CACHE_INITIAL_RETRY_S = 2
+CONFIG_CACHE_MAX_RETRY_S = 60 * 60
+
+# Default long-poll timeout; overridden by server via longpolling_timeout_ms.
+DEFAULT_LONG_POLL_TIMEOUT_S = 35
+
+# Media-type codes for getuploadurl (1=image, 2=video, 3=file, 4=voice)
+UPLOAD_MEDIA_IMAGE = 1
+UPLOAD_MEDIA_VIDEO = 2
+UPLOAD_MEDIA_FILE = 3
+UPLOAD_MEDIA_VOICE = 4
+
+# File extensions considered as images / videos for outbound media
+_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp", ".tiff", ".ico", ".svg"}
+_VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv"}
+_VOICE_EXTS = {".mp3", ".wav", ".amr", ".silk", ".ogg", ".m4a", ".aac", ".flac"}
+
+
+def _has_downloadable_media_locator(media: dict[str, Any] | None) -> bool:
+ if not isinstance(media, dict):
+ return False
+ return bool(str(media.get("encrypt_query_param", "") or "") or str(media.get("full_url", "") or "").strip())
+
+
+class WeixinConfig(Base):
+ """Personal WeChat channel configuration."""
+
+ enabled: bool = False
+ allow_from: list[str] = Field(default_factory=list)
+ base_url: str = "https://ilinkai.weixin.qq.com"
+ cdn_base_url: str = "https://novac2c.cdn.weixin.qq.com/c2c"
+ route_tag: str | int | None = None
+ token: str = "" # Manually set token, or obtained via QR login
+ state_dir: str = "" # Default: ~/.nanobot/weixin/
+ poll_timeout: int = DEFAULT_LONG_POLL_TIMEOUT_S # seconds for long-poll
+
+
+class WeixinChannel(BaseChannel):
+ """
+ Personal WeChat channel using HTTP long-poll.
+
+ Connects to ilinkai.weixin.qq.com API to receive and send personal
+ WeChat messages. Authentication is via QR code login which produces
+ a bot token.
+ """
+
+ name = "weixin"
+ display_name = "WeChat"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return WeixinConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = WeixinConfig.model_validate(config)
+ super().__init__(config, bus)
+ self.config: WeixinConfig = config
+
+ # State
+ self._client: httpx.AsyncClient | None = None
+ self._get_updates_buf: str = ""
+ self._context_tokens: dict[str, str] = {} # from_user_id -> context_token
+ self._processed_ids: OrderedDict[str, None] = OrderedDict()
+ self._state_dir: Path | None = None
+ self._token: str = ""
+ self._poll_task: asyncio.Task | None = None
+ self._next_poll_timeout_s: int = DEFAULT_LONG_POLL_TIMEOUT_S
+ self._session_pause_until: float = 0.0
+ self._typing_tasks: dict[str, asyncio.Task] = {}
+ self._typing_tickets: dict[str, dict[str, Any]] = {}
+
+ # ------------------------------------------------------------------
+ # State persistence
+ # ------------------------------------------------------------------
+
+ def _get_state_dir(self) -> Path:
+ if self._state_dir:
+ return self._state_dir
+ if self.config.state_dir:
+ d = Path(self.config.state_dir).expanduser()
+ else:
+ d = get_runtime_subdir("weixin")
+ d.mkdir(parents=True, exist_ok=True)
+ self._state_dir = d
+ return d
+
+ def _load_state(self) -> bool:
+ """Load saved account state. Returns True if a valid token was found."""
+ state_file = self._get_state_dir() / "account.json"
+ if not state_file.exists():
+ return False
+ try:
+ data = json.loads(state_file.read_text())
+ self._token = data.get("token", "")
+ self._get_updates_buf = data.get("get_updates_buf", "")
+ context_tokens = data.get("context_tokens", {})
+ if isinstance(context_tokens, dict):
+ self._context_tokens = {
+ str(user_id): str(token)
+ for user_id, token in context_tokens.items()
+ if str(user_id).strip() and str(token).strip()
+ }
+ else:
+ self._context_tokens = {}
+ typing_tickets = data.get("typing_tickets", {})
+ if isinstance(typing_tickets, dict):
+ self._typing_tickets = {
+ str(user_id): ticket
+ for user_id, ticket in typing_tickets.items()
+ if str(user_id).strip() and isinstance(ticket, dict)
+ }
+ else:
+ self._typing_tickets = {}
+ base_url = data.get("base_url", "")
+ if base_url:
+ self.config.base_url = base_url
+ return bool(self._token)
+ except Exception:
+ return False
+
+ def _save_state(self) -> None:
+ state_file = self._get_state_dir() / "account.json"
+ try:
+ data = {
+ "token": self._token,
+ "get_updates_buf": self._get_updates_buf,
+ "context_tokens": self._context_tokens,
+ "typing_tickets": self._typing_tickets,
+ "base_url": self.config.base_url,
+ }
+ state_file.write_text(json.dumps(data, ensure_ascii=False))
+ except Exception:
+ pass
+
+ # ------------------------------------------------------------------
+ # HTTP helpers (matches api.ts buildHeaders / apiFetch)
+ # ------------------------------------------------------------------
+
+ @staticmethod
+ def _random_wechat_uin() -> str:
+ """X-WECHAT-UIN: random uint32 → decimal string → base64.
+
+ Matches the reference plugin's ``randomWechatUin()`` in api.ts.
+ Generated fresh for **every** request (same as reference).
+ """
+ uint32 = int.from_bytes(os.urandom(4), "big")
+ return base64.b64encode(str(uint32).encode()).decode()
+
+ def _make_headers(self, *, auth: bool = True) -> dict[str, str]:
+ """Build per-request headers (new UIN each call, matching reference)."""
+ headers: dict[str, str] = {
+ "X-WECHAT-UIN": self._random_wechat_uin(),
+ "Content-Type": "application/json",
+ "AuthorizationType": "ilink_bot_token",
+ "iLink-App-Id": ILINK_APP_ID,
+ "iLink-App-ClientVersion": str(ILINK_APP_CLIENT_VERSION),
+ }
+ if auth and self._token:
+ headers["Authorization"] = f"Bearer {self._token}"
+ if self.config.route_tag is not None and str(self.config.route_tag).strip():
+ headers["SKRouteTag"] = str(self.config.route_tag).strip()
+ return headers
+
+ @staticmethod
+ def _is_retryable_media_download_error(err: Exception) -> bool:
+ if isinstance(err, httpx.TimeoutException | httpx.TransportError):
+ return True
+ if isinstance(err, httpx.HTTPStatusError):
+ status_code = err.response.status_code if err.response is not None else 0
+ return status_code >= 500
+ return False
+
+ async def _api_get(
+ self,
+ endpoint: str,
+ params: dict | None = None,
+ *,
+ auth: bool = True,
+ extra_headers: dict[str, str] | None = None,
+ ) -> dict:
+ assert self._client is not None
+ url = f"{self.config.base_url}/{endpoint}"
+ hdrs = self._make_headers(auth=auth)
+ if extra_headers:
+ hdrs.update(extra_headers)
+ resp = await self._client.get(url, params=params, headers=hdrs)
+ resp.raise_for_status()
+ return resp.json()
+
+ async def _api_get_with_base(
+ self,
+ *,
+ base_url: str,
+ endpoint: str,
+ params: dict | None = None,
+ auth: bool = True,
+ extra_headers: dict[str, str] | None = None,
+ ) -> dict:
+ """GET helper that allows overriding base_url for QR redirect polling."""
+ assert self._client is not None
+ url = f"{base_url.rstrip('/')}/{endpoint}"
+ hdrs = self._make_headers(auth=auth)
+ if extra_headers:
+ hdrs.update(extra_headers)
+ resp = await self._client.get(url, params=params, headers=hdrs)
+ resp.raise_for_status()
+ return resp.json()
+
+ async def _api_post(
+ self,
+ endpoint: str,
+ body: dict | None = None,
+ *,
+ auth: bool = True,
+ ) -> dict:
+ assert self._client is not None
+ url = f"{self.config.base_url}/{endpoint}"
+ payload = body or {}
+ if "base_info" not in payload:
+ payload["base_info"] = BASE_INFO
+ resp = await self._client.post(url, json=payload, headers=self._make_headers(auth=auth))
+ resp.raise_for_status()
+ return resp.json()
+
+ # ------------------------------------------------------------------
+ # QR Code Login (matches login-qr.ts)
+ # ------------------------------------------------------------------
+
+ async def _fetch_qr_code(self) -> tuple[str, str]:
+ """Fetch a fresh QR code. Returns (qrcode_id, scan_url)."""
+ data = await self._api_get(
+ "ilink/bot/get_bot_qrcode",
+ params={"bot_type": "3"},
+ auth=False,
+ )
+ qrcode_img_content = data.get("qrcode_img_content", "")
+ qrcode_id = data.get("qrcode", "")
+ if not qrcode_id:
+ raise RuntimeError(f"Failed to get QR code from WeChat API: {data}")
+ return qrcode_id, (qrcode_img_content or qrcode_id)
+
+ async def _qr_login(self) -> bool:
+ """Perform QR code login flow. Returns True on success."""
+ try:
+ refresh_count = 0
+ qrcode_id, scan_url = await self._fetch_qr_code()
+ self._print_qr_code(scan_url)
+ current_poll_base_url = self.config.base_url
+
+ while self._running:
+ try:
+ status_data = await self._api_get_with_base(
+ base_url=current_poll_base_url,
+ endpoint="ilink/bot/get_qrcode_status",
+ params={"qrcode": qrcode_id},
+ auth=False,
+ )
+ except Exception as e:
+ if self._is_retryable_qr_poll_error(e):
+ await asyncio.sleep(1)
+ continue
+ raise
+
+ if not isinstance(status_data, dict):
+ await asyncio.sleep(1)
+ continue
+
+ status = status_data.get("status", "")
+ if status == "confirmed":
+ token = status_data.get("bot_token", "")
+ bot_id = status_data.get("ilink_bot_id", "")
+ base_url = status_data.get("baseurl", "")
+ user_id = status_data.get("ilink_user_id", "")
+ if token:
+ self._token = token
+ if base_url:
+ self.config.base_url = base_url
+ self._save_state()
+ logger.info(
+ "WeChat login successful! bot_id={} user_id={}",
+ bot_id,
+ user_id,
+ )
+ return True
+ else:
+ logger.error("Login confirmed but no bot_token in response")
+ return False
+ elif status == "scaned_but_redirect":
+ redirect_host = str(status_data.get("redirect_host", "") or "").strip()
+ if redirect_host:
+ if redirect_host.startswith("http://") or redirect_host.startswith("https://"):
+ redirected_base = redirect_host
+ else:
+ redirected_base = f"https://{redirect_host}"
+ if redirected_base != current_poll_base_url:
+ current_poll_base_url = redirected_base
+ elif status == "expired":
+ refresh_count += 1
+ if refresh_count > MAX_QR_REFRESH_COUNT:
+ logger.warning(
+ "QR code expired too many times ({}/{}), giving up.",
+ refresh_count - 1,
+ MAX_QR_REFRESH_COUNT,
+ )
+ return False
+ qrcode_id, scan_url = await self._fetch_qr_code()
+ current_poll_base_url = self.config.base_url
+ self._print_qr_code(scan_url)
+ continue
+ # status == "wait" — keep polling
+
+ await asyncio.sleep(1)
+
+ except Exception as e:
+ logger.error("WeChat QR login failed: {}", e)
+
+ return False
+
+ @staticmethod
+ def _is_retryable_qr_poll_error(err: Exception) -> bool:
+ if isinstance(err, httpx.TimeoutException | httpx.TransportError):
+ return True
+ if isinstance(err, httpx.HTTPStatusError):
+ status_code = err.response.status_code if err.response is not None else 0
+ if status_code >= 500:
+ return True
+ return False
+
+ @staticmethod
+ def _print_qr_code(url: str) -> None:
+ try:
+ import qrcode as qr_lib
+
+ qr = qr_lib.QRCode(border=1)
+ qr.add_data(url)
+ qr.make(fit=True)
+ qr.print_ascii(invert=True)
+ except ImportError:
+ print(f"\nLogin URL: {url}\n")
+
+ # ------------------------------------------------------------------
+ # Channel lifecycle
+ # ------------------------------------------------------------------
+
+ async def login(self, force: bool = False) -> bool:
+ """Perform QR code login and save token. Returns True on success."""
+ if force:
+ self._token = ""
+ self._get_updates_buf = ""
+ state_file = self._get_state_dir() / "account.json"
+ if state_file.exists():
+ state_file.unlink()
+ if self._token or self._load_state():
+ return True
+
+ # Initialize HTTP client for the login flow
+ self._client = httpx.AsyncClient(
+ timeout=httpx.Timeout(60, connect=30),
+ follow_redirects=True,
+ )
+ self._running = True # Enable polling loop in _qr_login()
+ try:
+ return await self._qr_login()
+ finally:
+ self._running = False
+ if self._client:
+ await self._client.aclose()
+ self._client = None
+
+ async def start(self) -> None:
+ self._running = True
+ self._next_poll_timeout_s = self.config.poll_timeout
+ self._client = httpx.AsyncClient(
+ timeout=httpx.Timeout(self._next_poll_timeout_s + 10, connect=30),
+ follow_redirects=True,
+ )
+
+ if self.config.token:
+ self._token = self.config.token
+ elif not self._load_state():
+ if not await self._qr_login():
+ logger.error("WeChat login failed. Run 'nanobot channels login weixin' to authenticate.")
+ self._running = False
+ return
+
+ logger.info("WeChat channel starting with long-poll...")
+
+ consecutive_failures = 0
+ while self._running:
+ try:
+ await self._poll_once()
+ consecutive_failures = 0
+ except httpx.TimeoutException:
+ # Normal for long-poll, just retry
+ continue
+ except Exception as e:
+ if not self._running:
+ break
+ consecutive_failures += 1
+ if consecutive_failures >= MAX_CONSECUTIVE_FAILURES:
+ consecutive_failures = 0
+ await asyncio.sleep(BACKOFF_DELAY_S)
+ else:
+ await asyncio.sleep(RETRY_DELAY_S)
+
+ async def stop(self) -> None:
+ self._running = False
+ if self._poll_task and not self._poll_task.done():
+ self._poll_task.cancel()
+ for chat_id in list(self._typing_tasks):
+ await self._stop_typing(chat_id, clear_remote=False)
+ if self._client:
+ await self._client.aclose()
+ self._client = None
+ self._save_state()
+ # ------------------------------------------------------------------
+ # Polling (matches monitor.ts monitorWeixinProvider)
+ # ------------------------------------------------------------------
+
+ def _pause_session(self, duration_s: int = SESSION_PAUSE_DURATION_S) -> None:
+ self._session_pause_until = time.time() + duration_s
+
+ def _session_pause_remaining_s(self) -> int:
+ remaining = int(self._session_pause_until - time.time())
+ if remaining <= 0:
+ self._session_pause_until = 0.0
+ return 0
+ return remaining
+
+ def _assert_session_active(self) -> None:
+ remaining = self._session_pause_remaining_s()
+ if remaining > 0:
+ remaining_min = max((remaining + 59) // 60, 1)
+ raise RuntimeError(
+ f"WeChat session paused, {remaining_min} min remaining (errcode {ERRCODE_SESSION_EXPIRED})"
+ )
+
+ async def _poll_once(self) -> None:
+ remaining = self._session_pause_remaining_s()
+ if remaining > 0:
+ await asyncio.sleep(remaining)
+ return
+
+ body: dict[str, Any] = {
+ "get_updates_buf": self._get_updates_buf,
+ "base_info": BASE_INFO,
+ }
+
+ # Adjust httpx timeout to match the current poll timeout
+ assert self._client is not None
+ self._client.timeout = httpx.Timeout(self._next_poll_timeout_s + 10, connect=30)
+
+ data = await self._api_post("ilink/bot/getupdates", body)
+
+ # Check for API-level errors (monitor.ts checks both ret and errcode)
+ ret = data.get("ret", 0)
+ errcode = data.get("errcode", 0)
+ is_error = (ret is not None and ret != 0) or (errcode is not None and errcode != 0)
+
+ if is_error:
+ if errcode == ERRCODE_SESSION_EXPIRED or ret == ERRCODE_SESSION_EXPIRED:
+ self._pause_session()
+ remaining = self._session_pause_remaining_s()
+ logger.warning(
+ "WeChat session expired (errcode {}). Pausing {} min.",
+ errcode,
+ max((remaining + 59) // 60, 1),
+ )
+ return
+ raise RuntimeError(
+ f"getUpdates failed: ret={ret} errcode={errcode} errmsg={data.get('errmsg', '')}"
+ )
+
+ # Honour server-suggested poll timeout (monitor.ts:102-105)
+ server_timeout_ms = data.get("longpolling_timeout_ms")
+ if server_timeout_ms and server_timeout_ms > 0:
+ self._next_poll_timeout_s = max(server_timeout_ms // 1000, 5)
+
+ # Update cursor
+ new_buf = data.get("get_updates_buf", "")
+ if new_buf:
+ self._get_updates_buf = new_buf
+ self._save_state()
+
+ # Process messages (WeixinMessage[] from types.ts)
+ msgs: list[dict] = data.get("msgs", []) or []
+ for msg in msgs:
+ try:
+ await self._process_message(msg)
+ except Exception:
+ pass
+
+ # ------------------------------------------------------------------
+ # Inbound message processing (matches inbound.ts + process-message.ts)
+ # ------------------------------------------------------------------
+
+ async def _process_message(self, msg: dict) -> None:
+ """Process a single WeixinMessage from getUpdates."""
+ # Skip bot's own messages (message_type 2 = BOT)
+ if msg.get("message_type") == MESSAGE_TYPE_BOT:
+ return
+
+ # Deduplication by message_id
+ msg_id = str(msg.get("message_id", "") or msg.get("seq", ""))
+ if not msg_id:
+ msg_id = f"{msg.get('from_user_id', '')}_{msg.get('create_time_ms', '')}"
+ if msg_id in self._processed_ids:
+ return
+ self._processed_ids[msg_id] = None
+ while len(self._processed_ids) > 1000:
+ self._processed_ids.popitem(last=False)
+
+ from_user_id = msg.get("from_user_id", "") or ""
+ if not from_user_id:
+ return
+
+ # Cache context_token (required for all replies — inbound.ts:23-27)
+ ctx_token = msg.get("context_token", "")
+ if ctx_token:
+ self._context_tokens[from_user_id] = ctx_token
+ self._save_state()
+
+ # Parse item_list (WeixinMessage.item_list — types.ts:161)
+ item_list: list[dict] = msg.get("item_list") or []
+ content_parts: list[str] = []
+ media_paths: list[str] = []
+ has_top_level_downloadable_media = False
+
+ for item in item_list:
+ item_type = item.get("type", 0)
+
+ if item_type == ITEM_TEXT:
+ text = (item.get("text_item") or {}).get("text", "")
+ if text:
+ # Handle quoted/ref messages (inbound.ts:86-98)
+ ref = item.get("ref_msg")
+ if ref:
+ ref_item = ref.get("message_item")
+ # If quoted message is media, just pass the text
+ if ref_item and ref_item.get("type", 0) in (
+ ITEM_IMAGE,
+ ITEM_VOICE,
+ ITEM_FILE,
+ ITEM_VIDEO,
+ ):
+ content_parts.append(text)
+ else:
+ parts: list[str] = []
+ if ref.get("title"):
+ parts.append(ref["title"])
+ if ref_item:
+ ref_text = (ref_item.get("text_item") or {}).get("text", "")
+ if ref_text:
+ parts.append(ref_text)
+ if parts:
+ content_parts.append(f"[引用: {' | '.join(parts)}]\n{text}")
+ else:
+ content_parts.append(text)
+ else:
+ content_parts.append(text)
+
+ elif item_type == ITEM_IMAGE:
+ image_item = item.get("image_item") or {}
+ if _has_downloadable_media_locator(image_item.get("media")):
+ has_top_level_downloadable_media = True
+ file_path = await self._download_media_item(image_item, "image")
+ if file_path:
+ content_parts.append(f"[image]\n[Image: source: {file_path}]")
+ media_paths.append(file_path)
+ else:
+ content_parts.append("[image]")
+
+ elif item_type == ITEM_VOICE:
+ voice_item = item.get("voice_item") or {}
+ # Voice-to-text provided by WeChat (inbound.ts:101-103)
+ voice_text = voice_item.get("text", "")
+ if voice_text:
+ content_parts.append(f"[voice] {voice_text}")
+ else:
+ if _has_downloadable_media_locator(voice_item.get("media")):
+ has_top_level_downloadable_media = True
+ file_path = await self._download_media_item(voice_item, "voice")
+ if file_path:
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ content_parts.append(f"[voice] {transcription}")
+ else:
+ content_parts.append(f"[voice]\n[Audio: source: {file_path}]")
+ media_paths.append(file_path)
+ else:
+ content_parts.append("[voice]")
+
+ elif item_type == ITEM_FILE:
+ file_item = item.get("file_item") or {}
+ if _has_downloadable_media_locator(file_item.get("media")):
+ has_top_level_downloadable_media = True
+ file_name = file_item.get("file_name", "unknown")
+ file_path = await self._download_media_item(
+ file_item,
+ "file",
+ file_name,
+ )
+ if file_path:
+ content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
+ media_paths.append(file_path)
+ else:
+ content_parts.append(f"[file: {file_name}]")
+
+ elif item_type == ITEM_VIDEO:
+ video_item = item.get("video_item") or {}
+ if _has_downloadable_media_locator(video_item.get("media")):
+ has_top_level_downloadable_media = True
+ file_path = await self._download_media_item(video_item, "video")
+ if file_path:
+ content_parts.append(f"[video]\n[Video: source: {file_path}]")
+ media_paths.append(file_path)
+ else:
+ content_parts.append("[video]")
+
+ # Fallback: when no top-level media was downloaded, try quoted/referenced media.
+ # This aligns with the reference plugin behavior that checks ref_msg.message_item
+ # when main item_list has no downloadable media.
+ if not media_paths and not has_top_level_downloadable_media:
+ ref_media_item: dict[str, Any] | None = None
+ for item in item_list:
+ if item.get("type", 0) != ITEM_TEXT:
+ continue
+ ref = item.get("ref_msg") or {}
+ candidate = ref.get("message_item") or {}
+ if candidate.get("type", 0) in (ITEM_IMAGE, ITEM_VOICE, ITEM_FILE, ITEM_VIDEO):
+ ref_media_item = candidate
+ break
+
+ if ref_media_item:
+ ref_type = ref_media_item.get("type", 0)
+ if ref_type == ITEM_IMAGE:
+ image_item = ref_media_item.get("image_item") or {}
+ file_path = await self._download_media_item(image_item, "image")
+ if file_path:
+ content_parts.append(f"[image]\n[Image: source: {file_path}]")
+ media_paths.append(file_path)
+ elif ref_type == ITEM_VOICE:
+ voice_item = ref_media_item.get("voice_item") or {}
+ file_path = await self._download_media_item(voice_item, "voice")
+ if file_path:
+ transcription = await self.transcribe_audio(file_path)
+ if transcription:
+ content_parts.append(f"[voice] {transcription}")
+ else:
+ content_parts.append(f"[voice]\n[Audio: source: {file_path}]")
+ media_paths.append(file_path)
+ elif ref_type == ITEM_FILE:
+ file_item = ref_media_item.get("file_item") or {}
+ file_name = file_item.get("file_name", "unknown")
+ file_path = await self._download_media_item(file_item, "file", file_name)
+ if file_path:
+ content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
+ media_paths.append(file_path)
+ elif ref_type == ITEM_VIDEO:
+ video_item = ref_media_item.get("video_item") or {}
+ file_path = await self._download_media_item(video_item, "video")
+ if file_path:
+ content_parts.append(f"[video]\n[Video: source: {file_path}]")
+ media_paths.append(file_path)
+
+ content = "\n".join(content_parts)
+ if not content:
+ return
+
+ logger.info(
+ "WeChat inbound: from={} items={} bodyLen={}",
+ from_user_id,
+ ",".join(str(i.get("type", 0)) for i in item_list),
+ len(content),
+ )
+
+ await self._start_typing(from_user_id, ctx_token)
+
+ await self._handle_message(
+ sender_id=from_user_id,
+ chat_id=from_user_id,
+ content=content,
+ media=media_paths or None,
+ metadata={"message_id": msg_id},
+ )
+
+ # ------------------------------------------------------------------
+ # Media download (matches media-download.ts + pic-decrypt.ts)
+ # ------------------------------------------------------------------
+
+ async def _download_media_item(
+ self,
+ typed_item: dict,
+ media_type: str,
+ filename: str | None = None,
+ ) -> str | None:
+ """Download + AES-decrypt a media item. Returns local path or None."""
+ try:
+ media = typed_item.get("media") or {}
+ encrypt_query_param = str(media.get("encrypt_query_param", "") or "")
+ full_url = str(media.get("full_url", "") or "").strip()
+
+ if not encrypt_query_param and not full_url:
+ return None
+
+ # Resolve AES key (media-download.ts:43-45, pic-decrypt.ts:40-52)
+ # image_item.aeskey is a raw hex string (16 bytes as 32 hex chars).
+ # media.aes_key is always base64-encoded.
+ # For images, prefer image_item.aeskey; for others use media.aes_key.
+ raw_aeskey_hex = typed_item.get("aeskey", "")
+ media_aes_key_b64 = media.get("aes_key", "")
+
+ aes_key_b64: str = ""
+ if raw_aeskey_hex:
+ # Convert hex → raw bytes → base64 (matches media-download.ts:43-44)
+ aes_key_b64 = base64.b64encode(bytes.fromhex(raw_aeskey_hex)).decode()
+ elif media_aes_key_b64:
+ aes_key_b64 = media_aes_key_b64
+
+ # Reference protocol behavior: VOICE/FILE/VIDEO require aes_key;
+ # only IMAGE may be downloaded as plain bytes when key is missing.
+ if media_type != "image" and not aes_key_b64:
+ return None
+
+ assert self._client is not None
+ fallback_url = ""
+ if encrypt_query_param:
+ fallback_url = (
+ f"{self.config.cdn_base_url}/download"
+ f"?encrypted_query_param={quote(encrypt_query_param)}"
+ )
+
+ download_candidates: list[tuple[str, str]] = []
+ if full_url:
+ download_candidates.append(("full_url", full_url))
+ if fallback_url and (not full_url or fallback_url != full_url):
+ download_candidates.append(("encrypt_query_param", fallback_url))
+
+ data = b""
+ for idx, (download_source, cdn_url) in enumerate(download_candidates):
+ try:
+ resp = await self._client.get(cdn_url)
+ resp.raise_for_status()
+ data = resp.content
+ break
+ except Exception as e:
+ has_more_candidates = idx + 1 < len(download_candidates)
+ should_fallback = (
+ download_source == "full_url"
+ and has_more_candidates
+ and self._is_retryable_media_download_error(e)
+ )
+ if should_fallback:
+ logger.warning(
+ "WeChat media download failed via full_url, falling back to encrypt_query_param: type={} err={}",
+ media_type,
+ e,
+ )
+ continue
+ raise
+
+ if aes_key_b64 and data:
+ data = _decrypt_aes_ecb(data, aes_key_b64)
+
+ if not data:
+ return None
+
+ media_dir = get_media_dir("weixin")
+ ext = _ext_for_type(media_type)
+ if not filename:
+ ts = int(time.time())
+ hash_seed = encrypt_query_param or full_url
+ h = abs(hash(hash_seed)) % 100000
+ filename = f"{media_type}_{ts}_{h}{ext}"
+ safe_name = os.path.basename(filename)
+ file_path = media_dir / safe_name
+ file_path.write_bytes(data)
+ return str(file_path)
+
+ except Exception as e:
+ logger.error("Error downloading WeChat media: {}", e)
+ return None
+
+ # ------------------------------------------------------------------
+ # Outbound (matches send.ts buildTextMessageReq + sendMessageWeixin)
+ # ------------------------------------------------------------------
+
+ async def _get_typing_ticket(self, user_id: str, context_token: str = "") -> str:
+ """Get typing ticket with per-user refresh + failure backoff cache."""
+ now = time.time()
+ entry = self._typing_tickets.get(user_id)
+ if entry and now < float(entry.get("next_fetch_at", 0)):
+ return str(entry.get("ticket", "") or "")
+
+ body: dict[str, Any] = {
+ "ilink_user_id": user_id,
+ "context_token": context_token or None,
+ "base_info": BASE_INFO,
+ }
+ data = await self._api_post("ilink/bot/getconfig", body)
+ if data.get("ret", 0) == 0:
+ ticket = str(data.get("typing_ticket", "") or "")
+ self._typing_tickets[user_id] = {
+ "ticket": ticket,
+ "ever_succeeded": True,
+ "next_fetch_at": now + (random.random() * TYPING_TICKET_TTL_S),
+ "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
+ }
+ return ticket
+
+ prev_delay = float(entry.get("retry_delay_s", CONFIG_CACHE_INITIAL_RETRY_S)) if entry else CONFIG_CACHE_INITIAL_RETRY_S
+ next_delay = min(prev_delay * 2, CONFIG_CACHE_MAX_RETRY_S)
+ if entry:
+ entry["next_fetch_at"] = now + next_delay
+ entry["retry_delay_s"] = next_delay
+ return str(entry.get("ticket", "") or "")
+
+ self._typing_tickets[user_id] = {
+ "ticket": "",
+ "ever_succeeded": False,
+ "next_fetch_at": now + CONFIG_CACHE_INITIAL_RETRY_S,
+ "retry_delay_s": CONFIG_CACHE_INITIAL_RETRY_S,
+ }
+ return ""
+
+ async def _send_typing(self, user_id: str, typing_ticket: str, status: int) -> None:
+ """Best-effort sendtyping wrapper."""
+ if not typing_ticket:
+ return
+ body: dict[str, Any] = {
+ "ilink_user_id": user_id,
+ "typing_ticket": typing_ticket,
+ "status": status,
+ "base_info": BASE_INFO,
+ }
+ await self._api_post("ilink/bot/sendtyping", body)
+
+ async def _typing_keepalive_loop(self, user_id: str, typing_ticket: str, stop_event: asyncio.Event) -> None:
+ try:
+ while not stop_event.is_set():
+ await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
+ if stop_event.is_set():
+ break
+ try:
+ await self._send_typing(user_id, typing_ticket, TYPING_STATUS_TYPING)
+ except Exception:
+ pass
+ finally:
+ pass
+
+ async def send(self, msg: OutboundMessage) -> None:
+ if not self._client or not self._token:
+ logger.warning("WeChat client not initialized or not authenticated")
+ return
+ try:
+ self._assert_session_active()
+ except RuntimeError:
+ return
+
+ is_progress = bool((msg.metadata or {}).get("_progress", False))
+ if not is_progress:
+ await self._stop_typing(msg.chat_id, clear_remote=True)
+
+ content = msg.content.strip()
+ ctx_token = self._context_tokens.get(msg.chat_id, "")
+ if not ctx_token:
+ logger.warning(
+ "WeChat: no context_token for chat_id={}, cannot send",
+ msg.chat_id,
+ )
+ return
+
+ typing_ticket = ""
+ try:
+ typing_ticket = await self._get_typing_ticket(msg.chat_id, ctx_token)
+ except Exception:
+ typing_ticket = ""
+
+ if typing_ticket:
+ try:
+ await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_TYPING)
+ except Exception:
+ pass
+
+ typing_keepalive_stop = asyncio.Event()
+ typing_keepalive_task: asyncio.Task | None = None
+ if typing_ticket:
+ typing_keepalive_task = asyncio.create_task(
+ self._typing_keepalive_loop(msg.chat_id, typing_ticket, typing_keepalive_stop)
+ )
+
+ try:
+ # --- Send media files first (following Telegram channel pattern) ---
+ for media_path in (msg.media or []):
+ try:
+ await self._send_media_file(msg.chat_id, media_path, ctx_token)
+ except Exception as e:
+ filename = Path(media_path).name
+ logger.error("Failed to send WeChat media {}: {}", media_path, e)
+ # Notify user about failure via text
+ await self._send_text(
+ msg.chat_id, f"[Failed to send: {filename}]", ctx_token,
+ )
+
+ # --- Send text content ---
+ if not content:
+ return
+
+ chunks = split_message(content, WEIXIN_MAX_MESSAGE_LEN)
+ for chunk in chunks:
+ await self._send_text(msg.chat_id, chunk, ctx_token)
+ except Exception as e:
+ logger.error("Error sending WeChat message: {}", e)
+ raise
+ finally:
+ if typing_keepalive_task:
+ typing_keepalive_stop.set()
+ typing_keepalive_task.cancel()
+ try:
+ await typing_keepalive_task
+ except asyncio.CancelledError:
+ pass
+
+ if typing_ticket and not is_progress:
+ try:
+ await self._send_typing(msg.chat_id, typing_ticket, TYPING_STATUS_CANCEL)
+ except Exception:
+ pass
+
+ async def _start_typing(self, chat_id: str, context_token: str = "") -> None:
+ """Start typing indicator immediately when a message is received."""
+ if not self._client or not self._token or not chat_id:
+ return
+ await self._stop_typing(chat_id, clear_remote=False)
+ try:
+ ticket = await self._get_typing_ticket(chat_id, context_token)
+ if not ticket:
+ return
+ await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
+ except Exception as e:
+ logger.debug("WeChat typing indicator start failed for {}: {}", chat_id, e)
+ return
+
+ stop_event = asyncio.Event()
+
+ async def keepalive() -> None:
+ try:
+ while not stop_event.is_set():
+ await asyncio.sleep(TYPING_KEEPALIVE_INTERVAL_S)
+ if stop_event.is_set():
+ break
+ try:
+ await self._send_typing(chat_id, ticket, TYPING_STATUS_TYPING)
+ except Exception:
+ pass
+ finally:
+ pass
+
+ task = asyncio.create_task(keepalive())
+ task._typing_stop_event = stop_event # type: ignore[attr-defined]
+ self._typing_tasks[chat_id] = task
+
+ async def _stop_typing(self, chat_id: str, *, clear_remote: bool) -> None:
+ """Stop typing indicator for a chat."""
+ task = self._typing_tasks.pop(chat_id, None)
+ if task and not task.done():
+ stop_event = getattr(task, "_typing_stop_event", None)
+ if stop_event:
+ stop_event.set()
+ task.cancel()
+ try:
+ await task
+ except asyncio.CancelledError:
+ pass
+ if not clear_remote:
+ return
+ entry = self._typing_tickets.get(chat_id)
+ ticket = str(entry.get("ticket", "") or "") if isinstance(entry, dict) else ""
+ if not ticket:
+ return
+ try:
+ await self._send_typing(chat_id, ticket, TYPING_STATUS_CANCEL)
+ except Exception as e:
+ logger.debug("WeChat typing clear failed for {}: {}", chat_id, e)
+
+ async def _send_text(
+ self,
+ to_user_id: str,
+ text: str,
+ context_token: str,
+ ) -> None:
+ """Send a text message matching the exact protocol from send.ts."""
+ client_id = f"nanobot-{uuid.uuid4().hex[:12]}"
+
+ item_list: list[dict] = []
+ if text:
+ item_list.append({"type": ITEM_TEXT, "text_item": {"text": text}})
+
+ weixin_msg: dict[str, Any] = {
+ "from_user_id": "",
+ "to_user_id": to_user_id,
+ "client_id": client_id,
+ "message_type": MESSAGE_TYPE_BOT,
+ "message_state": MESSAGE_STATE_FINISH,
+ }
+ if item_list:
+ weixin_msg["item_list"] = item_list
+ if context_token:
+ weixin_msg["context_token"] = context_token
+
+ body: dict[str, Any] = {
+ "msg": weixin_msg,
+ "base_info": BASE_INFO,
+ }
+
+ data = await self._api_post("ilink/bot/sendmessage", body)
+ errcode = data.get("errcode", 0)
+ if errcode and errcode != 0:
+ logger.warning(
+ "WeChat send error (code {}): {}",
+ errcode,
+ data.get("errmsg", ""),
+ )
+
+ async def _send_media_file(
+ self,
+ to_user_id: str,
+ media_path: str,
+ context_token: str,
+ ) -> None:
+ """Upload a local file to WeChat CDN and send it as a media message.
+
+ Follows the exact protocol from ``@tencent-weixin/openclaw-weixin`` v1.0.3:
+ 1. Generate a random 16-byte AES key (client-side).
+ 2. Call ``getuploadurl`` with file metadata + hex-encoded AES key.
+ 3. AES-128-ECB encrypt the file and POST to CDN (``{cdnBaseUrl}/upload``).
+ 4. Read ``x-encrypted-param`` header from CDN response as the download param.
+ 5. Send a ``sendmessage`` with the appropriate media item referencing the upload.
+ """
+ p = Path(media_path)
+ if not p.is_file():
+ raise FileNotFoundError(f"Media file not found: {media_path}")
+
+ raw_data = p.read_bytes()
+ raw_size = len(raw_data)
+ raw_md5 = hashlib.md5(raw_data).hexdigest()
+
+ # Determine upload media type from extension
+ ext = p.suffix.lower()
+ if ext in _IMAGE_EXTS:
+ upload_type = UPLOAD_MEDIA_IMAGE
+ item_type = ITEM_IMAGE
+ item_key = "image_item"
+ elif ext in _VIDEO_EXTS:
+ upload_type = UPLOAD_MEDIA_VIDEO
+ item_type = ITEM_VIDEO
+ item_key = "video_item"
+ elif ext in _VOICE_EXTS:
+ upload_type = UPLOAD_MEDIA_VOICE
+ item_type = ITEM_VOICE
+ item_key = "voice_item"
+ else:
+ upload_type = UPLOAD_MEDIA_FILE
+ item_type = ITEM_FILE
+ item_key = "file_item"
+
+ # Generate client-side AES-128 key (16 random bytes)
+ aes_key_raw = os.urandom(16)
+ aes_key_hex = aes_key_raw.hex()
+
+ # Compute encrypted size: PKCS7 padding to 16-byte boundary
+ # Matches aesEcbPaddedSize: Math.ceil((size + 1) / 16) * 16
+ padded_size = ((raw_size + 1 + 15) // 16) * 16
+
+ # Step 1: Get upload URL from server (prefer upload_full_url, fallback to upload_param)
+ file_key = os.urandom(16).hex()
+ upload_body: dict[str, Any] = {
+ "filekey": file_key,
+ "media_type": upload_type,
+ "to_user_id": to_user_id,
+ "rawsize": raw_size,
+ "rawfilemd5": raw_md5,
+ "filesize": padded_size,
+ "no_need_thumb": True,
+ "aeskey": aes_key_hex,
+ }
+
+ assert self._client is not None
+ upload_resp = await self._api_post("ilink/bot/getuploadurl", upload_body)
+
+ upload_full_url = str(upload_resp.get("upload_full_url", "") or "").strip()
+ upload_param = str(upload_resp.get("upload_param", "") or "")
+ if not upload_full_url and not upload_param:
+ raise RuntimeError(
+ "getuploadurl returned no upload URL "
+ f"(need upload_full_url or upload_param): {upload_resp}"
+ )
+
+ # Step 2: AES-128-ECB encrypt and POST to CDN
+ aes_key_b64 = base64.b64encode(aes_key_raw).decode()
+ encrypted_data = _encrypt_aes_ecb(raw_data, aes_key_b64)
+
+ if upload_full_url:
+ cdn_upload_url = upload_full_url
+ else:
+ cdn_upload_url = (
+ f"{self.config.cdn_base_url}/upload"
+ f"?encrypted_query_param={quote(upload_param)}"
+ f"&filekey={quote(file_key)}"
+ )
+
+ cdn_resp = await self._client.post(
+ cdn_upload_url,
+ content=encrypted_data,
+ headers={"Content-Type": "application/octet-stream"},
+ )
+ cdn_resp.raise_for_status()
+
+ # The download encrypted_query_param comes from CDN response header
+ download_param = cdn_resp.headers.get("x-encrypted-param", "")
+ if not download_param:
+ raise RuntimeError(
+ "CDN upload response missing x-encrypted-param header; "
+ f"status={cdn_resp.status_code} headers={dict(cdn_resp.headers)}"
+ )
+
+ # Step 3: Send message with the media item
+ # aes_key for CDNMedia is the hex key encoded as base64
+ # (matches: Buffer.from(uploaded.aeskey).toString("base64"))
+ cdn_aes_key_b64 = base64.b64encode(aes_key_hex.encode()).decode()
+
+ media_item: dict[str, Any] = {
+ "media": {
+ "encrypt_query_param": download_param,
+ "aes_key": cdn_aes_key_b64,
+ "encrypt_type": 1,
+ },
+ }
+
+ if item_type == ITEM_IMAGE:
+ media_item["mid_size"] = padded_size
+ elif item_type == ITEM_VIDEO:
+ media_item["video_size"] = padded_size
+ elif item_type == ITEM_FILE:
+ media_item["file_name"] = p.name
+ media_item["len"] = str(raw_size)
+
+ # Send each media item as its own message (matching reference plugin)
+ client_id = f"nanobot-{uuid.uuid4().hex[:12]}"
+ item_list: list[dict] = [{"type": item_type, item_key: media_item}]
+
+ weixin_msg: dict[str, Any] = {
+ "from_user_id": "",
+ "to_user_id": to_user_id,
+ "client_id": client_id,
+ "message_type": MESSAGE_TYPE_BOT,
+ "message_state": MESSAGE_STATE_FINISH,
+ "item_list": item_list,
+ }
+ if context_token:
+ weixin_msg["context_token"] = context_token
+
+ body: dict[str, Any] = {
+ "msg": weixin_msg,
+ "base_info": BASE_INFO,
+ }
+
+ data = await self._api_post("ilink/bot/sendmessage", body)
+ errcode = data.get("errcode", 0)
+ if errcode and errcode != 0:
+ raise RuntimeError(
+ f"WeChat send media error (code {errcode}): {data.get('errmsg', '')}"
+ )
+
+
+# ---------------------------------------------------------------------------
+# AES-128-ECB encryption / decryption (matches pic-decrypt.ts / aes-ecb.ts)
+# ---------------------------------------------------------------------------
+
+
+def _parse_aes_key(aes_key_b64: str) -> bytes:
+ """Parse a base64-encoded AES key, handling both encodings seen in the wild.
+
+ From ``pic-decrypt.ts parseAesKey``:
+
+ * ``base64(raw 16 bytes)`` → images (media.aes_key)
+ * ``base64(hex string of 16 bytes)`` → file / voice / video
+
+ In the second case base64-decoding yields 32 ASCII hex chars which must
+ then be parsed as hex to recover the actual 16-byte key.
+ """
+ decoded = base64.b64decode(aes_key_b64)
+ if len(decoded) == 16:
+ return decoded
+ if len(decoded) == 32 and re.fullmatch(rb"[0-9a-fA-F]{32}", decoded):
+ # hex-encoded key: base64 → hex string → raw bytes
+ return bytes.fromhex(decoded.decode("ascii"))
+ raise ValueError(
+ f"aes_key must decode to 16 raw bytes or 32-char hex string, got {len(decoded)} bytes"
+ )
+
+
+def _encrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
+ """Encrypt data with AES-128-ECB and PKCS7 padding for CDN upload."""
+ try:
+ key = _parse_aes_key(aes_key_b64)
+ except Exception as e:
+ logger.warning("Failed to parse AES key for encryption, sending raw: {}", e)
+ return data
+
+ # PKCS7 padding
+ pad_len = 16 - len(data) % 16
+ padded = data + bytes([pad_len] * pad_len)
+
+ try:
+ from Crypto.Cipher import AES
+
+ cipher = AES.new(key, AES.MODE_ECB)
+ return cipher.encrypt(padded)
+ except ImportError:
+ pass
+
+ try:
+ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+
+ cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
+ encryptor = cipher_obj.encryptor()
+ return encryptor.update(padded) + encryptor.finalize()
+ except ImportError:
+ logger.warning("Cannot encrypt media: install 'pycryptodome' or 'cryptography'")
+ return data
+
+
+def _decrypt_aes_ecb(data: bytes, aes_key_b64: str) -> bytes:
+ """Decrypt AES-128-ECB media data.
+
+ ``aes_key_b64`` is always base64-encoded (caller converts hex keys first).
+ """
+ try:
+ key = _parse_aes_key(aes_key_b64)
+ except Exception as e:
+ logger.warning("Failed to parse AES key, returning raw data: {}", e)
+ return data
+
+ decrypted: bytes | None = None
+
+ try:
+ from Crypto.Cipher import AES
+
+ cipher = AES.new(key, AES.MODE_ECB)
+ decrypted = cipher.decrypt(data)
+ except ImportError:
+ pass
+
+ if decrypted is None:
+ try:
+ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+
+ cipher_obj = Cipher(algorithms.AES(key), modes.ECB())
+ decryptor = cipher_obj.decryptor()
+ decrypted = decryptor.update(data) + decryptor.finalize()
+ except ImportError:
+ logger.warning("Cannot decrypt media: install 'pycryptodome' or 'cryptography'")
+ return data
+
+ return _pkcs7_unpad_safe(decrypted)
+
+
+def _pkcs7_unpad_safe(data: bytes, block_size: int = 16) -> bytes:
+ """Safely remove PKCS7 padding when valid; otherwise return original bytes."""
+ if not data:
+ return data
+ if len(data) % block_size != 0:
+ return data
+ pad_len = data[-1]
+ if pad_len < 1 or pad_len > block_size:
+ return data
+ if data[-pad_len:] != bytes([pad_len]) * pad_len:
+ return data
+ return data[:-pad_len]
+
+
+def _ext_for_type(media_type: str) -> str:
+ return {
+ "image": ".jpg",
+ "voice": ".silk",
+ "video": ".mp4",
+ "file": "",
+ }.get(media_type, "")
diff --git a/nanobot/channels/whatsapp.py b/nanobot/channels/whatsapp.py
index c14a6c3e6..a788dd727 100644
--- a/nanobot/channels/whatsapp.py
+++ b/nanobot/channels/whatsapp.py
@@ -2,140 +2,331 @@
import asyncio
import json
-from typing import Any
+import mimetypes
+import os
+import secrets
+import shutil
+import subprocess
+from collections import OrderedDict
+from pathlib import Path
+from typing import Any, Literal
from loguru import logger
+from pydantic import Field
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
-from nanobot.config.schema import WhatsAppConfig
+from nanobot.config.schema import Base
+
+
+class WhatsAppConfig(Base):
+ """WhatsApp channel configuration."""
+
+ enabled: bool = False
+ bridge_url: str = "ws://localhost:3001"
+ bridge_token: str = ""
+ allow_from: list[str] = Field(default_factory=list)
+ group_policy: Literal["open", "mention"] = "open" # "open" responds to all, "mention" only when @mentioned
+
+
+def _bridge_token_path() -> Path:
+ from nanobot.config.paths import get_runtime_subdir
+
+ return get_runtime_subdir("whatsapp-auth") / "bridge-token"
+
+
+def _load_or_create_bridge_token(path: Path) -> str:
+ """Load a persisted bridge token or create one on first use."""
+ if path.exists():
+ token = path.read_text(encoding="utf-8").strip()
+ if token:
+ return token
+
+ path.parent.mkdir(parents=True, exist_ok=True)
+ token = secrets.token_urlsafe(32)
+ path.write_text(token, encoding="utf-8")
+ try:
+ path.chmod(0o600)
+ except OSError:
+ pass
+ return token
class WhatsAppChannel(BaseChannel):
"""
WhatsApp channel that connects to a Node.js bridge.
-
+
The bridge uses @whiskeysockets/baileys to handle the WhatsApp Web protocol.
Communication between Python and Node.js is via WebSocket.
"""
-
+
name = "whatsapp"
-
- def __init__(self, config: WhatsAppConfig, bus: MessageBus):
+ display_name = "WhatsApp"
+
+ @classmethod
+ def default_config(cls) -> dict[str, Any]:
+ return WhatsAppConfig().model_dump(by_alias=True)
+
+ def __init__(self, config: Any, bus: MessageBus):
+ if isinstance(config, dict):
+ config = WhatsAppConfig.model_validate(config)
super().__init__(config, bus)
- self.config: WhatsAppConfig = config
self._ws = None
self._connected = False
-
+ self._processed_message_ids: OrderedDict[str, None] = OrderedDict()
+ self._bridge_token: str | None = None
+
+ def _effective_bridge_token(self) -> str:
+ """Resolve the bridge token, generating a local secret when needed."""
+ if self._bridge_token is not None:
+ return self._bridge_token
+ configured = self.config.bridge_token.strip()
+ if configured:
+ self._bridge_token = configured
+ else:
+ self._bridge_token = _load_or_create_bridge_token(_bridge_token_path())
+ return self._bridge_token
+
+ async def login(self, force: bool = False) -> bool:
+ """
+ Set up and run the WhatsApp bridge for QR code login.
+
+ This spawns the Node.js bridge process which handles the WhatsApp
+ authentication flow. The process blocks until the user scans the QR code
+ or interrupts with Ctrl+C.
+ """
+ try:
+ bridge_dir = _ensure_bridge_setup()
+ except RuntimeError as e:
+ logger.error("{}", e)
+ return False
+
+ env = {**os.environ}
+ env["BRIDGE_TOKEN"] = self._effective_bridge_token()
+ env["AUTH_DIR"] = str(_bridge_token_path().parent)
+
+ logger.info("Starting WhatsApp bridge for QR login...")
+ try:
+ subprocess.run(
+ [shutil.which("npm"), "start"], cwd=bridge_dir, check=True, env=env
+ )
+ except subprocess.CalledProcessError:
+ return False
+
+ return True
+
async def start(self) -> None:
"""Start the WhatsApp channel by connecting to the bridge."""
import websockets
-
+
bridge_url = self.config.bridge_url
-
- logger.info(f"Connecting to WhatsApp bridge at {bridge_url}...")
-
+
+ logger.info("Connecting to WhatsApp bridge at {}...", bridge_url)
+
self._running = True
-
+
while self._running:
try:
async with websockets.connect(bridge_url) as ws:
self._ws = ws
+ await ws.send(
+ json.dumps({"type": "auth", "token": self._effective_bridge_token()})
+ )
self._connected = True
logger.info("Connected to WhatsApp bridge")
-
+
# Listen for messages
async for message in ws:
try:
await self._handle_bridge_message(message)
except Exception as e:
- logger.error(f"Error handling bridge message: {e}")
-
+ logger.error("Error handling bridge message: {}", e)
+
except asyncio.CancelledError:
break
except Exception as e:
self._connected = False
self._ws = None
- logger.warning(f"WhatsApp bridge connection error: {e}")
-
+ logger.warning("WhatsApp bridge connection error: {}", e)
+
if self._running:
logger.info("Reconnecting in 5 seconds...")
await asyncio.sleep(5)
-
+
async def stop(self) -> None:
"""Stop the WhatsApp channel."""
self._running = False
self._connected = False
-
+
if self._ws:
await self._ws.close()
self._ws = None
-
+
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WhatsApp."""
if not self._ws or not self._connected:
logger.warning("WhatsApp bridge not connected")
return
-
- try:
- payload = {
- "type": "send",
- "to": msg.chat_id,
- "text": msg.content
- }
- await self._ws.send(json.dumps(payload))
- except Exception as e:
- logger.error(f"Error sending WhatsApp message: {e}")
-
+
+ chat_id = msg.chat_id
+
+ if msg.content:
+ try:
+ payload = {"type": "send", "to": chat_id, "text": msg.content}
+ await self._ws.send(json.dumps(payload, ensure_ascii=False))
+ except Exception as e:
+ logger.error("Error sending WhatsApp message: {}", e)
+ raise
+
+ for media_path in msg.media or []:
+ try:
+ mime, _ = mimetypes.guess_type(media_path)
+ payload = {
+ "type": "send_media",
+ "to": chat_id,
+ "filePath": media_path,
+ "mimetype": mime or "application/octet-stream",
+ "fileName": media_path.rsplit("/", 1)[-1],
+ }
+ await self._ws.send(json.dumps(payload, ensure_ascii=False))
+ except Exception as e:
+ logger.error("Error sending WhatsApp media {}: {}", media_path, e)
+ raise
+
async def _handle_bridge_message(self, raw: str) -> None:
"""Handle a message from the bridge."""
try:
data = json.loads(raw)
except json.JSONDecodeError:
- logger.warning(f"Invalid JSON from bridge: {raw[:100]}")
+ logger.warning("Invalid JSON from bridge: {}", raw[:100])
return
-
+
msg_type = data.get("type")
-
+
if msg_type == "message":
# Incoming message from WhatsApp
+ # Deprecated by whatsapp: old phone number style typically: @s.whatspp.net
+ pn = data.get("pn", "")
+ # New LID sytle typically:
sender = data.get("sender", "")
content = data.get("content", "")
-
- # sender is typically: @s.whatsapp.net
- # Extract just the phone number as chat_id
- chat_id = sender.split("@")[0] if "@" in sender else sender
-
+ message_id = data.get("id", "")
+
+ if message_id:
+ if message_id in self._processed_message_ids:
+ return
+ self._processed_message_ids[message_id] = None
+ while len(self._processed_message_ids) > 1000:
+ self._processed_message_ids.popitem(last=False)
+
+ # Extract just the phone number or lid as chat_id
+ is_group = data.get("isGroup", False)
+ was_mentioned = data.get("wasMentioned", False)
+
+ if is_group and getattr(self.config, "group_policy", "open") == "mention":
+ if not was_mentioned:
+ return
+
+ user_id = pn if pn else sender
+ sender_id = user_id.split("@")[0] if "@" in user_id else user_id
+ logger.info("Sender {}", sender)
+
# Handle voice transcription if it's a voice message
if content == "[Voice Message]":
- logger.info(f"Voice message received from {chat_id}, but direct download from bridge is not yet supported.")
+ logger.info(
+ "Voice message received from {}, but direct download from bridge is not yet supported.",
+ sender_id,
+ )
content = "[Voice Message: Transcription not available for WhatsApp yet]"
-
+
+ # Extract media paths (images/documents/videos downloaded by the bridge)
+ media_paths = data.get("media") or []
+
+ # Build content tags matching Telegram's pattern: [image: /path] or [file: /path]
+ if media_paths:
+ for p in media_paths:
+ mime, _ = mimetypes.guess_type(p)
+ media_type = "image" if mime and mime.startswith("image/") else "file"
+ media_tag = f"[{media_type}: {p}]"
+ content = f"{content}\n{media_tag}" if content else media_tag
+
await self._handle_message(
- sender_id=chat_id,
- chat_id=sender, # Use full JID for replies
+ sender_id=sender_id,
+ chat_id=sender, # Use full LID for replies
content=content,
+ media=media_paths,
metadata={
- "message_id": data.get("id"),
+ "message_id": message_id,
"timestamp": data.get("timestamp"),
- "is_group": data.get("isGroup", False)
- }
+ "is_group": data.get("isGroup", False),
+ },
)
-
+
elif msg_type == "status":
# Connection status update
status = data.get("status")
- logger.info(f"WhatsApp status: {status}")
-
+ logger.info("WhatsApp status: {}", status)
+
if status == "connected":
self._connected = True
elif status == "disconnected":
self._connected = False
-
+
elif msg_type == "qr":
# QR code for authentication
logger.info("Scan QR code in the bridge terminal to connect WhatsApp")
-
+
elif msg_type == "error":
- logger.error(f"WhatsApp bridge error: {data.get('error')}")
+ logger.error("WhatsApp bridge error: {}", data.get("error"))
+
+
+def _ensure_bridge_setup() -> Path:
+ """
+ Ensure the WhatsApp bridge is set up and built.
+
+ Returns the bridge directory. Raises RuntimeError if npm is not found
+ or bridge cannot be built.
+ """
+ from nanobot.config.paths import get_bridge_install_dir
+
+ user_bridge = get_bridge_install_dir()
+
+ if (user_bridge / "dist" / "index.js").exists():
+ return user_bridge
+
+ npm_path = shutil.which("npm")
+ if not npm_path:
+ raise RuntimeError("npm not found. Please install Node.js >= 18.")
+
+ # Find source bridge
+ current_file = Path(__file__)
+ pkg_bridge = current_file.parent.parent / "bridge"
+ src_bridge = current_file.parent.parent.parent / "bridge"
+
+ source = None
+ if (pkg_bridge / "package.json").exists():
+ source = pkg_bridge
+ elif (src_bridge / "package.json").exists():
+ source = src_bridge
+
+ if not source:
+ raise RuntimeError(
+ "WhatsApp bridge source not found. "
+ "Try reinstalling: pip install --force-reinstall nanobot"
+ )
+
+ logger.info("Setting up WhatsApp bridge...")
+ user_bridge.parent.mkdir(parents=True, exist_ok=True)
+ if user_bridge.exists():
+ shutil.rmtree(user_bridge)
+ shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
+
+ logger.info(" Installing dependencies...")
+ subprocess.run([npm_path, "install"], cwd=user_bridge, check=True, capture_output=True)
+
+ logger.info(" Building...")
+ subprocess.run([npm_path, "run", "build"], cwd=user_bridge, check=True, capture_output=True)
+
+ logger.info("Bridge ready")
+ return user_bridge
diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py
index c2241fbf2..dfb13ba97 100644
--- a/nanobot/cli/commands.py
+++ b/nanobot/cli/commands.py
@@ -1,21 +1,235 @@
"""CLI commands for nanobot."""
import asyncio
+from contextlib import contextmanager, nullcontext
+
+import os
+import select
+import signal
+import sys
from pathlib import Path
+from typing import Any
+
+# Force UTF-8 encoding for Windows console
+if sys.platform == "win32":
+ if sys.stdout.encoding != "utf-8":
+ os.environ["PYTHONIOENCODING"] = "utf-8"
+ # Re-open stdout/stderr with UTF-8 encoding
+ try:
+ sys.stdout.reconfigure(encoding="utf-8", errors="replace")
+ sys.stderr.reconfigure(encoding="utf-8", errors="replace")
+ except Exception:
+ pass
import typer
+from loguru import logger
+from prompt_toolkit import PromptSession, print_formatted_text
+from prompt_toolkit.application import run_in_terminal
+from prompt_toolkit.formatted_text import ANSI, HTML
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.patch_stdout import patch_stdout
from rich.console import Console
+from rich.markdown import Markdown
from rich.table import Table
+from rich.text import Text
-from nanobot import __version__, __logo__
+from nanobot import __logo__, __version__
+from nanobot.cli.stream import StreamRenderer, ThinkingSpinner
+from nanobot.config.paths import get_workspace_path, is_default_workspace
+from nanobot.config.schema import Config
+from nanobot.utils.helpers import sync_workspace_templates
+from nanobot.utils.restart import (
+ consume_restart_notice_from_env,
+ format_restart_completed_message,
+ should_show_cli_restart_notice,
+)
app = typer.Typer(
name="nanobot",
+ context_settings={"help_option_names": ["-h", "--help"]},
help=f"{__logo__} nanobot - Personal AI Assistant",
no_args_is_help=True,
)
console = Console()
+EXIT_COMMANDS = {"exit", "quit", "/exit", "/quit", ":q"}
+
+# ---------------------------------------------------------------------------
+# CLI input: prompt_toolkit for editing, paste, history, and display
+# ---------------------------------------------------------------------------
+
+_PROMPT_SESSION: PromptSession | None = None
+_SAVED_TERM_ATTRS = None # original termios settings, restored on exit
+
+
+def _flush_pending_tty_input() -> None:
+ """Drop unread keypresses typed while the model was generating output."""
+ try:
+ fd = sys.stdin.fileno()
+ if not os.isatty(fd):
+ return
+ except Exception:
+ return
+
+ try:
+ import termios
+ termios.tcflush(fd, termios.TCIFLUSH)
+ return
+ except Exception:
+ pass
+
+ try:
+ while True:
+ ready, _, _ = select.select([fd], [], [], 0)
+ if not ready:
+ break
+ if not os.read(fd, 4096):
+ break
+ except Exception:
+ return
+
+
+def _restore_terminal() -> None:
+ """Restore terminal to its original state (echo, line buffering, etc.)."""
+ if _SAVED_TERM_ATTRS is None:
+ return
+ try:
+ import termios
+ termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
+ except Exception:
+ pass
+
+
+def _init_prompt_session() -> None:
+ """Create the prompt_toolkit session with persistent file history."""
+ global _PROMPT_SESSION, _SAVED_TERM_ATTRS
+
+ # Save terminal state so we can restore it on exit
+ try:
+ import termios
+ _SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
+ except Exception:
+ pass
+
+ from nanobot.config.paths import get_cli_history_path
+
+ history_file = get_cli_history_path()
+ history_file.parent.mkdir(parents=True, exist_ok=True)
+
+ _PROMPT_SESSION = PromptSession(
+ history=FileHistory(str(history_file)),
+ enable_open_in_editor=False,
+ multiline=False, # Enter submits (single line mode)
+ )
+
+
+def _make_console() -> Console:
+ return Console(file=sys.stdout)
+
+
+def _render_interactive_ansi(render_fn) -> str:
+ """Render Rich output to ANSI so prompt_toolkit can print it safely."""
+ ansi_console = Console(
+ force_terminal=True,
+ color_system=console.color_system or "standard",
+ width=console.width,
+ )
+ with ansi_console.capture() as capture:
+ render_fn(ansi_console)
+ return capture.get()
+
+
+def _print_agent_response(
+ response: str,
+ render_markdown: bool,
+ metadata: dict | None = None,
+) -> None:
+ """Render assistant response with consistent terminal styling."""
+ console = _make_console()
+ content = response or ""
+ body = _response_renderable(content, render_markdown, metadata)
+ console.print()
+ console.print(f"[cyan]{__logo__} nanobot[/cyan]")
+ console.print(body)
+ console.print()
+
+
+def _response_renderable(content: str, render_markdown: bool, metadata: dict | None = None):
+ """Render plain-text command output without markdown collapsing newlines."""
+ if not render_markdown:
+ return Text(content)
+ if (metadata or {}).get("render_as") == "text":
+ return Text(content)
+ return Markdown(content)
+
+
+async def _print_interactive_line(text: str) -> None:
+ """Print async interactive updates with prompt_toolkit-safe Rich styling."""
+ def _write() -> None:
+ ansi = _render_interactive_ansi(
+ lambda c: c.print(f" [dim]↳ {text}[/dim]")
+ )
+ print_formatted_text(ANSI(ansi), end="")
+
+ await run_in_terminal(_write)
+
+
+async def _print_interactive_response(
+ response: str,
+ render_markdown: bool,
+ metadata: dict | None = None,
+) -> None:
+ """Print async interactive replies with prompt_toolkit-safe Rich styling."""
+ def _write() -> None:
+ content = response or ""
+ ansi = _render_interactive_ansi(
+ lambda c: (
+ c.print(),
+ c.print(f"[cyan]{__logo__} nanobot[/cyan]"),
+ c.print(_response_renderable(content, render_markdown, metadata)),
+ c.print(),
+ )
+ )
+ print_formatted_text(ANSI(ansi), end="")
+
+ await run_in_terminal(_write)
+
+
+def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
+ """Print a CLI progress line, pausing the spinner if needed."""
+ with thinking.pause() if thinking else nullcontext():
+ console.print(f" [dim]↳ {text}[/dim]")
+
+
+async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
+ """Print an interactive progress line, pausing the spinner if needed."""
+ with thinking.pause() if thinking else nullcontext():
+ await _print_interactive_line(text)
+
+
+def _is_exit_command(command: str) -> bool:
+ """Return True when input should end interactive chat."""
+ return command.lower() in EXIT_COMMANDS
+
+
+async def _read_interactive_input_async() -> str:
+ """Read user input using prompt_toolkit (handles paste, history, display).
+
+ prompt_toolkit natively handles:
+ - Multiline paste (bracketed paste mode)
+ - History navigation (up/down arrows)
+ - Clean display (no ghost characters or artifacts)
+ """
+ if _PROMPT_SESSION is None:
+ raise RuntimeError("Call _init_prompt_session() first")
+ try:
+ with patch_stdout():
+ return await _PROMPT_SESSION.prompt_async(
+ HTML("You: "),
+ )
+ except EOFError as exc:
+ raise KeyboardInterrupt from exc
+
def version_callback(value: bool):
@@ -40,111 +254,337 @@ def main(
@app.command()
-def onboard():
+def onboard(
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
+ wizard: bool = typer.Option(False, "--wizard", help="Use interactive wizard"),
+):
"""Initialize nanobot configuration and workspace."""
- from nanobot.config.loader import get_config_path, save_config
+ from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path
from nanobot.config.schema import Config
- from nanobot.utils.helpers import get_workspace_path
-
- config_path = get_config_path()
-
+
+ if config:
+ config_path = Path(config).expanduser().resolve()
+ set_config_path(config_path)
+ console.print(f"[dim]Using config: {config_path}[/dim]")
+ else:
+ config_path = get_config_path()
+
+ def _apply_workspace_override(loaded: Config) -> Config:
+ if workspace:
+ loaded.agents.defaults.workspace = workspace
+ return loaded
+
+ # Create or update config
if config_path.exists():
- console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
- if not typer.confirm("Overwrite?"):
- raise typer.Exit()
-
- # Create default config
- config = Config()
- save_config(config)
- console.print(f"[green]✓[/green] Created config at {config_path}")
-
- # Create workspace
- workspace = get_workspace_path()
- console.print(f"[green]✓[/green] Created workspace at {workspace}")
-
- # Create default bootstrap files
- _create_workspace_templates(workspace)
-
+ if wizard:
+ config = _apply_workspace_override(load_config(config_path))
+ else:
+ console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
+ console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
+ console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
+ if typer.confirm("Overwrite?"):
+ config = _apply_workspace_override(Config())
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
+ else:
+ config = _apply_workspace_override(load_config(config_path))
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
+ else:
+ config = _apply_workspace_override(Config())
+ # In wizard mode, don't save yet - the wizard will handle saving if should_save=True
+ if not wizard:
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Created config at {config_path}")
+
+ # Run interactive wizard if enabled
+ if wizard:
+ from nanobot.cli.onboard import run_onboard
+
+ try:
+ result = run_onboard(initial_config=config)
+ if not result.should_save:
+ console.print("[yellow]Configuration discarded. No changes were saved.[/yellow]")
+ return
+
+ config = result.config
+ save_config(config, config_path)
+ console.print(f"[green]✓[/green] Config saved at {config_path}")
+ except Exception as e:
+ console.print(f"[red]✗[/red] Error during configuration: {e}")
+ console.print("[yellow]Please run 'nanobot onboard' again to complete setup.[/yellow]")
+ raise typer.Exit(1)
+ _onboard_plugins(config_path)
+
+ # Create workspace, preferring the configured workspace path.
+ workspace_path = get_workspace_path(config.workspace_path)
+ if not workspace_path.exists():
+ workspace_path.mkdir(parents=True, exist_ok=True)
+ console.print(f"[green]✓[/green] Created workspace at {workspace_path}")
+
+ sync_workspace_templates(workspace_path)
+
+ agent_cmd = 'nanobot agent -m "Hello!"'
+ gateway_cmd = "nanobot gateway"
+ if config:
+ agent_cmd += f" --config {config_path}"
+ gateway_cmd += f" --config {config_path}"
+
console.print(f"\n{__logo__} nanobot is ready!")
console.print("\nNext steps:")
- console.print(" 1. Add your API key to [cyan]~/.nanobot/config.json[/cyan]")
- console.print(" Get one at: https://openrouter.ai/keys")
- console.print(" 2. Chat: [cyan]nanobot agent -m \"Hello!\"[/cyan]")
+ if wizard:
+ console.print(f" 1. Chat: [cyan]{agent_cmd}[/cyan]")
+ console.print(f" 2. Start gateway: [cyan]{gateway_cmd}[/cyan]")
+ else:
+ console.print(f" 1. Add your API key to [cyan]{config_path}[/cyan]")
+ console.print(" Get one at: https://openrouter.ai/keys")
+ console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
+def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
+ """Recursively fill in missing values from defaults without overwriting user config."""
+ if not isinstance(existing, dict) or not isinstance(defaults, dict):
+ return existing
+
+ merged = dict(existing)
+ for key, value in defaults.items():
+ if key not in merged:
+ merged[key] = value
+ else:
+ merged[key] = _merge_missing_defaults(merged[key], value)
+ return merged
-def _create_workspace_templates(workspace: Path):
- """Create default workspace template files."""
- templates = {
- "AGENTS.md": """# Agent Instructions
+def _onboard_plugins(config_path: Path) -> None:
+ """Inject default config for all discovered channels (built-in + plugins)."""
+ import json
-You are a helpful AI assistant. Be concise, accurate, and friendly.
+ from nanobot.channels.registry import discover_all
-## Guidelines
+ all_channels = discover_all()
+ if not all_channels:
+ return
-- Always explain what you're doing before taking actions
-- Ask for clarification when the request is ambiguous
-- Use tools to help accomplish tasks
-- Remember important information in your memory files
-""",
- "SOUL.md": """# Soul
+ with open(config_path, encoding="utf-8") as f:
+ data = json.load(f)
-I am nanobot, a lightweight AI assistant.
+ channels = data.setdefault("channels", {})
+ for name, cls in all_channels.items():
+ if name not in channels:
+ channels[name] = cls.default_config()
+ else:
+ channels[name] = _merge_missing_defaults(channels[name], cls.default_config())
-## Personality
+ with open(config_path, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2, ensure_ascii=False)
-- Helpful and friendly
-- Concise and to the point
-- Curious and eager to learn
-## Values
+def _make_provider(config: Config):
+ """Create the appropriate LLM provider from config.
-- Accuracy over speed
-- User privacy and safety
-- Transparency in actions
-""",
- "USER.md": """# User
+ Routing is driven by ``ProviderSpec.backend`` in the registry.
+ """
+ from nanobot.providers.base import GenerationSettings
+ from nanobot.providers.registry import find_by_name
-Information about the user goes here.
+ model = config.agents.defaults.model
+ provider_name = config.get_provider_name(model)
+ p = config.get_provider(model)
+ spec = find_by_name(provider_name) if provider_name else None
+ backend = spec.backend if spec else "openai_compat"
-## Preferences
+ # --- validation ---
+ if backend == "azure_openai":
+ if not p or not p.api_key or not p.api_base:
+ console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
+ console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
+ console.print("Use the model field to specify the deployment name.")
+ raise typer.Exit(1)
+ elif backend == "openai_compat" and not model.startswith("bedrock/"):
+ needs_key = not (p and p.api_key)
+ exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
+ if needs_key and not exempt:
+ console.print("[red]Error: No API key configured.[/red]")
+ console.print("Set one in ~/.nanobot/config.json under providers section")
+ raise typer.Exit(1)
-- Communication style: (casual/formal)
-- Timezone: (your timezone)
-- Language: (your preferred language)
-""",
- }
-
- for filename, content in templates.items():
- file_path = workspace / filename
- if not file_path.exists():
- file_path.write_text(content)
- console.print(f" [dim]Created {filename}[/dim]")
-
- # Create memory directory and MEMORY.md
- memory_dir = workspace / "memory"
- memory_dir.mkdir(exist_ok=True)
- memory_file = memory_dir / "MEMORY.md"
- if not memory_file.exists():
- memory_file.write_text("""# Long-term Memory
+ # --- instantiation by backend ---
+ if backend == "openai_codex":
+ from nanobot.providers.openai_codex_provider import OpenAICodexProvider
+ provider = OpenAICodexProvider(default_model=model)
+ elif backend == "azure_openai":
+ from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+ provider = AzureOpenAIProvider(
+ api_key=p.api_key,
+ api_base=p.api_base,
+ default_model=model,
+ )
+ elif backend == "github_copilot":
+ from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
+ provider = GitHubCopilotProvider(default_model=model)
+ elif backend == "anthropic":
+ from nanobot.providers.anthropic_provider import AnthropicProvider
+ provider = AnthropicProvider(
+ api_key=p.api_key if p else None,
+ api_base=config.get_api_base(model),
+ default_model=model,
+ extra_headers=p.extra_headers if p else None,
+ )
+ else:
+ from nanobot.providers.openai_compat_provider import OpenAICompatProvider
+ provider = OpenAICompatProvider(
+ api_key=p.api_key if p else None,
+ api_base=config.get_api_base(model),
+ default_model=model,
+ extra_headers=p.extra_headers if p else None,
+ spec=spec,
+ )
-This file stores important information that should persist across sessions.
+ defaults = config.agents.defaults
+ provider.generation = GenerationSettings(
+ temperature=defaults.temperature,
+ max_tokens=defaults.max_tokens,
+ reasoning_effort=defaults.reasoning_effort,
+ )
+ return provider
-## User Information
-(Important facts about the user)
+def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
+ """Load config and optionally override the active workspace."""
+ from nanobot.config.loader import load_config, set_config_path
-## Preferences
+ config_path = None
+ if config:
+ config_path = Path(config).expanduser().resolve()
+ if not config_path.exists():
+ console.print(f"[red]Error: Config file not found: {config_path}[/red]")
+ raise typer.Exit(1)
+ set_config_path(config_path)
+ console.print(f"[dim]Using config: {config_path}[/dim]")
-(User preferences learned over time)
+ loaded = load_config(config_path)
+ _warn_deprecated_config_keys(config_path)
+ if workspace:
+ loaded.agents.defaults.workspace = workspace
+ return loaded
-## Important Notes
-(Things to remember)
-""")
- console.print(" [dim]Created memory/MEMORY.md[/dim]")
+def _warn_deprecated_config_keys(config_path: Path | None) -> None:
+ """Hint users to remove obsolete keys from their config file."""
+ import json
+ from nanobot.config.loader import get_config_path
+
+ path = config_path or get_config_path()
+ try:
+ raw = json.loads(path.read_text(encoding="utf-8"))
+ except Exception:
+ return
+ if "memoryWindow" in raw.get("agents", {}).get("defaults", {}):
+ console.print(
+ "[dim]Hint: `memoryWindow` in your config is no longer used "
+ "and can be safely removed.[/dim]"
+ )
+
+
+def _migrate_cron_store(config: "Config") -> None:
+ """One-time migration: move legacy global cron store into the workspace."""
+ from nanobot.config.paths import get_cron_dir
+
+ legacy_path = get_cron_dir() / "jobs.json"
+ new_path = config.workspace_path / "cron" / "jobs.json"
+ if legacy_path.is_file() and not new_path.exists():
+ new_path.parent.mkdir(parents=True, exist_ok=True)
+ import shutil
+ shutil.move(str(legacy_path), str(new_path))
+
+
+# ============================================================================
+# OpenAI-Compatible API Server
+# ============================================================================
+
+
+@app.command()
+def serve(
+ port: int | None = typer.Option(None, "--port", "-p", help="API server port"),
+ host: str | None = typer.Option(None, "--host", "-H", help="Bind address"),
+ timeout: float | None = typer.Option(None, "--timeout", "-t", help="Per-request timeout (seconds)"),
+ verbose: bool = typer.Option(False, "--verbose", "-v", help="Show nanobot runtime logs"),
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
+):
+ """Start the OpenAI-compatible API server (/v1/chat/completions)."""
+ try:
+ from aiohttp import web # noqa: F401
+ except ImportError:
+ console.print("[red]aiohttp is required. Install with: pip install 'nanobot-ai[api]'[/red]")
+ raise typer.Exit(1)
+
+ from loguru import logger
+ from nanobot.agent.loop import AgentLoop
+ from nanobot.api.server import create_app
+ from nanobot.bus.queue import MessageBus
+ from nanobot.session.manager import SessionManager
+
+ if verbose:
+ logger.enable("nanobot")
+ else:
+ logger.disable("nanobot")
+
+ runtime_config = _load_runtime_config(config, workspace)
+ api_cfg = runtime_config.api
+ host = host if host is not None else api_cfg.host
+ port = port if port is not None else api_cfg.port
+ timeout = timeout if timeout is not None else api_cfg.timeout
+ sync_workspace_templates(runtime_config.workspace_path)
+ bus = MessageBus()
+ provider = _make_provider(runtime_config)
+ session_manager = SessionManager(runtime_config.workspace_path)
+ agent_loop = AgentLoop(
+ bus=bus,
+ provider=provider,
+ workspace=runtime_config.workspace_path,
+ model=runtime_config.agents.defaults.model,
+ max_iterations=runtime_config.agents.defaults.max_tool_iterations,
+ context_window_tokens=runtime_config.agents.defaults.context_window_tokens,
+ context_block_limit=runtime_config.agents.defaults.context_block_limit,
+ max_tool_result_chars=runtime_config.agents.defaults.max_tool_result_chars,
+ provider_retry_mode=runtime_config.agents.defaults.provider_retry_mode,
+ web_config=runtime_config.tools.web,
+ exec_config=runtime_config.tools.exec,
+ restrict_to_workspace=runtime_config.tools.restrict_to_workspace,
+ session_manager=session_manager,
+ mcp_servers=runtime_config.tools.mcp_servers,
+ channels_config=runtime_config.channels,
+ timezone=runtime_config.agents.defaults.timezone,
+ )
+
+ model_name = runtime_config.agents.defaults.model
+ console.print(f"{__logo__} Starting OpenAI-compatible API server")
+ console.print(f" [cyan]Endpoint[/cyan] : http://{host}:{port}/v1/chat/completions")
+ console.print(f" [cyan]Model[/cyan] : {model_name}")
+ console.print(" [cyan]Session[/cyan] : api:default")
+ console.print(f" [cyan]Timeout[/cyan] : {timeout}s")
+ if host in {"0.0.0.0", "::"}:
+ console.print(
+ "[yellow]Warning:[/yellow] API is bound to all interfaces. "
+ "Only do this behind a trusted network boundary, firewall, or reverse proxy."
+ )
+ console.print()
+
+ api_app = create_app(agent_loop, model_name=model_name, request_timeout=timeout)
+
+ async def on_startup(_app):
+ await agent_loop._connect_mcp()
+
+ async def on_cleanup(_app):
+ await agent_loop.close_mcp()
+
+ api_app.on_startup.append(on_startup)
+ api_app.on_cleanup.append(on_cleanup)
+
+ web.run_app(api_app, host=host, port=port, print=lambda msg: logger.info(msg))
# ============================================================================
@@ -154,104 +594,208 @@ This file stores important information that should persist across sessions.
@app.command()
def gateway(
- port: int = typer.Option(18790, "--port", "-p", help="Gateway port"),
+ port: int | None = typer.Option(None, "--port", "-p", help="Gateway port"),
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
verbose: bool = typer.Option(False, "--verbose", "-v", help="Verbose output"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
"""Start the nanobot gateway."""
- from nanobot.config.loader import load_config, get_data_dir
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.litellm_provider import LiteLLMProvider
from nanobot.agent.loop import AgentLoop
+ from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
-
+ from nanobot.session.manager import SessionManager
+
if verbose:
import logging
logging.basicConfig(level=logging.DEBUG)
-
- console.print(f"{__logo__} Starting nanobot gateway on port {port}...")
-
- config = load_config()
-
- # Create components
- bus = MessageBus()
-
- # Create provider (supports OpenRouter, Anthropic, OpenAI, Bedrock)
- api_key = config.get_api_key()
- api_base = config.get_api_base()
- model = config.agents.defaults.model
- is_bedrock = model.startswith("bedrock/")
- if not api_key and not is_bedrock:
- console.print("[red]Error: No API key configured.[/red]")
- console.print("Set one in ~/.nanobot/config.json under providers.openrouter.apiKey")
- raise typer.Exit(1)
-
- provider = LiteLLMProvider(
- api_key=api_key,
- api_base=api_base,
- default_model=config.agents.defaults.model
- )
-
- # Create agent
+ config = _load_runtime_config(config, workspace)
+ port = port if port is not None else config.gateway.port
+
+ console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
+ sync_workspace_templates(config.workspace_path)
+ bus = MessageBus()
+ provider = _make_provider(config)
+ session_manager = SessionManager(config.workspace_path)
+
+ # Preserve existing single-workspace installs, but keep custom workspaces clean.
+ if is_default_workspace(config.workspace_path):
+ _migrate_cron_store(config)
+
+ # Create cron service with workspace-scoped store
+ cron_store_path = config.workspace_path / "cron" / "jobs.json"
+ cron = CronService(cron_store_path)
+
+ # Create agent with cron service
agent = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
max_iterations=config.agents.defaults.max_tool_iterations,
- brave_api_key=config.tools.web.search.api_key or None,
+ context_window_tokens=config.agents.defaults.context_window_tokens,
+ web_config=config.tools.web,
+ context_block_limit=config.agents.defaults.context_block_limit,
+ max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
+ provider_retry_mode=config.agents.defaults.provider_retry_mode,
exec_config=config.tools.exec,
+ cron_service=cron,
+ restrict_to_workspace=config.tools.restrict_to_workspace,
+ session_manager=session_manager,
+ mcp_servers=config.tools.mcp_servers,
+ channels_config=config.channels,
+ timezone=config.agents.defaults.timezone,
)
-
- # Create cron service
+
+ # Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent."""
- response = await agent.process_direct(
- job.payload.message,
- session_key=f"cron:{job.id}"
+ # Dream is an internal job — run directly, not through the agent loop.
+ if job.name == "dream":
+ try:
+ await agent.dream.run()
+ logger.info("Dream cron job completed")
+ except Exception:
+ logger.exception("Dream cron job failed")
+ return None
+
+ from nanobot.agent.tools.cron import CronTool
+ from nanobot.agent.tools.message import MessageTool
+ from nanobot.utils.evaluator import evaluate_response
+
+ reminder_note = (
+ "[Scheduled Task] Timer finished.\n\n"
+ f"Task '{job.name}' has been triggered.\n"
+ f"Scheduled instruction: {job.payload.message}"
)
- # Optionally deliver to channel
- if job.payload.deliver and job.payload.to:
- from nanobot.bus.events import OutboundMessage
- await bus.publish_outbound(OutboundMessage(
- channel=job.payload.channel or "whatsapp",
- chat_id=job.payload.to,
- content=response or ""
- ))
+
+ cron_tool = agent.tools.get("cron")
+ cron_token = None
+ if isinstance(cron_tool, CronTool):
+ cron_token = cron_tool.set_cron_context(True)
+ try:
+ resp = await agent.process_direct(
+ reminder_note,
+ session_key=f"cron:{job.id}",
+ channel=job.payload.channel or "cli",
+ chat_id=job.payload.to or "direct",
+ )
+ finally:
+ if isinstance(cron_tool, CronTool) and cron_token is not None:
+ cron_tool.reset_cron_context(cron_token)
+
+ response = resp.content if resp else ""
+
+ message_tool = agent.tools.get("message")
+ if isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
+ return response
+
+ if job.payload.deliver and job.payload.to and response:
+ should_notify = await evaluate_response(
+ response, job.payload.message, provider, agent.model,
+ )
+ if should_notify:
+ from nanobot.bus.events import OutboundMessage
+ await bus.publish_outbound(OutboundMessage(
+ channel=job.payload.channel or "cli",
+ chat_id=job.payload.to,
+ content=response,
+ ))
return response
-
- cron_store_path = get_data_dir() / "cron" / "jobs.json"
- cron = CronService(cron_store_path, on_job=on_cron_job)
-
- # Create heartbeat service
- async def on_heartbeat(prompt: str) -> str:
- """Execute heartbeat through the agent."""
- return await agent.process_direct(prompt, session_key="heartbeat")
-
- heartbeat = HeartbeatService(
- workspace=config.workspace_path,
- on_heartbeat=on_heartbeat,
- interval_s=30 * 60, # 30 minutes
- enabled=True
- )
-
+ cron.on_job = on_cron_job
+
# Create channel manager
channels = ChannelManager(config, bus)
-
+
+ def _pick_heartbeat_target() -> tuple[str, str]:
+ """Pick a routable channel/chat target for heartbeat-triggered messages."""
+ enabled = set(channels.enabled_channels)
+ # Prefer the most recently updated non-internal session on an enabled channel.
+ for item in session_manager.list_sessions():
+ key = item.get("key") or ""
+ if ":" not in key:
+ continue
+ channel, chat_id = key.split(":", 1)
+ if channel in {"cli", "system"}:
+ continue
+ if channel in enabled and chat_id:
+ return channel, chat_id
+ # Fallback keeps prior behavior but remains explicit.
+ return "cli", "direct"
+
+ # Create heartbeat service
+ async def on_heartbeat_execute(tasks: str) -> str:
+ """Phase 2: execute heartbeat tasks through the full agent loop."""
+ channel, chat_id = _pick_heartbeat_target()
+
+ async def _silent(*_args, **_kwargs):
+ pass
+
+ resp = await agent.process_direct(
+ tasks,
+ session_key="heartbeat",
+ channel=channel,
+ chat_id=chat_id,
+ on_progress=_silent,
+ )
+
+ # Keep a small tail of heartbeat history so the loop stays bounded
+ # without losing all short-term context between runs.
+ session = agent.sessions.get_or_create("heartbeat")
+ session.retain_recent_legal_suffix(hb_cfg.keep_recent_messages)
+ agent.sessions.save(session)
+
+ return resp.content if resp else ""
+
+ async def on_heartbeat_notify(response: str) -> None:
+ """Deliver a heartbeat response to the user's channel."""
+ from nanobot.bus.events import OutboundMessage
+ channel, chat_id = _pick_heartbeat_target()
+ if channel == "cli":
+ return # No external channel available to deliver to
+ await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
+
+ hb_cfg = config.gateway.heartbeat
+ heartbeat = HeartbeatService(
+ workspace=config.workspace_path,
+ provider=provider,
+ model=agent.model,
+ on_execute=on_heartbeat_execute,
+ on_notify=on_heartbeat_notify,
+ interval_s=hb_cfg.interval_s,
+ enabled=hb_cfg.enabled,
+ timezone=config.agents.defaults.timezone,
+ )
+
if channels.enabled_channels:
console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}")
else:
console.print("[yellow]Warning: No channels enabled[/yellow]")
-
+
cron_status = cron.status()
if cron_status["jobs"] > 0:
console.print(f"[green]✓[/green] Cron: {cron_status['jobs']} scheduled jobs")
-
- console.print(f"[green]✓[/green] Heartbeat: every 30m")
-
+
+ console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s")
+
+ # Register Dream system job (always-on, idempotent on restart)
+ dream_cfg = config.agents.defaults.dream
+ if dream_cfg.model_override:
+ agent.dream.model = dream_cfg.model_override
+ agent.dream.max_batch_size = dream_cfg.max_batch_size
+ agent.dream.max_iterations = dream_cfg.max_iterations
+ from nanobot.cron.types import CronJob, CronPayload
+ cron.register_system_job(CronJob(
+ id="dream",
+ name="dream",
+ schedule=dream_cfg.build_schedule(config.agents.defaults.timezone),
+ payload=CronPayload(kind="system_event"),
+ ))
+ console.print(f"[green]✓[/green] Dream: {dream_cfg.describe_schedule()}")
+
async def run():
try:
await cron.start()
@@ -262,11 +806,17 @@ def gateway(
)
except KeyboardInterrupt:
console.print("\nShutting down...")
+ except Exception:
+ import traceback
+ console.print("\n[red]Error: Gateway crashed unexpectedly[/red]")
+ console.print(traceback.format_exc())
+ finally:
+ await agent.close_mcp()
heartbeat.stop()
cron.stop()
agent.stop()
await channels.stop_all()
-
+
asyncio.run(run())
@@ -280,64 +830,231 @@ def gateway(
@app.command()
def agent(
message: str = typer.Option(None, "--message", "-m", help="Message to send to the agent"),
- session_id: str = typer.Option("cli:default", "--session", "-s", help="Session ID"),
+ session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
+ workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
+ config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
+ markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
+ logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
):
"""Interact with the agent directly."""
- from nanobot.config.loader import load_config
- from nanobot.bus.queue import MessageBus
- from nanobot.providers.litellm_provider import LiteLLMProvider
- from nanobot.agent.loop import AgentLoop
-
- config = load_config()
-
- api_key = config.get_api_key()
- api_base = config.get_api_base()
- model = config.agents.defaults.model
- is_bedrock = model.startswith("bedrock/")
+ from loguru import logger
- if not api_key and not is_bedrock:
- console.print("[red]Error: No API key configured.[/red]")
- raise typer.Exit(1)
+ from nanobot.agent.loop import AgentLoop
+ from nanobot.bus.queue import MessageBus
+ from nanobot.cron.service import CronService
+
+ config = _load_runtime_config(config, workspace)
+ sync_workspace_templates(config.workspace_path)
bus = MessageBus()
- provider = LiteLLMProvider(
- api_key=api_key,
- api_base=api_base,
- default_model=config.agents.defaults.model
- )
-
+ provider = _make_provider(config)
+
+ # Preserve existing single-workspace installs, but keep custom workspaces clean.
+ if is_default_workspace(config.workspace_path):
+ _migrate_cron_store(config)
+
+ # Create cron service with workspace-scoped store
+ cron_store_path = config.workspace_path / "cron" / "jobs.json"
+ cron = CronService(cron_store_path)
+
+ if logs:
+ logger.enable("nanobot")
+ else:
+ logger.disable("nanobot")
+
agent_loop = AgentLoop(
bus=bus,
provider=provider,
workspace=config.workspace_path,
- brave_api_key=config.tools.web.search.api_key or None,
+ model=config.agents.defaults.model,
+ max_iterations=config.agents.defaults.max_tool_iterations,
+ context_window_tokens=config.agents.defaults.context_window_tokens,
+ web_config=config.tools.web,
+ context_block_limit=config.agents.defaults.context_block_limit,
+ max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
+ provider_retry_mode=config.agents.defaults.provider_retry_mode,
exec_config=config.tools.exec,
+ cron_service=cron,
+ restrict_to_workspace=config.tools.restrict_to_workspace,
+ mcp_servers=config.tools.mcp_servers,
+ channels_config=config.channels,
+ timezone=config.agents.defaults.timezone,
)
-
+ restart_notice = consume_restart_notice_from_env()
+ if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
+ _print_agent_response(
+ format_restart_completed_message(restart_notice.started_at_raw),
+ render_markdown=False,
+ )
+
+ # Shared reference for progress callbacks
+ _thinking: ThinkingSpinner | None = None
+
+ async def _cli_progress(content: str, *, tool_hint: bool = False) -> None:
+ ch = agent_loop.channels_config
+ if ch and tool_hint and not ch.send_tool_hints:
+ return
+ if ch and not tool_hint and not ch.send_progress:
+ return
+ _print_cli_progress_line(content, _thinking)
+
if message:
- # Single message mode
+ # Single message mode — direct call, no bus needed
async def run_once():
- response = await agent_loop.process_direct(message, session_id)
- console.print(f"\n{__logo__} {response}")
-
+ renderer = StreamRenderer(render_markdown=markdown)
+ response = await agent_loop.process_direct(
+ message, session_id,
+ on_progress=_cli_progress,
+ on_stream=renderer.on_delta,
+ on_stream_end=renderer.on_end,
+ )
+ if not renderer.streamed:
+ await renderer.close()
+ _print_agent_response(
+ response.content if response else "",
+ render_markdown=markdown,
+ metadata=response.metadata if response else None,
+ )
+ await agent_loop.close_mcp()
+
asyncio.run(run_once())
else:
- # Interactive mode
- console.print(f"{__logo__} Interactive mode (Ctrl+C to exit)\n")
-
+ # Interactive mode — route through bus like other channels
+ from nanobot.bus.events import InboundMessage
+ _init_prompt_session()
+ console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
+
+ if ":" in session_id:
+ cli_channel, cli_chat_id = session_id.split(":", 1)
+ else:
+ cli_channel, cli_chat_id = "cli", session_id
+
+ def _handle_signal(signum, frame):
+ sig_name = signal.Signals(signum).name
+ _restore_terminal()
+ console.print(f"\nReceived {sig_name}, goodbye!")
+ sys.exit(0)
+
+ signal.signal(signal.SIGINT, _handle_signal)
+ signal.signal(signal.SIGTERM, _handle_signal)
+ # SIGHUP is not available on Windows
+ if hasattr(signal, 'SIGHUP'):
+ signal.signal(signal.SIGHUP, _handle_signal)
+ # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
+ # SIGPIPE is not available on Windows
+ if hasattr(signal, 'SIGPIPE'):
+ signal.signal(signal.SIGPIPE, signal.SIG_IGN)
+
async def run_interactive():
- while True:
- try:
- user_input = console.input("[bold blue]You:[/bold blue] ")
- if not user_input.strip():
+ bus_task = asyncio.create_task(agent_loop.run())
+ turn_done = asyncio.Event()
+ turn_done.set()
+ turn_response: list[tuple[str, dict]] = []
+ renderer: StreamRenderer | None = None
+
+ async def _consume_outbound():
+ while True:
+ try:
+ msg = await asyncio.wait_for(bus.consume_outbound(), timeout=1.0)
+
+ if msg.metadata.get("_stream_delta"):
+ if renderer:
+ await renderer.on_delta(msg.content)
+ continue
+ if msg.metadata.get("_stream_end"):
+ if renderer:
+ await renderer.on_end(
+ resuming=msg.metadata.get("_resuming", False),
+ )
+ continue
+ if msg.metadata.get("_streamed"):
+ turn_done.set()
+ continue
+
+ if msg.metadata.get("_progress"):
+ is_tool_hint = msg.metadata.get("_tool_hint", False)
+ ch = agent_loop.channels_config
+ if ch and is_tool_hint and not ch.send_tool_hints:
+ pass
+ elif ch and not is_tool_hint and not ch.send_progress:
+ pass
+ else:
+ await _print_interactive_progress_line(msg.content, _thinking)
+ continue
+
+ if not turn_done.is_set():
+ if msg.content:
+ turn_response.append((msg.content, dict(msg.metadata or {})))
+ turn_done.set()
+ elif msg.content:
+ await _print_interactive_response(
+ msg.content,
+ render_markdown=markdown,
+ metadata=msg.metadata,
+ )
+
+ except asyncio.TimeoutError:
continue
-
- response = await agent_loop.process_direct(user_input, session_id)
- console.print(f"\n{__logo__} {response}\n")
- except KeyboardInterrupt:
- console.print("\nGoodbye!")
- break
-
+ except asyncio.CancelledError:
+ break
+
+ outbound_task = asyncio.create_task(_consume_outbound())
+
+ try:
+ while True:
+ try:
+ _flush_pending_tty_input()
+ # Stop spinner before user input to avoid prompt_toolkit conflicts
+ if renderer:
+ renderer.stop_for_input()
+ user_input = await _read_interactive_input_async()
+ command = user_input.strip()
+ if not command:
+ continue
+
+ if _is_exit_command(command):
+ _restore_terminal()
+ console.print("\nGoodbye!")
+ break
+
+ turn_done.clear()
+ turn_response.clear()
+ renderer = StreamRenderer(render_markdown=markdown)
+
+ await bus.publish_inbound(InboundMessage(
+ channel=cli_channel,
+ sender_id="user",
+ chat_id=cli_chat_id,
+ content=user_input,
+ metadata={"_wants_stream": True},
+ ))
+
+ await turn_done.wait()
+
+ if turn_response:
+ content, meta = turn_response[0]
+ if content and not meta.get("_streamed"):
+ if renderer:
+ await renderer.close()
+ _print_agent_response(
+ content, render_markdown=markdown, metadata=meta,
+ )
+ elif renderer and not renderer.streamed:
+ await renderer.close()
+ except KeyboardInterrupt:
+ _restore_terminal()
+ console.print("\nGoodbye!")
+ break
+ except EOFError:
+ _restore_terminal()
+ console.print("\nGoodbye!")
+ break
+ finally:
+ agent_loop.stop()
+ outbound_task.cancel()
+ await asyncio.gather(bus_task, outbound_task, return_exceptions=True)
+ await agent_loop.close_mcp()
+
asyncio.run(run_interactive())
@@ -351,33 +1068,35 @@ app.add_typer(channels_app, name="channels")
@channels_app.command("status")
-def channels_status():
+def channels_status(
+ config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
+):
"""Show channel status."""
- from nanobot.config.loader import load_config
+ from nanobot.channels.registry import discover_all
+ from nanobot.config.loader import load_config, set_config_path
- config = load_config()
+ resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
+ if resolved_config_path is not None:
+ set_config_path(resolved_config_path)
+
+ config = load_config(resolved_config_path)
table = Table(title="Channel Status")
table.add_column("Channel", style="cyan")
table.add_column("Enabled", style="green")
- table.add_column("Configuration", style="yellow")
- # WhatsApp
- wa = config.channels.whatsapp
- table.add_row(
- "WhatsApp",
- "✓" if wa.enabled else "✗",
- wa.bridge_url
- )
-
- # Telegram
- tg = config.channels.telegram
- tg_config = f"token: {tg.token[:10]}..." if tg.token else "[dim]not configured[/dim]"
- table.add_row(
- "Telegram",
- "✓" if tg.enabled else "✗",
- tg_config
- )
+ for name, cls in sorted(discover_all().items()):
+ section = getattr(config.channels, name, None)
+ if section is None:
+ enabled = False
+ elif isinstance(section, dict):
+ enabled = section.get("enabled", False)
+ else:
+ enabled = getattr(section, "enabled", False)
+ table.add_row(
+ cls.display_name,
+ "[green]\u2713[/green]" if enabled else "[dim]\u2717[/dim]",
+ )
console.print(table)
@@ -386,233 +1105,138 @@ def _get_bridge_dir() -> Path:
"""Get the bridge directory, setting it up if needed."""
import shutil
import subprocess
-
+
# User's bridge location
- user_bridge = Path.home() / ".nanobot" / "bridge"
-
+ from nanobot.config.paths import get_bridge_install_dir
+
+ user_bridge = get_bridge_install_dir()
+
# Check if already built
if (user_bridge / "dist" / "index.js").exists():
return user_bridge
-
+
# Check for npm
- if not shutil.which("npm"):
+ npm_path = shutil.which("npm")
+ if not npm_path:
console.print("[red]npm not found. Please install Node.js >= 18.[/red]")
raise typer.Exit(1)
-
+
# Find source bridge: first check package data, then source dir
pkg_bridge = Path(__file__).parent.parent / "bridge" # nanobot/bridge (installed)
src_bridge = Path(__file__).parent.parent.parent / "bridge" # repo root/bridge (dev)
-
+
source = None
if (pkg_bridge / "package.json").exists():
source = pkg_bridge
elif (src_bridge / "package.json").exists():
source = src_bridge
-
+
if not source:
console.print("[red]Bridge source not found.[/red]")
console.print("Try reinstalling: pip install --force-reinstall nanobot")
raise typer.Exit(1)
-
+
console.print(f"{__logo__} Setting up bridge...")
-
+
# Copy to user directory
user_bridge.parent.mkdir(parents=True, exist_ok=True)
if user_bridge.exists():
shutil.rmtree(user_bridge)
shutil.copytree(source, user_bridge, ignore=shutil.ignore_patterns("node_modules", "dist"))
-
+
# Install and build
try:
console.print(" Installing dependencies...")
- subprocess.run(["npm", "install"], cwd=user_bridge, check=True, capture_output=True)
-
+ subprocess.run([npm_path, "install"], cwd=user_bridge, check=True, capture_output=True)
+
console.print(" Building...")
- subprocess.run(["npm", "run", "build"], cwd=user_bridge, check=True, capture_output=True)
-
+ subprocess.run([npm_path, "run", "build"], cwd=user_bridge, check=True, capture_output=True)
+
console.print("[green]✓[/green] Bridge ready\n")
except subprocess.CalledProcessError as e:
console.print(f"[red]Build failed: {e}[/red]")
if e.stderr:
console.print(f"[dim]{e.stderr.decode()[:500]}[/dim]")
raise typer.Exit(1)
-
+
return user_bridge
@channels_app.command("login")
-def channels_login():
- """Link device via QR code."""
- import subprocess
-
- bridge_dir = _get_bridge_dir()
-
- console.print(f"{__logo__} Starting bridge...")
- console.print("Scan the QR code to connect.\n")
-
- try:
- subprocess.run(["npm", "start"], cwd=bridge_dir, check=True)
- except subprocess.CalledProcessError as e:
- console.print(f"[red]Bridge failed: {e}[/red]")
- except FileNotFoundError:
- console.print("[red]npm not found. Please install Node.js.[/red]")
-
-
-# ============================================================================
-# Cron Commands
-# ============================================================================
-
-cron_app = typer.Typer(help="Manage scheduled tasks")
-app.add_typer(cron_app, name="cron")
-
-
-@cron_app.command("list")
-def cron_list(
- all: bool = typer.Option(False, "--all", "-a", help="Include disabled jobs"),
+def channels_login(
+ channel_name: str = typer.Argument(..., help="Channel name (e.g. weixin, whatsapp)"),
+ force: bool = typer.Option(False, "--force", "-f", help="Force re-authentication even if already logged in"),
+ config_path: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
):
- """List scheduled jobs."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- jobs = service.list_jobs(include_disabled=all)
-
- if not jobs:
- console.print("No scheduled jobs.")
- return
-
- table = Table(title="Scheduled Jobs")
- table.add_column("ID", style="cyan")
- table.add_column("Name")
- table.add_column("Schedule")
- table.add_column("Status")
- table.add_column("Next Run")
-
- import time
- for job in jobs:
- # Format schedule
- if job.schedule.kind == "every":
- sched = f"every {(job.schedule.every_ms or 0) // 1000}s"
- elif job.schedule.kind == "cron":
- sched = job.schedule.expr or ""
- else:
- sched = "one-time"
-
- # Format next run
- next_run = ""
- if job.state.next_run_at_ms:
- next_time = time.strftime("%Y-%m-%d %H:%M", time.localtime(job.state.next_run_at_ms / 1000))
- next_run = next_time
-
- status = "[green]enabled[/green]" if job.enabled else "[dim]disabled[/dim]"
-
- table.add_row(job.id, job.name, sched, status, next_run)
-
- console.print(table)
+ """Authenticate with a channel via QR code or other interactive login."""
+ from nanobot.channels.registry import discover_all
+ from nanobot.config.loader import load_config, set_config_path
+ resolved_config_path = Path(config_path).expanduser().resolve() if config_path else None
+ if resolved_config_path is not None:
+ set_config_path(resolved_config_path)
-@cron_app.command("add")
-def cron_add(
- name: str = typer.Option(..., "--name", "-n", help="Job name"),
- message: str = typer.Option(..., "--message", "-m", help="Message for agent"),
- every: int = typer.Option(None, "--every", "-e", help="Run every N seconds"),
- cron_expr: str = typer.Option(None, "--cron", "-c", help="Cron expression (e.g. '0 9 * * *')"),
- at: str = typer.Option(None, "--at", help="Run once at time (ISO format)"),
- deliver: bool = typer.Option(False, "--deliver", "-d", help="Deliver response to channel"),
- to: str = typer.Option(None, "--to", help="Recipient for delivery"),
- channel: str = typer.Option(None, "--channel", help="Channel for delivery (e.g. 'telegram', 'whatsapp')"),
-):
- """Add a scheduled job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
- from nanobot.cron.types import CronSchedule
-
- # Determine schedule type
- if every:
- schedule = CronSchedule(kind="every", every_ms=every * 1000)
- elif cron_expr:
- schedule = CronSchedule(kind="cron", expr=cron_expr)
- elif at:
- import datetime
- dt = datetime.datetime.fromisoformat(at)
- schedule = CronSchedule(kind="at", at_ms=int(dt.timestamp() * 1000))
- else:
- console.print("[red]Error: Must specify --every, --cron, or --at[/red]")
+ config = load_config(resolved_config_path)
+ channel_cfg = getattr(config.channels, channel_name, None) or {}
+
+ # Validate channel exists
+ all_channels = discover_all()
+ if channel_name not in all_channels:
+ available = ", ".join(all_channels.keys())
+ console.print(f"[red]Unknown channel: {channel_name}[/red] Available: {available}")
+ raise typer.Exit(1)
+
+ console.print(f"{__logo__} {all_channels[channel_name].display_name} Login\n")
+
+ channel_cls = all_channels[channel_name]
+ channel = channel_cls(channel_cfg, bus=None)
+
+ success = asyncio.run(channel.login(force=force))
+
+ if not success:
raise typer.Exit(1)
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- job = service.add_job(
- name=name,
- schedule=schedule,
- message=message,
- deliver=deliver,
- to=to,
- channel=channel,
- )
-
- console.print(f"[green]✓[/green] Added job '{job.name}' ({job.id})")
-@cron_app.command("remove")
-def cron_remove(
- job_id: str = typer.Argument(..., help="Job ID to remove"),
-):
- """Remove a scheduled job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- if service.remove_job(job_id):
- console.print(f"[green]✓[/green] Removed job {job_id}")
- else:
- console.print(f"[red]Job {job_id} not found[/red]")
+# ============================================================================
+# Plugin Commands
+# ============================================================================
+
+plugins_app = typer.Typer(help="Manage channel plugins")
+app.add_typer(plugins_app, name="plugins")
-@cron_app.command("enable")
-def cron_enable(
- job_id: str = typer.Argument(..., help="Job ID"),
- disable: bool = typer.Option(False, "--disable", help="Disable instead of enable"),
-):
- """Enable or disable a job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- job = service.enable_job(job_id, enabled=not disable)
- if job:
- status = "disabled" if disable else "enabled"
- console.print(f"[green]✓[/green] Job '{job.name}' {status}")
- else:
- console.print(f"[red]Job {job_id} not found[/red]")
+@plugins_app.command("list")
+def plugins_list():
+ """List all discovered channels (built-in and plugins)."""
+ from nanobot.channels.registry import discover_all, discover_channel_names
+ from nanobot.config.loader import load_config
+ config = load_config()
+ builtin_names = set(discover_channel_names())
+ all_channels = discover_all()
-@cron_app.command("run")
-def cron_run(
- job_id: str = typer.Argument(..., help="Job ID to run"),
- force: bool = typer.Option(False, "--force", "-f", help="Run even if disabled"),
-):
- """Manually run a job."""
- from nanobot.config.loader import get_data_dir
- from nanobot.cron.service import CronService
-
- store_path = get_data_dir() / "cron" / "jobs.json"
- service = CronService(store_path)
-
- async def run():
- return await service.run_job(job_id, force=force)
-
- if asyncio.run(run()):
- console.print(f"[green]✓[/green] Job executed")
- else:
- console.print(f"[red]Failed to run job {job_id}[/red]")
+ table = Table(title="Channel Plugins")
+ table.add_column("Name", style="cyan")
+ table.add_column("Source", style="magenta")
+ table.add_column("Enabled", style="green")
+
+ for name in sorted(all_channels):
+ cls = all_channels[name]
+ source = "builtin" if name in builtin_names else "plugin"
+ section = getattr(config.channels, name, None)
+ if section is None:
+ enabled = False
+ elif isinstance(section, dict):
+ enabled = section.get("enabled", False)
+ else:
+ enabled = getattr(section, "enabled", False)
+ table.add_row(
+ cls.display_name,
+ source,
+ "[green]yes[/green]" if enabled else "[dim]no[/dim]",
+ )
+
+ console.print(table)
# ============================================================================
@@ -623,7 +1247,7 @@ def cron_run(
@app.command()
def status():
"""Show nanobot status."""
- from nanobot.config.loader import load_config, get_config_path
+ from nanobot.config.loader import get_config_path, load_config
config_path = get_config_path()
config = load_config()
@@ -635,21 +1259,108 @@ def status():
console.print(f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}")
if config_path.exists():
+ from nanobot.providers.registry import PROVIDERS
+
console.print(f"Model: {config.agents.defaults.model}")
-
- # Check API keys
- has_openrouter = bool(config.providers.openrouter.api_key)
- has_anthropic = bool(config.providers.anthropic.api_key)
- has_openai = bool(config.providers.openai.api_key)
- has_gemini = bool(config.providers.gemini.api_key)
- has_vllm = bool(config.providers.vllm.api_base)
-
- console.print(f"OpenRouter API: {'[green]✓[/green]' if has_openrouter else '[dim]not set[/dim]'}")
- console.print(f"Anthropic API: {'[green]✓[/green]' if has_anthropic else '[dim]not set[/dim]'}")
- console.print(f"OpenAI API: {'[green]✓[/green]' if has_openai else '[dim]not set[/dim]'}")
- console.print(f"Gemini API: {'[green]✓[/green]' if has_gemini else '[dim]not set[/dim]'}")
- vllm_status = f"[green]✓ {config.providers.vllm.api_base}[/green]" if has_vllm else "[dim]not set[/dim]"
- console.print(f"vLLM/Local: {vllm_status}")
+
+ # Check API keys from registry
+ for spec in PROVIDERS:
+ p = getattr(config.providers, spec.name, None)
+ if p is None:
+ continue
+ if spec.is_oauth:
+ console.print(f"{spec.label}: [green]✓ (OAuth)[/green]")
+ elif spec.is_local:
+ # Local deployments show api_base instead of api_key
+ if p.api_base:
+ console.print(f"{spec.label}: [green]✓ {p.api_base}[/green]")
+ else:
+ console.print(f"{spec.label}: [dim]not set[/dim]")
+ else:
+ has_key = bool(p.api_key)
+ console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}")
+
+
+# ============================================================================
+# OAuth Login
+# ============================================================================
+
+provider_app = typer.Typer(help="Manage providers")
+app.add_typer(provider_app, name="provider")
+
+
+_LOGIN_HANDLERS: dict[str, callable] = {}
+
+
+def _register_login(name: str):
+ def decorator(fn):
+ _LOGIN_HANDLERS[name] = fn
+ return fn
+ return decorator
+
+
+@provider_app.command("login")
+def provider_login(
+ provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"),
+):
+ """Authenticate with an OAuth provider."""
+ from nanobot.providers.registry import PROVIDERS
+
+ key = provider.replace("-", "_")
+ spec = next((s for s in PROVIDERS if s.name == key and s.is_oauth), None)
+ if not spec:
+ names = ", ".join(s.name.replace("_", "-") for s in PROVIDERS if s.is_oauth)
+ console.print(f"[red]Unknown OAuth provider: {provider}[/red] Supported: {names}")
+ raise typer.Exit(1)
+
+ handler = _LOGIN_HANDLERS.get(spec.name)
+ if not handler:
+ console.print(f"[red]Login not implemented for {spec.label}[/red]")
+ raise typer.Exit(1)
+
+ console.print(f"{__logo__} OAuth Login - {spec.label}\n")
+ handler()
+
+
+@_register_login("openai_codex")
+def _login_openai_codex() -> None:
+ try:
+ from oauth_cli_kit import get_token, login_oauth_interactive
+ token = None
+ try:
+ token = get_token()
+ except Exception:
+ pass
+ if not (token and token.access):
+ console.print("[cyan]Starting interactive OAuth login...[/cyan]\n")
+ token = login_oauth_interactive(
+ print_fn=lambda s: console.print(s),
+ prompt_fn=lambda s: typer.prompt(s),
+ )
+ if not (token and token.access):
+ console.print("[red]✗ Authentication failed[/red]")
+ raise typer.Exit(1)
+ console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]")
+ except ImportError:
+ console.print("[red]oauth_cli_kit not installed. Run: pip install oauth-cli-kit[/red]")
+ raise typer.Exit(1)
+
+
+@_register_login("github_copilot")
+def _login_github_copilot() -> None:
+ try:
+ from nanobot.providers.github_copilot_provider import login_github_copilot
+
+ console.print("[cyan]Starting GitHub Copilot device flow...[/cyan]\n")
+ token = login_github_copilot(
+ print_fn=lambda s: console.print(s),
+ prompt_fn=lambda s: typer.prompt(s),
+ )
+ account = token.account_id or "GitHub"
+ console.print(f"[green]✓ Authenticated with GitHub Copilot[/green] [dim]{account}[/dim]")
+ except Exception as e:
+ console.print(f"[red]Authentication error: {e}[/red]")
+ raise typer.Exit(1)
if __name__ == "__main__":
diff --git a/nanobot/cli/models.py b/nanobot/cli/models.py
new file mode 100644
index 000000000..0ba24018f
--- /dev/null
+++ b/nanobot/cli/models.py
@@ -0,0 +1,31 @@
+"""Model information helpers for the onboard wizard.
+
+Model database / autocomplete is temporarily disabled while litellm is
+being replaced. All public function signatures are preserved so callers
+continue to work without changes.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+
+def get_all_models() -> list[str]:
+ return []
+
+
+def find_model_info(model_name: str) -> dict[str, Any] | None:
+ return None
+
+
+def get_model_context_limit(model: str, provider: str = "auto") -> int | None:
+ return None
+
+
+def get_model_suggestions(partial: str, provider: str = "auto", limit: int = 20) -> list[str]:
+ return []
+
+
+def format_token_count(tokens: int) -> str:
+ """Format token count for display (e.g., 200000 -> '200,000')."""
+ return f"{tokens:,}"
diff --git a/nanobot/cli/onboard.py b/nanobot/cli/onboard.py
new file mode 100644
index 000000000..4e3b6e562
--- /dev/null
+++ b/nanobot/cli/onboard.py
@@ -0,0 +1,1023 @@
+"""Interactive onboarding questionnaire for nanobot."""
+
+import json
+import types
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, NamedTuple, get_args, get_origin
+
+try:
+ import questionary
+except ModuleNotFoundError: # pragma: no cover - exercised in environments without wizard deps
+ questionary = None
+from loguru import logger
+from pydantic import BaseModel
+from rich.console import Console
+from rich.panel import Panel
+from rich.table import Table
+
+from nanobot.cli.models import (
+ format_token_count,
+ get_model_context_limit,
+ get_model_suggestions,
+)
+from nanobot.config.loader import get_config_path, load_config
+from nanobot.config.schema import Config
+
+console = Console()
+
+
+@dataclass
+class OnboardResult:
+ """Result of an onboarding session."""
+
+ config: Config
+ should_save: bool
+
+# --- Field Hints for Select Fields ---
+# Maps field names to (choices, hint_text)
+# To add a new select field with hints, add an entry:
+# "field_name": (["choice1", "choice2", ...], "hint text for the field")
+_SELECT_FIELD_HINTS: dict[str, tuple[list[str], str]] = {
+ "reasoning_effort": (
+ ["low", "medium", "high"],
+ "low / medium / high - enables LLM thinking mode",
+ ),
+}
+
+# --- Key Bindings for Navigation ---
+
+_BACK_PRESSED = object() # Sentinel value for back navigation
+
+
+def _get_questionary():
+ """Return questionary or raise a clear error when wizard deps are unavailable."""
+ if questionary is None:
+ raise RuntimeError(
+ "Interactive onboarding requires the optional 'questionary' dependency. "
+ "Install project dependencies and rerun with --wizard."
+ )
+ return questionary
+
+
+def _select_with_back(
+ prompt: str, choices: list[str], default: str | None = None
+) -> str | None | object:
+ """Select with Escape/Left arrow support for going back.
+
+ Args:
+ prompt: The prompt text to display.
+ choices: List of choices to select from. Must not be empty.
+ default: The default choice to pre-select. If not in choices, first item is used.
+
+ Returns:
+ _BACK_PRESSED sentinel if user pressed Escape or Left arrow
+ The selected choice string if user confirmed
+ None if user cancelled (Ctrl+C)
+ """
+ from prompt_toolkit.application import Application
+ from prompt_toolkit.key_binding import KeyBindings
+ from prompt_toolkit.keys import Keys
+ from prompt_toolkit.layout import Layout
+ from prompt_toolkit.layout.containers import HSplit, Window
+ from prompt_toolkit.layout.controls import FormattedTextControl
+ from prompt_toolkit.styles import Style
+
+ # Validate choices
+ if not choices:
+ logger.warning("Empty choices list provided to _select_with_back")
+ return None
+
+ # Find default index
+ selected_index = 0
+ if default and default in choices:
+ selected_index = choices.index(default)
+
+ # State holder for the result
+ state: dict[str, str | None | object] = {"result": None}
+
+ # Build menu items (uses closure over selected_index)
+ def get_menu_text():
+ items = []
+ for i, choice in enumerate(choices):
+ if i == selected_index:
+ items.append(("class:selected", f"> {choice}\n"))
+ else:
+ items.append(("", f" {choice}\n"))
+ return items
+
+ # Create layout
+ menu_control = FormattedTextControl(get_menu_text)
+ menu_window = Window(content=menu_control, height=len(choices))
+
+ prompt_control = FormattedTextControl(lambda: [("class:question", f"> {prompt}")])
+ prompt_window = Window(content=prompt_control, height=1)
+
+ layout = Layout(HSplit([prompt_window, menu_window]))
+
+ # Key bindings
+ bindings = KeyBindings()
+
+ @bindings.add(Keys.Up)
+ def _up(event):
+ nonlocal selected_index
+ selected_index = (selected_index - 1) % len(choices)
+ event.app.invalidate()
+
+ @bindings.add(Keys.Down)
+ def _down(event):
+ nonlocal selected_index
+ selected_index = (selected_index + 1) % len(choices)
+ event.app.invalidate()
+
+ @bindings.add(Keys.Enter)
+ def _enter(event):
+ state["result"] = choices[selected_index]
+ event.app.exit()
+
+ @bindings.add("escape")
+ def _escape(event):
+ state["result"] = _BACK_PRESSED
+ event.app.exit()
+
+ @bindings.add(Keys.Left)
+ def _left(event):
+ state["result"] = _BACK_PRESSED
+ event.app.exit()
+
+ @bindings.add(Keys.ControlC)
+ def _ctrl_c(event):
+ state["result"] = None
+ event.app.exit()
+
+ # Style
+ style = Style.from_dict({
+ "selected": "fg:green bold",
+ "question": "fg:cyan",
+ })
+
+ app = Application(layout=layout, key_bindings=bindings, style=style)
+ try:
+ app.run()
+ except Exception:
+ logger.exception("Error in select prompt")
+ return None
+
+ return state["result"]
+
+# --- Type Introspection ---
+
+
+class FieldTypeInfo(NamedTuple):
+ """Result of field type introspection."""
+
+ type_name: str
+ inner_type: Any
+
+
+def _get_field_type_info(field_info) -> FieldTypeInfo:
+ """Extract field type info from Pydantic field."""
+ annotation = field_info.annotation
+ if annotation is None:
+ return FieldTypeInfo("str", None)
+
+ origin = get_origin(annotation)
+ args = get_args(annotation)
+
+ if origin is types.UnionType:
+ non_none_args = [a for a in args if a is not type(None)]
+ if len(non_none_args) == 1:
+ annotation = non_none_args[0]
+ origin = get_origin(annotation)
+ args = get_args(annotation)
+
+ _SIMPLE_TYPES: dict[type, str] = {bool: "bool", int: "int", float: "float"}
+
+ if origin is list or (hasattr(origin, "__name__") and origin.__name__ == "List"):
+ return FieldTypeInfo("list", args[0] if args else str)
+ if origin is dict or (hasattr(origin, "__name__") and origin.__name__ == "Dict"):
+ return FieldTypeInfo("dict", None)
+ for py_type, name in _SIMPLE_TYPES.items():
+ if annotation is py_type:
+ return FieldTypeInfo(name, None)
+ if isinstance(annotation, type) and issubclass(annotation, BaseModel):
+ return FieldTypeInfo("model", annotation)
+ return FieldTypeInfo("str", None)
+
+
+def _get_field_display_name(field_key: str, field_info) -> str:
+ """Get display name for a field."""
+ if field_info and field_info.description:
+ return field_info.description
+ name = field_key
+ suffix_map = {
+ "_s": " (seconds)",
+ "_ms": " (ms)",
+ "_url": " URL",
+ "_path": " Path",
+ "_id": " ID",
+ "_key": " Key",
+ "_token": " Token",
+ }
+ for suffix, replacement in suffix_map.items():
+ if name.endswith(suffix):
+ name = name[: -len(suffix)] + replacement
+ break
+ return name.replace("_", " ").title()
+
+
+# --- Sensitive Field Masking ---
+
+_SENSITIVE_KEYWORDS = frozenset({"api_key", "token", "secret", "password", "credentials"})
+
+
+def _is_sensitive_field(field_name: str) -> bool:
+ """Check if a field name indicates sensitive content."""
+ return any(kw in field_name.lower() for kw in _SENSITIVE_KEYWORDS)
+
+
+def _mask_value(value: str) -> str:
+ """Mask a sensitive value, showing only the last 4 characters."""
+ if len(value) <= 4:
+ return "****"
+ return "*" * (len(value) - 4) + value[-4:]
+
+
+# --- Value Formatting ---
+
+
+def _format_value(value: Any, rich: bool = True, field_name: str = "") -> str:
+ """Single recursive entry point for safe value display. Handles any depth."""
+ if value is None or value == "" or value == {} or value == []:
+ return "[dim]not set[/dim]" if rich else "[not set]"
+ if _is_sensitive_field(field_name) and isinstance(value, str):
+ masked = _mask_value(value)
+ return f"[dim]{masked}[/dim]" if rich else masked
+ if isinstance(value, BaseModel):
+ parts = []
+ for fname, _finfo in type(value).model_fields.items():
+ fval = getattr(value, fname, None)
+ formatted = _format_value(fval, rich=False, field_name=fname)
+ if formatted != "[not set]":
+ parts.append(f"{fname}={formatted}")
+ return ", ".join(parts) if parts else ("[dim]not set[/dim]" if rich else "[not set]")
+ if isinstance(value, list):
+ return ", ".join(str(v) for v in value)
+ if isinstance(value, dict):
+ return json.dumps(value)
+ return str(value)
+
+
+def _format_value_for_input(value: Any, field_type: str) -> str:
+ """Format a value for use as input default."""
+ if value is None or value == "":
+ return ""
+ if field_type == "list" and isinstance(value, list):
+ return ",".join(str(v) for v in value)
+ if field_type == "dict" and isinstance(value, dict):
+ return json.dumps(value)
+ return str(value)
+
+
+# --- Rich UI Components ---
+
+
+def _show_config_panel(display_name: str, model: BaseModel, fields: list) -> None:
+ """Display current configuration as a rich table."""
+ table = Table(show_header=False, box=None, padding=(0, 2))
+ table.add_column("Field", style="cyan")
+ table.add_column("Value")
+
+ for fname, field_info in fields:
+ value = getattr(model, fname, None)
+ display = _get_field_display_name(fname, field_info)
+ formatted = _format_value(value, rich=True, field_name=fname)
+ table.add_row(display, formatted)
+
+ console.print(Panel(table, title=f"[bold]{display_name}[/bold]", border_style="blue"))
+
+
+def _show_main_menu_header() -> None:
+ """Display the main menu header."""
+ from nanobot import __logo__, __version__
+
+ console.print()
+ # Use Align.CENTER for the single line of text
+ from rich.align import Align
+
+ console.print(
+ Align.center(f"{__logo__} [bold cyan]nanobot[{__version__}][/bold cyan]")
+ )
+ console.print()
+
+
+def _show_section_header(title: str, subtitle: str = "") -> None:
+ """Display a section header."""
+ console.print()
+ if subtitle:
+ console.print(
+ Panel(f"[dim]{subtitle}[/dim]", title=f"[bold]{title}[/bold]", border_style="blue")
+ )
+ else:
+ console.print(Panel("", title=f"[bold]{title}[/bold]", border_style="blue"))
+
+
+# --- Input Handlers ---
+
+
+def _input_bool(display_name: str, current: bool | None) -> bool | None:
+ """Get boolean input via confirm dialog."""
+ return _get_questionary().confirm(
+ display_name,
+ default=bool(current) if current is not None else False,
+ ).ask()
+
+
+def _input_text(display_name: str, current: Any, field_type: str) -> Any:
+ """Get text input and parse based on field type."""
+ default = _format_value_for_input(current, field_type)
+
+ value = _get_questionary().text(f"{display_name}:", default=default).ask()
+
+ if value is None or value == "":
+ return None
+
+ if field_type == "int":
+ try:
+ return int(value)
+ except ValueError:
+ console.print("[yellow]! Invalid number format, value not saved[/yellow]")
+ return None
+ elif field_type == "float":
+ try:
+ return float(value)
+ except ValueError:
+ console.print("[yellow]! Invalid number format, value not saved[/yellow]")
+ return None
+ elif field_type == "list":
+ return [v.strip() for v in value.split(",") if v.strip()]
+ elif field_type == "dict":
+ try:
+ return json.loads(value)
+ except json.JSONDecodeError:
+ console.print("[yellow]! Invalid JSON format, value not saved[/yellow]")
+ return None
+
+ return value
+
+
+def _input_with_existing(
+ display_name: str, current: Any, field_type: str
+) -> Any:
+ """Handle input with 'keep existing' option for non-empty values."""
+ has_existing = current is not None and current != "" and current != {} and current != []
+
+ if has_existing and not isinstance(current, list):
+ choice = _get_questionary().select(
+ display_name,
+ choices=["Enter new value", "Keep existing value"],
+ default="Keep existing value",
+ ).ask()
+ if choice == "Keep existing value" or choice is None:
+ return None
+
+ return _input_text(display_name, current, field_type)
+
+
+# --- Pydantic Model Configuration ---
+
+
+def _get_current_provider(model: BaseModel) -> str:
+ """Get the current provider setting from a model (if available)."""
+ if hasattr(model, "provider"):
+ return getattr(model, "provider", "auto") or "auto"
+ return "auto"
+
+
+def _input_model_with_autocomplete(
+ display_name: str, current: Any, provider: str
+) -> str | None:
+ """Get model input with autocomplete suggestions.
+
+ """
+ from prompt_toolkit.completion import Completer, Completion
+
+ default = str(current) if current else ""
+
+ class DynamicModelCompleter(Completer):
+ """Completer that dynamically fetches model suggestions."""
+
+ def __init__(self, provider_name: str):
+ self.provider = provider_name
+
+ def get_completions(self, document, complete_event):
+ text = document.text_before_cursor
+ suggestions = get_model_suggestions(text, provider=self.provider, limit=50)
+ for model in suggestions:
+ # Skip if model doesn't contain the typed text
+ if text.lower() not in model.lower():
+ continue
+ yield Completion(
+ model,
+ start_position=-len(text),
+ display=model,
+ )
+
+ value = _get_questionary().autocomplete(
+ f"{display_name}:",
+ choices=[""], # Placeholder, actual completions from completer
+ completer=DynamicModelCompleter(provider),
+ default=default,
+ qmark=">",
+ ).ask()
+
+ return value if value else None
+
+
+def _input_context_window_with_recommendation(
+ display_name: str, current: Any, model_obj: BaseModel
+) -> int | None:
+ """Get context window input with option to fetch recommended value."""
+ current_val = current if current else ""
+
+ choices = ["Enter new value"]
+ if current_val:
+ choices.append("Keep existing value")
+ choices.append("[?] Get recommended value")
+
+ choice = _get_questionary().select(
+ display_name,
+ choices=choices,
+ default="Enter new value",
+ ).ask()
+
+ if choice is None:
+ return None
+
+ if choice == "Keep existing value":
+ return None
+
+ if choice == "[?] Get recommended value":
+ # Get the model name from the model object
+ model_name = getattr(model_obj, "model", None)
+ if not model_name:
+ console.print("[yellow]! Please configure the model field first[/yellow]")
+ return None
+
+ provider = _get_current_provider(model_obj)
+ context_limit = get_model_context_limit(model_name, provider)
+
+ if context_limit:
+ console.print(f"[green]+ Recommended context window: {format_token_count(context_limit)} tokens[/green]")
+ return context_limit
+ else:
+ console.print("[yellow]! Could not fetch model info, please enter manually[/yellow]")
+ # Fall through to manual input
+
+ # Manual input
+ value = _get_questionary().text(
+ f"{display_name}:",
+ default=str(current_val) if current_val else "",
+ ).ask()
+
+ if value is None or value == "":
+ return None
+
+ try:
+ return int(value)
+ except ValueError:
+ console.print("[yellow]! Invalid number format, value not saved[/yellow]")
+ return None
+
+
+def _handle_model_field(
+ working_model: BaseModel, field_name: str, field_display: str, current_value: Any
+) -> None:
+ """Handle the 'model' field with autocomplete and context-window auto-fill."""
+ provider = _get_current_provider(working_model)
+ new_value = _input_model_with_autocomplete(field_display, current_value, provider)
+ if new_value is not None and new_value != current_value:
+ setattr(working_model, field_name, new_value)
+ _try_auto_fill_context_window(working_model, new_value)
+
+
+def _handle_context_window_field(
+ working_model: BaseModel, field_name: str, field_display: str, current_value: Any
+) -> None:
+ """Handle context_window_tokens with recommendation lookup."""
+ new_value = _input_context_window_with_recommendation(
+ field_display, current_value, working_model
+ )
+ if new_value is not None:
+ setattr(working_model, field_name, new_value)
+
+
+_FIELD_HANDLERS: dict[str, Any] = {
+ "model": _handle_model_field,
+ "context_window_tokens": _handle_context_window_field,
+}
+
+
+def _configure_pydantic_model(
+ model: BaseModel,
+ display_name: str,
+ *,
+ skip_fields: set[str] | None = None,
+) -> BaseModel | None:
+ """Configure a Pydantic model interactively.
+
+ Returns the updated model only when the user explicitly selects "Done".
+ Back and cancel actions discard the section draft.
+ """
+ skip_fields = skip_fields or set()
+ working_model = model.model_copy(deep=True)
+
+ fields = [
+ (name, info)
+ for name, info in type(working_model).model_fields.items()
+ if name not in skip_fields
+ ]
+ if not fields:
+ console.print(f"[dim]{display_name}: No configurable fields[/dim]")
+ return working_model
+
+ def get_choices() -> list[str]:
+ items = []
+ for fname, finfo in fields:
+ value = getattr(working_model, fname, None)
+ display = _get_field_display_name(fname, finfo)
+ formatted = _format_value(value, rich=False, field_name=fname)
+ items.append(f"{display}: {formatted}")
+ return items + ["[Done]"]
+
+ while True:
+ console.clear()
+ _show_config_panel(display_name, working_model, fields)
+ choices = get_choices()
+ answer = _select_with_back("Select field to configure:", choices)
+
+ if answer is _BACK_PRESSED or answer is None:
+ return None
+ if answer == "[Done]":
+ return working_model
+
+ field_idx = next((i for i, c in enumerate(choices) if c == answer), -1)
+ if field_idx < 0 or field_idx >= len(fields):
+ return None
+
+ field_name, field_info = fields[field_idx]
+ current_value = getattr(working_model, field_name, None)
+ ftype = _get_field_type_info(field_info)
+ field_display = _get_field_display_name(field_name, field_info)
+
+ # Nested Pydantic model - recurse
+ if ftype.type_name == "model":
+ nested = current_value
+ created = nested is None
+ if nested is None and ftype.inner_type:
+ nested = ftype.inner_type()
+ if nested and isinstance(nested, BaseModel):
+ updated = _configure_pydantic_model(nested, field_display)
+ if updated is not None:
+ setattr(working_model, field_name, updated)
+ elif created:
+ setattr(working_model, field_name, None)
+ continue
+
+ # Registered special-field handlers
+ handler = _FIELD_HANDLERS.get(field_name)
+ if handler:
+ handler(working_model, field_name, field_display, current_value)
+ continue
+
+ # Select fields with hints (e.g. reasoning_effort)
+ if field_name in _SELECT_FIELD_HINTS:
+ choices_list, hint = _SELECT_FIELD_HINTS[field_name]
+ select_choices = choices_list + ["(clear/unset)"]
+ console.print(f"[dim] Hint: {hint}[/dim]")
+ new_value = _select_with_back(
+ field_display, select_choices, default=current_value or select_choices[0]
+ )
+ if new_value is _BACK_PRESSED:
+ continue
+ if new_value == "(clear/unset)":
+ setattr(working_model, field_name, None)
+ elif new_value is not None:
+ setattr(working_model, field_name, new_value)
+ continue
+
+ # Generic field input
+ if ftype.type_name == "bool":
+ new_value = _input_bool(field_display, current_value)
+ else:
+ new_value = _input_with_existing(field_display, current_value, ftype.type_name)
+ if new_value is not None:
+ setattr(working_model, field_name, new_value)
+
+
+def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None:
+ """Try to auto-fill context_window_tokens if it's at default value.
+
+ Note:
+ This function imports AgentDefaults from nanobot.config.schema to get
+ the default context_window_tokens value. If the schema changes, this
+ coupling needs to be updated accordingly.
+ """
+ # Check if context_window_tokens field exists
+ if not hasattr(model, "context_window_tokens"):
+ return
+
+ current_context = getattr(model, "context_window_tokens", None)
+
+ # Check if current value is the default (65536)
+ # We only auto-fill if the user hasn't changed it from default
+ from nanobot.config.schema import AgentDefaults
+
+ default_context = AgentDefaults.model_fields["context_window_tokens"].default
+
+ if current_context != default_context:
+ return # User has customized it, don't override
+
+ provider = _get_current_provider(model)
+ context_limit = get_model_context_limit(new_model_name, provider)
+
+ if context_limit:
+ setattr(model, "context_window_tokens", context_limit)
+ console.print(f"[green]+ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]")
+ else:
+ console.print("[dim](i) Could not auto-fill context window (model not in database)[/dim]")
+
+
+# --- Provider Configuration ---
+
+
+@lru_cache(maxsize=1)
+def _get_provider_info() -> dict[str, tuple[str, bool, bool, str]]:
+ """Get provider info from registry (cached)."""
+ from nanobot.providers.registry import PROVIDERS
+
+ return {
+ spec.name: (
+ spec.display_name or spec.name,
+ spec.is_gateway,
+ spec.is_local,
+ spec.default_api_base,
+ )
+ for spec in PROVIDERS
+ if not spec.is_oauth
+ }
+
+
+def _get_provider_names() -> dict[str, str]:
+ """Get provider display names."""
+ info = _get_provider_info()
+ return {name: data[0] for name, data in info.items() if name}
+
+
+def _configure_provider(config: Config, provider_name: str) -> None:
+ """Configure a single LLM provider."""
+ provider_config = getattr(config.providers, provider_name, None)
+ if provider_config is None:
+ console.print(f"[red]Unknown provider: {provider_name}[/red]")
+ return
+
+ display_name = _get_provider_names().get(provider_name, provider_name)
+ info = _get_provider_info()
+ default_api_base = info.get(provider_name, (None, None, None, None))[3]
+
+ if default_api_base and not provider_config.api_base:
+ provider_config.api_base = default_api_base
+
+ updated_provider = _configure_pydantic_model(
+ provider_config,
+ display_name,
+ )
+ if updated_provider is not None:
+ setattr(config.providers, provider_name, updated_provider)
+
+
+def _configure_providers(config: Config) -> None:
+ """Configure LLM providers."""
+
+ def get_provider_choices() -> list[str]:
+ """Build provider choices with config status indicators."""
+ choices = []
+ for name, display in _get_provider_names().items():
+ provider = getattr(config.providers, name, None)
+ if provider and provider.api_key:
+ choices.append(f"{display} *")
+ else:
+ choices.append(display)
+ return choices + ["<- Back"]
+
+ while True:
+ try:
+ console.clear()
+ _show_section_header("LLM Providers", "Select a provider to configure API key and endpoint")
+ choices = get_provider_choices()
+ answer = _select_with_back("Select provider:", choices)
+
+ if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
+ break
+
+ # Type guard: answer is now guaranteed to be a string
+ assert isinstance(answer, str)
+ # Extract provider name from choice (remove " *" suffix if present)
+ provider_name = answer.replace(" *", "")
+ # Find the actual provider key from display names
+ for name, display in _get_provider_names().items():
+ if display == provider_name:
+ _configure_provider(config, name)
+ break
+
+ except KeyboardInterrupt:
+ console.print("\n[dim]Returning to main menu...[/dim]")
+ break
+
+
+# --- Channel Configuration ---
+
+
+@lru_cache(maxsize=1)
+def _get_channel_info() -> dict[str, tuple[str, type[BaseModel]]]:
+ """Get channel info (display name + config class) from channel modules."""
+ import importlib
+
+ from nanobot.channels.registry import discover_all
+
+ result: dict[str, tuple[str, type[BaseModel]]] = {}
+ for name, channel_cls in discover_all().items():
+ try:
+ mod = importlib.import_module(f"nanobot.channels.{name}")
+ config_name = channel_cls.__name__.replace("Channel", "Config")
+ config_cls = getattr(mod, config_name, None)
+ if config_cls and isinstance(config_cls, type) and issubclass(config_cls, BaseModel):
+ display_name = getattr(channel_cls, "display_name", name.capitalize())
+ result[name] = (display_name, config_cls)
+ except Exception:
+ logger.warning(f"Failed to load channel module: {name}")
+ return result
+
+
+def _get_channel_names() -> dict[str, str]:
+ """Get channel display names."""
+ return {name: info[0] for name, info in _get_channel_info().items()}
+
+
+def _get_channel_config_class(channel: str) -> type[BaseModel] | None:
+ """Get channel config class."""
+ entry = _get_channel_info().get(channel)
+ return entry[1] if entry else None
+
+
+def _configure_channel(config: Config, channel_name: str) -> None:
+ """Configure a single channel."""
+ channel_dict = getattr(config.channels, channel_name, None)
+ if channel_dict is None:
+ channel_dict = {}
+ setattr(config.channels, channel_name, channel_dict)
+
+ display_name = _get_channel_names().get(channel_name, channel_name)
+ config_cls = _get_channel_config_class(channel_name)
+
+ if config_cls is None:
+ console.print(f"[red]No configuration class found for {display_name}[/red]")
+ return
+
+ model = config_cls.model_validate(channel_dict) if channel_dict else config_cls()
+
+ updated_channel = _configure_pydantic_model(
+ model,
+ display_name,
+ )
+ if updated_channel is not None:
+ new_dict = updated_channel.model_dump(by_alias=True, exclude_none=True)
+ setattr(config.channels, channel_name, new_dict)
+
+
+def _configure_channels(config: Config) -> None:
+ """Configure chat channels."""
+ channel_names = list(_get_channel_names().keys())
+ choices = channel_names + ["<- Back"]
+
+ while True:
+ try:
+ console.clear()
+ _show_section_header("Chat Channels", "Select a channel to configure connection settings")
+ answer = _select_with_back("Select channel:", choices)
+
+ if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
+ break
+
+ # Type guard: answer is now guaranteed to be a string
+ assert isinstance(answer, str)
+ _configure_channel(config, answer)
+ except KeyboardInterrupt:
+ console.print("\n[dim]Returning to main menu...[/dim]")
+ break
+
+
+# --- General Settings ---
+
+_SETTINGS_SECTIONS: dict[str, tuple[str, str, set[str] | None]] = {
+ "Agent Settings": ("Agent Defaults", "Configure default model, temperature, and behavior", None),
+ "Gateway": ("Gateway Settings", "Configure server host, port, and heartbeat", None),
+ "Tools": ("Tools Settings", "Configure web search, shell exec, and other tools", {"mcp_servers"}),
+}
+
+_SETTINGS_GETTER = {
+ "Agent Settings": lambda c: c.agents.defaults,
+ "Gateway": lambda c: c.gateway,
+ "Tools": lambda c: c.tools,
+}
+
+_SETTINGS_SETTER = {
+ "Agent Settings": lambda c, v: setattr(c.agents, "defaults", v),
+ "Gateway": lambda c, v: setattr(c, "gateway", v),
+ "Tools": lambda c, v: setattr(c, "tools", v),
+}
+
+
+def _configure_general_settings(config: Config, section: str) -> None:
+ """Configure a general settings section (header + model edit + writeback)."""
+ meta = _SETTINGS_SECTIONS.get(section)
+ if not meta:
+ return
+ display_name, subtitle, skip = meta
+ model = _SETTINGS_GETTER[section](config)
+ updated = _configure_pydantic_model(model, display_name, skip_fields=skip)
+ if updated is not None:
+ _SETTINGS_SETTER[section](config, updated)
+
+
+# --- Summary ---
+
+
+def _summarize_model(obj: BaseModel) -> list[tuple[str, str]]:
+ """Recursively summarize a Pydantic model. Returns list of (field, value) tuples."""
+ items: list[tuple[str, str]] = []
+ for field_name, field_info in type(obj).model_fields.items():
+ value = getattr(obj, field_name, None)
+ if value is None or value == "" or value == {} or value == []:
+ continue
+ display = _get_field_display_name(field_name, field_info)
+ ftype = _get_field_type_info(field_info)
+ if ftype.type_name == "model" and isinstance(value, BaseModel):
+ for nested_field, nested_value in _summarize_model(value):
+ items.append((f"{display}.{nested_field}", nested_value))
+ continue
+ formatted = _format_value(value, rich=False, field_name=field_name)
+ if formatted != "[not set]":
+ items.append((display, formatted))
+ return items
+
+
+def _print_summary_panel(rows: list[tuple[str, str]], title: str) -> None:
+ """Build a two-column summary panel and print it."""
+ if not rows:
+ return
+ table = Table(show_header=False, box=None, padding=(0, 2))
+ table.add_column("Setting", style="cyan")
+ table.add_column("Value")
+ for field, value in rows:
+ table.add_row(field, value)
+ console.print(Panel(table, title=f"[bold]{title}[/bold]", border_style="blue"))
+
+
+def _show_summary(config: Config) -> None:
+ """Display configuration summary using rich."""
+ console.print()
+
+ # Providers
+ provider_rows = []
+ for name, display in _get_provider_names().items():
+ provider = getattr(config.providers, name, None)
+ status = "[green]configured[/green]" if (provider and provider.api_key) else "[dim]not configured[/dim]"
+ provider_rows.append((display, status))
+ _print_summary_panel(provider_rows, "LLM Providers")
+
+ # Channels
+ channel_rows = []
+ for name, display in _get_channel_names().items():
+ channel = getattr(config.channels, name, None)
+ if channel:
+ enabled = (
+ channel.get("enabled", False)
+ if isinstance(channel, dict)
+ else getattr(channel, "enabled", False)
+ )
+ status = "[green]enabled[/green]" if enabled else "[dim]disabled[/dim]"
+ else:
+ status = "[dim]not configured[/dim]"
+ channel_rows.append((display, status))
+ _print_summary_panel(channel_rows, "Chat Channels")
+
+ # Settings sections
+ for title, model in [
+ ("Agent Settings", config.agents.defaults),
+ ("Gateway", config.gateway),
+ ("Tools", config.tools),
+ ("Channel Common", config.channels),
+ ]:
+ _print_summary_panel(_summarize_model(model), title)
+
+
+# --- Main Entry Point ---
+
+
+def _has_unsaved_changes(original: Config, current: Config) -> bool:
+ """Return True when the onboarding session has committed changes."""
+ return original.model_dump(by_alias=True) != current.model_dump(by_alias=True)
+
+
+def _prompt_main_menu_exit(has_unsaved_changes: bool) -> str:
+ """Resolve how to leave the main menu."""
+ if not has_unsaved_changes:
+ return "discard"
+
+ answer = _get_questionary().select(
+ "You have unsaved changes. What would you like to do?",
+ choices=[
+ "[S] Save and Exit",
+ "[X] Exit Without Saving",
+ "[R] Resume Editing",
+ ],
+ default="[R] Resume Editing",
+ qmark=">",
+ ).ask()
+
+ if answer == "[S] Save and Exit":
+ return "save"
+ if answer == "[X] Exit Without Saving":
+ return "discard"
+ return "resume"
+
+
+def run_onboard(initial_config: Config | None = None) -> OnboardResult:
+ """Run the interactive onboarding questionnaire.
+
+ Args:
+ initial_config: Optional pre-loaded config to use as starting point.
+ If None, loads from config file or creates new default.
+ """
+ _get_questionary()
+
+ if initial_config is not None:
+ base_config = initial_config.model_copy(deep=True)
+ else:
+ config_path = get_config_path()
+ if config_path.exists():
+ base_config = load_config()
+ else:
+ base_config = Config()
+
+ original_config = base_config.model_copy(deep=True)
+ config = base_config.model_copy(deep=True)
+
+ while True:
+ console.clear()
+ _show_main_menu_header()
+
+ try:
+ answer = _get_questionary().select(
+ "What would you like to configure?",
+ choices=[
+ "[P] LLM Provider",
+ "[C] Chat Channel",
+ "[A] Agent Settings",
+ "[G] Gateway",
+ "[T] Tools",
+ "[V] View Configuration Summary",
+ "[S] Save and Exit",
+ "[X] Exit Without Saving",
+ ],
+ qmark=">",
+ ).ask()
+ except KeyboardInterrupt:
+ answer = None
+
+ if answer is None:
+ action = _prompt_main_menu_exit(_has_unsaved_changes(original_config, config))
+ if action == "save":
+ return OnboardResult(config=config, should_save=True)
+ if action == "discard":
+ return OnboardResult(config=original_config, should_save=False)
+ continue
+
+ _MENU_DISPATCH = {
+ "[P] LLM Provider": lambda: _configure_providers(config),
+ "[C] Chat Channel": lambda: _configure_channels(config),
+ "[A] Agent Settings": lambda: _configure_general_settings(config, "Agent Settings"),
+ "[G] Gateway": lambda: _configure_general_settings(config, "Gateway"),
+ "[T] Tools": lambda: _configure_general_settings(config, "Tools"),
+ "[V] View Configuration Summary": lambda: _show_summary(config),
+ }
+
+ if answer == "[S] Save and Exit":
+ return OnboardResult(config=config, should_save=True)
+ if answer == "[X] Exit Without Saving":
+ return OnboardResult(config=original_config, should_save=False)
+
+ action_fn = _MENU_DISPATCH.get(answer)
+ if action_fn:
+ action_fn()
diff --git a/nanobot/cli/stream.py b/nanobot/cli/stream.py
new file mode 100644
index 000000000..8151e3ddc
--- /dev/null
+++ b/nanobot/cli/stream.py
@@ -0,0 +1,132 @@
+"""Streaming renderer for CLI output.
+
+Uses Rich Live with auto_refresh=False for stable, flicker-free
+markdown rendering during streaming. Ellipsis mode handles overflow.
+"""
+
+from __future__ import annotations
+
+import sys
+import time
+
+from rich.console import Console
+from rich.live import Live
+from rich.markdown import Markdown
+from rich.text import Text
+
+from nanobot import __logo__
+
+
+def _make_console() -> Console:
+ return Console(file=sys.stdout, force_terminal=True)
+
+
+class ThinkingSpinner:
+ """Spinner that shows 'nanobot is thinking...' with pause support."""
+
+ def __init__(self, console: Console | None = None):
+ c = console or _make_console()
+ self._spinner = c.status("[dim]nanobot is thinking...[/dim]", spinner="dots")
+ self._active = False
+
+ def __enter__(self):
+ self._spinner.start()
+ self._active = True
+ return self
+
+ def __exit__(self, *exc):
+ self._active = False
+ self._spinner.stop()
+ return False
+
+ def pause(self):
+ """Context manager: temporarily stop spinner for clean output."""
+ from contextlib import contextmanager
+
+ @contextmanager
+ def _ctx():
+ if self._spinner and self._active:
+ self._spinner.stop()
+ try:
+ yield
+ finally:
+ if self._spinner and self._active:
+ self._spinner.start()
+
+ return _ctx()
+
+
+class StreamRenderer:
+ """Rich Live streaming with markdown. auto_refresh=False avoids render races.
+
+ Deltas arrive pre-filtered (no tags) from the agent loop.
+
+ Flow per round:
+ spinner -> first visible delta -> header + Live renders ->
+ on_end -> Live stops (content stays on screen)
+ """
+
+ def __init__(self, render_markdown: bool = True, show_spinner: bool = True):
+ self._md = render_markdown
+ self._show_spinner = show_spinner
+ self._buf = ""
+ self._live: Live | None = None
+ self._t = 0.0
+ self.streamed = False
+ self._spinner: ThinkingSpinner | None = None
+ self._start_spinner()
+
+ def _render(self):
+ return Markdown(self._buf) if self._md and self._buf else Text(self._buf or "")
+
+ def _start_spinner(self) -> None:
+ if self._show_spinner:
+ self._spinner = ThinkingSpinner()
+ self._spinner.__enter__()
+
+ def _stop_spinner(self) -> None:
+ if self._spinner:
+ self._spinner.__exit__(None, None, None)
+ self._spinner = None
+
+ async def on_delta(self, delta: str) -> None:
+ self.streamed = True
+ self._buf += delta
+ if self._live is None:
+ if not self._buf.strip():
+ return
+ self._stop_spinner()
+ c = _make_console()
+ c.print()
+ c.print(f"[cyan]{__logo__} nanobot[/cyan]")
+ self._live = Live(self._render(), console=c, auto_refresh=False)
+ self._live.start()
+ now = time.monotonic()
+ if "\n" in delta or (now - self._t) > 0.05:
+ self._live.update(self._render())
+ self._live.refresh()
+ self._t = now
+
+ async def on_end(self, *, resuming: bool = False) -> None:
+ if self._live:
+ self._live.update(self._render())
+ self._live.refresh()
+ self._live.stop()
+ self._live = None
+ self._stop_spinner()
+ if resuming:
+ self._buf = ""
+ self._start_spinner()
+ else:
+ _make_console().print()
+
+ def stop_for_input(self) -> None:
+ """Stop spinner before user input to avoid prompt_toolkit conflicts."""
+ self._stop_spinner()
+
+ async def close(self) -> None:
+ """Stop spinner/live without rendering a final streamed round."""
+ if self._live:
+ self._live.stop()
+ self._live = None
+ self._stop_spinner()
diff --git a/nanobot/command/__init__.py b/nanobot/command/__init__.py
new file mode 100644
index 000000000..84e7138c6
--- /dev/null
+++ b/nanobot/command/__init__.py
@@ -0,0 +1,6 @@
+"""Slash command routing and built-in handlers."""
+
+from nanobot.command.builtin import register_builtin_commands
+from nanobot.command.router import CommandContext, CommandRouter
+
+__all__ = ["CommandContext", "CommandRouter", "register_builtin_commands"]
diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py
new file mode 100644
index 000000000..514ac1438
--- /dev/null
+++ b/nanobot/command/builtin.py
@@ -0,0 +1,329 @@
+"""Built-in slash command handlers."""
+
+from __future__ import annotations
+
+import asyncio
+import os
+import sys
+
+from nanobot import __version__
+from nanobot.bus.events import OutboundMessage
+from nanobot.command.router import CommandContext, CommandRouter
+from nanobot.utils.helpers import build_status_content
+from nanobot.utils.restart import set_restart_notice_to_env
+
+
+async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
+ """Cancel all active tasks and subagents for the session."""
+ loop = ctx.loop
+ msg = ctx.msg
+ tasks = loop._active_tasks.pop(msg.session_key, [])
+ cancelled = sum(1 for t in tasks if not t.done() and t.cancel())
+ for t in tasks:
+ try:
+ await t
+ except (asyncio.CancelledError, Exception):
+ pass
+ sub_cancelled = await loop.subagents.cancel_by_session(msg.session_key)
+ total = cancelled + sub_cancelled
+ content = f"Stopped {total} task(s)." if total else "No active task to stop."
+ return OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content=content,
+ metadata=dict(msg.metadata or {})
+ )
+
+
+async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
+ """Restart the process in-place via os.execv."""
+ msg = ctx.msg
+ set_restart_notice_to_env(channel=msg.channel, chat_id=msg.chat_id)
+
+ async def _do_restart():
+ await asyncio.sleep(1)
+ os.execv(sys.executable, [sys.executable, "-m", "nanobot"] + sys.argv[1:])
+
+ asyncio.create_task(_do_restart())
+ return OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content="Restarting...",
+ metadata=dict(msg.metadata or {})
+ )
+
+
+async def cmd_status(ctx: CommandContext) -> OutboundMessage:
+ """Build an outbound status message for a session."""
+ loop = ctx.loop
+ session = ctx.session or loop.sessions.get_or_create(ctx.key)
+ ctx_est = 0
+ try:
+ ctx_est, _ = loop.consolidator.estimate_session_prompt_tokens(session)
+ except Exception:
+ pass
+ if ctx_est <= 0:
+ ctx_est = loop._last_usage.get("prompt_tokens", 0)
+ return OutboundMessage(
+ channel=ctx.msg.channel,
+ chat_id=ctx.msg.chat_id,
+ content=build_status_content(
+ version=__version__, model=loop.model,
+ start_time=loop._start_time, last_usage=loop._last_usage,
+ context_window_tokens=loop.context_window_tokens,
+ session_msg_count=len(session.get_history(max_messages=0)),
+ context_tokens_estimate=ctx_est,
+ ),
+ metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
+ )
+
+
+async def cmd_new(ctx: CommandContext) -> OutboundMessage:
+ """Start a fresh session."""
+ loop = ctx.loop
+ session = ctx.session or loop.sessions.get_or_create(ctx.key)
+ snapshot = session.messages[session.last_consolidated:]
+ session.clear()
+ loop.sessions.save(session)
+ loop.sessions.invalidate(session.key)
+ if snapshot:
+ loop._schedule_background(loop.consolidator.archive(snapshot))
+ return OutboundMessage(
+ channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
+ content="New session started.",
+ metadata=dict(ctx.msg.metadata or {})
+ )
+
+
+async def cmd_dream(ctx: CommandContext) -> OutboundMessage:
+ """Manually trigger a Dream consolidation run."""
+ import time
+
+ loop = ctx.loop
+ msg = ctx.msg
+
+ async def _run_dream():
+ t0 = time.monotonic()
+ try:
+ did_work = await loop.dream.run()
+ elapsed = time.monotonic() - t0
+ if did_work:
+ content = f"Dream completed in {elapsed:.1f}s."
+ else:
+ content = "Dream: nothing to process."
+ except Exception as e:
+ elapsed = time.monotonic() - t0
+ content = f"Dream failed after {elapsed:.1f}s: {e}"
+ await loop.bus.publish_outbound(OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content=content,
+ ))
+
+ asyncio.create_task(_run_dream())
+ return OutboundMessage(
+ channel=msg.channel, chat_id=msg.chat_id, content="Dreaming...",
+ )
+
+
+def _extract_changed_files(diff: str) -> list[str]:
+ """Extract changed file paths from a unified diff."""
+ files: list[str] = []
+ seen: set[str] = set()
+ for line in diff.splitlines():
+ if not line.startswith("diff --git "):
+ continue
+ parts = line.split()
+ if len(parts) < 4:
+ continue
+ path = parts[3]
+ if path.startswith("b/"):
+ path = path[2:]
+ if path in seen:
+ continue
+ seen.add(path)
+ files.append(path)
+ return files
+
+
+def _format_changed_files(diff: str) -> str:
+ files = _extract_changed_files(diff)
+ if not files:
+ return "No tracked memory files changed."
+ return ", ".join(f"`{path}`" for path in files)
+
+
+def _format_dream_log_content(commit, diff: str, *, requested_sha: str | None = None) -> str:
+ files_line = _format_changed_files(diff)
+ lines = [
+ "## Dream Update",
+ "",
+ "Here is the selected Dream memory change." if requested_sha else "Here is the latest Dream memory change.",
+ "",
+ f"- Commit: `{commit.sha}`",
+ f"- Time: {commit.timestamp}",
+ f"- Changed files: {files_line}",
+ ]
+ if diff:
+ lines.extend([
+ "",
+ f"Use `/dream-restore {commit.sha}` to undo this change.",
+ "",
+ "```diff",
+ diff.rstrip(),
+ "```",
+ ])
+ else:
+ lines.extend([
+ "",
+ "Dream recorded this version, but there is no file diff to display.",
+ ])
+ return "\n".join(lines)
+
+
+def _format_dream_restore_list(commits: list) -> str:
+ lines = [
+ "## Dream Restore",
+ "",
+ "Choose a Dream memory version to restore. Latest first:",
+ "",
+ ]
+ for c in commits:
+ lines.append(f"- `{c.sha}` {c.timestamp} - {c.message.splitlines()[0]}")
+ lines.extend([
+ "",
+ "Preview a version with `/dream-log ` before restoring it.",
+ "Restore a version with `/dream-restore `.",
+ ])
+ return "\n".join(lines)
+
+
+async def cmd_dream_log(ctx: CommandContext) -> OutboundMessage:
+ """Show what the last Dream changed.
+
+ Default: diff of the latest commit (HEAD~1 vs HEAD).
+ With /dream-log : diff of that specific commit.
+ """
+ store = ctx.loop.consolidator.store
+ git = store.git
+
+ if not git.is_initialized():
+ if store.get_last_dream_cursor() == 0:
+ msg = "Dream has not run yet. Run `/dream`, or wait for the next scheduled Dream cycle."
+ else:
+ msg = "Dream history is not available because memory versioning is not initialized."
+ return OutboundMessage(
+ channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
+ content=msg, metadata={"render_as": "text"},
+ )
+
+ args = ctx.args.strip()
+
+ if args:
+ # Show diff of a specific commit
+ sha = args.split()[0]
+ result = git.show_commit_diff(sha)
+ if not result:
+ content = (
+ f"Couldn't find Dream change `{sha}`.\n\n"
+ "Use `/dream-restore` to list recent versions, "
+ "or `/dream-log` to inspect the latest one."
+ )
+ else:
+ commit, diff = result
+ content = _format_dream_log_content(commit, diff, requested_sha=sha)
+ else:
+ # Default: show the latest commit's diff
+ commits = git.log(max_entries=1)
+ result = git.show_commit_diff(commits[0].sha) if commits else None
+ if result:
+ commit, diff = result
+ content = _format_dream_log_content(commit, diff)
+ else:
+ content = "Dream memory has no saved versions yet."
+
+ return OutboundMessage(
+ channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
+ content=content, metadata={"render_as": "text"},
+ )
+
+
+async def cmd_dream_restore(ctx: CommandContext) -> OutboundMessage:
+ """Restore memory files from a previous dream commit.
+
+ Usage:
+ /dream-restore — list recent commits
+ /dream-restore — revert a specific commit
+ """
+ store = ctx.loop.consolidator.store
+ git = store.git
+ if not git.is_initialized():
+ return OutboundMessage(
+ channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
+ content="Dream history is not available because memory versioning is not initialized.",
+ )
+
+ args = ctx.args.strip()
+ if not args:
+ # Show recent commits for the user to pick
+ commits = git.log(max_entries=10)
+ if not commits:
+ content = "Dream memory has no saved versions to restore yet."
+ else:
+ content = _format_dream_restore_list(commits)
+ else:
+ sha = args.split()[0]
+ result = git.show_commit_diff(sha)
+ changed_files = _format_changed_files(result[1]) if result else "the tracked memory files"
+ new_sha = git.revert(sha)
+ if new_sha:
+ content = (
+ f"Restored Dream memory to the state before `{sha}`.\n\n"
+ f"- New safety commit: `{new_sha}`\n"
+ f"- Restored files: {changed_files}\n\n"
+ f"Use `/dream-log {new_sha}` to inspect the restore diff."
+ )
+ else:
+ content = (
+ f"Couldn't restore Dream change `{sha}`.\n\n"
+ "It may not exist, or it may be the first saved version with no earlier state to restore."
+ )
+ return OutboundMessage(
+ channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
+ content=content, metadata={"render_as": "text"},
+ )
+
+
+async def cmd_help(ctx: CommandContext) -> OutboundMessage:
+ """Return available slash commands."""
+ return OutboundMessage(
+ channel=ctx.msg.channel,
+ chat_id=ctx.msg.chat_id,
+ content=build_help_text(),
+ metadata={**dict(ctx.msg.metadata or {}), "render_as": "text"},
+ )
+
+
+def build_help_text() -> str:
+ """Build canonical help text shared across channels."""
+ lines = [
+ "🐈 nanobot commands:",
+ "/new — Start a new conversation",
+ "/stop — Stop the current task",
+ "/restart — Restart the bot",
+ "/status — Show bot status",
+ "/dream — Manually trigger Dream consolidation",
+ "/dream-log — Show what the last Dream changed",
+ "/dream-restore — Revert memory to a previous state",
+ "/help — Show available commands",
+ ]
+ return "\n".join(lines)
+
+
+def register_builtin_commands(router: CommandRouter) -> None:
+ """Register the default set of slash commands."""
+ router.priority("/stop", cmd_stop)
+ router.priority("/restart", cmd_restart)
+ router.priority("/status", cmd_status)
+ router.exact("/new", cmd_new)
+ router.exact("/status", cmd_status)
+ router.exact("/dream", cmd_dream)
+ router.exact("/dream-log", cmd_dream_log)
+ router.prefix("/dream-log ", cmd_dream_log)
+ router.exact("/dream-restore", cmd_dream_restore)
+ router.prefix("/dream-restore ", cmd_dream_restore)
+ router.exact("/help", cmd_help)
diff --git a/nanobot/command/router.py b/nanobot/command/router.py
new file mode 100644
index 000000000..35a475453
--- /dev/null
+++ b/nanobot/command/router.py
@@ -0,0 +1,84 @@
+"""Minimal command routing table for slash commands."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Awaitable, Callable
+
+if TYPE_CHECKING:
+ from nanobot.bus.events import InboundMessage, OutboundMessage
+ from nanobot.session.manager import Session
+
+Handler = Callable[["CommandContext"], Awaitable["OutboundMessage | None"]]
+
+
+@dataclass
+class CommandContext:
+ """Everything a command handler needs to produce a response."""
+
+ msg: InboundMessage
+ session: Session | None
+ key: str
+ raw: str
+ args: str = ""
+ loop: Any = None
+
+
+class CommandRouter:
+ """Pure dict-based command dispatch.
+
+ Three tiers checked in order:
+ 1. *priority* — exact-match commands handled before the dispatch lock
+ (e.g. /stop, /restart).
+ 2. *exact* — exact-match commands handled inside the dispatch lock.
+ 3. *prefix* — longest-prefix-first match (e.g. "/team ").
+ 4. *interceptors* — fallback predicates (e.g. team-mode active check).
+ """
+
+ def __init__(self) -> None:
+ self._priority: dict[str, Handler] = {}
+ self._exact: dict[str, Handler] = {}
+ self._prefix: list[tuple[str, Handler]] = []
+ self._interceptors: list[Handler] = []
+
+ def priority(self, cmd: str, handler: Handler) -> None:
+ self._priority[cmd] = handler
+
+ def exact(self, cmd: str, handler: Handler) -> None:
+ self._exact[cmd] = handler
+
+ def prefix(self, pfx: str, handler: Handler) -> None:
+ self._prefix.append((pfx, handler))
+ self._prefix.sort(key=lambda p: len(p[0]), reverse=True)
+
+ def intercept(self, handler: Handler) -> None:
+ self._interceptors.append(handler)
+
+ def is_priority(self, text: str) -> bool:
+ return text.strip().lower() in self._priority
+
+ async def dispatch_priority(self, ctx: CommandContext) -> OutboundMessage | None:
+ """Dispatch a priority command. Called from run() without the lock."""
+ handler = self._priority.get(ctx.raw.lower())
+ if handler:
+ return await handler(ctx)
+ return None
+
+ async def dispatch(self, ctx: CommandContext) -> OutboundMessage | None:
+ """Try exact, prefix, then interceptors. Returns None if unhandled."""
+ cmd = ctx.raw.lower()
+
+ if handler := self._exact.get(cmd):
+ return await handler(ctx)
+
+ for pfx, handler in self._prefix:
+ if cmd.startswith(pfx):
+ ctx.args = ctx.raw[len(pfx):]
+ return await handler(ctx)
+
+ for interceptor in self._interceptors:
+ result = await interceptor(ctx)
+ if result is not None:
+ return result
+
+ return None
diff --git a/nanobot/config/__init__.py b/nanobot/config/__init__.py
index 88e8e9b07..4b9fccec3 100644
--- a/nanobot/config/__init__.py
+++ b/nanobot/config/__init__.py
@@ -1,6 +1,32 @@
"""Configuration module for nanobot."""
-from nanobot.config.loader import load_config, get_config_path
+from nanobot.config.loader import get_config_path, load_config
+from nanobot.config.paths import (
+ get_bridge_install_dir,
+ get_cli_history_path,
+ get_cron_dir,
+ get_data_dir,
+ get_legacy_sessions_dir,
+ is_default_workspace,
+ get_logs_dir,
+ get_media_dir,
+ get_runtime_subdir,
+ get_workspace_path,
+)
from nanobot.config.schema import Config
-__all__ = ["Config", "load_config", "get_config_path"]
+__all__ = [
+ "Config",
+ "load_config",
+ "get_config_path",
+ "get_data_dir",
+ "get_runtime_subdir",
+ "get_media_dir",
+ "get_cron_dir",
+ "get_logs_dir",
+ "get_workspace_path",
+ "is_default_workspace",
+ "get_cli_history_path",
+ "get_bridge_install_dir",
+ "get_legacy_sessions_dir",
+]
diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py
index f8de88177..f5b2f33b8 100644
--- a/nanobot/config/loader.py
+++ b/nanobot/config/loader.py
@@ -2,94 +2,85 @@
import json
from pathlib import Path
-from typing import Any
+
+import pydantic
+from loguru import logger
from nanobot.config.schema import Config
+# Global variable to store current config path (for multi-instance support)
+_current_config_path: Path | None = None
+
+
+def set_config_path(path: Path) -> None:
+ """Set the current config path (used to derive data directory)."""
+ global _current_config_path
+ _current_config_path = path
+
def get_config_path() -> Path:
- """Get the default configuration file path."""
+ """Get the configuration file path."""
+ if _current_config_path:
+ return _current_config_path
return Path.home() / ".nanobot" / "config.json"
-def get_data_dir() -> Path:
- """Get the nanobot data directory."""
- from nanobot.utils.helpers import get_data_path
- return get_data_path()
-
-
def load_config(config_path: Path | None = None) -> Config:
"""
Load configuration from file or create default.
-
+
Args:
config_path: Optional path to config file. Uses default if not provided.
-
+
Returns:
Loaded configuration object.
"""
path = config_path or get_config_path()
-
+
+ config = Config()
if path.exists():
try:
- with open(path) as f:
+ with open(path, encoding="utf-8") as f:
data = json.load(f)
- return Config.model_validate(convert_keys(data))
- except (json.JSONDecodeError, ValueError) as e:
- print(f"Warning: Failed to load config from {path}: {e}")
- print("Using default configuration.")
-
- return Config()
+ data = _migrate_config(data)
+ config = Config.model_validate(data)
+ except (json.JSONDecodeError, ValueError, pydantic.ValidationError) as e:
+ logger.warning(f"Failed to load config from {path}: {e}")
+ logger.warning("Using default configuration.")
+
+ _apply_ssrf_whitelist(config)
+ return config
+
+
+def _apply_ssrf_whitelist(config: Config) -> None:
+ """Apply SSRF whitelist from config to the network security module."""
+ from nanobot.security.network import configure_ssrf_whitelist
+
+ configure_ssrf_whitelist(config.tools.ssrf_whitelist)
def save_config(config: Config, config_path: Path | None = None) -> None:
"""
Save configuration to file.
-
+
Args:
config: Configuration to save.
config_path: Optional path to save to. Uses default if not provided.
"""
path = config_path or get_config_path()
path.parent.mkdir(parents=True, exist_ok=True)
-
- # Convert to camelCase format
- data = config.model_dump()
- data = convert_to_camel(data)
-
- with open(path, "w") as f:
- json.dump(data, f, indent=2)
+
+ data = config.model_dump(mode="json", by_alias=True)
+
+ with open(path, "w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2, ensure_ascii=False)
-def convert_keys(data: Any) -> Any:
- """Convert camelCase keys to snake_case for Pydantic."""
- if isinstance(data, dict):
- return {camel_to_snake(k): convert_keys(v) for k, v in data.items()}
- if isinstance(data, list):
- return [convert_keys(item) for item in data]
+def _migrate_config(data: dict) -> dict:
+ """Migrate old config formats to current."""
+ # Move tools.exec.restrictToWorkspace → tools.restrictToWorkspace
+ tools = data.get("tools", {})
+ exec_cfg = tools.get("exec", {})
+ if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools:
+ tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace")
return data
-
-
-def convert_to_camel(data: Any) -> Any:
- """Convert snake_case keys to camelCase."""
- if isinstance(data, dict):
- return {snake_to_camel(k): convert_to_camel(v) for k, v in data.items()}
- if isinstance(data, list):
- return [convert_to_camel(item) for item in data]
- return data
-
-
-def camel_to_snake(name: str) -> str:
- """Convert camelCase to snake_case."""
- result = []
- for i, char in enumerate(name):
- if char.isupper() and i > 0:
- result.append("_")
- result.append(char.lower())
- return "".join(result)
-
-
-def snake_to_camel(name: str) -> str:
- """Convert snake_case to camelCase."""
- components = name.split("_")
- return components[0] + "".join(x.title() for x in components[1:])
diff --git a/nanobot/config/paths.py b/nanobot/config/paths.py
new file mode 100644
index 000000000..527c5f38e
--- /dev/null
+++ b/nanobot/config/paths.py
@@ -0,0 +1,62 @@
+"""Runtime path helpers derived from the active config context."""
+
+from __future__ import annotations
+
+from pathlib import Path
+
+from nanobot.config.loader import get_config_path
+from nanobot.utils.helpers import ensure_dir
+
+
+def get_data_dir() -> Path:
+ """Return the instance-level runtime data directory."""
+ return ensure_dir(get_config_path().parent)
+
+
+def get_runtime_subdir(name: str) -> Path:
+ """Return a named runtime subdirectory under the instance data dir."""
+ return ensure_dir(get_data_dir() / name)
+
+
+def get_media_dir(channel: str | None = None) -> Path:
+ """Return the media directory, optionally namespaced per channel."""
+ base = get_runtime_subdir("media")
+ return ensure_dir(base / channel) if channel else base
+
+
+def get_cron_dir() -> Path:
+ """Return the cron storage directory."""
+ return get_runtime_subdir("cron")
+
+
+def get_logs_dir() -> Path:
+ """Return the logs directory."""
+ return get_runtime_subdir("logs")
+
+
+def get_workspace_path(workspace: str | None = None) -> Path:
+ """Resolve and ensure the agent workspace path."""
+ path = Path(workspace).expanduser() if workspace else Path.home() / ".nanobot" / "workspace"
+ return ensure_dir(path)
+
+
+def is_default_workspace(workspace: str | Path | None) -> bool:
+ """Return whether a workspace resolves to nanobot's default workspace path."""
+ current = Path(workspace).expanduser() if workspace is not None else Path.home() / ".nanobot" / "workspace"
+ default = Path.home() / ".nanobot" / "workspace"
+ return current.resolve(strict=False) == default.resolve(strict=False)
+
+
+def get_cli_history_path() -> Path:
+ """Return the shared CLI history file path."""
+ return Path.home() / ".nanobot" / "history" / "cli_history"
+
+
+def get_bridge_install_dir() -> Path:
+ """Return the shared WhatsApp bridge installation directory."""
+ return Path.home() / ".nanobot" / "bridge"
+
+
+def get_legacy_sessions_dir() -> Path:
+ """Return the legacy global session directory used for migration fallback."""
+ return Path.home() / ".nanobot" / "sessions"
diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py
index 4c348348e..dfb91c528 100644
--- a/nanobot/config/schema.py
+++ b/nanobot/config/schema.py
@@ -1,126 +1,311 @@
"""Configuration schema using Pydantic."""
from pathlib import Path
-from pydantic import BaseModel, Field
+from typing import Literal
+
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field
+from pydantic.alias_generators import to_camel
from pydantic_settings import BaseSettings
-
-class WhatsAppConfig(BaseModel):
- """WhatsApp channel configuration."""
- enabled: bool = False
- bridge_url: str = "ws://localhost:3001"
- allow_from: list[str] = Field(default_factory=list) # Allowed phone numbers
+from nanobot.cron.types import CronSchedule
-class TelegramConfig(BaseModel):
- """Telegram channel configuration."""
- enabled: bool = False
- token: str = "" # Bot token from @BotFather
- allow_from: list[str] = Field(default_factory=list) # Allowed user IDs or usernames
+class Base(BaseModel):
+ """Base model that accepts both camelCase and snake_case keys."""
+
+ model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)
+
+class ChannelsConfig(Base):
+ """Configuration for chat channels.
+
+ Built-in and plugin channel configs are stored as extra fields (dicts).
+ Each channel parses its own config in __init__.
+ Per-channel "streaming": true enables streaming output (requires send_delta impl).
+ """
+
+ model_config = ConfigDict(extra="allow")
+
+ send_progress: bool = True # stream agent's text progress to the channel
+ send_tool_hints: bool = False # stream tool-call hints (e.g. read_file("…"))
+ send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included)
-class ChannelsConfig(BaseModel):
- """Configuration for chat channels."""
- whatsapp: WhatsAppConfig = Field(default_factory=WhatsAppConfig)
- telegram: TelegramConfig = Field(default_factory=TelegramConfig)
+class DreamConfig(Base):
+ """Dream memory consolidation configuration."""
+
+ _HOUR_MS = 3_600_000
+
+ interval_h: int = Field(default=2, ge=1) # Every 2 hours by default
+ cron: str | None = Field(default=None, exclude=True) # Legacy compatibility override
+ model_override: str | None = Field(
+ default=None,
+ validation_alias=AliasChoices("modelOverride", "model", "model_override"),
+ ) # Optional Dream-specific model override
+ max_batch_size: int = Field(default=20, ge=1) # Max history entries per run
+ max_iterations: int = Field(default=10, ge=1) # Max tool calls per Phase 2
+
+ def build_schedule(self, timezone: str) -> CronSchedule:
+ """Build the runtime schedule, preferring the legacy cron override if present."""
+ if self.cron:
+ return CronSchedule(kind="cron", expr=self.cron, tz=timezone)
+ return CronSchedule(kind="every", every_ms=self.interval_h * self._HOUR_MS)
+
+ def describe_schedule(self) -> str:
+ """Return a human-readable summary for logs and startup output."""
+ if self.cron:
+ return f"cron {self.cron} (legacy)"
+ hours = self.interval_h
+ return f"every {hours}h"
-class AgentDefaults(BaseModel):
+class AgentDefaults(Base):
"""Default agent configuration."""
+
workspace: str = "~/.nanobot/workspace"
model: str = "anthropic/claude-opus-4-5"
+ provider: str = (
+ "auto" # Provider name (e.g. "anthropic", "openrouter") or "auto" for auto-detection
+ )
max_tokens: int = 8192
- temperature: float = 0.7
- max_tool_iterations: int = 20
+ context_window_tokens: int = 65_536
+ context_block_limit: int | None = None
+ temperature: float = 0.1
+ max_tool_iterations: int = 200
+ max_tool_result_chars: int = 16_000
+ provider_retry_mode: Literal["standard", "persistent"] = "standard"
+ reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode
+ timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
+ dream: DreamConfig = Field(default_factory=DreamConfig)
-class AgentsConfig(BaseModel):
+class AgentsConfig(Base):
"""Agent configuration."""
+
defaults: AgentDefaults = Field(default_factory=AgentDefaults)
-class ProviderConfig(BaseModel):
+class ProviderConfig(Base):
"""LLM provider configuration."""
+
api_key: str = ""
api_base: str | None = None
+ extra_headers: dict[str, str] | None = None # Custom headers (e.g. APP-Code for AiHubMix)
-class ProvidersConfig(BaseModel):
+class ProvidersConfig(Base):
"""Configuration for LLM providers."""
+
+ custom: ProviderConfig = Field(default_factory=ProviderConfig) # Any OpenAI-compatible endpoint
+ azure_openai: ProviderConfig = Field(default_factory=ProviderConfig) # Azure OpenAI (model = deployment name)
anthropic: ProviderConfig = Field(default_factory=ProviderConfig)
openai: ProviderConfig = Field(default_factory=ProviderConfig)
openrouter: ProviderConfig = Field(default_factory=ProviderConfig)
+ deepseek: ProviderConfig = Field(default_factory=ProviderConfig)
groq: ProviderConfig = Field(default_factory=ProviderConfig)
zhipu: ProviderConfig = Field(default_factory=ProviderConfig)
+ dashscope: ProviderConfig = Field(default_factory=ProviderConfig)
vllm: ProviderConfig = Field(default_factory=ProviderConfig)
+ ollama: ProviderConfig = Field(default_factory=ProviderConfig) # Ollama local models
+ ovms: ProviderConfig = Field(default_factory=ProviderConfig) # OpenVINO Model Server (OVMS)
gemini: ProviderConfig = Field(default_factory=ProviderConfig)
+ moonshot: ProviderConfig = Field(default_factory=ProviderConfig)
+ minimax: ProviderConfig = Field(default_factory=ProviderConfig)
+ mistral: ProviderConfig = Field(default_factory=ProviderConfig)
+ stepfun: ProviderConfig = Field(default_factory=ProviderConfig) # Step Fun (阶跃星辰)
+ xiaomi_mimo: ProviderConfig = Field(default_factory=ProviderConfig) # Xiaomi MIMO (小米)
+ aihubmix: ProviderConfig = Field(default_factory=ProviderConfig) # AiHubMix API gateway
+ siliconflow: ProviderConfig = Field(default_factory=ProviderConfig) # SiliconFlow (硅基流动)
+ volcengine: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine (火山引擎)
+ volcengine_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # VolcEngine Coding Plan
+ byteplus: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus (VolcEngine international)
+ byteplus_coding_plan: ProviderConfig = Field(default_factory=ProviderConfig) # BytePlus Coding Plan
+ openai_codex: ProviderConfig = Field(default_factory=ProviderConfig, exclude=True) # OpenAI Codex (OAuth)
+ github_copilot: ProviderConfig = Field(default_factory=ProviderConfig, exclude=True) # Github Copilot (OAuth)
+ qianfan: ProviderConfig = Field(default_factory=ProviderConfig) # Qianfan (百度千帆)
-class GatewayConfig(BaseModel):
+class HeartbeatConfig(Base):
+ """Heartbeat service configuration."""
+
+ enabled: bool = True
+ interval_s: int = 30 * 60 # 30 minutes
+ keep_recent_messages: int = 8
+
+
+class ApiConfig(Base):
+ """OpenAI-compatible API server configuration."""
+
+ host: str = "127.0.0.1" # Safer default: local-only bind.
+ port: int = 8900
+ timeout: float = 120.0 # Per-request timeout in seconds.
+
+
+class GatewayConfig(Base):
"""Gateway/server configuration."""
+
host: str = "0.0.0.0"
port: int = 18790
+ heartbeat: HeartbeatConfig = Field(default_factory=HeartbeatConfig)
-class WebSearchConfig(BaseModel):
+class WebSearchConfig(Base):
"""Web search tool configuration."""
- api_key: str = "" # Brave Search API key
+
+ provider: str = "duckduckgo" # brave, tavily, duckduckgo, searxng, jina
+ api_key: str = ""
+ base_url: str = "" # SearXNG base URL
max_results: int = 5
+ timeout: int = 30 # Wall-clock timeout (seconds) for search operations
-class WebToolsConfig(BaseModel):
+class WebToolsConfig(Base):
"""Web tools configuration."""
+
+ enable: bool = True
+ proxy: str | None = (
+ None # HTTP/SOCKS5 proxy URL, e.g. "http://127.0.0.1:7890" or "socks5://127.0.0.1:1080"
+ )
search: WebSearchConfig = Field(default_factory=WebSearchConfig)
-class ExecToolConfig(BaseModel):
+class ExecToolConfig(Base):
"""Shell exec tool configuration."""
+
+ enable: bool = True
timeout: int = 60
- restrict_to_workspace: bool = False # If true, block commands accessing paths outside workspace
+ path_append: str = ""
+ sandbox: str = "" # sandbox backend: "" (none) or "bwrap"
+class MCPServerConfig(Base):
+ """MCP server connection configuration (stdio or HTTP)."""
-class ToolsConfig(BaseModel):
+ type: Literal["stdio", "sse", "streamableHttp"] | None = None # auto-detected if omitted
+ command: str = "" # Stdio: command to run (e.g. "npx")
+ args: list[str] = Field(default_factory=list) # Stdio: command arguments
+ env: dict[str, str] = Field(default_factory=dict) # Stdio: extra env vars
+ url: str = "" # HTTP/SSE: endpoint URL
+ headers: dict[str, str] = Field(default_factory=dict) # HTTP/SSE: custom headers
+ tool_timeout: int = 30 # seconds before a tool call is cancelled
+ enabled_tools: list[str] = Field(default_factory=lambda: ["*"]) # Only register these tools; accepts raw MCP names or wrapped mcp__ names; ["*"] = all tools; [] = no tools
+
+class ToolsConfig(Base):
"""Tools configuration."""
+
web: WebToolsConfig = Field(default_factory=WebToolsConfig)
exec: ExecToolConfig = Field(default_factory=ExecToolConfig)
+ restrict_to_workspace: bool = False # restrict all tool access to workspace directory
+ mcp_servers: dict[str, MCPServerConfig] = Field(default_factory=dict)
+ ssrf_whitelist: list[str] = Field(default_factory=list) # CIDR ranges to exempt from SSRF blocking (e.g. ["100.64.0.0/10"] for Tailscale)
class Config(BaseSettings):
"""Root configuration for nanobot."""
+
agents: AgentsConfig = Field(default_factory=AgentsConfig)
channels: ChannelsConfig = Field(default_factory=ChannelsConfig)
providers: ProvidersConfig = Field(default_factory=ProvidersConfig)
+ api: ApiConfig = Field(default_factory=ApiConfig)
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
tools: ToolsConfig = Field(default_factory=ToolsConfig)
-
+
@property
def workspace_path(self) -> Path:
"""Get expanded workspace path."""
return Path(self.agents.defaults.workspace).expanduser()
-
- def get_api_key(self) -> str | None:
- """Get API key in priority order: OpenRouter > Anthropic > OpenAI > Gemini > Zhipu > Groq > vLLM."""
- return (
- self.providers.openrouter.api_key or
- self.providers.anthropic.api_key or
- self.providers.openai.api_key or
- self.providers.gemini.api_key or
- self.providers.zhipu.api_key or
- self.providers.groq.api_key or
- self.providers.vllm.api_key or
- None
- )
-
- def get_api_base(self) -> str | None:
- """Get API base URL if using OpenRouter, Zhipu or vLLM."""
- if self.providers.openrouter.api_key:
- return self.providers.openrouter.api_base or "https://openrouter.ai/api/v1"
- if self.providers.zhipu.api_key:
- return self.providers.zhipu.api_base
- if self.providers.vllm.api_base:
- return self.providers.vllm.api_base
+
+ def _match_provider(
+ self, model: str | None = None
+ ) -> tuple["ProviderConfig | None", str | None]:
+ """Match provider config and its registry name. Returns (config, spec_name)."""
+ from nanobot.providers.registry import PROVIDERS, find_by_name
+
+ forced = self.agents.defaults.provider
+ if forced != "auto":
+ spec = find_by_name(forced)
+ if spec:
+ p = getattr(self.providers, spec.name, None)
+ return (p, spec.name) if p else (None, None)
+ return None, None
+
+ model_lower = (model or self.agents.defaults.model).lower()
+ model_normalized = model_lower.replace("-", "_")
+ model_prefix = model_lower.split("/", 1)[0] if "/" in model_lower else ""
+ normalized_prefix = model_prefix.replace("-", "_")
+
+ def _kw_matches(kw: str) -> bool:
+ kw = kw.lower()
+ return kw in model_lower or kw.replace("-", "_") in model_normalized
+
+ # Explicit provider prefix wins — prevents `github-copilot/...codex` matching openai_codex.
+ for spec in PROVIDERS:
+ p = getattr(self.providers, spec.name, None)
+ if p and model_prefix and normalized_prefix == spec.name:
+ if spec.is_oauth or spec.is_local or p.api_key:
+ return p, spec.name
+
+ # Match by keyword (order follows PROVIDERS registry)
+ for spec in PROVIDERS:
+ p = getattr(self.providers, spec.name, None)
+ if p and any(_kw_matches(kw) for kw in spec.keywords):
+ if spec.is_oauth or spec.is_local or p.api_key:
+ return p, spec.name
+
+ # Fallback: configured local providers can route models without
+ # provider-specific keywords (for example plain "llama3.2" on Ollama).
+ # Prefer providers whose detect_by_base_keyword matches the configured api_base
+ # (e.g. Ollama's "11434" in "http://localhost:11434") over plain registry order.
+ local_fallback: tuple[ProviderConfig, str] | None = None
+ for spec in PROVIDERS:
+ if not spec.is_local:
+ continue
+ p = getattr(self.providers, spec.name, None)
+ if not (p and p.api_base):
+ continue
+ if spec.detect_by_base_keyword and spec.detect_by_base_keyword in p.api_base:
+ return p, spec.name
+ if local_fallback is None:
+ local_fallback = (p, spec.name)
+ if local_fallback:
+ return local_fallback
+
+ # Fallback: gateways first, then others (follows registry order)
+ # OAuth providers are NOT valid fallbacks — they require explicit model selection
+ for spec in PROVIDERS:
+ if spec.is_oauth:
+ continue
+ p = getattr(self.providers, spec.name, None)
+ if p and p.api_key:
+ return p, spec.name
+ return None, None
+
+ def get_provider(self, model: str | None = None) -> ProviderConfig | None:
+ """Get matched provider config (api_key, api_base, extra_headers). Falls back to first available."""
+ p, _ = self._match_provider(model)
+ return p
+
+ def get_provider_name(self, model: str | None = None) -> str | None:
+ """Get the registry name of the matched provider (e.g. "deepseek", "openrouter")."""
+ _, name = self._match_provider(model)
+ return name
+
+ def get_api_key(self, model: str | None = None) -> str | None:
+ """Get API key for the given model. Falls back to first available key."""
+ p = self.get_provider(model)
+ return p.api_key if p else None
+
+ def get_api_base(self, model: str | None = None) -> str | None:
+ """Get API base URL for the given model. Applies default URLs for gateway/local providers."""
+ from nanobot.providers.registry import find_by_name
+
+ p, name = self._match_provider(model)
+ if p and p.api_base:
+ return p.api_base
+ # Only gateways get a default api_base here. Standard providers
+ # resolve their base URL from the registry in the provider constructor.
+ if name:
+ spec = find_by_name(name)
+ if spec and (spec.is_gateway or spec.is_local) and spec.default_api_base:
+ return spec.default_api_base
return None
-
- class Config:
- env_prefix = "NANOBOT_"
- env_nested_delimiter = "__"
+
+ model_config = ConfigDict(env_prefix="NANOBOT_", env_nested_delimiter="__")
diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py
index d1965a9ec..d60846640 100644
--- a/nanobot/cron/service.py
+++ b/nanobot/cron/service.py
@@ -4,12 +4,13 @@ import asyncio
import json
import time
import uuid
+from datetime import datetime
from pathlib import Path
-from typing import Any, Callable, Coroutine
+from typing import Any, Callable, Coroutine, Literal
from loguru import logger
-from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule, CronStore
+from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
def _now_ms() -> int:
@@ -20,47 +21,75 @@ def _compute_next_run(schedule: CronSchedule, now_ms: int) -> int | None:
"""Compute next run time in ms."""
if schedule.kind == "at":
return schedule.at_ms if schedule.at_ms and schedule.at_ms > now_ms else None
-
+
if schedule.kind == "every":
if not schedule.every_ms or schedule.every_ms <= 0:
return None
# Next interval from now
return now_ms + schedule.every_ms
-
+
if schedule.kind == "cron" and schedule.expr:
try:
+ from zoneinfo import ZoneInfo
+
from croniter import croniter
- cron = croniter(schedule.expr, time.time())
- next_time = cron.get_next()
- return int(next_time * 1000)
+ # Use caller-provided reference time for deterministic scheduling
+ base_time = now_ms / 1000
+ tz = ZoneInfo(schedule.tz) if schedule.tz else datetime.now().astimezone().tzinfo
+ base_dt = datetime.fromtimestamp(base_time, tz=tz)
+ cron = croniter(schedule.expr, base_dt)
+ next_dt = cron.get_next(datetime)
+ return int(next_dt.timestamp() * 1000)
except Exception:
return None
-
+
return None
+def _validate_schedule_for_add(schedule: CronSchedule) -> None:
+ """Validate schedule fields that would otherwise create non-runnable jobs."""
+ if schedule.tz and schedule.kind != "cron":
+ raise ValueError("tz can only be used with cron schedules")
+
+ if schedule.kind == "cron" and schedule.tz:
+ try:
+ from zoneinfo import ZoneInfo
+
+ ZoneInfo(schedule.tz)
+ except Exception:
+ raise ValueError(f"unknown timezone '{schedule.tz}'") from None
+
+
class CronService:
"""Service for managing and executing scheduled jobs."""
-
+
+ _MAX_RUN_HISTORY = 20
+
def __init__(
self,
store_path: Path,
- on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None
+ on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
):
self.store_path = store_path
- self.on_job = on_job # Callback to execute job, returns response text
+ self.on_job = on_job
self._store: CronStore | None = None
+ self._last_mtime: float = 0.0
self._timer_task: asyncio.Task | None = None
self._running = False
-
+
def _load_store(self) -> CronStore:
- """Load jobs from disk."""
+ """Load jobs from disk. Reloads automatically if file was modified externally."""
+ if self._store and self.store_path.exists():
+ mtime = self.store_path.stat().st_mtime
+ if mtime != self._last_mtime:
+ logger.info("Cron: jobs.json modified externally, reloading")
+ self._store = None
if self._store:
return self._store
-
+
if self.store_path.exists():
try:
- data = json.loads(self.store_path.read_text())
+ data = json.loads(self.store_path.read_text(encoding="utf-8"))
jobs = []
for j in data.get("jobs", []):
jobs.append(CronJob(
@@ -86,6 +115,15 @@ class CronService:
last_run_at_ms=j.get("state", {}).get("lastRunAtMs"),
last_status=j.get("state", {}).get("lastStatus"),
last_error=j.get("state", {}).get("lastError"),
+ run_history=[
+ CronRunRecord(
+ run_at_ms=r["runAtMs"],
+ status=r["status"],
+ duration_ms=r.get("durationMs", 0),
+ error=r.get("error"),
+ )
+ for r in j.get("state", {}).get("runHistory", [])
+ ],
),
created_at_ms=j.get("createdAtMs", 0),
updated_at_ms=j.get("updatedAtMs", 0),
@@ -93,20 +131,20 @@ class CronService:
))
self._store = CronStore(jobs=jobs)
except Exception as e:
- logger.warning(f"Failed to load cron store: {e}")
+ logger.warning("Failed to load cron store: {}", e)
self._store = CronStore()
else:
self._store = CronStore()
-
+
return self._store
-
+
def _save_store(self) -> None:
"""Save jobs to disk."""
if not self._store:
return
-
+
self.store_path.parent.mkdir(parents=True, exist_ok=True)
-
+
data = {
"version": self._store.version,
"jobs": [
@@ -133,6 +171,15 @@ class CronService:
"lastRunAtMs": j.state.last_run_at_ms,
"lastStatus": j.state.last_status,
"lastError": j.state.last_error,
+ "runHistory": [
+ {
+ "runAtMs": r.run_at_ms,
+ "status": r.status,
+ "durationMs": r.duration_ms,
+ "error": r.error,
+ }
+ for r in j.state.run_history
+ ],
},
"createdAtMs": j.created_at_ms,
"updatedAtMs": j.updated_at_ms,
@@ -141,8 +188,9 @@ class CronService:
for j in self._store.jobs
]
}
-
- self.store_path.write_text(json.dumps(data, indent=2))
+
+ self.store_path.write_text(json.dumps(data, indent=2, ensure_ascii=False), encoding="utf-8")
+ self._last_mtime = self.store_path.stat().st_mtime
async def start(self) -> None:
"""Start the cron service."""
@@ -151,15 +199,15 @@ class CronService:
self._recompute_next_runs()
self._save_store()
self._arm_timer()
- logger.info(f"Cron service started with {len(self._store.jobs if self._store else [])} jobs")
-
+ logger.info("Cron service started with {} jobs", len(self._store.jobs if self._store else []))
+
def stop(self) -> None:
"""Stop the cron service."""
self._running = False
if self._timer_task:
self._timer_task.cancel()
self._timer_task = None
-
+
def _recompute_next_runs(self) -> None:
"""Recompute next run times for all enabled jobs."""
if not self._store:
@@ -168,73 +216,82 @@ class CronService:
for job in self._store.jobs:
if job.enabled:
job.state.next_run_at_ms = _compute_next_run(job.schedule, now)
-
+
def _get_next_wake_ms(self) -> int | None:
"""Get the earliest next run time across all jobs."""
if not self._store:
return None
- times = [j.state.next_run_at_ms for j in self._store.jobs
+ times = [j.state.next_run_at_ms for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms]
return min(times) if times else None
-
+
def _arm_timer(self) -> None:
"""Schedule the next timer tick."""
if self._timer_task:
self._timer_task.cancel()
-
+
next_wake = self._get_next_wake_ms()
if not next_wake or not self._running:
return
-
+
delay_ms = max(0, next_wake - _now_ms())
delay_s = delay_ms / 1000
-
+
async def tick():
await asyncio.sleep(delay_s)
if self._running:
await self._on_timer()
-
+
self._timer_task = asyncio.create_task(tick())
-
+
async def _on_timer(self) -> None:
"""Handle timer tick - run due jobs."""
+ self._load_store()
if not self._store:
return
-
+
now = _now_ms()
due_jobs = [
j for j in self._store.jobs
if j.enabled and j.state.next_run_at_ms and now >= j.state.next_run_at_ms
]
-
+
for job in due_jobs:
await self._execute_job(job)
-
+
self._save_store()
self._arm_timer()
-
+
async def _execute_job(self, job: CronJob) -> None:
"""Execute a single job."""
start_ms = _now_ms()
- logger.info(f"Cron: executing job '{job.name}' ({job.id})")
-
+ logger.info("Cron: executing job '{}' ({})", job.name, job.id)
+
try:
- response = None
if self.on_job:
- response = await self.on_job(job)
-
+ await self.on_job(job)
+
job.state.last_status = "ok"
job.state.last_error = None
- logger.info(f"Cron: job '{job.name}' completed")
-
+ logger.info("Cron: job '{}' completed", job.name)
+
except Exception as e:
job.state.last_status = "error"
job.state.last_error = str(e)
- logger.error(f"Cron: job '{job.name}' failed: {e}")
-
+ logger.error("Cron: job '{}' failed: {}", job.name, e)
+
+ end_ms = _now_ms()
job.state.last_run_at_ms = start_ms
- job.updated_at_ms = _now_ms()
-
+ job.updated_at_ms = end_ms
+
+ job.state.run_history.append(CronRunRecord(
+ run_at_ms=start_ms,
+ status=job.state.last_status,
+ duration_ms=end_ms - start_ms,
+ error=job.state.last_error,
+ ))
+ job.state.run_history = job.state.run_history[-self._MAX_RUN_HISTORY:]
+
# Handle one-shot jobs
if job.schedule.kind == "at":
if job.delete_after_run:
@@ -245,15 +302,15 @@ class CronService:
else:
# Compute next run
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
-
+
# ========== Public API ==========
-
+
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
"""List all jobs."""
store = self._load_store()
jobs = store.jobs if include_disabled else [j for j in store.jobs if j.enabled]
return sorted(jobs, key=lambda j: j.state.next_run_at_ms or float('inf'))
-
+
def add_job(
self,
name: str,
@@ -266,8 +323,9 @@ class CronService:
) -> CronJob:
"""Add a new job."""
store = self._load_store()
+ _validate_schedule_for_add(schedule)
now = _now_ms()
-
+
job = CronJob(
id=str(uuid.uuid4())[:8],
name=name,
@@ -285,28 +343,50 @@ class CronService:
updated_at_ms=now,
delete_after_run=delete_after_run,
)
-
+
store.jobs.append(job)
self._save_store()
self._arm_timer()
-
- logger.info(f"Cron: added job '{name}' ({job.id})")
+
+ logger.info("Cron: added job '{}' ({})", name, job.id)
return job
-
- def remove_job(self, job_id: str) -> bool:
- """Remove a job by ID."""
+
+ def register_system_job(self, job: CronJob) -> CronJob:
+ """Register an internal system job (idempotent on restart)."""
store = self._load_store()
+ now = _now_ms()
+ job.state = CronJobState(next_run_at_ms=_compute_next_run(job.schedule, now))
+ job.created_at_ms = now
+ job.updated_at_ms = now
+ store.jobs = [j for j in store.jobs if j.id != job.id]
+ store.jobs.append(job)
+ self._save_store()
+ self._arm_timer()
+ logger.info("Cron: registered system job '{}' ({})", job.name, job.id)
+ return job
+
+ def remove_job(self, job_id: str) -> Literal["removed", "protected", "not_found"]:
+ """Remove a job by ID, unless it is a protected system job."""
+ store = self._load_store()
+ job = next((j for j in store.jobs if j.id == job_id), None)
+ if job is None:
+ return "not_found"
+ if job.payload.kind == "system_event":
+ logger.info("Cron: refused to remove protected system job {}", job_id)
+ return "protected"
+
before = len(store.jobs)
store.jobs = [j for j in store.jobs if j.id != job_id]
removed = len(store.jobs) < before
-
+
if removed:
self._save_store()
self._arm_timer()
- logger.info(f"Cron: removed job {job_id}")
-
- return removed
-
+ logger.info("Cron: removed job {}", job_id)
+ return "removed"
+
+ return "not_found"
+
def enable_job(self, job_id: str, enabled: bool = True) -> CronJob | None:
"""Enable or disable a job."""
store = self._load_store()
@@ -322,7 +402,7 @@ class CronService:
self._arm_timer()
return job
return None
-
+
async def run_job(self, job_id: str, force: bool = False) -> bool:
"""Manually run a job."""
store = self._load_store()
@@ -335,7 +415,12 @@ class CronService:
self._arm_timer()
return True
return False
-
+
+ def get_job(self, job_id: str) -> CronJob | None:
+ """Get a job by ID."""
+ store = self._load_store()
+ return next((j for j in store.jobs if j.id == job_id), None)
+
def status(self) -> dict:
"""Get service status."""
store = self._load_store()
diff --git a/nanobot/cron/types.py b/nanobot/cron/types.py
index 2b4206057..e7b2c4391 100644
--- a/nanobot/cron/types.py
+++ b/nanobot/cron/types.py
@@ -29,6 +29,15 @@ class CronPayload:
to: str | None = None # e.g. phone number
+@dataclass
+class CronRunRecord:
+ """A single execution record for a cron job."""
+ run_at_ms: int
+ status: Literal["ok", "error", "skipped"]
+ duration_ms: int = 0
+ error: str | None = None
+
+
@dataclass
class CronJobState:
"""Runtime state of a job."""
@@ -36,6 +45,7 @@ class CronJobState:
last_run_at_ms: int | None = None
last_status: Literal["ok", "error", "skipped"] | None = None
last_error: str | None = None
+ run_history: list[CronRunRecord] = field(default_factory=list)
@dataclass
diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py
index 221ed27d9..00f6b17e1 100644
--- a/nanobot/heartbeat/service.py
+++ b/nanobot/heartbeat/service.py
@@ -1,92 +1,135 @@
"""Heartbeat service - periodic agent wake-up to check for tasks."""
+from __future__ import annotations
+
import asyncio
from pathlib import Path
-from typing import Any, Callable, Coroutine
+from typing import TYPE_CHECKING, Any, Callable, Coroutine
from loguru import logger
-# Default interval: 30 minutes
-DEFAULT_HEARTBEAT_INTERVAL_S = 30 * 60
+if TYPE_CHECKING:
+ from nanobot.providers.base import LLMProvider
-# The prompt sent to agent during heartbeat
-HEARTBEAT_PROMPT = """Read HEARTBEAT.md in your workspace (if it exists).
-Follow any instructions or tasks listed there.
-If nothing needs attention, reply with just: HEARTBEAT_OK"""
-
-# Token that indicates "nothing to do"
-HEARTBEAT_OK_TOKEN = "HEARTBEAT_OK"
-
-
-def _is_heartbeat_empty(content: str | None) -> bool:
- """Check if HEARTBEAT.md has no actionable content."""
- if not content:
- return True
-
- # Lines to skip: empty, headers, HTML comments, empty checkboxes
- skip_patterns = {"- [ ]", "* [ ]", "- [x]", "* [x]"}
-
- for line in content.split("\n"):
- line = line.strip()
- if not line or line.startswith("#") or line.startswith("