diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0c33bc5c8..675865350 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.memory import Consolidator, Dream -from nanobot.agent.runner import AgentRunSpec, AgentRunner +from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool from nanobot.agent.skills import BUILTIN_SKILLS_DIR @@ -207,6 +207,10 @@ class AgentLoop: self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._background_tasks: list[asyncio.Task] = [] self._session_locks: dict[str, asyncio.Lock] = {} + # Per-session pending queues for mid-turn message injection. + # When a session has an active task, new messages for that session + # are routed here instead of creating a new task. + self._pending_queues: dict[str, asyncio.Queue] = {} # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3. _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3")) self._concurrency_gate: asyncio.Semaphore | None = ( @@ -320,6 +324,12 @@ class AgentLoop: return format_tool_hints(tool_calls) + def _effective_session_key(self, msg: InboundMessage) -> str: + """Return the session key used for task routing and mid-turn injections.""" + if self._unified_session and not msg.session_key_override: + return UNIFIED_SESSION_KEY + return msg.session_key + async def _run_agent_loop( self, initial_messages: list[dict], @@ -331,13 +341,16 @@ class AgentLoop: channel: str = "cli", chat_id: str = "direct", message_id: str | None = None, - ) -> tuple[str | None, list[str], list[dict], str]: + pending_queue: asyncio.Queue | None = None, + ) -> tuple[str | None, list[str], list[dict], str, bool]: """Run the agent iteration loop. *on_stream*: called with each content delta during streaming. *on_stream_end(resuming)*: called when a streaming session finishes. ``resuming=True`` means tool calls follow (spinner should restart); ``resuming=False`` means this is the final response. + + Returns (final_content, tools_used, messages, stop_reason, had_injections). """ loop_hook = _LoopHook( self, @@ -357,31 +370,56 @@ class AgentLoop: return self._set_runtime_checkpoint(session, payload) - result = await self.runner.run( - AgentRunSpec( - initial_messages=initial_messages, - tools=self.tools, - model=self.model, - max_iterations=self.max_iterations, - max_tool_result_chars=self.max_tool_result_chars, - hook=hook, - error_message="Sorry, I encountered an error calling the AI model.", - concurrent_tools=True, - workspace=self.workspace, - session_key=session.key if session else None, - context_window_tokens=self.context_window_tokens, - context_block_limit=self.context_block_limit, - provider_retry_mode=self.provider_retry_mode, - progress_callback=on_progress, - checkpoint_callback=_checkpoint, - ) - ) + async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]: + """Non-blocking drain of follow-up messages from the pending queue.""" + if pending_queue is None: + return [] + items: list[dict[str, Any]] = [] + while len(items) < limit: + try: + pending_msg = pending_queue.get_nowait() + except asyncio.QueueEmpty: + break + user_content = self.context._build_user_content( + pending_msg.content, + pending_msg.media if pending_msg.media else None, + ) + runtime_ctx = self.context._build_runtime_context( + pending_msg.channel, + pending_msg.chat_id, + self.context.timezone, + ) + if isinstance(user_content, str): + merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}" + else: + merged = [{"type": "text", "text": runtime_ctx}] + user_content + items.append({"role": "user", "content": merged}) + return items + + result = await self.runner.run(AgentRunSpec( + initial_messages=initial_messages, + tools=self.tools, + model=self.model, + max_iterations=self.max_iterations, + max_tool_result_chars=self.max_tool_result_chars, + hook=hook, + error_message="Sorry, I encountered an error calling the AI model.", + concurrent_tools=True, + workspace=self.workspace, + session_key=session.key if session else None, + context_window_tokens=self.context_window_tokens, + context_block_limit=self.context_block_limit, + provider_retry_mode=self.provider_retry_mode, + progress_callback=on_progress, + checkpoint_callback=_checkpoint, + injection_callback=_drain_pending, + )) self._last_usage = result.usage if result.stop_reason == "max_iterations": logger.warning("Max iterations ({}) reached", self.max_iterations) elif result.stop_reason == "error": logger.error("LLM returned error: {}", (result.final_content or "")[:200]) - return result.final_content, result.tools_used, result.messages, result.stop_reason + return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections async def run(self) -> None: """Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" @@ -412,13 +450,32 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + effective_key = self._effective_session_key(msg) + # If this session already has an active pending queue (i.e. a task + # is processing this session), route the message there for mid-turn + # injection instead of creating a competing task. + if effective_key in self._pending_queues: + pending_msg = msg + if effective_key != msg.session_key: + pending_msg = dataclasses.replace( + msg, + session_key_override=effective_key, + ) + try: + self._pending_queues[effective_key].put_nowait(pending_msg) + except asyncio.QueueFull: + logger.warning( + "Pending queue full for session {}, falling back to queued task", + effective_key, + ) + else: + logger.info( + "Routed follow-up message to pending queue for session {}", + effective_key, + ) + continue # Compute the effective session key before dispatching # This ensures /stop command can find tasks correctly when unified session is enabled - effective_key = ( - UNIFIED_SESSION_KEY - if self._unified_session and not msg.session_key_override - else msg.session_key - ) task = asyncio.create_task(self._dispatch(msg)) self._active_tasks.setdefault(effective_key, []).append(task) task.add_done_callback( @@ -430,78 +487,91 @@ class AgentLoop: async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" - if self._unified_session and not msg.session_key_override: - msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) - lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) + session_key = self._effective_session_key(msg) + if session_key != msg.session_key: + msg = dataclasses.replace(msg, session_key_override=session_key) + lock = self._session_locks.setdefault(session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() - async with lock, gate: - try: - on_stream = on_stream_end = None - if msg.metadata.get("_wants_stream"): - # Split one answer into distinct stream segments. - stream_base_id = f"{msg.session_key}:{time.time_ns()}" - stream_segment = 0 - def _current_stream_id() -> str: - return f"{stream_base_id}:{stream_segment}" + # Register a pending queue so follow-up messages for this session are + # routed here (mid-turn injection) instead of spawning a new task. + pending = asyncio.Queue(maxsize=20) + self._pending_queues[session_key] = pending - async def on_stream(delta: str) -> None: - meta = dict(msg.metadata or {}) - meta["_stream_delta"] = True - meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, + try: + async with lock, gate: + try: + on_stream = on_stream_end = None + if msg.metadata.get("_wants_stream"): + # Split one answer into distinct stream segments. + stream_base_id = f"{msg.session_key}:{time.time_ns()}" + stream_segment = 0 + + def _current_stream_id() -> str: + return f"{stream_base_id}:{stream_segment}" + + async def on_stream(delta: str) -> None: + meta = dict(msg.metadata or {}) + meta["_stream_delta"] = True + meta["_stream_id"] = _current_stream_id() + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content=delta, metadata=meta, - ) - ) + )) - async def on_stream_end(*, resuming: bool = False) -> None: - nonlocal stream_segment - meta = dict(msg.metadata or {}) - meta["_stream_end"] = True - meta["_resuming"] = resuming - meta["_stream_id"] = _current_stream_id() - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, + async def on_stream_end(*, resuming: bool = False) -> None: + nonlocal stream_segment + meta = dict(msg.metadata or {}) + meta["_stream_end"] = True + meta["_resuming"] = resuming + meta["_stream_id"] = _current_stream_id() + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content="", metadata=meta, - ) - ) - stream_segment += 1 + )) + stream_segment += 1 - response = await self._process_message( - msg, - on_stream=on_stream, - on_stream_end=on_stream_end, - ) - if response is not None: - await self.bus.publish_outbound(response) - elif msg.channel == "cli": - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content="", - metadata=msg.metadata or {}, - ) + response = await self._process_message( + msg, on_stream=on_stream, on_stream_end=on_stream_end, + pending_queue=pending, ) - except asyncio.CancelledError: - logger.info("Task cancelled for session {}", msg.session_key) - raise - except Exception: - logger.exception("Error processing message for session {}", msg.session_key) - await self.bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, + if response is not None: + await self.bus.publish_outbound(response) + elif msg.channel == "cli": + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, + content="", metadata=msg.metadata or {}, + )) + except asyncio.CancelledError: + logger.info("Task cancelled for session {}", session_key) + raise + except Exception: + logger.exception("Error processing message for session {}", session_key) + await self.bus.publish_outbound(OutboundMessage( + channel=msg.channel, chat_id=msg.chat_id, content="Sorry, I encountered an error.", + )) + finally: + # Drain any messages still in the pending queue and re-publish + # them to the bus so they are processed as fresh inbound messages + # rather than silently lost. + queue = self._pending_queues.pop(session_key, None) + if queue is not None: + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await self.bus.publish_inbound(item) + leftover += 1 + if leftover: + logger.info( + "Re-published {} leftover message(s) to bus for session {}", + leftover, session_key, ) - ) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" @@ -533,6 +603,7 @@ class AgentLoop: on_progress: Callable[[str], Awaitable[None]] | None = None, on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None, + pending_queue: asyncio.Queue | None = None, ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" # System messages: parse origin from chat_id ("channel:chat_id") @@ -559,11 +630,8 @@ class AgentLoop: session_summary=pending, current_role=current_role, ) - final_content, _, all_msgs, _ = await self._run_agent_loop( - messages, - session=session, - channel=channel, - chat_id=chat_id, + final_content, _, all_msgs, _, _ = await self._run_agent_loop( + messages, session=session, channel=channel, chat_id=chat_id, message_id=msg.metadata.get("message_id"), ) self._save_turn(session, all_msgs, 1 + len(history)) @@ -623,7 +691,7 @@ class AgentLoop: ) ) - final_content, _, all_msgs, stop_reason = await self._run_agent_loop( + final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop( initial_messages, on_progress=on_progress or _bus_progress, on_stream=on_stream, @@ -632,6 +700,7 @@ class AgentLoop: channel=msg.channel, chat_id=msg.chat_id, message_id=msg.metadata.get("message_id"), + pending_queue=pending_queue, ) if final_content is None or not final_content.strip(): @@ -642,8 +711,15 @@ class AgentLoop: self.sessions.save(session) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) + # When follow-up messages were injected mid-turn, a later natural + # language reply may address those follow-ups and should not be + # suppressed just because MessageTool was used earlier in the turn. + # However, if the turn falls back to the empty-final-response + # placeholder, suppress it when the real user-visible output already + # came from MessageTool. if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: - return None + if not had_injections or stop_reason == "empty_final_response": + return None preview = final_content[:120] + "..." if len(final_content) > 120 else final_content logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview) diff --git a/nanobot/agent/runner.py b/nanobot/agent/runner.py index 0d8062842..e92d864f2 100644 --- a/nanobot/agent/runner.py +++ b/nanobot/agent/runner.py @@ -4,6 +4,7 @@ from __future__ import annotations import asyncio from dataclasses import dataclass, field +import inspect from pathlib import Path from typing import Any @@ -34,6 +35,8 @@ _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model." _PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]" _MAX_EMPTY_RETRIES = 2 _MAX_LENGTH_RECOVERIES = 3 +_MAX_INJECTIONS_PER_TURN = 3 +_MAX_INJECTION_CYCLES = 5 _SNIP_SAFETY_BUFFER = 1024 _MICROCOMPACT_KEEP_RECENT = 10 _MICROCOMPACT_MIN_CHARS = 500 @@ -42,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({ "web_search", "web_fetch", "list_dir", }) _BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]" + + + @dataclass(slots=True) class AgentRunSpec: """Configuration for a single agent execution.""" @@ -66,6 +72,7 @@ class AgentRunSpec: provider_retry_mode: str = "standard" progress_callback: Any | None = None checkpoint_callback: Any | None = None + injection_callback: Any | None = None @dataclass(slots=True) @@ -79,6 +86,7 @@ class AgentRunResult: stop_reason: str = "completed" error: str | None = None tool_events: list[dict[str, str]] = field(default_factory=list) + had_injections: bool = False class AgentRunner: @@ -87,6 +95,90 @@ class AgentRunner: def __init__(self, provider: LLMProvider): self.provider = provider + @staticmethod + def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]: + if isinstance(left, str) and isinstance(right, str): + return f"{left}\n\n{right}" if left else right + + def _to_blocks(value: Any) -> list[dict[str, Any]]: + if isinstance(value, list): + return [ + item if isinstance(item, dict) else {"type": "text", "text": str(item)} + for item in value + ] + if value is None: + return [] + return [{"type": "text", "text": str(value)}] + + return _to_blocks(left) + _to_blocks(right) + + @classmethod + def _append_injected_messages( + cls, + messages: list[dict[str, Any]], + injections: list[dict[str, Any]], + ) -> None: + """Append injected user messages while preserving role alternation.""" + for injection in injections: + if ( + messages + and injection.get("role") == "user" + and messages[-1].get("role") == "user" + ): + merged = dict(messages[-1]) + merged["content"] = cls._merge_message_content( + merged.get("content"), + injection.get("content"), + ) + messages[-1] = merged + continue + messages.append(injection) + + async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: + """Drain pending user messages via the injection callback. + + Returns normalized user messages (capped by + ``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is + nothing to inject. Messages beyond the cap are logged so they + are not silently lost. + """ + if spec.injection_callback is None: + return [] + try: + signature = inspect.signature(spec.injection_callback) + accepts_limit = ( + "limit" in signature.parameters + or any( + parameter.kind is inspect.Parameter.VAR_KEYWORD + for parameter in signature.parameters.values() + ) + ) + if accepts_limit: + items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN) + else: + items = await spec.injection_callback() + except Exception: + logger.exception("injection_callback failed") + return [] + if not items: + return [] + injected_messages: list[dict[str, Any]] = [] + for item in items: + if isinstance(item, dict) and item.get("role") == "user" and "content" in item: + injected_messages.append(item) + continue + text = getattr(item, "content", str(item)) + if text.strip(): + injected_messages.append({"role": "user", "content": text}) + if len(injected_messages) > _MAX_INJECTIONS_PER_TURN: + dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN + logger.warning( + "Injection callback returned {} messages, capping to {} ({} dropped)", + len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped, + ) + injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN] + return injected_messages + async def run(self, spec: AgentRunSpec) -> AgentRunResult: hook = spec.hook or AgentHook() messages = list(spec.initial_messages) @@ -99,6 +191,8 @@ class AgentRunner: external_lookup_counts: dict[str, int] = {} empty_content_retries = 0 length_recovery_count = 0 + had_injections = False + injection_cycles = 0 for iteration in range(spec.max_iterations): try: @@ -207,6 +301,17 @@ class AgentRunner: ) empty_content_retries = 0 length_recovery_count = 0 + # Checkpoint 1: drain injections after tools, before next LLM call + if injection_cycles < _MAX_INJECTION_CYCLES: + injections = await self._drain_injections(spec) + if injections: + had_injections = True + injection_cycles += 1 + self._append_injected_messages(messages, injections) + logger.info( + "Injected {} follow-up message(s) after tool execution ({}/{})", + len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + ) await hook.after_iteration(context) continue @@ -263,8 +368,49 @@ class AgentRunner: await hook.after_iteration(context) continue + assistant_message: dict[str, Any] | None = None + if response.finish_reason != "error" and not is_blank_text(clean): + assistant_message = build_assistant_message( + clean, + reasoning_content=response.reasoning_content, + thinking_blocks=response.thinking_blocks, + ) + + # Check for mid-turn injections BEFORE signaling stream end. + # If injections are found we keep the stream alive (resuming=True) + # so streaming channels don't prematurely finalize the card. + _injected_after_final = False + if injection_cycles < _MAX_INJECTION_CYCLES: + injections = await self._drain_injections(spec) + if injections: + had_injections = True + injection_cycles += 1 + _injected_after_final = True + if assistant_message is not None: + messages.append(assistant_message) + await self._emit_checkpoint( + spec, + { + "phase": "final_response", + "iteration": iteration, + "model": spec.model, + "assistant_message": assistant_message, + "completed_tool_results": [], + "pending_tool_calls": [], + }, + ) + self._append_injected_messages(messages, injections) + logger.info( + "Injected {} follow-up message(s) after final response ({}/{})", + len(injections), injection_cycles, _MAX_INJECTION_CYCLES, + ) + if hook.wants_streaming(): - await hook.on_stream_end(context, resuming=False) + await hook.on_stream_end(context, resuming=_injected_after_final) + + if _injected_after_final: + await hook.after_iteration(context) + continue if response.finish_reason == "error": final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE @@ -287,7 +433,7 @@ class AgentRunner: await hook.after_iteration(context) break - messages.append(build_assistant_message( + messages.append(assistant_message or build_assistant_message( clean, reasoning_content=response.reasoning_content, thinking_blocks=response.thinking_blocks, @@ -330,6 +476,7 @@ class AgentRunner: stop_reason=stop_reason, error=error, tool_events=tool_events, + had_injections=had_injections, ) def _build_request_kwargs( diff --git a/nanobot/channels/qq.py b/nanobot/channels/qq.py index bef2cf27a..484eed6e2 100644 --- a/nanobot/channels/qq.py +++ b/nanobot/channels/qq.py @@ -242,43 +242,46 @@ class QQChannel(BaseChannel): async def send(self, msg: OutboundMessage) -> None: """Send attachments first, then text.""" - if not self._client: - logger.warning("QQ client not initialized") - return + try: + if not self._client: + logger.warning("QQ client not initialized") + return - msg_id = msg.metadata.get("message_id") - chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") - is_group = chat_type == "group" + msg_id = msg.metadata.get("message_id") + chat_type = self._chat_type_cache.get(msg.chat_id, "c2c") + is_group = chat_type == "group" - # 1) Send media - for media_ref in msg.media or []: - ok = await self._send_media( - chat_id=msg.chat_id, - media_ref=media_ref, - msg_id=msg_id, - is_group=is_group, - ) - if not ok: - filename = ( - os.path.basename(urlparse(media_ref).path) - or os.path.basename(media_ref) - or "file" + # 1) Send media + for media_ref in msg.media or []: + ok = await self._send_media( + chat_id=msg.chat_id, + media_ref=media_ref, + msg_id=msg_id, + is_group=is_group, ) + if not ok: + filename = ( + os.path.basename(urlparse(media_ref).path) + or os.path.basename(media_ref) + or "file" + ) + await self._send_text_only( + chat_id=msg.chat_id, + is_group=is_group, + msg_id=msg_id, + content=f"[Attachment send failed: {filename}]", + ) + + # 2) Send text + if msg.content and msg.content.strip(): await self._send_text_only( chat_id=msg.chat_id, is_group=is_group, msg_id=msg_id, - content=f"[Attachment send failed: {filename}]", + content=msg.content.strip(), ) - - # 2) Send text - if msg.content and msg.content.strip(): - await self._send_text_only( - chat_id=msg.chat_id, - is_group=is_group, - msg_id=msg_id, - content=msg.content.strip(), - ) + except Exception: + logger.exception("Error sending QQ message to chat_id={}", msg.chat_id) async def _send_text_only( self, @@ -438,15 +441,26 @@ class QQChannel(BaseChannel): endpoint = "/v2/users/{openid}/files" id_key = "openid" - payload = { + payload: dict[str, Any] = { id_key: chat_id, "file_type": file_type, "file_data": file_data, - "file_name": file_name, "srv_send_msg": srv_send_msg, } + # Only pass file_name for non-image types (file_type=4). + # Passing file_name for images causes QQ client to render them as + # file attachments instead of inline images. + if file_type != QQ_FILE_TYPE_IMAGE and file_name: + payload["file_name"] = file_name + route = Route("POST", endpoint, **{id_key: chat_id}) - return await self._client.api._http.request(route, json=payload) + result = await self._client.api._http.request(route, json=payload) + + # Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.) + # that may confuse QQ client when sending the media object. + if isinstance(result, dict) and "file_info" in result: + return {"file_info": result["file_info"]} + return result # --------------------------- # Inbound (receive) @@ -454,58 +468,68 @@ class QQChannel(BaseChannel): async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None: """Parse inbound message, download attachments, and publish to the bus.""" - if data.id in self._processed_ids: - return - self._processed_ids.append(data.id) + try: + if data.id in self._processed_ids: + return + self._processed_ids.append(data.id) - if is_group: - chat_id = data.group_openid - user_id = data.author.member_openid - self._chat_type_cache[chat_id] = "group" - else: - chat_id = str( - getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown") - ) - user_id = chat_id - self._chat_type_cache[chat_id] = "c2c" - - content = (data.content or "").strip() - - # the data used by tests don't contain attachments property - # so we use getattr with a default of [] to avoid AttributeError in tests - attachments = getattr(data, "attachments", None) or [] - media_paths, recv_lines, att_meta = await self._handle_attachments(attachments) - - # Compose content that always contains actionable saved paths - if recv_lines: - tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]" - file_block = "Received files:\n" + "\n".join(recv_lines) - content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" - - if not content and not media_paths: - return - - if self.config.ack_message: - try: - await self._send_text_only( - chat_id=chat_id, - is_group=is_group, - msg_id=data.id, - content=self.config.ack_message, + if is_group: + chat_id = data.group_openid + user_id = data.author.member_openid + self._chat_type_cache[chat_id] = "group" + else: + chat_id = str( + getattr(data.author, "id", None) + or getattr(data.author, "user_openid", "unknown") ) - except Exception: - logger.debug("QQ ack message failed for chat_id={}", chat_id) + user_id = chat_id + self._chat_type_cache[chat_id] = "c2c" - await self._handle_message( - sender_id=user_id, - chat_id=chat_id, - content=content, - media=media_paths if media_paths else None, - metadata={ - "message_id": data.id, - "attachments": att_meta, - }, - ) + content = (data.content or "").strip() + + # the data used by tests don't contain attachments property + # so we use getattr with a default of [] to avoid AttributeError in tests + attachments = getattr(data, "attachments", None) or [] + media_paths, recv_lines, att_meta = await self._handle_attachments(attachments) + + # Compose content that always contains actionable saved paths + if recv_lines: + tag = ( + "[Image]" + if any(_is_image_name(Path(p).name) for p in media_paths) + else "[File]" + ) + file_block = "Received files:\n" + "\n".join(recv_lines) + content = ( + f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" + ) + + if not content and not media_paths: + return + + if self.config.ack_message: + try: + await self._send_text_only( + chat_id=chat_id, + is_group=is_group, + msg_id=data.id, + content=self.config.ack_message, + ) + except Exception: + logger.debug("QQ ack message failed for chat_id={}", chat_id) + + await self._handle_message( + sender_id=user_id, + chat_id=chat_id, + content=content, + media=media_paths if media_paths else None, + metadata={ + "message_id": data.id, + "attachments": att_meta, + }, + ) + except Exception: + logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?")) async def _handle_attachments( self, @@ -520,7 +544,9 @@ class QQChannel(BaseChannel): return media_paths, recv_lines, att_meta for att in attachments: - url, filename, ctype = att.url, att.filename, att.content_type + url = getattr(att, "url", None) or "" + filename = getattr(att, "filename", None) or "" + ctype = getattr(att, "content_type", None) or "" logger.info("Downloading file from QQ: {}", filename or url) local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename) @@ -555,6 +581,10 @@ class QQChannel(BaseChannel): Enforces a max download size and writes to a .part temp file that is atomically renamed on success. """ + # Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...") + if url.startswith("//"): + url = f"https:{url}" + if not self._http: self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) diff --git a/nanobot/channels/wecom.py b/nanobot/channels/wecom.py index 05ad14825..a7d7f1fe2 100644 --- a/nanobot/channels/wecom.py +++ b/nanobot/channels/wecom.py @@ -1,9 +1,13 @@ """WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk.""" import asyncio +import base64 +import hashlib import importlib.util import os +import re from collections import OrderedDict +from pathlib import Path from typing import Any from loguru import logger @@ -17,6 +21,37 @@ from pydantic import Field WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None +# Upload safety limits (matching QQ channel defaults) +WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB + +# Replace unsafe characters with "_", keep Chinese and common safe punctuation. +_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE) + + +def _sanitize_filename(name: str) -> str: + """Sanitize filename to avoid traversal and problematic chars.""" + name = (name or "").strip() + name = Path(name).name + name = _SAFE_NAME_RE.sub("_", name).strip("._ ") + return name + + +_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"} +_VIDEO_EXTS = {".mp4", ".avi", ".mov"} +_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"} + + +def _guess_wecom_media_type(filename: str) -> str: + """Classify file extension as WeCom media_type string.""" + ext = Path(filename).suffix.lower() + if ext in _IMAGE_EXTS: + return "image" + if ext in _VIDEO_EXTS: + return "video" + if ext in _AUDIO_EXTS: + return "voice" + return "file" + class WecomConfig(Base): """WeCom (Enterprise WeChat) AI Bot channel configuration.""" @@ -217,6 +252,7 @@ class WecomChannel(BaseChannel): chat_id = body.get("chatid", sender_id) content_parts = [] + media_paths: list[str] = [] if msg_type == "text": text = body.get("text", {}).get("content", "") @@ -232,7 +268,8 @@ class WecomChannel(BaseChannel): file_path = await self._download_and_save_media(file_url, aes_key, "image") if file_path: filename = os.path.basename(file_path) - content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]") + content_parts.append(f"[image: {filename}]") + media_paths.append(file_path) else: content_parts.append("[image: download failed]") else: @@ -256,7 +293,8 @@ class WecomChannel(BaseChannel): if file_url and aes_key: file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name) if file_path: - content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]") + content_parts.append(f"[file: {file_name}]") + media_paths.append(file_path) else: content_parts.append(f"[file: {file_name}: download failed]") else: @@ -286,12 +324,11 @@ class WecomChannel(BaseChannel): self._chat_frames[chat_id] = frame # Forward to message bus - # Note: media paths are included in content for broader model compatibility await self._handle_message( sender_id=sender_id, chat_id=chat_id, content=content, - media=None, + media=media_paths or None, metadata={ "message_id": msg_id, "msg_type": msg_type, @@ -322,13 +359,21 @@ class WecomChannel(BaseChannel): logger.warning("Failed to download media from WeCom") return None + if len(data) > WECOM_UPLOAD_MAX_BYTES: + logger.warning( + "WeCom inbound media too large: {} bytes (max {})", + len(data), + WECOM_UPLOAD_MAX_BYTES, + ) + return None + media_dir = get_media_dir("wecom") if not filename: filename = fname or f"{media_type}_{hash(file_url) % 100000}" - filename = os.path.basename(filename) + filename = _sanitize_filename(filename) file_path = media_dir / filename - file_path.write_bytes(data) + await asyncio.to_thread(file_path.write_bytes, data) logger.debug("Downloaded {} to {}", media_type, file_path) return str(file_path) @@ -336,6 +381,100 @@ class WecomChannel(BaseChannel): logger.error("Error downloading media: {}", e) return None + async def _upload_media_ws( + self, client: Any, file_path: str, + ) -> "tuple[str, str] | tuple[None, None]": + """Upload a local file to WeCom via WebSocket 3-step protocol (base64). + + Uses the WeCom WebSocket upload commands directly via + ``client._ws_manager.send_reply()``: + + ``aibot_upload_media_init`` → upload_id + ``aibot_upload_media_chunk`` × N (≤512 KB raw per chunk, base64) + ``aibot_upload_media_finish`` → media_id + + Returns (media_id, media_type) on success, (None, None) on failure. + """ + from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id + + try: + fname = os.path.basename(file_path) + media_type = _guess_wecom_media_type(fname) + + # Read file size and data in a thread to avoid blocking the event loop + def _read_file(): + file_size = os.path.getsize(file_path) + if file_size > WECOM_UPLOAD_MAX_BYTES: + raise ValueError( + f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})" + ) + with open(file_path, "rb") as f: + return file_size, f.read() + + file_size, data = await asyncio.to_thread(_read_file) + # MD5 is used for file integrity only, not cryptographic security + md5_hash = hashlib.md5(data).hexdigest() + + CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64) + mv = memoryview(data) + chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)] + n_chunks = len(chunk_list) + del mv, data + + # Step 1: init + req_id = _gen_req_id("upload_init") + resp = await client._ws_manager.send_reply(req_id, { + "type": media_type, + "filename": fname, + "total_size": file_size, + "total_chunks": n_chunks, + "md5": md5_hash, + }, "aibot_upload_media_init") + if resp.errcode != 0: + logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg) + return None, None + upload_id = resp.body.get("upload_id") if resp.body else None + if not upload_id: + logger.warning("WeCom upload init: no upload_id in response") + return None, None + + # Step 2: send chunks + for i, chunk in enumerate(chunk_list): + req_id = _gen_req_id("upload_chunk") + resp = await client._ws_manager.send_reply(req_id, { + "upload_id": upload_id, + "chunk_index": i, + "base64_data": base64.b64encode(chunk).decode(), + }, "aibot_upload_media_chunk") + if resp.errcode != 0: + logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg) + return None, None + + # Step 3: finish + req_id = _gen_req_id("upload_finish") + resp = await client._ws_manager.send_reply(req_id, { + "upload_id": upload_id, + }, "aibot_upload_media_finish") + if resp.errcode != 0: + logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg) + return None, None + + media_id = resp.body.get("media_id") if resp.body else None + if not media_id: + logger.warning("WeCom upload finish: no media_id in response body={}", resp.body) + return None, None + + suffix = "..." if len(media_id) > 16 else "" + logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix) + return media_id, media_type + + except ValueError as e: + logger.warning("WeCom upload skipped for {}: {}", file_path, e) + return None, None + except Exception as e: + logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e) + return None, None + async def send(self, msg: OutboundMessage) -> None: """Send a message through WeCom.""" if not self._client: @@ -343,29 +482,59 @@ class WecomChannel(BaseChannel): return try: - content = msg.content.strip() - if not content: - return + content = (msg.content or "").strip() + is_progress = bool(msg.metadata.get("_progress")) # Get the stored frame for this chat frame = self._chat_frames.get(msg.chat_id) - if not frame: - logger.warning("No frame found for chat {}, cannot reply", msg.chat_id) + + # Send media files via WebSocket upload + for file_path in msg.media or []: + if not os.path.isfile(file_path): + logger.warning("WeCom media file not found: {}", file_path) + continue + media_id, media_type = await self._upload_media_ws(self._client, file_path) + if media_id: + if frame: + await self._client.reply(frame, { + "msgtype": media_type, + media_type: {"media_id": media_id}, + }) + else: + await self._client.send_message(msg.chat_id, { + "msgtype": media_type, + media_type: {"media_id": media_id}, + }) + logger.debug("WeCom sent {} → {}", media_type, msg.chat_id) + else: + content += f"\n[file upload failed: {os.path.basename(file_path)}]" + + if not content: return - # Use streaming reply for better UX - stream_id = self._generate_req_id("stream") + if frame: + # Both progress and final messages must use reply_stream (cmd="aibot_respond_msg"). + # The plain reply() uses cmd="reply" which does not support "text" msgtype + # and causes errcode=40008 from WeCom API. + stream_id = self._generate_req_id("stream") + await self._client.reply_stream( + frame, + stream_id, + content, + finish=not is_progress, + ) + logger.debug( + "WeCom {} sent to {}", + "progress" if is_progress else "message", + msg.chat_id, + ) + else: + # No frame (e.g. cron push): proactive send only supports markdown + await self._client.send_message(msg.chat_id, { + "msgtype": "markdown", + "markdown": {"content": content}, + }) + logger.info("WeCom proactive send to {}", msg.chat_id) - # Send as streaming message with finish=True - await self._client.reply_stream( - frame, - stream_id, - content, - finish=True, - ) - - logger.debug("WeCom message sent to {}", msg.chat_id) - - except Exception as e: - logger.error("Error sending WeCom message: {}", e) - raise + except Exception: + logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id) diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 672f38ed2..8971d48ec 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, tools_used, messages, _ = await loop._run_agent_loop( + content, tools_used, messages, _, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path): ) loop.tools.get_definitions = MagicMock(return_value=[]) - content, _, _, _ = await loop._run_agent_loop( + content, _, _, _, _ = await loop._run_agent_loop( [{"role": "user", "content": "hi"}] ) @@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - content, tools_used, _, _ = await loop._run_agent_loop([]) + content, tools_used, _, _, _ = await loop._run_agent_loop([]) assert content == ( "I reached the maximum number of tool call iterations (2) " "without completing the task. You can try breaking the task into smaller steps." diff --git a/tests/agent/test_runner.py b/tests/agent/test_runner.py index 45da0896c..a62457aa8 100644 --- a/tests/agent/test_runner.py +++ b/tests/agent/test_runner.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import base64 import os import time from unittest.mock import AsyncMock, MagicMock, patch @@ -798,7 +799,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path): loop.tools.execute = AsyncMock(return_value="ok") loop.max_iterations = 2 - final_content, _, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _, _ = await loop._run_agent_loop([]) assert final_content == ( "I reached the maximum number of tool call iterations (2) " @@ -825,7 +826,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp async def on_stream_end(*, resuming: bool = False) -> None: endings.append(resuming) - final_content, _, _, _ = await loop._run_agent_loop( + final_content, _, _, _, _ = await loop._run_agent_loop( [], on_stream=on_stream, on_stream_end=on_stream_end, @@ -849,7 +850,7 @@ async def test_loop_retries_think_only_final_response(tmp_path): loop.provider.chat_with_retry = chat_with_retry - final_content, _, _, _ = await loop._run_agent_loop([]) + final_content, _, _, _, _ = await loop._run_agent_loop([]) assert final_content == "Recovered answer" assert call_count["n"] == 2 @@ -1722,3 +1723,690 @@ def test_governance_fallback_still_repairs_orphans(): repaired = AgentRunner._backfill_missing_tool_results(repaired) # Orphan tool result should be gone. assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired) +# ── Mid-turn injection tests ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_drain_injections_returns_empty_when_no_callback(): + """No injection_callback → empty list.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=None, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_drain_injections_extracts_content_from_inbound_messages(): + """Should extract .content from InboundMessage objects.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [ + {"role": "user", "content": "hello"}, + {"role": "user", "content": "world"}, + ] + + +@pytest.mark.asyncio +async def test_drain_injections_passes_limit_to_callback_when_supported(): + """Limit-aware callbacks can preserve overflow in their own queue.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + seen_limits: list[int] = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}") + for i in range(_MAX_INJECTIONS_PER_TURN + 3) + ] + + async def cb(*, limit: int): + seen_limits.append(limit) + return msgs[:limit] + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert seen_limits == [_MAX_INJECTIONS_PER_TURN] + assert result == [ + {"role": "user", "content": "msg0"}, + {"role": "user", "content": "msg1"}, + {"role": "user", "content": "msg2"}, + ] + + +@pytest.mark.asyncio +async def test_drain_injections_skips_empty_content(): + """Messages with blank content should be filtered out.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + msgs = [ + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "), + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"), + ] + + async def cb(): + return msgs + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [{"role": "user", "content": "valid"}] + + +@pytest.mark.asyncio +async def test_drain_injections_handles_callback_exception(): + """If the callback raises, return empty list (error is logged).""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + runner = AgentRunner(provider) + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def cb(): + raise RuntimeError("boom") + + spec = AgentRunSpec( + initial_messages=[], tools=tools, model="m", + max_iterations=1, max_tool_result_chars=1000, + injection_callback=cb, + ) + result = await runner._drain_injections(spec) + assert result == [] + + +@pytest.mark.asyncio +async def test_checkpoint1_injects_after_tool_execution(): + """Follow-up messages are injected after tool execution, before next LLM call.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse( + content="using tool", + tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})], + usage={}, + ) + return LLMResponse(content="final answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + tools.execute = AsyncMock(return_value="file content") + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + # Put a follow-up message in the queue before the run starts + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "final answer" + # The second call should have the injected user message + assert call_count["n"] == 2 + last_messages = captured_messages[-1] + injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"] + assert len(injected) == 1 + + +@pytest.mark.asyncio +async def test_checkpoint2_injects_after_final_response_with_resuming_stream(): + """After final response, if injections exist, stream_end should get resuming=True.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.agent.hook import AgentHook, AgentHookContext + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + stream_end_calls = [] + + class TrackingHook(AgentHook): + def wants_streaming(self) -> bool: + return True + + async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None: + stream_end_calls.append(resuming) + + def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None: + return content + + async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_stream_with_retry = chat_stream_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + # Inject a follow-up that arrives during the first response + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + hook=TrackingHook(), + injection_callback=inject_cb, + )) + + assert result.had_injections is True + assert result.final_content == "second answer" + assert call_count["n"] == 2 + # First stream_end should have resuming=True (because injections found) + assert stream_end_calls[0] is True + # Second (final) stream_end should have resuming=False + assert stream_end_calls[-1] is False + + +@pytest.mark.asyncio +async def test_checkpoint2_preserves_final_response_in_history_before_followup(): + """A follow-up injected after a final answer must still see that answer in history.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + injection_queue = asyncio.Queue() + + async def inject_cb(): + items = [] + while not injection_queue.empty(): + items.append(await injection_queue.get()) + return items + + await injection_queue.put( + InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question") + ) + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + assert captured_messages[-1] == [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "first answer"}, + {"role": "user", "content": "follow-up question"}, + ] + assert [ + {"role": message["role"], "content": message["content"]} + for message in result.messages + if message.get("role") == "assistant" + ] == [ + {"role": "assistant", "content": "first answer"}, + {"role": "assistant", "content": "second answer"}, + ] + + +@pytest.mark.asyncio +async def test_loop_injected_followup_preserves_image_media(tmp_path): + """Mid-turn follow-ups with images should keep multimodal content.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + + image_path = tmp_path / "followup.png" + image_path.write_bytes(base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII=" + )) + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append(list(messages)) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content="", + media=[str(image_path)], + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "second answer" + assert had_injections is True + assert call_count["n"] == 2 + injected_user_messages = [ + message for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), list) + ] + assert injected_user_messages + assert any( + block.get("type") == "image_url" + for block in injected_user_messages[-1]["content"] + if isinstance(block, dict) + ) + + +@pytest.mark.asyncio +async def test_runner_merges_multiple_injected_user_messages_without_losing_media(): + """Multiple injected follow-ups should not create lossy consecutive user messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + call_count = {"n": 0} + captured_messages = [] + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + if call_count["n"] == 1: + return LLMResponse(content="first answer", tool_calls=[], usage={}) + return LLMResponse(content="second answer", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + async def inject_cb(): + if call_count["n"] == 1: + return [ + { + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + {"type": "text", "text": "look at this"}, + ], + }, + {"role": "user", "content": "and answer briefly"}, + ] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hello"}], + tools=tools, + model="test-model", + max_iterations=5, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.final_content == "second answer" + assert call_count["n"] == 2 + second_call = captured_messages[-1] + user_messages = [message for message in second_call if message.get("role") == "user"] + assert len(user_messages) == 2 + injected = user_messages[-1] + assert isinstance(injected["content"], list) + assert any( + block.get("type") == "image_url" + for block in injected["content"] + if isinstance(block, dict) + ) + assert any( + block.get("type") == "text" and block.get("text") == "and answer briefly" + for block in injected["content"] + if isinstance(block, dict) + ) + + +@pytest.mark.asyncio +async def test_injection_cycles_capped_at_max(): + """Injection cycles should be capped at _MAX_INJECTION_CYCLES.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES + from nanobot.bus.events import InboundMessage + + provider = MagicMock() + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + drain_count = {"n": 0} + + async def inject_cb(): + drain_count["n"] += 1 + # Only inject for the first _MAX_INJECTION_CYCLES drains + if drain_count["n"] <= _MAX_INJECTION_CYCLES: + return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")] + return [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "start"}], + tools=tools, + model="test-model", + max_iterations=20, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + injection_callback=inject_cb, + )) + + assert result.had_injections is True + # Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round + assert call_count["n"] == _MAX_INJECTION_CYCLES + 1 + + +@pytest.mark.asyncio +async def test_no_injections_flag_is_false_by_default(): + """had_injections should be False when no injection callback or no messages.""" + from nanobot.agent.runner import AgentRunSpec, AgentRunner + + provider = MagicMock() + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + tools = MagicMock() + tools.get_definitions.return_value = [] + + runner = AgentRunner(provider) + result = await runner.run(AgentRunSpec( + initial_messages=[{"role": "user", "content": "hi"}], + tools=tools, + model="test-model", + max_iterations=1, + max_tool_result_chars=_MAX_TOOL_RESULT_CHARS, + )) + + assert result.had_injections is False + + +@pytest.mark.asyncio +async def test_pending_queue_cleanup_on_dispatch(tmp_path): + """_pending_queues should be cleaned up after _dispatch completes.""" + loop = _make_loop(tmp_path) + + async def chat_with_retry(**kwargs): + return LLMResponse(content="done", tool_calls=[], usage={}) + + loop.provider.chat_with_retry = chat_with_retry + + from nanobot.bus.events import InboundMessage + + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello") + # The queue should not exist before dispatch + assert msg.session_key not in loop._pending_queues + + await loop._dispatch(msg) + + # The queue should be cleaned up after dispatch + assert msg.session_key not in loop._pending_queues + + +@pytest.mark.asyncio +async def test_followup_routed_to_pending_queue(tmp_path): + """Unified-session follow-ups should route into the active pending queue.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + loop._unified_session = True + loop._dispatch = AsyncMock() # type: ignore[method-assign] + + pending = asyncio.Queue(maxsize=20) + loop._pending_queues[UNIFIED_SESSION_KEY] = pending + + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) + + deadline = time.time() + 2 + while pending.empty() and time.time() < deadline: + await asyncio.sleep(0.01) + + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 0 + assert not pending.empty() + queued_msg = pending.get_nowait() + assert queued_msg.content == "follow-up" + assert queued_msg.session_key == UNIFIED_SESSION_KEY + + +@pytest.mark.asyncio +async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path): + """Pending queue should leave overflow messages queued for later drains.""" + from nanobot.agent.loop import AgentLoop + from nanobot.bus.events import InboundMessage + from nanobot.bus.queue import MessageBus + from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN + + bus = MessageBus() + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + captured_messages: list[list[dict]] = [] + call_count = {"n": 0} + + async def chat_with_retry(*, messages, **kwargs): + call_count["n"] += 1 + captured_messages.append([dict(message) for message in messages]) + return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={}) + + provider.chat_with_retry = chat_with_retry + loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + loop.tools.get_definitions = MagicMock(return_value=[]) + + pending_queue = asyncio.Queue() + total_followups = _MAX_INJECTIONS_PER_TURN + 2 + for idx in range(total_followups): + await pending_queue.put(InboundMessage( + channel="cli", + sender_id="u", + chat_id="c", + content=f"follow-up-{idx}", + )) + + final_content, _, _, _, had_injections = await loop._run_agent_loop( + [{"role": "user", "content": "hello"}], + channel="cli", + chat_id="c", + pending_queue=pending_queue, + ) + + assert final_content == "answer-3" + assert had_injections is True + assert call_count["n"] == 3 + flattened_user_content = "\n".join( + message["content"] + for message in captured_messages[-1] + if message.get("role") == "user" and isinstance(message.get("content"), str) + ) + for idx in range(total_followups): + assert f"follow-up-{idx}" in flattened_user_content + assert pending_queue.empty() + + +@pytest.mark.asyncio +async def test_pending_queue_full_falls_back_to_queued_task(tmp_path): + """QueueFull should preserve the message by dispatching a queued task.""" + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + loop._dispatch = AsyncMock() # type: ignore[method-assign] + + pending = asyncio.Queue(maxsize=1) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued")) + loop._pending_queues["cli:c"] = pending + + run_task = asyncio.create_task(loop.run()) + msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up") + await loop.bus.publish_inbound(msg) + + deadline = time.time() + 2 + while loop._dispatch.await_count == 0 and time.time() < deadline: + await asyncio.sleep(0.01) + + loop.stop() + await asyncio.wait_for(run_task, timeout=2) + + assert loop._dispatch.await_count == 1 + dispatched_msg = loop._dispatch.await_args.args[0] + assert dispatched_msg.content == "follow-up" + assert pending.qsize() == 1 + + +@pytest.mark.asyncio +async def test_dispatch_republishes_leftover_queue_messages(tmp_path): + """Messages left in the pending queue after _dispatch are re-published to the bus. + + This tests the finally-block cleanup that prevents message loss when + the runner exits early (e.g., max_iterations, tool_error) with messages + still in the queue. + """ + from nanobot.bus.events import InboundMessage + + loop = _make_loop(tmp_path) + bus = loop.bus + + # Simulate a completed dispatch by manually registering a queue + # with leftover messages, then running the cleanup logic directly. + pending = asyncio.Queue(maxsize=20) + session_key = "cli:c" + loop._pending_queues[session_key] = pending + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1")) + pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2")) + + # Execute the cleanup logic from the finally block + queue = loop._pending_queues.pop(session_key, None) + assert queue is not None + leftover = 0 + while True: + try: + item = queue.get_nowait() + except asyncio.QueueEmpty: + break + await bus.publish_inbound(item) + leftover += 1 + + assert leftover == 2 + + # Verify the messages are now on the bus + msgs = [] + while not bus.inbound.empty(): + msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5)) + contents = [m.content for m in msgs] + assert "leftover-1" in contents + assert "leftover-2" in contents diff --git a/tests/channels/test_qq_media.py b/tests/channels/test_qq_media.py new file mode 100644 index 000000000..80a5ad20e --- /dev/null +++ b/tests/channels/test_qq_media.py @@ -0,0 +1,304 @@ +"""Tests for QQ channel media support: helpers, send, inbound, and upload.""" + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +try: + from nanobot.channels import qq + + QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False) +except ImportError: + QQ_AVAILABLE = False + +if not QQ_AVAILABLE: + pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True) + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.qq import ( + QQ_FILE_TYPE_FILE, + QQ_FILE_TYPE_IMAGE, + QQChannel, + QQConfig, + _guess_send_file_type, + _is_image_name, + _sanitize_filename, +) + + +class _FakeApi: + def __init__(self) -> None: + self.c2c_calls: list[dict] = [] + self.group_calls: list[dict] = [] + + async def post_c2c_message(self, **kwargs) -> None: + self.c2c_calls.append(kwargs) + + async def post_group_message(self, **kwargs) -> None: + self.group_calls.append(kwargs) + + +class _FakeHttp: + """Fake _http for _post_base64file tests.""" + + def __init__(self, return_value: dict | None = None) -> None: + self.return_value = return_value or {} + self.calls: list[tuple] = [] + + async def request(self, route, **kwargs): + self.calls.append((route, kwargs)) + return self.return_value + + +class _FakeClient: + def __init__(self, http_return: dict | None = None) -> None: + self.api = _FakeApi() + self.api._http = _FakeHttp(http_return) + + +# ── Helper function tests (pure, no async) ────────────────────────── + + +def test_sanitize_filename_strips_path_traversal() -> None: + assert _sanitize_filename("../../etc/passwd") == "passwd" + + +def test_sanitize_filename_keeps_chinese_chars() -> None: + assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg" + + +def test_sanitize_filename_strips_unsafe_chars() -> None: + result = _sanitize_filename('file<>:"|?*.txt') + # All unsafe chars replaced with "_", but * is replaced too + assert result.startswith("file") + assert result.endswith(".txt") + assert "<" not in result + assert ">" not in result + assert '"' not in result + assert "|" not in result + assert "?" not in result + + +def test_sanitize_filename_empty_input() -> None: + assert _sanitize_filename("") == "" + + +def test_is_image_name_with_known_extensions() -> None: + for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"): + assert _is_image_name(f"photo{ext}") is True + + +def test_is_image_name_with_unknown_extension() -> None: + for ext in (".pdf", ".txt", ".mp3", ".mp4"): + assert _is_image_name(f"doc{ext}") is False + + +def test_guess_send_file_type_image() -> None: + assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE + assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE + + +def test_guess_send_file_type_file() -> None: + assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE + + +def test_guess_send_file_type_by_mime() -> None: + # A filename with no known extension but whose mime type is image/* + assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE + + +# ── send() exception handling ─────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_send_exception_caught_not_raised() -> None: + """Exceptions inside send() must not propagate.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")): + await channel.send( + OutboundMessage(channel="qq", chat_id="user1", content="hello") + ) + # No exception raised — test passes if we get here. + + +@pytest.mark.asyncio +async def test_send_media_then_text() -> None: + """Media is sent before text when both are present.""" + import tempfile + + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload: + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="text after image", + media=[tmp], + metadata={"message_id": "m1"}, + ) + ) + assert mock_upload.called + + # Text should have been sent via c2c (default chat type) + text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0] + assert len(text_calls) >= 1 + assert text_calls[-1]["content"] == "text after image" + finally: + import os + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_send_media_failure_falls_back_to_text() -> None: + """When _send_media returns False, a failure notice is appended.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False): + await channel.send( + OutboundMessage( + channel="qq", + chat_id="user1", + content="hello", + media=["https://example.com/bad.png"], + metadata={"message_id": "m1"}, + ) + ) + + # Should have the failure text among the c2c calls + failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")] + assert len(failure_calls) == 1 + assert "bad.png" in failure_calls[0]["content"] + + +# ── _on_message() exception handling ──────────────────────────────── + + +@pytest.mark.asyncio +async def test_on_message_exception_caught_not_raised() -> None: + """Missing required attributes should not crash _on_message.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + # Construct a message-like object that lacks 'author' — triggers AttributeError + bad_data = SimpleNamespace(id="x1", content="hi") + # Should not raise + await channel._on_message(bad_data, is_group=False) + + +@pytest.mark.asyncio +async def test_on_message_with_attachments() -> None: + """Messages with attachments produce media_paths and formatted content.""" + import tempfile + + channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus()) + channel._client = _FakeClient() + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + saved_path = f.name + + att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png") + + # Patch _download_to_media_dir_chunked to return the temp file path + async def fake_download(url, filename_hint=""): + return saved_path + + try: + with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download): + data = SimpleNamespace( + id="att1", + content="look at this", + author=SimpleNamespace(user_openid="u1"), + attachments=[att], + ) + await channel._on_message(data, is_group=False) + + msg = await channel.bus.consume_inbound() + assert "look at this" in msg.content + assert "screenshot.png" in msg.content + assert "Received files:" in msg.content + assert len(msg.media) == 1 + assert msg.media[0] == saved_path + finally: + import os + os.unlink(saved_path) + + +# ── _post_base64file() ───────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_post_base64file_omits_file_name_for_images() -> None: + """file_type=1 (image) → payload must not contain file_name.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + channel._client = _FakeClient(http_return={"file_info": "img_abc"}) + + await channel._post_base64file( + chat_id="user1", + is_group=False, + file_type=QQ_FILE_TYPE_IMAGE, + file_data="ZmFrZQ==", + file_name="photo.png", + ) + + http = channel._client.api._http + assert len(http.calls) == 1 + payload = http.calls[0][1]["json"] + assert "file_name" not in payload + assert payload["file_type"] == QQ_FILE_TYPE_IMAGE + + +@pytest.mark.asyncio +async def test_post_base64file_includes_file_name_for_files() -> None: + """file_type=4 (file) → payload must contain file_name.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + channel._client = _FakeClient(http_return={"file_info": "file_abc"}) + + await channel._post_base64file( + chat_id="user1", + is_group=False, + file_type=QQ_FILE_TYPE_FILE, + file_data="ZmFrZQ==", + file_name="report.pdf", + ) + + http = channel._client.api._http + assert len(http.calls) == 1 + payload = http.calls[0][1]["json"] + assert payload["file_name"] == "report.pdf" + assert payload["file_type"] == QQ_FILE_TYPE_FILE + + +@pytest.mark.asyncio +async def test_post_base64file_filters_response_to_file_info() -> None: + """Response with file_info + extra fields must be filtered to only file_info.""" + channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus()) + channel._client = _FakeClient(http_return={ + "file_info": "fi_123", + "file_uuid": "uuid_xxx", + "ttl": 3600, + }) + + result = await channel._post_base64file( + chat_id="user1", + is_group=False, + file_type=QQ_FILE_TYPE_FILE, + file_data="ZmFrZQ==", + file_name="doc.pdf", + ) + + assert result == {"file_info": "fi_123"} + assert "file_uuid" not in result + assert "ttl" not in result diff --git a/tests/channels/test_wecom_channel.py b/tests/channels/test_wecom_channel.py new file mode 100644 index 000000000..b79c023ba --- /dev/null +++ b/tests/channels/test_wecom_channel.py @@ -0,0 +1,584 @@ +"""Tests for WeCom channel: helpers, download, upload, send, and message processing.""" + +import os +import tempfile +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +try: + import importlib.util + + WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None +except ImportError: + WECOM_AVAILABLE = False + +if not WECOM_AVAILABLE: + pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True) + +from nanobot.bus.events import OutboundMessage +from nanobot.bus.queue import MessageBus +from nanobot.channels.wecom import ( + WecomChannel, + WecomConfig, + _guess_wecom_media_type, + _sanitize_filename, +) + +# Try to import the real response class; fall back to a stub if unavailable. +try: + from wecom_aibot_sdk.utils import WsResponse + + _RealWsResponse = WsResponse +except ImportError: + _RealWsResponse = None + + +class _FakeResponse: + """Minimal stand-in for wecom_aibot_sdk WsResponse.""" + + def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"): + self.errcode = errcode + self.errmsg = errmsg + self.body = body or {} + + +class _FakeWsManager: + """Tracks send_reply calls and returns configurable responses.""" + + def __init__(self, responses: list[_FakeResponse] | None = None): + self.responses = responses or [] + self.calls: list[tuple[str, dict, str]] = [] + self._idx = 0 + + async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse: + self.calls.append((req_id, data, cmd)) + if self._idx < len(self.responses): + resp = self.responses[self._idx] + self._idx += 1 + return resp + return _FakeResponse() + + +class _FakeFrame: + """Minimal frame object with a body dict.""" + + def __init__(self, body: dict | None = None): + self.body = body or {} + + +class _FakeWeComClient: + """Fake WeCom client with mock methods.""" + + def __init__(self, ws_responses: list[_FakeResponse] | None = None): + self._ws_manager = _FakeWsManager(ws_responses) + self.download_file = AsyncMock(return_value=(None, None)) + self.reply = AsyncMock() + self.reply_stream = AsyncMock() + self.send_message = AsyncMock() + self.reply_welcome = AsyncMock() + + +# ── Helper function tests (pure, no async) ────────────────────────── + + +def test_sanitize_filename_strips_path_traversal() -> None: + assert _sanitize_filename("../../etc/passwd") == "passwd" + + +def test_sanitize_filename_keeps_chinese_chars() -> None: + assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg" + + +def test_sanitize_filename_empty_input() -> None: + assert _sanitize_filename("") == "" + + +def test_guess_wecom_media_type_image() -> None: + for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"): + assert _guess_wecom_media_type(f"photo{ext}") == "image" + + +def test_guess_wecom_media_type_video() -> None: + for ext in (".mp4", ".avi", ".mov"): + assert _guess_wecom_media_type(f"video{ext}") == "video" + + +def test_guess_wecom_media_type_voice() -> None: + for ext in (".amr", ".mp3", ".wav", ".ogg"): + assert _guess_wecom_media_type(f"audio{ext}") == "voice" + + +def test_guess_wecom_media_type_file_fallback() -> None: + for ext in (".pdf", ".doc", ".xlsx", ".zip"): + assert _guess_wecom_media_type(f"doc{ext}") == "file" + + +def test_guess_wecom_media_type_case_insensitive() -> None: + assert _guess_wecom_media_type("photo.PNG") == "image" + assert _guess_wecom_media_type("photo.Jpg") == "image" + + +# ── _download_and_save_media() ────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_download_and_save_success() -> None: + """Successful download writes file and returns sanitized path.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + fake_data = b"\x89PNG\r\nfake image" + client.download_file.return_value = (fake_data, "raw_photo.png") + + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())): + path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png") + + assert path is not None + assert os.path.isfile(path) + assert os.path.basename(path) == "photo.png" + # Cleanup + os.unlink(path) + + +@pytest.mark.asyncio +async def test_download_and_save_oversized_rejected() -> None: + """Data exceeding 200MB is rejected → returns None.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + big_data = b"\x00" * (200 * 1024 * 1024 + 1) # 200MB + 1 byte + client.download_file.return_value = (big_data, "big.bin") + + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())): + result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin") + + assert result is None + + +@pytest.mark.asyncio +async def test_download_and_save_failure() -> None: + """SDK returns None data → returns None.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + client.download_file.return_value = (None, None) + + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())): + result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image") + + assert result is None + + +# ── _upload_media_ws() ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_upload_media_ws_success() -> None: + """Happy path: init → chunk → finish → returns (media_id, media_type).""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=0, body={}), + _FakeResponse(errcode=0, body={"media_id": "media_abc"}), + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + media_id, media_type = await channel._upload_media_ws(client, tmp) + + assert media_id == "media_abc" + assert media_type == "image" + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_upload_media_ws_oversized_file() -> None: + """File >200MB triggers ValueError → returns (None, None).""" + # Instead of creating a real 200MB+ file, mock os.path.getsize and open + with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 1), \ + patch("builtins.open", MagicMock()): + client = _FakeWeComClient() + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + result = await channel._upload_media_ws(client, "/fake/large.bin") + assert result == (None, None) + + +@pytest.mark.asyncio +async def test_upload_media_ws_init_failure() -> None: + """Init step returns errcode != 0 → returns (None, None).""" + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f: + f.write(b"hello") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=50001, errmsg="invalid"), + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + result = await channel._upload_media_ws(client, tmp) + + assert result == (None, None) + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_upload_media_ws_chunk_failure() -> None: + """Chunk step returns errcode != 0 → returns (None, None).""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=50002, errmsg="chunk fail"), + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + result = await channel._upload_media_ws(client, tmp) + + assert result == (None, None) + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_upload_media_ws_finish_no_media_id() -> None: + """Finish step returns empty media_id → returns (None, None).""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=0, body={}), + _FakeResponse(errcode=0, body={}), # no media_id + ] + + client = _FakeWeComClient(responses) + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + channel._client = client + + with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"): + result = await channel._upload_media_ws(client, tmp) + + assert result == (None, None) + finally: + os.unlink(tmp) + + +# ── send() ────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_send_text_with_frame() -> None: + """When frame is stored, send uses reply_stream for final text.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="hello") + ) + + client.reply_stream.assert_called_once() + call_args = client.reply_stream.call_args + assert call_args[0][2] == "hello" # content arg + + +@pytest.mark.asyncio +async def test_send_progress_with_frame() -> None: + """When metadata has _progress, send uses reply_stream with finish=False.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True}) + ) + + client.reply_stream.assert_called_once() + call_args = client.reply_stream.call_args + assert call_args[0][2] == "thinking..." # content arg + assert call_args[1]["finish"] is False + + +@pytest.mark.asyncio +async def test_send_proactive_without_frame() -> None: + """Without stored frame, send uses send_message with markdown.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg") + ) + + client.send_message.assert_called_once() + call_args = client.send_message.call_args + assert call_args[0][0] == "chat1" + assert call_args[0][1]["msgtype"] == "markdown" + + +@pytest.mark.asyncio +async def test_send_media_then_text() -> None: + """Media files are uploaded and sent before text content.""" + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + tmp = f.name + + try: + responses = [ + _FakeResponse(errcode=0, body={"upload_id": "up_1"}), + _FakeResponse(errcode=0, body={}), + _FakeResponse(errcode=0, body={"media_id": "media_123"}), + ] + + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient(responses) + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp]) + ) + + # Media should have been sent via reply + media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"] + assert len(media_calls) == 1 + assert media_calls[0][0][1]["image"]["media_id"] == "media_123" + + # Text should have been sent via reply_stream + client.reply_stream.assert_called_once() + finally: + os.unlink(tmp) + + +@pytest.mark.asyncio +async def test_send_media_file_not_found() -> None: + """Non-existent media path is skipped with a warning.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"]) + ) + + # reply_stream should still be called for the text part + client.reply_stream.assert_called_once() + # No media reply should happen + media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")] + assert len(media_calls) == 0 + + +@pytest.mark.asyncio +async def test_send_exception_caught_not_raised() -> None: + """Exceptions inside send() must not propagate.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + channel._generate_req_id = lambda x: f"req_{x}" + channel._chat_frames["chat1"] = _FakeFrame() + + # Make reply_stream raise + client.reply_stream.side_effect = RuntimeError("boom") + + await channel.send( + OutboundMessage(channel="wecom", chat_id="chat1", content="fail test") + ) + # No exception — test passes if we reach here. + + +# ── _process_message() ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_process_text_message() -> None: + """Text message is routed to bus with correct fields.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_text_1", + "chatid": "chat1", + "chattype": "single", + "from": {"userid": "user1"}, + "text": {"content": "hello wecom"}, + }) + + await channel._process_message(frame, "text") + + msg = await channel.bus.consume_inbound() + assert msg.sender_id == "user1" + assert msg.chat_id == "chat1" + assert msg.content == "hello wecom" + assert msg.metadata["msg_type"] == "text" + + +@pytest.mark.asyncio +async def test_process_image_message() -> None: + """Image message: download success → media_paths non-empty.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f: + f.write(b"\x89PNG\r\n") + saved = f.name + + client.download_file.return_value = (b"\x89PNG\r\n", "photo.png") + channel._client = client + + try: + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))): + frame = _FakeFrame(body={ + "msgid": "msg_img_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "image": {"url": "https://example.com/img.png", "aeskey": "key123"}, + }) + await channel._process_message(frame, "image") + + msg = await channel.bus.consume_inbound() + assert len(msg.media) == 1 + assert msg.media[0].endswith("photo.png") + assert "[image:" in msg.content + finally: + if os.path.exists(saved): + pass # may have been overwritten; clean up if exists + # Clean up any photo.png in tempdir + p = os.path.join(os.path.dirname(saved), "photo.png") + if os.path.exists(p): + os.unlink(p) + + +@pytest.mark.asyncio +async def test_process_file_message() -> None: + """File message: download success → media_paths non-empty (critical fix verification).""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + + with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f: + f.write(b"%PDF-1.4 fake") + saved = f.name + + client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf") + channel._client = client + + try: + with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))): + frame = _FakeFrame(body={ + "msgid": "msg_file_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"}, + }) + await channel._process_message(frame, "file") + + msg = await channel.bus.consume_inbound() + assert len(msg.media) == 1 + assert msg.media[0].endswith("report.pdf") + assert "[file: report.pdf]" in msg.content + finally: + p = os.path.join(os.path.dirname(saved), "report.pdf") + if os.path.exists(p): + os.unlink(p) + + +@pytest.mark.asyncio +async def test_process_voice_message() -> None: + """Voice message: transcribed text is included in content.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_voice_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "voice": {"content": "transcribed text here"}, + }) + + await channel._process_message(frame, "voice") + + msg = await channel.bus.consume_inbound() + assert "transcribed text here" in msg.content + assert "[voice]" in msg.content + + +@pytest.mark.asyncio +async def test_process_message_deduplication() -> None: + """Same msg_id is not processed twice.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_dup_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "text": {"content": "once"}, + }) + + await channel._process_message(frame, "text") + await channel._process_message(frame, "text") + + msg = await channel.bus.consume_inbound() + assert msg.content == "once" + + # Second message should not appear on the bus + assert channel.bus.inbound.empty() + + +@pytest.mark.asyncio +async def test_process_message_empty_content_skipped() -> None: + """Message with empty content produces no bus message.""" + channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus()) + client = _FakeWeComClient() + channel._client = client + + frame = _FakeFrame(body={ + "msgid": "msg_empty_1", + "chatid": "chat1", + "from": {"userid": "user1"}, + "text": {"content": ""}, + }) + + await channel._process_message(frame, "text") + + assert channel.bus.inbound.empty() diff --git a/tests/tools/test_message_tool_suppress.py b/tests/tools/test_message_tool_suppress.py index 3f06b4a70..434b2ca71 100644 --- a/tests/tools/test_message_tool_suppress.py +++ b/tests/tools/test_message_tool_suppress.py @@ -1,5 +1,6 @@ """Test message tool suppress logic for final replies.""" +import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic: assert result is not None assert "Hello" in result.content + @pytest.mark.asyncio + async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback( + self, tmp_path: Path + ) -> None: + loop = _make_loop(tmp_path) + tool_call = ToolCallRequest( + id="call1", name="message", + arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"}, + ) + calls = iter([ + LLMResponse(content="First answer", tool_calls=[]), + LLMResponse(content="", tool_calls=[tool_call]), + LLMResponse(content="", tool_calls=[]), + LLMResponse(content="", tool_calls=[]), + LLMResponse(content="", tool_calls=[]), + ]) + loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls)) + loop.tools.get_definitions = MagicMock(return_value=[]) + + sent: list[OutboundMessage] = [] + mt = loop.tools.get("message") + if isinstance(mt, MessageTool): + mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m))) + + pending_queue = asyncio.Queue() + await pending_queue.put( + InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up") + ) + + msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start") + result = await loop._process_message(msg, pending_queue=pending_queue) + + assert len(sent) == 1 + assert sent[0].content == "Tool reply" + assert result is None + async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: loop = _make_loop(tmp_path) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) @@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic: async def on_progress(content: str, *, tool_hint: bool = False) -> None: progress.append((content, tool_hint)) - final_content, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) + final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) assert final_content == "Done" assert progress == [