Merge branch 'main' into fix/tool-call-result-order-2943

This commit is contained in:
layla 2026-04-11 22:00:07 +08:00 committed by GitHub
commit f25cdb7138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 2246 additions and 211 deletions

View File

@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
@ -207,6 +207,10 @@ class AgentLoop:
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._background_tasks: list[asyncio.Task] = []
self._session_locks: dict[str, asyncio.Lock] = {}
# Per-session pending queues for mid-turn message injection.
# When a session has an active task, new messages for that session
# are routed here instead of creating a new task.
self._pending_queues: dict[str, asyncio.Queue] = {}
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
self._concurrency_gate: asyncio.Semaphore | None = (
@ -320,6 +324,12 @@ class AgentLoop:
return format_tool_hints(tool_calls)
def _effective_session_key(self, msg: InboundMessage) -> str:
    """Resolve the session key under which *msg* is routed.

    The same key is used for task routing and for mid-turn injections.
    When unified-session mode is on and the message carries no
    ``session_key_override``, the shared ``UNIFIED_SESSION_KEY`` is
    returned; in every other case the message's own ``session_key`` is.
    """
    use_unified = self._unified_session and not msg.session_key_override
    return UNIFIED_SESSION_KEY if use_unified else msg.session_key
async def _run_agent_loop(
self,
initial_messages: list[dict],
@ -331,13 +341,16 @@ class AgentLoop:
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
) -> tuple[str | None, list[str], list[dict], str]:
pending_queue: asyncio.Queue | None = None,
) -> tuple[str | None, list[str], list[dict], str, bool]:
"""Run the agent iteration loop.
*on_stream*: called with each content delta during streaming.
*on_stream_end(resuming)*: called when a streaming session finishes.
``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response.
Returns (final_content, tools_used, messages, stop_reason, had_injections).
"""
loop_hook = _LoopHook(
self,
@ -357,31 +370,56 @@ class AgentLoop:
return
self._set_runtime_checkpoint(session, payload)
result = await self.runner.run(
AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
)
)
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
    """Collect queued follow-up messages without waiting for new ones.

    At most *limit* messages are taken from ``pending_queue``. Each one is
    turned into a ``{"role": "user", ...}`` dict whose content is the
    runtime context followed by the message's user content. Returns an
    empty list when no queue was supplied or the queue is empty.
    """
    drained: list[dict[str, Any]] = []
    if pending_queue is None:
        return drained

    while len(drained) < limit:
        try:
            queued = pending_queue.get_nowait()
        except asyncio.QueueEmpty:
            # Nothing more is waiting right now; never block the turn.
            break

        media = queued.media if queued.media else None
        body = self.context._build_user_content(queued.content, media)
        header = self.context._build_runtime_context(
            queued.channel,
            queued.chat_id,
            self.context.timezone,
        )

        content: str | list[dict[str, Any]]
        if isinstance(body, str):
            content = f"{header}\n\n{body}"
        else:
            # Block-style content: the runtime context becomes a leading text block.
            content = [{"type": "text", "text": header}] + body
        drained.append({"role": "user", "content": content})

    return drained
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages,
tools=self.tools,
model=self.model,
max_iterations=self.max_iterations,
max_tool_result_chars=self.max_tool_result_chars,
hook=hook,
error_message="Sorry, I encountered an error calling the AI model.",
concurrent_tools=True,
workspace=self.workspace,
session_key=session.key if session else None,
context_window_tokens=self.context_window_tokens,
context_block_limit=self.context_block_limit,
provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress,
checkpoint_callback=_checkpoint,
injection_callback=_drain_pending,
))
self._last_usage = result.usage
if result.stop_reason == "max_iterations":
logger.warning("Max iterations ({}) reached", self.max_iterations)
elif result.stop_reason == "error":
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
return result.final_content, result.tools_used, result.messages, result.stop_reason
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@ -412,13 +450,32 @@ class AgentLoop:
if result:
await self.bus.publish_outbound(result)
continue
effective_key = self._effective_session_key(msg)
# If this session already has an active pending queue (i.e. a task
# is processing this session), route the message there for mid-turn
# injection instead of creating a competing task.
if effective_key in self._pending_queues:
pending_msg = msg
if effective_key != msg.session_key:
pending_msg = dataclasses.replace(
msg,
session_key_override=effective_key,
)
try:
self._pending_queues[effective_key].put_nowait(pending_msg)
except asyncio.QueueFull:
logger.warning(
"Pending queue full for session {}, falling back to queued task",
effective_key,
)
else:
logger.info(
"Routed follow-up message to pending queue for session {}",
effective_key,
)
continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = (
UNIFIED_SESSION_KEY
if self._unified_session and not msg.session_key_override
else msg.session_key
)
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(
@ -430,78 +487,91 @@ class AgentLoop:
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent."""
if self._unified_session and not msg.session_key_override:
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
session_key = self._effective_session_key(msg)
if session_key != msg.session_key:
msg = dataclasses.replace(msg, session_key_override=session_key)
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext()
async with lock, gate:
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
# Split one answer into distinct stream segments.
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
# Register a pending queue so follow-up messages for this session are
# routed here (mid-turn injection) instead of spawning a new task.
pending = asyncio.Queue(maxsize=20)
self._pending_queues[session_key] = pending
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
try:
async with lock, gate:
try:
on_stream = on_stream_end = None
if msg.metadata.get("_wants_stream"):
# Split one answer into distinct stream segments.
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
stream_segment = 0
def _current_stream_id() -> str:
return f"{stream_base_id}:{stream_segment}"
async def on_stream(delta: str) -> None:
meta = dict(msg.metadata or {})
meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content=delta,
metadata=meta,
)
)
))
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment
meta = dict(msg.metadata or {})
meta["_stream_end"] = True
meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="",
metadata=meta,
)
)
stream_segment += 1
))
stream_segment += 1
response = await self._process_message(
msg,
on_stream=on_stream,
on_stream_end=on_stream_end,
)
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="",
metadata=msg.metadata or {},
)
response = await self._process_message(
msg, on_stream=on_stream, on_stream_end=on_stream_end,
pending_queue=pending,
)
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", msg.session_key)
await self.bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
if response is not None:
await self.bus.publish_outbound(response)
elif msg.channel == "cli":
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="", metadata=msg.metadata or {},
))
except asyncio.CancelledError:
logger.info("Task cancelled for session {}", session_key)
raise
except Exception:
logger.exception("Error processing message for session {}", session_key)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id,
content="Sorry, I encountered an error.",
))
finally:
# Drain any messages still in the pending queue and re-publish
# them to the bus so they are processed as fresh inbound messages
# rather than silently lost.
queue = self._pending_queues.pop(session_key, None)
if queue is not None:
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await self.bus.publish_inbound(item)
leftover += 1
if leftover:
logger.info(
"Re-published {} leftover message(s) to bus for session {}",
leftover, session_key,
)
)
async def close_mcp(self) -> None:
"""Drain pending background archives, then close MCP connections."""
@ -533,6 +603,7 @@ class AgentLoop:
on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id")
@ -559,11 +630,8 @@ class AgentLoop:
session_summary=pending,
current_role=current_role,
)
final_content, _, all_msgs, _ = await self._run_agent_loop(
messages,
session=session,
channel=channel,
chat_id=chat_id,
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
)
self._save_turn(session, all_msgs, 1 + len(history))
@ -623,7 +691,7 @@ class AgentLoop:
)
)
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
initial_messages,
on_progress=on_progress or _bus_progress,
on_stream=on_stream,
@ -632,6 +700,7 @@ class AgentLoop:
channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
pending_queue=pending_queue,
)
if final_content is None or not final_content.strip():
@ -642,8 +711,15 @@ class AgentLoop:
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
# When follow-up messages were injected mid-turn, a later natural
# language reply may address those follow-ups and should not be
# suppressed just because MessageTool was used earlier in the turn.
# However, if the turn falls back to the empty-final-response
# placeholder, suppress it when the real user-visible output already
# came from MessageTool.
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None
if not had_injections or stop_reason == "empty_final_response":
return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import inspect
from pathlib import Path
from typing import Any
@ -34,6 +35,8 @@ _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
_MAX_EMPTY_RETRIES = 2
_MAX_LENGTH_RECOVERIES = 3
_MAX_INJECTIONS_PER_TURN = 3
_MAX_INJECTION_CYCLES = 5
_SNIP_SAFETY_BUFFER = 1024
_MICROCOMPACT_KEEP_RECENT = 10
_MICROCOMPACT_MIN_CHARS = 500
@ -42,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
"web_search", "web_fetch", "list_dir",
})
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
@dataclass(slots=True)
class AgentRunSpec:
"""Configuration for a single agent execution."""
@ -66,6 +72,7 @@ class AgentRunSpec:
provider_retry_mode: str = "standard"
progress_callback: Any | None = None
checkpoint_callback: Any | None = None
injection_callback: Any | None = None
@dataclass(slots=True)
@ -79,6 +86,7 @@ class AgentRunResult:
stop_reason: str = "completed"
error: str | None = None
tool_events: list[dict[str, str]] = field(default_factory=list)
had_injections: bool = False
class AgentRunner:
@ -87,6 +95,90 @@ class AgentRunner:
def __init__(self, provider: LLMProvider):
self.provider = provider
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
    """Combine two message ``content`` values into a single value.

    Two strings are joined with a blank line between them; if either one
    is empty the other is returned unchanged, so no dangling separator is
    produced. For any other combination both sides are converted to lists
    of content blocks and concatenated.

    Args:
        left: Content of the earlier message (str, list of blocks, or None).
        right: Content of the later message (str, list of blocks, or None).

    Returns:
        A string when both inputs are strings, otherwise a list of blocks.
    """
    if isinstance(left, str) and isinstance(right, str):
        if not left:
            return right
        if not right:
            return left
        return f"{left}\n\n{right}"

    def _to_blocks(value: Any) -> list[dict[str, Any]]:
        # None and "" carry no content. Turning "" into a text block would
        # add an empty block to the message, which some provider APIs are
        # known to reject.
        if value is None or (isinstance(value, str) and not value):
            return []
        if isinstance(value, list):
            return [
                item if isinstance(item, dict) else {"type": "text", "text": str(item)}
                for item in value
            ]
        return [{"type": "text", "text": str(value)}]

    return _to_blocks(left) + _to_blocks(right)
@classmethod
def _append_injected_messages(
    cls,
    messages: list[dict[str, Any]],
    injections: list[dict[str, Any]],
) -> None:
    """Add *injections* to *messages* in place, avoiding back-to-back user turns.

    A user injection that would directly follow another user message is
    folded into that message via ``_merge_message_content``. The existing
    entry is replaced by a copy rather than mutated. Any other injection
    is appended as-is.
    """
    for incoming in injections:
        if messages and incoming.get("role") == "user":
            tail = messages[-1]
            if tail.get("role") == "user":
                # Replace the tail with a merged copy so the dict that was
                # originally in the list stays untouched.
                messages[-1] = {
                    **tail,
                    "content": cls._merge_message_content(
                        tail.get("content"),
                        incoming.get("content"),
                    ),
                }
                continue
        messages.append(incoming)
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
    """Drain pending user messages via ``spec.injection_callback``.

    The callback is awaited with ``limit=_MAX_INJECTIONS_PER_TURN`` when its
    signature accepts a ``limit`` keyword (or ``**kwargs``), and with no
    arguments otherwise. A callback that raises is logged and treated as
    having returned nothing.

    Returned items are normalized to ``{"role": "user", "content": ...}``:

    * dicts that already have ``role == "user"`` and a ``content`` key are
      passed through unchanged;
    * for anything else the ``content`` attribute (or ``str(item)`` when
      there is none) is used: non-empty lists are kept as block content,
      ``None`` and blank text are skipped, other values are stringified.

    Returns:
        At most ``_MAX_INJECTIONS_PER_TURN`` user messages, or an empty list
        when there is nothing to inject. Messages beyond the cap are
        dropped; a warning records how many.
    """
    callback = spec.injection_callback
    if callback is None:
        return []

    try:
        parameters = inspect.signature(callback).parameters
        accepts_limit = "limit" in parameters or any(
            parameter.kind is inspect.Parameter.VAR_KEYWORD
            for parameter in parameters.values()
        )
        if accepts_limit:
            items = await callback(limit=_MAX_INJECTIONS_PER_TURN)
        else:
            items = await callback()
    except Exception:
        logger.exception("injection_callback failed")
        return []

    if not items:
        return []

    injected_messages: list[dict[str, Any]] = []
    for item in items:
        if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
            injected_messages.append(item)
            continue

        text = getattr(item, "content", str(item))
        if isinstance(text, list):
            # Block-style content has no .strip(); keep it when non-empty.
            if text:
                injected_messages.append({"role": "user", "content": text})
            continue
        if text is None:
            continue
        if not isinstance(text, str):
            text = str(text)
        if text.strip():
            injected_messages.append({"role": "user", "content": text})

    if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
        dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
        logger.warning(
            "Injection callback returned {} messages, capping to {} ({} dropped)",
            len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
        )
        injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
    return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages)
@ -99,6 +191,8 @@ class AgentRunner:
external_lookup_counts: dict[str, int] = {}
empty_content_retries = 0
length_recovery_count = 0
had_injections = False
injection_cycles = 0
for iteration in range(spec.max_iterations):
try:
@ -207,6 +301,17 @@ class AgentRunner:
)
empty_content_retries = 0
length_recovery_count = 0
# Checkpoint 1: drain injections after tools, before next LLM call
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after tool execution ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
await hook.after_iteration(context)
continue
@ -263,8 +368,49 @@ class AgentRunner:
await hook.after_iteration(context)
continue
assistant_message: dict[str, Any] | None = None
if response.finish_reason != "error" and not is_blank_text(clean):
assistant_message = build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
_injected_after_final = False
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
_injected_after_final = True
if assistant_message is not None:
messages.append(assistant_message)
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after final response ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.on_stream_end(context, resuming=_injected_after_final)
if _injected_after_final:
await hook.after_iteration(context)
continue
if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
@ -287,7 +433,7 @@ class AgentRunner:
await hook.after_iteration(context)
break
messages.append(build_assistant_message(
messages.append(assistant_message or build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
@ -330,6 +476,7 @@ class AgentRunner:
stop_reason=stop_reason,
error=error,
tool_events=tool_events,
had_injections=had_injections,
)
def _build_request_kwargs(

View File

@ -242,43 +242,46 @@ class QQChannel(BaseChannel):
async def send(self, msg: OutboundMessage) -> None:
"""Send attachments first, then text."""
if not self._client:
logger.warning("QQ client not initialized")
return
try:
if not self._client:
logger.warning("QQ client not initialized")
return
msg_id = msg.metadata.get("message_id")
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
is_group = chat_type == "group"
msg_id = msg.metadata.get("message_id")
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
is_group = chat_type == "group"
# 1) Send media
for media_ref in msg.media or []:
ok = await self._send_media(
chat_id=msg.chat_id,
media_ref=media_ref,
msg_id=msg_id,
is_group=is_group,
)
if not ok:
filename = (
os.path.basename(urlparse(media_ref).path)
or os.path.basename(media_ref)
or "file"
# 1) Send media
for media_ref in msg.media or []:
ok = await self._send_media(
chat_id=msg.chat_id,
media_ref=media_ref,
msg_id=msg_id,
is_group=is_group,
)
if not ok:
filename = (
os.path.basename(urlparse(media_ref).path)
or os.path.basename(media_ref)
or "file"
)
await self._send_text_only(
chat_id=msg.chat_id,
is_group=is_group,
msg_id=msg_id,
content=f"[Attachment send failed: {filename}]",
)
# 2) Send text
if msg.content and msg.content.strip():
await self._send_text_only(
chat_id=msg.chat_id,
is_group=is_group,
msg_id=msg_id,
content=f"[Attachment send failed: {filename}]",
content=msg.content.strip(),
)
# 2) Send text
if msg.content and msg.content.strip():
await self._send_text_only(
chat_id=msg.chat_id,
is_group=is_group,
msg_id=msg_id,
content=msg.content.strip(),
)
except Exception:
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
async def _send_text_only(
self,
@ -438,15 +441,26 @@ class QQChannel(BaseChannel):
endpoint = "/v2/users/{openid}/files"
id_key = "openid"
payload = {
payload: dict[str, Any] = {
id_key: chat_id,
"file_type": file_type,
"file_data": file_data,
"file_name": file_name,
"srv_send_msg": srv_send_msg,
}
# Only pass file_name for non-image types (file_type=4).
# Passing file_name for images causes QQ client to render them as
# file attachments instead of inline images.
if file_type != QQ_FILE_TYPE_IMAGE and file_name:
payload["file_name"] = file_name
route = Route("POST", endpoint, **{id_key: chat_id})
return await self._client.api._http.request(route, json=payload)
result = await self._client.api._http.request(route, json=payload)
# Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
# that may confuse QQ client when sending the media object.
if isinstance(result, dict) and "file_info" in result:
return {"file_info": result["file_info"]}
return result
# ---------------------------
# Inbound (receive)
@ -454,58 +468,68 @@ class QQChannel(BaseChannel):
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
"""Parse inbound message, download attachments, and publish to the bus."""
if data.id in self._processed_ids:
return
self._processed_ids.append(data.id)
try:
if data.id in self._processed_ids:
return
self._processed_ids.append(data.id)
if is_group:
chat_id = data.group_openid
user_id = data.author.member_openid
self._chat_type_cache[chat_id] = "group"
else:
chat_id = str(
getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
)
user_id = chat_id
self._chat_type_cache[chat_id] = "c2c"
content = (data.content or "").strip()
# the data used by tests don't contain attachments property
# so we use getattr with a default of [] to avoid AttributeError in tests
attachments = getattr(data, "attachments", None) or []
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
# Compose content that always contains actionable saved paths
if recv_lines:
tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
file_block = "Received files:\n" + "\n".join(recv_lines)
content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
if not content and not media_paths:
return
if self.config.ack_message:
try:
await self._send_text_only(
chat_id=chat_id,
is_group=is_group,
msg_id=data.id,
content=self.config.ack_message,
if is_group:
chat_id = data.group_openid
user_id = data.author.member_openid
self._chat_type_cache[chat_id] = "group"
else:
chat_id = str(
getattr(data.author, "id", None)
or getattr(data.author, "user_openid", "unknown")
)
except Exception:
logger.debug("QQ ack message failed for chat_id={}", chat_id)
user_id = chat_id
self._chat_type_cache[chat_id] = "c2c"
await self._handle_message(
sender_id=user_id,
chat_id=chat_id,
content=content,
media=media_paths if media_paths else None,
metadata={
"message_id": data.id,
"attachments": att_meta,
},
)
content = (data.content or "").strip()
# the data used by tests don't contain attachments property
# so we use getattr with a default of [] to avoid AttributeError in tests
attachments = getattr(data, "attachments", None) or []
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
# Compose content that always contains actionable saved paths
if recv_lines:
tag = (
"[Image]"
if any(_is_image_name(Path(p).name) for p in media_paths)
else "[File]"
)
file_block = "Received files:\n" + "\n".join(recv_lines)
content = (
f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
)
if not content and not media_paths:
return
if self.config.ack_message:
try:
await self._send_text_only(
chat_id=chat_id,
is_group=is_group,
msg_id=data.id,
content=self.config.ack_message,
)
except Exception:
logger.debug("QQ ack message failed for chat_id={}", chat_id)
await self._handle_message(
sender_id=user_id,
chat_id=chat_id,
content=content,
media=media_paths if media_paths else None,
metadata={
"message_id": data.id,
"attachments": att_meta,
},
)
except Exception:
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
async def _handle_attachments(
self,
@ -520,7 +544,9 @@ class QQChannel(BaseChannel):
return media_paths, recv_lines, att_meta
for att in attachments:
url, filename, ctype = att.url, att.filename, att.content_type
url = getattr(att, "url", None) or ""
filename = getattr(att, "filename", None) or ""
ctype = getattr(att, "content_type", None) or ""
logger.info("Downloading file from QQ: {}", filename or url)
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
@ -555,6 +581,10 @@ class QQChannel(BaseChannel):
Enforces a max download size and writes to a .part temp file
that is atomically renamed on success.
"""
# Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
if url.startswith("//"):
url = f"https:{url}"
if not self._http:
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))

View File

@ -1,9 +1,13 @@
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
import asyncio
import base64
import hashlib
import importlib.util
import os
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any
from loguru import logger
@ -17,6 +21,37 @@ from pydantic import Field
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
# Upload safety limits (matching QQ channel defaults)
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB
# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)

# Returned when sanitizing leaves nothing usable. An empty name joined onto a
# directory path yields the directory itself, so writing to it would fail.
_FALLBACK_FILENAME = "file"


def _sanitize_filename(name: str) -> str:
    """Return a safe, non-empty base filename derived from *name*.

    Directory components are discarded, each run of characters outside the
    allowed set is replaced by a single ``_``, and leading/trailing dots,
    underscores and spaces are stripped.

    Args:
        name: Candidate filename; ``None`` or an empty string is tolerated.

    Returns:
        The sanitized name, or ``"file"`` when nothing usable remains
        (for example for ``""`` or ``"..."``).
    """
    base = Path((name or "").strip()).name
    cleaned = _SAFE_NAME_RE.sub("_", base).strip("._ ")
    return cleaned or _FALLBACK_FILENAME
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}


def _guess_wecom_media_type(filename: str) -> str:
    """Map a filename's extension to a WeCom ``media_type`` string.

    The extension is compared case-insensitively. Returns ``"image"``,
    ``"video"`` or ``"voice"`` for the known extension sets, and ``"file"``
    for everything else, including names without an extension.
    """
    suffix = Path(filename).suffix.lower()
    known_types = (
        (_IMAGE_EXTS, "image"),
        (_VIDEO_EXTS, "video"),
        (_AUDIO_EXTS, "voice"),
    )
    for extensions, media_type in known_types:
        if suffix in extensions:
            return media_type
    return "file"
class WecomConfig(Base):
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
chat_id = body.get("chatid", sender_id)
content_parts = []
media_paths: list[str] = []
if msg_type == "text":
text = body.get("text", {}).get("content", "")
@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
file_path = await self._download_and_save_media(file_url, aes_key, "image")
if file_path:
filename = os.path.basename(file_path)
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
content_parts.append(f"[image: {filename}]")
media_paths.append(file_path)
else:
content_parts.append("[image: download failed]")
else:
@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
if file_path:
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
content_parts.append(f"[file: {file_name}]")
media_paths.append(file_path)
else:
content_parts.append(f"[file: {file_name}: download failed]")
else:
@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
self._chat_frames[chat_id] = frame
# Forward to message bus
# Note: media paths are included in content for broader model compatibility
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=content,
media=None,
media=media_paths or None,
metadata={
"message_id": msg_id,
"msg_type": msg_type,
@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
logger.warning("Failed to download media from WeCom")
return None
if len(data) > WECOM_UPLOAD_MAX_BYTES:
logger.warning(
"WeCom inbound media too large: {} bytes (max {})",
len(data),
WECOM_UPLOAD_MAX_BYTES,
)
return None
media_dir = get_media_dir("wecom")
if not filename:
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
filename = os.path.basename(filename)
filename = _sanitize_filename(filename)
file_path = media_dir / filename
file_path.write_bytes(data)
await asyncio.to_thread(file_path.write_bytes, data)
logger.debug("Downloaded {} to {}", media_type, file_path)
return str(file_path)
@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
logger.error("Error downloading media: {}", e)
return None
async def _upload_media_ws(
    self, client: Any, file_path: str,
) -> "tuple[str, str] | tuple[None, None]":
    """Upload a local file to WeCom via the WebSocket 3-step protocol (base64).

    Sends the upload commands directly through
    ``client._ws_manager.send_reply()``:

    1. ``aibot_upload_media_init`` -- response body carries ``upload_id``
    2. ``aibot_upload_media_chunk`` x N (512 KB raw per chunk, base64-encoded)
    3. ``aibot_upload_media_finish`` -- response body carries ``media_id``

    Args:
        client: WeCom client object exposing ``_ws_manager.send_reply``.
        file_path: Path of the local file to upload.

    Returns:
        ``(media_id, media_type)`` on success, ``(None, None)`` on any
        failure (oversized file, non-zero ``errcode``, missing id, or an
        exception raised during the upload).
    """
    from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id

    chunk_size = 512 * 1024  # 512 KB raw (before base64)

    def _read_file() -> bytes:
        # Stat first so an oversized file is rejected without reading it.
        file_size = os.path.getsize(file_path)
        if file_size > WECOM_UPLOAD_MAX_BYTES:
            raise ValueError(
                f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
            )
        with open(file_path, "rb") as f:
            payload = f.read()
        # The file may have grown between the stat and the read.
        if len(payload) > WECOM_UPLOAD_MAX_BYTES:
            raise ValueError(
                f"File too large: {len(payload)} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
            )
        return payload

    try:
        fname = os.path.basename(file_path)
        media_type = _guess_wecom_media_type(fname)

        # Read in a worker thread to avoid blocking the event loop.
        data = await asyncio.to_thread(_read_file)

        # Derive size, chunk count and checksum from the bytes actually read
        # so the three values announced in the init step always agree.
        file_size = len(data)
        n_chunks = (file_size + chunk_size - 1) // chunk_size
        # MD5 is used for file integrity only, not cryptographic security.
        md5_hash = hashlib.md5(data, usedforsecurity=False).hexdigest()
        # Slice through a memoryview so chunks are not copied up front.
        view = memoryview(data)

        # Step 1: init
        req_id = _gen_req_id("upload_init")
        resp = await client._ws_manager.send_reply(req_id, {
            "type": media_type,
            "filename": fname,
            "total_size": file_size,
            "total_chunks": n_chunks,
            "md5": md5_hash,
        }, "aibot_upload_media_init")
        if resp.errcode != 0:
            logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
            return None, None

        upload_id = resp.body.get("upload_id") if resp.body else None
        if not upload_id:
            logger.warning("WeCom upload init: no upload_id in response")
            return None, None

        # Step 2: send chunks
        for index in range(n_chunks):
            start = index * chunk_size
            req_id = _gen_req_id("upload_chunk")
            resp = await client._ws_manager.send_reply(req_id, {
                "upload_id": upload_id,
                "chunk_index": index,
                "base64_data": base64.b64encode(view[start:start + chunk_size]).decode(),
            }, "aibot_upload_media_chunk")
            if resp.errcode != 0:
                logger.warning("WeCom upload chunk {} failed ({}): {}", index, resp.errcode, resp.errmsg)
                return None, None

        # Step 3: finish
        req_id = _gen_req_id("upload_finish")
        resp = await client._ws_manager.send_reply(req_id, {
            "upload_id": upload_id,
        }, "aibot_upload_media_finish")
        if resp.errcode != 0:
            logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
            return None, None

        media_id = resp.body.get("media_id") if resp.body else None
        if not media_id:
            logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
            return None, None

        # Log only a prefix of the media_id.
        suffix = "..." if len(media_id) > 16 else ""
        logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
        return media_id, media_type

    except ValueError as e:
        logger.warning("WeCom upload skipped for {}: {}", file_path, e)
        return None, None
    except Exception as e:
        logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
        return None, None
async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WeCom."""
if not self._client:
@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
return
try:
content = msg.content.strip()
if not content:
return
content = (msg.content or "").strip()
is_progress = bool(msg.metadata.get("_progress"))
# Get the stored frame for this chat
frame = self._chat_frames.get(msg.chat_id)
if not frame:
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
# Send media files via WebSocket upload
for file_path in msg.media or []:
if not os.path.isfile(file_path):
logger.warning("WeCom media file not found: {}", file_path)
continue
media_id, media_type = await self._upload_media_ws(self._client, file_path)
if media_id:
if frame:
await self._client.reply(frame, {
"msgtype": media_type,
media_type: {"media_id": media_id},
})
else:
await self._client.send_message(msg.chat_id, {
"msgtype": media_type,
media_type: {"media_id": media_id},
})
logger.debug("WeCom sent {}{}", media_type, msg.chat_id)
else:
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
if not content:
return
# Use streaming reply for better UX
stream_id = self._generate_req_id("stream")
if frame:
# Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
# The plain reply() uses cmd="reply" which does not support "text" msgtype
# and causes errcode=40008 from WeCom API.
stream_id = self._generate_req_id("stream")
await self._client.reply_stream(
frame,
stream_id,
content,
finish=not is_progress,
)
logger.debug(
"WeCom {} sent to {}",
"progress" if is_progress else "message",
msg.chat_id,
)
else:
# No frame (e.g. cron push): proactive send only supports markdown
await self._client.send_message(msg.chat_id, {
"msgtype": "markdown",
"markdown": {"content": content},
})
logger.info("WeCom proactive send to {}", msg.chat_id)
# Send as streaming message with finish=True
await self._client.reply_stream(
frame,
stream_id,
content,
finish=True,
)
logger.debug("WeCom message sent to {}", msg.chat_id)
except Exception as e:
logger.error("Error sending WeCom message: {}", e)
raise
except Exception:
logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)

View File

@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, tools_used, messages, _ = await loop._run_agent_loop(
content, tools_used, messages, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
)
loop.tools.get_definitions = MagicMock(return_value=[])
content, _, _, _ = await loop._run_agent_loop(
content, _, _, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}]
)
@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
content, tools_used, _, _ = await loop._run_agent_loop([])
content, tools_used, _, _, _ = await loop._run_agent_loop([])
assert content == (
"I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps."

View File

@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import base64
import os
import time
from unittest.mock import AsyncMock, MagicMock, patch
@ -798,7 +799,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2
final_content, _, _, _ = await loop._run_agent_loop([])
final_content, _, _, _, _ = await loop._run_agent_loop([])
assert final_content == (
"I reached the maximum number of tool call iterations (2) "
@ -825,7 +826,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
async def on_stream_end(*, resuming: bool = False) -> None:
endings.append(resuming)
final_content, _, _, _ = await loop._run_agent_loop(
final_content, _, _, _, _ = await loop._run_agent_loop(
[],
on_stream=on_stream,
on_stream_end=on_stream_end,
@ -849,7 +850,7 @@ async def test_loop_retries_think_only_final_response(tmp_path):
loop.provider.chat_with_retry = chat_with_retry
final_content, _, _, _ = await loop._run_agent_loop([])
final_content, _, _, _, _ = await loop._run_agent_loop([])
assert final_content == "Recovered answer"
assert call_count["n"] == 2
@ -1722,3 +1723,690 @@ def test_governance_fallback_still_repairs_orphans():
repaired = AgentRunner._backfill_missing_tool_results(repaired)
# Orphan tool result should be gone.
assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired)
# ── Mid-turn injection tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_injections_returns_empty_when_no_callback():
    """Draining a spec whose ``injection_callback`` is ``None`` yields ``[]``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    run_spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="m",
        max_iterations=1,
        max_tool_result_chars=1000,
        injection_callback=None,
    )

    drained = await AgentRunner(MagicMock())._drain_injections(run_spec)

    assert drained == []
@pytest.mark.asyncio
async def test_drain_injections_extracts_content_from_inbound_messages():
    """Each drained ``InboundMessage`` is returned as a user-role message dict."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    texts = ("hello", "world")
    queued = [
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text)
        for text in texts
    ]

    async def supply_messages():
        return queued

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    run_spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="m",
        max_iterations=1,
        max_tool_result_chars=1000,
        injection_callback=supply_messages,
    )

    drained = await AgentRunner(MagicMock())._drain_injections(run_spec)

    assert drained == [{"role": "user", "content": text} for text in texts]
@pytest.mark.asyncio
async def test_drain_injections_passes_limit_to_callback_when_supported():
    """A callback accepting ``limit`` is called once with ``_MAX_INJECTIONS_PER_TURN``.

    The callback honours the limit itself (returning only the first ``limit``
    messages), which is how a limit-aware callback can keep overflow in its
    own queue.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
    from nanobot.bus.events import InboundMessage

    provider = MagicMock()
    runner = AgentRunner(provider)
    tools = MagicMock()
    tools.get_definitions.return_value = []
    seen_limits: list[int] = []
    # Offer more messages than the per-turn limit so truncation is observable.
    msgs = [
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
        for i in range(_MAX_INJECTIONS_PER_TURN + 3)
    ]

    async def cb(*, limit: int):
        seen_limits.append(limit)
        return msgs[:limit]

    spec = AgentRunSpec(
        initial_messages=[], tools=tools, model="m",
        max_iterations=1, max_tool_result_chars=1000,
        injection_callback=cb,
    )
    result = await runner._drain_injections(spec)

    assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
    # Derive the expectation from the constant instead of hard-coding three
    # entries, so the test stays valid if the limit is retuned.
    assert result == [
        {"role": "user", "content": f"msg{i}"}
        for i in range(_MAX_INJECTIONS_PER_TURN)
    ]
@pytest.mark.asyncio
async def test_drain_injections_skips_empty_content():
    """Empty and blank messages are dropped; only the real text survives."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    queued = [
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text)
        for text in ("", " ", "valid")
    ]

    async def supply_messages():
        return queued

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    run_spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="m",
        max_iterations=1,
        max_tool_result_chars=1000,
        injection_callback=supply_messages,
    )

    drained = await AgentRunner(MagicMock())._drain_injections(run_spec)

    assert drained == [{"role": "user", "content": "valid"}]
@pytest.mark.asyncio
async def test_drain_injections_handles_callback_exception():
    """A raising callback must not propagate; the drain result is ``[]``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def exploding_callback():
        raise RuntimeError("boom")

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    run_spec = AgentRunSpec(
        initial_messages=[],
        tools=tool_registry,
        model="m",
        max_iterations=1,
        max_tool_result_chars=1000,
        injection_callback=exploding_callback,
    )

    drained = await AgentRunner(MagicMock())._drain_injections(run_spec)

    assert drained == []
@pytest.mark.asyncio
async def test_checkpoint1_injects_after_tool_execution():
    """A queued follow-up appears in the LLM call that follows a tool round."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    llm_calls = {"n": 0}
    seen_histories = []

    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        seen_histories.append(list(messages))
        if llm_calls["n"] != 1:
            return LLMResponse(content="final answer", tool_calls=[], usage={})
        # First round asks for a tool so that a second LLM call is required.
        return LLMResponse(
            content="using tool",
            tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})],
            usage={},
        )

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []
    tool_registry.execute = AsyncMock(return_value="file content")

    followups = asyncio.Queue()

    async def drain_followups():
        drained = []
        while True:
            try:
                drained.append(followups.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    # The follow-up is already waiting in the queue before the run starts.
    await followups.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
    )

    outcome = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=drain_followups,
    ))

    assert outcome.had_injections is True
    assert outcome.final_content == "final answer"
    # The second (last) call must carry exactly one copy of the injected message.
    assert llm_calls["n"] == 2
    matching = [
        entry for entry in seen_histories[-1]
        if entry.get("role") == "user" and entry.get("content") == "follow-up question"
    ]
    assert len(matching) == 1
@pytest.mark.asyncio
async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
    """With a pending follow-up, the first stream end resumes and the last does not."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.agent.hook import AgentHook, AgentHookContext
    from nanobot.bus.events import InboundMessage

    llm_calls = {"n": 0}
    resuming_flags = []

    class TrackingHook(AgentHook):
        """Streaming hook that records the ``resuming`` flag of every stream end."""

        def wants_streaming(self) -> bool:
            return True

        async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
            resuming_flags.append(resuming)

        def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
            return content

    async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs):
        llm_calls["n"] += 1
        answer = "first answer" if llm_calls["n"] == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_stream_with_retry = chat_stream_with_retry

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    followups = asyncio.Queue()

    async def drain_followups():
        drained = []
        while True:
            try:
                drained.append(followups.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    # Stands in for a follow-up that arrived while the first answer was produced.
    await followups.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up")
    )

    outcome = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        hook=TrackingHook(),
        injection_callback=drain_followups,
    ))

    assert outcome.had_injections is True
    assert outcome.final_content == "second answer"
    assert llm_calls["n"] == 2
    # Injections were found after the first answer, so that stream end resumes...
    assert resuming_flags[0] is True
    # ...while the final stream end closes the turn.
    assert resuming_flags[-1] is False
@pytest.mark.asyncio
async def test_checkpoint2_preserves_final_response_in_history_before_followup():
    """The first answer stays in history ahead of a follow-up injected after it."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner
    from nanobot.bus.events import InboundMessage

    llm_calls = {"n": 0}
    seen_histories = []

    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        # Snapshot each message dict so later changes cannot alter what was seen.
        seen_histories.append([dict(entry) for entry in messages])
        answer = "first answer" if llm_calls["n"] == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    followups = asyncio.Queue()

    async def drain_followups():
        drained = []
        while True:
            try:
                drained.append(followups.get_nowait())
            except asyncio.QueueEmpty:
                return drained

    await followups.put(
        InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
    )

    outcome = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=drain_followups,
    ))

    assert outcome.final_content == "second answer"
    assert llm_calls["n"] == 2
    assert seen_histories[-1] == [
        {"role": "user", "content": "hello"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "follow-up question"},
    ]

    assistant_turns = [
        {"role": entry["role"], "content": entry["content"]}
        for entry in outcome.messages
        if entry.get("role") == "assistant"
    ]
    assert assistant_turns == [
        {"role": "assistant", "content": "first answer"},
        {"role": "assistant", "content": "second answer"},
    ]
@pytest.mark.asyncio
async def test_loop_injected_followup_preserves_image_media(tmp_path):
    """An image-only follow-up injected mid-turn reaches the LLM as an image block."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus

    # Small PNG written to disk so the follow-up can reference a real file.
    png_bytes = base64.b64decode(
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
    )
    image_path = tmp_path / "followup.png"
    image_path.write_bytes(png_bytes)

    llm_calls = {"n": 0}
    seen_histories: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        seen_histories.append(list(messages))
        answer = "first answer" if llm_calls["n"] == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    pending_queue = asyncio.Queue()
    followup = InboundMessage(
        channel="cli",
        sender_id="u",
        chat_id="c",
        content="",
        media=[str(image_path)],
    )
    await pending_queue.put(followup)

    final_content, _, _, _, had_injections = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "second answer"
    assert had_injections is True
    assert llm_calls["n"] == 2

    # Multimodal user messages carry a list of blocks rather than a plain string.
    multimodal_users = [
        entry for entry in seen_histories[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), list)
    ]
    assert multimodal_users
    block_types = [
        block.get("type")
        for block in multimodal_users[-1]["content"]
        if isinstance(block, dict)
    ]
    assert "image_url" in block_types
@pytest.mark.asyncio
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
    """Two injected follow-ups reach the LLM as one user message keeping image and text."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    llm_calls = {"n": 0}
    seen_histories = []

    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        seen_histories.append([dict(entry) for entry in messages])
        answer = "first answer" if llm_calls["n"] == 1 else "second answer"
        return LLMResponse(content=answer, tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    async def inject_cb():
        # Only supply follow-ups while exactly one LLM call has been made.
        if llm_calls["n"] != 1:
            return []
        multimodal = {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
                {"type": "text", "text": "look at this"},
            ],
        }
        plain = {"role": "user", "content": "and answer briefly"}
        return [multimodal, plain]

    outcome = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hello"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=5,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    ))

    assert outcome.final_content == "second answer"
    assert llm_calls["n"] == 2

    user_turns = [entry for entry in seen_histories[-1] if entry.get("role") == "user"]
    # Original prompt plus a single merged follow-up.
    assert len(user_turns) == 2

    merged_content = user_turns[-1]["content"]
    assert isinstance(merged_content, list)
    dict_blocks = [block for block in merged_content if isinstance(block, dict)]
    assert any(block.get("type") == "image_url" for block in dict_blocks)
    assert any(
        block.get("type") == "text" and block.get("text") == "and answer briefly"
        for block in dict_blocks
    )
@pytest.mark.asyncio
async def test_injection_cycles_capped_at_max():
    """With ``_MAX_INJECTION_CYCLES`` injection rounds the LLM is called once more than that.

    NOTE(review): the callback itself stops supplying messages after
    ``_MAX_INJECTION_CYCLES`` drains, so this may pass even without a cap in
    the runner - confirm whether an always-injecting callback was intended.
    """
    from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
    from nanobot.bus.events import InboundMessage

    llm_calls = {"n": 0}
    drains = {"n": 0}

    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        return LLMResponse(content=f"answer-{llm_calls['n']}", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    async def inject_cb():
        drains["n"] += 1
        # One message per drain, but only for the first _MAX_INJECTION_CYCLES drains.
        if drains["n"] > _MAX_INJECTION_CYCLES:
            return []
        return [
            InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drains['n']}")
        ]

    outcome = await AgentRunner(provider).run(AgentRunSpec(
        initial_messages=[{"role": "user", "content": "start"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=20,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
        injection_callback=inject_cb,
    ))

    assert outcome.had_injections is True
    # _MAX_INJECTION_CYCLES injection rounds plus one final round.
    assert llm_calls["n"] == _MAX_INJECTION_CYCLES + 1
@pytest.mark.asyncio
async def test_no_injections_flag_is_false_by_default():
    """A run without any injection callback reports ``had_injections`` as ``False``."""
    from nanobot.agent.runner import AgentRunSpec, AgentRunner

    async def chat_with_retry(**kwargs):
        return LLMResponse(content="done", tool_calls=[], usage={})

    provider = MagicMock()
    provider.chat_with_retry = chat_with_retry

    tool_registry = MagicMock()
    tool_registry.get_definitions.return_value = []

    run_spec = AgentRunSpec(
        initial_messages=[{"role": "user", "content": "hi"}],
        tools=tool_registry,
        model="test-model",
        max_iterations=1,
        max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
    )
    outcome = await AgentRunner(provider).run(run_spec)

    assert outcome.had_injections is False
@pytest.mark.asyncio
async def test_pending_queue_cleanup_on_dispatch(tmp_path):
    """No pending queue is registered for the session before or after ``_dispatch``."""
    from nanobot.bus.events import InboundMessage

    async def chat_with_retry(**kwargs):
        return LLMResponse(content="done", tool_calls=[], usage={})

    loop = _make_loop(tmp_path)
    loop.provider.chat_with_retry = chat_with_retry

    inbound = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello")
    session_key = inbound.session_key

    # Nothing is registered for this session up front...
    assert session_key not in loop._pending_queues

    await loop._dispatch(inbound)

    # ...and nothing is left behind once the dispatch has finished.
    assert inbound.session_key not in loop._pending_queues
@pytest.mark.asyncio
async def test_followup_routed_to_pending_queue(tmp_path):
    """Unified-session follow-ups should route into the active pending queue.

    A pending queue is pre-registered under ``UNIFIED_SESSION_KEY``; a message
    published on the bus must land in that queue and must not trigger
    ``_dispatch``.
    """
    from nanobot.agent.loop import UNIFIED_SESSION_KEY
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    loop._unified_session = True
    loop._dispatch = AsyncMock()  # type: ignore[method-assign]

    pending = asyncio.Queue(maxsize=20)
    loop._pending_queues[UNIFIED_SESSION_KEY] = pending

    run_task = asyncio.create_task(loop.run())
    msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up")
    await loop.bus.publish_inbound(msg)

    # Poll with a monotonic deadline: wall-clock time can jump (NTP, DST),
    # which would make the wait end early or run long.
    deadline = time.monotonic() + 2
    while pending.empty() and time.monotonic() < deadline:
        await asyncio.sleep(0.01)

    loop.stop()
    await asyncio.wait_for(run_task, timeout=2)

    assert loop._dispatch.await_count == 0
    assert not pending.empty()
    queued_msg = pending.get_nowait()
    assert queued_msg.content == "follow-up"
    assert queued_msg.session_key == UNIFIED_SESSION_KEY
@pytest.mark.asyncio
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
    """Follow-ups beyond the per-turn limit stay queued and are delivered in a later drain."""
    from nanobot.agent.loop import AgentLoop
    from nanobot.bus.events import InboundMessage
    from nanobot.bus.queue import MessageBus
    from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN

    llm_calls = {"n": 0}
    seen_histories: list[list[dict]] = []

    async def chat_with_retry(*, messages, **kwargs):
        llm_calls["n"] += 1
        seen_histories.append([dict(entry) for entry in messages])
        return LLMResponse(content=f"answer-{llm_calls['n']}", tool_calls=[], usage={})

    provider = MagicMock()
    provider.get_default_model.return_value = "test-model"
    provider.chat_with_retry = chat_with_retry

    loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Two more follow-ups than a single drain is allowed to take.
    total_followups = _MAX_INJECTIONS_PER_TURN + 2
    expected_texts = [f"follow-up-{idx}" for idx in range(total_followups)]

    pending_queue = asyncio.Queue()
    for text in expected_texts:
        await pending_queue.put(InboundMessage(
            channel="cli",
            sender_id="u",
            chat_id="c",
            content=text,
        ))

    final_content, _, _, _, had_injections = await loop._run_agent_loop(
        [{"role": "user", "content": "hello"}],
        channel="cli",
        chat_id="c",
        pending_queue=pending_queue,
    )

    assert final_content == "answer-3"
    assert had_injections is True
    assert llm_calls["n"] == 3

    user_texts = [
        entry["content"]
        for entry in seen_histories[-1]
        if entry.get("role") == "user" and isinstance(entry.get("content"), str)
    ]
    flattened_user_content = "\n".join(user_texts)
    for text in expected_texts:
        assert text in flattened_user_content
    assert pending_queue.empty()
@pytest.mark.asyncio
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
    """QueueFull should preserve the message by dispatching a queued task.

    The session's pending queue is pre-filled to capacity, so the new message
    must go through ``_dispatch`` and the queue must keep its single item.
    """
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    loop._dispatch = AsyncMock()  # type: ignore[method-assign]

    pending = asyncio.Queue(maxsize=1)
    pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued"))
    loop._pending_queues["cli:c"] = pending

    run_task = asyncio.create_task(loop.run())
    msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
    await loop.bus.publish_inbound(msg)

    # Poll with a monotonic deadline: wall-clock time can jump (NTP, DST),
    # which would make the wait end early or run long.
    deadline = time.monotonic() + 2
    while loop._dispatch.await_count == 0 and time.monotonic() < deadline:
        await asyncio.sleep(0.01)

    loop.stop()
    await asyncio.wait_for(run_task, timeout=2)

    assert loop._dispatch.await_count == 1
    dispatched_msg = loop._dispatch.await_args.args[0]
    assert dispatched_msg.content == "follow-up"
    assert pending.qsize() == 1
@pytest.mark.asyncio
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
    """Leftover pending-queue messages end up back on the inbound bus.

    The cleanup is replayed by hand here: the registered queue is popped and
    each remaining item is re-published.

    NOTE(review): ``_dispatch`` itself is never called, so this checks the
    replayed steps rather than the production finally-block - confirm intent.
    """
    from nanobot.bus.events import InboundMessage

    loop = _make_loop(tmp_path)
    bus = loop.bus

    # Register a queue holding two leftovers, as if a dispatch had just ended.
    session_key = "cli:c"
    pending = asyncio.Queue(maxsize=20)
    loop._pending_queues[session_key] = pending
    for text in ("leftover-1", "leftover-2"):
        pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content=text))

    # Replay the cleanup: detach the queue and re-publish whatever is in it.
    queue = loop._pending_queues.pop(session_key, None)
    assert queue is not None
    republished = 0
    while not queue.empty():
        await bus.publish_inbound(queue.get_nowait())
        republished += 1
    assert republished == 2

    # Both messages must now be readable from the bus.
    received = []
    while not bus.inbound.empty():
        received.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5))
    contents = [entry.content for entry in received]
    assert "leftover-1" in contents
    assert "leftover-2" in contents

View File

@ -0,0 +1,304 @@
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
try:
from nanobot.channels import qq
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
except ImportError:
QQ_AVAILABLE = False
if not QQ_AVAILABLE:
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import (
QQ_FILE_TYPE_FILE,
QQ_FILE_TYPE_IMAGE,
QQChannel,
QQConfig,
_guess_send_file_type,
_is_image_name,
_sanitize_filename,
)
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeHttp:
"""Fake _http for _post_base64file tests."""
def __init__(self, return_value: dict | None = None) -> None:
self.return_value = return_value or {}
self.calls: list[tuple] = []
async def request(self, route, **kwargs):
self.calls.append((route, kwargs))
return self.return_value
class _FakeClient:
    """Fake client exposing a ``_FakeApi`` whose ``_http`` is a ``_FakeHttp``."""

    def __init__(self, http_return: dict | None = None) -> None:
        fake_api = _FakeApi()
        # ``http_return`` becomes the canned response of the fake HTTP layer.
        fake_api._http = _FakeHttp(http_return)
        self.api = fake_api
# ── Helper function tests (pure, no async) ──────────────────────────
def test_sanitize_filename_strips_path_traversal() -> None:
    """Directory components of a traversal path are dropped, leaving the base name."""
    cleaned = _sanitize_filename("../../etc/passwd")
    assert cleaned == "passwd"
def test_sanitize_filename_keeps_chinese_chars() -> None:
    """A file name containing Chinese characters passes through unchanged."""
    original_name = "文件1.jpg"
    assert _sanitize_filename(original_name) == "文件1.jpg"
def test_sanitize_filename_strips_unsafe_chars() -> None:
    """Unsafe punctuation is removed while the stem prefix and extension survive."""
    cleaned = _sanitize_filename('file<>:"|?*.txt')

    assert cleaned.startswith("file")
    assert cleaned.endswith(".txt")
    # Only these five characters are checked; ":" and "*" are not asserted on.
    leaked = [ch for ch in '<>"|?' if ch in cleaned]
    assert leaked == []
def test_sanitize_filename_empty_input() -> None:
    """An empty name stays empty."""
    cleaned = _sanitize_filename("")
    assert cleaned == ""
def test_is_image_name_with_known_extensions() -> None:
    """Every listed image extension is recognised as an image name."""
    image_exts = (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg")
    rejected = [ext for ext in image_exts if _is_image_name(f"photo{ext}") is not True]
    assert rejected == []
def test_is_image_name_with_unknown_extension() -> None:
    """Document, audio and video extensions are not treated as image names."""
    other_exts = (".pdf", ".txt", ".mp3", ".mp4")
    accepted = [ext for ext in other_exts if _is_image_name(f"doc{ext}") is not False]
    assert accepted == []
def test_guess_send_file_type_image() -> None:
    """PNG and JPG file names map to the image file type."""
    guessed = [_guess_send_file_type(name) for name in ("photo.png", "pic.jpg")]
    assert guessed == [QQ_FILE_TYPE_IMAGE, QQ_FILE_TYPE_IMAGE]
def test_guess_send_file_type_file() -> None:
    """A PDF file name maps to the generic file type."""
    guessed = _guess_send_file_type("doc.pdf")
    assert guessed == QQ_FILE_TYPE_FILE
def test_guess_send_file_type_by_mime() -> None:
    """A name with an unrecognised extension is expected to be the generic file type."""
    # NOTE(review): despite the test name, the asserted result is the
    # non-image type, so no image/* MIME match is demonstrated here - confirm
    # whether a positive MIME-based case was intended.
    guessed = _guess_send_file_type("photo.xyz_image_test")
    assert guessed == QQ_FILE_TYPE_FILE
# ── send() exception handling ───────────────────────────────────────
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
    """``send()`` must not propagate an error raised by ``_send_text_only``."""
    config = QQConfig(app_id="app", secret="secret", allow_from=["*"])
    channel = QQChannel(config, MessageBus())
    channel._client = _FakeClient()

    outbound = OutboundMessage(channel="qq", chat_id="user1", content="hello")
    failing_text_send = patch.object(
        channel,
        "_send_text_only",
        new_callable=AsyncMock,
        side_effect=RuntimeError("boom"),
    )
    with failing_text_send:
        # Reaching the end without an exception is the whole assertion.
        await channel.send(outbound)
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
    """With media and text present, the upload happens and the text goes out via C2C."""
    import os
    import tempfile

    config = QQConfig(app_id="app", secret="secret", allow_from=["*"])
    channel = QQChannel(config, MessageBus())
    channel._client = _FakeClient()

    # A real on-disk file is needed; it is removed again in the finally below.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(b"\x89PNG\r\n")
        tmp = handle.name

    try:
        outbound = OutboundMessage(
            channel="qq",
            chat_id="user1",
            content="text after image",
            media=[tmp],
            metadata={"message_id": "m1"},
        )
        with patch.object(
            channel,
            "_post_base64file",
            new_callable=AsyncMock,
            return_value={"file_info": "1"},
        ) as mock_upload:
            await channel.send(outbound)

        assert mock_upload.called
        # msg_type 0 calls on the C2C endpoint are the text sends.
        text_calls = [
            call for call in channel._client.api.c2c_calls if call.get("msg_type") == 0
        ]
        assert len(text_calls) >= 1
        assert text_calls[-1]["content"] == "text after image"
    finally:
        os.unlink(tmp)
@pytest.mark.asyncio
async def test_send_media_failure_falls_back_to_text() -> None:
    """If ``_send_media`` returns ``False``, one failure notice naming the file is sent."""
    config = QQConfig(app_id="app", secret="secret", allow_from=["*"])
    channel = QQChannel(config, MessageBus())
    channel._client = _FakeClient()

    outbound = OutboundMessage(
        channel="qq",
        chat_id="user1",
        content="hello",
        media=["https://example.com/bad.png"],
        metadata={"message_id": "m1"},
    )
    with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
        await channel.send(outbound)

    # Exactly one C2C call carries the failure notice, and it names the attachment.
    notices = [
        call for call in channel._client.api.c2c_calls
        if "Attachment send failed" in call.get("content", "")
    ]
    assert len(notices) == 1
    assert "bad.png" in notices[0]["content"]
# ── _on_message() exception handling ────────────────────────────────
@pytest.mark.asyncio
async def test_on_message_exception_caught_not_raised() -> None:
    """``_on_message`` must not raise when the payload lacks required attributes."""
    config = QQConfig(app_id="app", secret="secret", allow_from=["*"])
    channel = QQChannel(config, MessageBus())
    channel._client = _FakeClient()

    # The payload deliberately has no ``author`` attribute.
    payload_without_author = SimpleNamespace(id="x1", content="hi")

    # Completing without an exception is the whole assertion.
    await channel._on_message(payload_without_author, is_group=False)
@pytest.mark.asyncio
async def test_on_message_with_attachments() -> None:
    """Messages with attachments produce media_paths and formatted content.

    The download helper is patched to return a pre-created temp file, so the
    published inbound message should list that path as its only media entry
    and mention the attachment's file name in its content.
    """
    import asyncio
    import os
    import tempfile

    channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
    channel._client = _FakeClient()

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        f.write(b"\x89PNG\r\n")
        saved_path = f.name

    att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")

    # Patch _download_to_media_dir_chunked to return the temp file path
    async def fake_download(url, filename_hint=""):
        return saved_path

    try:
        with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
            data = SimpleNamespace(
                id="att1",
                content="look at this",
                author=SimpleNamespace(user_openid="u1"),
                attachments=[att],
            )
            await channel._on_message(data, is_group=False)

        # Bound the wait: if _on_message published nothing, an unbounded
        # consume would block forever and hang the run instead of failing.
        msg = await asyncio.wait_for(channel.bus.consume_inbound(), timeout=2)
        assert "look at this" in msg.content
        assert "screenshot.png" in msg.content
        assert "Received files:" in msg.content
        assert len(msg.media) == 1
        assert msg.media[0] == saved_path
    finally:
        os.unlink(saved_path)
# ── _post_base64file() ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_post_base64file_omits_file_name_for_images() -> None:
    """An upload with QQ_FILE_TYPE_IMAGE must send a payload without ``file_name``."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
    channel._client = _FakeClient(http_return={"file_info": "img_abc"})

    await channel._post_base64file(
        chat_id="user1",
        is_group=False,
        file_type=QQ_FILE_TYPE_IMAGE,
        file_data="ZmFrZQ==",
        file_name="photo.png",
    )

    # Exactly one HTTP request is expected; inspect its JSON body.
    recorded = channel._client.api._http.calls
    assert len(recorded) == 1
    sent_json = recorded[0][1]["json"]
    assert "file_name" not in sent_json
    assert sent_json["file_type"] == QQ_FILE_TYPE_IMAGE
@pytest.mark.asyncio
async def test_post_base64file_includes_file_name_for_files() -> None:
    """An upload with QQ_FILE_TYPE_FILE must carry ``file_name`` in the payload."""
    channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
    channel._client = _FakeClient(http_return={"file_info": "file_abc"})

    await channel._post_base64file(
        chat_id="user1",
        is_group=False,
        file_type=QQ_FILE_TYPE_FILE,
        file_data="ZmFrZQ==",
        file_name="report.pdf",
    )

    # Exactly one HTTP request is expected; inspect its JSON body.
    recorded = channel._client.api._http.calls
    assert len(recorded) == 1
    sent_json = recorded[0][1]["json"]
    assert sent_json["file_name"] == "report.pdf"
    assert sent_json["file_type"] == QQ_FILE_TYPE_FILE
@pytest.mark.asyncio
async def test_post_base64file_filters_response_to_file_info() -> None:
    """Extra fields in the HTTP response are dropped; only ``file_info`` is returned."""
    raw_response = {
        "file_info": "fi_123",
        "file_uuid": "uuid_xxx",
        "ttl": 3600,
    }
    channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
    channel._client = _FakeClient(http_return=raw_response)

    result = await channel._post_base64file(
        chat_id="user1",
        is_group=False,
        file_type=QQ_FILE_TYPE_FILE,
        file_data="ZmFrZQ==",
        file_name="doc.pdf",
    )

    assert result == {"file_info": "fi_123"}
    # Spelled out separately so a failure names the field that leaked through.
    assert "file_uuid" not in result
    assert "ttl" not in result

View File

@ -0,0 +1,584 @@
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""
import os
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Probe for the optional WeCom SDK by looking up its import spec.
try:
    import importlib.util

    WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
except ImportError:
    # An ImportError raised during the lookup is treated as "SDK not installed".
    WECOM_AVAILABLE = False

# Skip every test in this module when the SDK cannot be found.
if not WECOM_AVAILABLE:
    pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.wecom import (
WecomChannel,
WecomConfig,
_guess_wecom_media_type,
_sanitize_filename,
)
# Try to import the real response class; fall back to a stub if unavailable.
try:
from wecom_aibot_sdk.utils import WsResponse
_RealWsResponse = WsResponse
except ImportError:
_RealWsResponse = None
class _FakeResponse:
"""Minimal stand-in for wecom_aibot_sdk WsResponse."""
def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
self.errcode = errcode
self.errmsg = errmsg
self.body = body or {}
class _FakeWsManager:
    """Records every ``send_reply`` call and replays a scripted list of responses."""

    def __init__(self, responses: list[_FakeResponse] | None = None):
        self.responses = responses or []
        # Each recorded entry is (req_id, data, cmd), in call order.
        self.calls: list[tuple[str, dict, str]] = []
        self._idx = 0

    async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
        """Record the call, then return the next scripted response.

        Once the script is exhausted, a default ``_FakeResponse()`` is returned.
        """
        self.calls.append((req_id, data, cmd))
        position = self._idx
        if position >= len(self.responses):
            return _FakeResponse()
        self._idx = position + 1
        return self.responses[position]
class _FakeFrame:
"""Minimal frame object with a body dict."""
def __init__(self, body: dict | None = None):
self.body = body or {}
class _FakeWeComClient:
    """Client double: a scripted ``_FakeWsManager`` plus ``AsyncMock`` methods."""

    def __init__(self, ws_responses: list[_FakeResponse] | None = None):
        self._ws_manager = _FakeWsManager(ws_responses)
        # By default a download reports "nothing fetched"; tests override return_value.
        self.download_file = AsyncMock(return_value=(None, None))
        # The remaining methods are plain AsyncMocks so tests can inspect their calls.
        for method_name in ("reply", "reply_stream", "send_message", "reply_welcome"):
            setattr(self, method_name, AsyncMock())
# ── Helper function tests (pure, no async) ──────────────────────────
def test_sanitize_filename_strips_path_traversal() -> None:
    """A traversal-style path is reduced to its final component."""
    sanitized = _sanitize_filename("../../etc/passwd")
    assert sanitized == "passwd"
def test_sanitize_filename_keeps_chinese_chars() -> None:
    """A name containing Chinese characters passes through unchanged."""
    original = "文件1.jpg"
    assert _sanitize_filename(original) == "文件1.jpg"
def test_sanitize_filename_empty_input() -> None:
    """An empty name stays empty."""
    sanitized = _sanitize_filename("")
    assert sanitized == ""
def test_guess_wecom_media_type_image() -> None:
    """Each listed image extension is classified as ``"image"``."""
    image_names = [
        f"photo{suffix}" for suffix in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
    ]
    for name in image_names:
        assert _guess_wecom_media_type(name) == "image"
def test_guess_wecom_media_type_video() -> None:
    """Each listed video extension is classified as ``"video"``."""
    video_names = [f"video{suffix}" for suffix in (".mp4", ".avi", ".mov")]
    for name in video_names:
        assert _guess_wecom_media_type(name) == "video"
def test_guess_wecom_media_type_voice() -> None:
    """Each listed audio extension is classified as ``"voice"``."""
    audio_names = [f"audio{suffix}" for suffix in (".amr", ".mp3", ".wav", ".ogg")]
    for name in audio_names:
        assert _guess_wecom_media_type(name) == "voice"
def test_guess_wecom_media_type_file_fallback() -> None:
    """Document and archive extensions fall back to ``"file"``."""
    other_names = [f"doc{suffix}" for suffix in (".pdf", ".doc", ".xlsx", ".zip")]
    for name in other_names:
        assert _guess_wecom_media_type(name) == "file"
def test_guess_wecom_media_type_case_insensitive() -> None:
    """Upper- and mixed-case extensions are classified like their lowercase forms."""
    for name in ("photo.PNG", "photo.Jpg"):
        assert _guess_wecom_media_type(name) == "image"
# ── _download_and_save_media() ──────────────────────────────────────
@pytest.mark.asyncio
async def test_download_and_save_success() -> None:
    """Successful download writes file and returns sanitized path."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    fake_data = b"\x89PNG\r\nfake image"
    client.download_file.return_value = (fake_data, "raw_photo.png")

    # Use a private media directory instead of the shared system temp dir: the saved
    # file is removed even when an assertion fails, and the fixed name "photo.png"
    # cannot collide with files left by other tests or parallel runs.
    with tempfile.TemporaryDirectory() as media_dir:
        with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)):
            path = await channel._download_and_save_media(
                "https://example.com/img.png", "aes_key", "image", "photo.png"
            )

        assert path is not None
        assert os.path.isfile(path)
        assert os.path.basename(path) == "photo.png"
@pytest.mark.asyncio
async def test_download_and_save_oversized_rejected() -> None:
    """Data exceeding 200MB is rejected → returns None."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    client = _FakeWeComClient()
    channel._client = client

    big_data = b"\x00" * (200 * 1024 * 1024 + 1)  # 200MB + 1 byte
    client.download_file.return_value = (big_data, "big.bin")

    # Point the media dir at a throwaway directory: if the size guard ever regresses,
    # the oversized file is written there and removed on exit instead of being left
    # behind in the shared system temp dir.
    with tempfile.TemporaryDirectory() as media_dir:
        with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)):
            result = await channel._download_and_save_media(
                "https://example.com/big.bin", "key", "file", "big.bin"
            )

    assert result is None
@pytest.mark.asyncio
async def test_download_and_save_failure() -> None:
    """When the client's download yields ``(None, None)``, the helper returns None."""
    fake_client = _FakeWeComClient()
    fake_client.download_file.return_value = (None, None)
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = fake_client

    media_root = Path(tempfile.gettempdir())
    with patch("nanobot.channels.wecom.get_media_dir", return_value=media_root):
        outcome = await channel._download_and_save_media(
            "https://example.com/fail.png", "key", "image"
        )

    assert outcome is None
# ── _upload_media_ws() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upload_media_ws_success() -> None:
    """Happy path: init → chunk → finish → returns (media_id, media_type)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(b"\x89PNG\r\n")
    png_path = handle.name

    try:
        # One scripted reply per send_reply call, consumed in order.
        scripted = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={"media_id": "media_abc"}),
        ]
        client = _FakeWeComClient(scripted)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            uploaded_id, uploaded_kind = await channel._upload_media_ws(client, png_path)

        assert uploaded_id == "media_abc"
        assert uploaded_kind == "image"
    finally:
        os.unlink(png_path)
@pytest.mark.asyncio
async def test_upload_media_ws_oversized_file() -> None:
    """A reported size just over 200MB makes the upload return (None, None)."""
    too_big = 200 * 1024 * 1024 + 1
    client = _FakeWeComClient()
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = client

    # No real 200MB+ file is created: the size lookup and open() are both mocked.
    with patch("os.path.getsize", return_value=too_big), patch("builtins.open", MagicMock()):
        outcome = await channel._upload_media_ws(client, "/fake/large.bin")

    assert outcome == (None, None)
@pytest.mark.asyncio
async def test_upload_media_ws_init_failure() -> None:
    """Init step returns errcode != 0 → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as handle:
        handle.write(b"hello")
    txt_path = handle.name

    try:
        # The very first reply already carries an error code.
        scripted = [_FakeResponse(errcode=50001, errmsg="invalid")]
        client = _FakeWeComClient(scripted)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            outcome = await channel._upload_media_ws(client, txt_path)

        assert outcome == (None, None)
    finally:
        os.unlink(txt_path)
@pytest.mark.asyncio
async def test_upload_media_ws_chunk_failure() -> None:
    """Chunk step returns errcode != 0 → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(b"\x89PNG\r\n")
    png_path = handle.name

    try:
        # First reply succeeds with an upload id; the second reply is an error.
        scripted = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=50002, errmsg="chunk fail"),
        ]
        client = _FakeWeComClient(scripted)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            outcome = await channel._upload_media_ws(client, png_path)

        assert outcome == (None, None)
    finally:
        os.unlink(png_path)
@pytest.mark.asyncio
async def test_upload_media_ws_finish_no_media_id() -> None:
    """Finish step returns empty media_id → returns (None, None)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(b"\x89PNG\r\n")
    png_path = handle.name

    try:
        # All three replies succeed, but the last one carries no media_id.
        scripted = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={}),
        ]
        client = _FakeWeComClient(scripted)
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        channel._client = client

        with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
            outcome = await channel._upload_media_ws(client, png_path)

        assert outcome == (None, None)
    finally:
        os.unlink(png_path)
# ── send() ──────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_send_text_with_frame() -> None:
    """With a frame stored for the chat, the text goes out through reply_stream."""
    client = _FakeWeComClient()
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    outbound = OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
    await channel.send(outbound)

    client.reply_stream.assert_called_once()
    # The content is the third positional argument of reply_stream.
    assert client.reply_stream.call_args.args[2] == "hello"
@pytest.mark.asyncio
async def test_send_progress_with_frame() -> None:
    """A message flagged ``_progress`` is streamed with ``finish=False``."""
    client = _FakeWeComClient()
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    outbound = OutboundMessage(
        channel="wecom",
        chat_id="chat1",
        content="thinking...",
        metadata={"_progress": True},
    )
    await channel.send(outbound)

    client.reply_stream.assert_called_once()
    stream_call = client.reply_stream.call_args
    # The content is the third positional argument; finish is passed by keyword.
    assert stream_call.args[2] == "thinking..."
    assert stream_call.kwargs["finish"] is False
@pytest.mark.asyncio
async def test_send_proactive_without_frame() -> None:
    """With no stored frame, the text goes out via send_message as markdown."""
    client = _FakeWeComClient()
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = client

    outbound = OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
    await channel.send(outbound)

    client.send_message.assert_called_once()
    # Positional arguments: the target chat id, then the message payload.
    positional = client.send_message.call_args.args
    assert positional[0] == "chat1"
    assert positional[1]["msgtype"] == "markdown"
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
    """An attached image is uploaded and sent via reply; the text goes via reply_stream."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        handle.write(b"\x89PNG\r\n")
    png_path = handle.name

    try:
        # Scripted replies for the upload handshake; the last one carries the media id.
        scripted = [
            _FakeResponse(errcode=0, body={"upload_id": "up_1"}),
            _FakeResponse(errcode=0, body={}),
            _FakeResponse(errcode=0, body={"media_id": "media_123"}),
        ]
        channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
        client = _FakeWeComClient(scripted)
        channel._client = client
        channel._generate_req_id = lambda x: f"req_{x}"
        channel._chat_frames["chat1"] = _FakeFrame()

        outbound = OutboundMessage(
            channel="wecom", chat_id="chat1", content="see image", media=[png_path]
        )
        await channel.send(outbound)

        # Media should have been sent via reply
        image_replies = [
            recorded
            for recorded in client.reply.call_args_list
            if recorded.args[1].get("msgtype") == "image"
        ]
        assert len(image_replies) == 1
        assert image_replies[0].args[1]["image"]["media_id"] == "media_123"

        # Text should have been sent via reply_stream
        client.reply_stream.assert_called_once()
    finally:
        os.unlink(png_path)
@pytest.mark.asyncio
async def test_send_media_file_not_found() -> None:
    """A non-existent media path yields no media reply; the text is still streamed."""
    client = _FakeWeComClient()
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    outbound = OutboundMessage(
        channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"]
    )
    await channel.send(outbound)

    # reply_stream should still be called for the text part
    client.reply_stream.assert_called_once()

    # No media reply should happen
    media_kinds = ("image", "file", "video")
    media_replies = [
        recorded
        for recorded in client.reply.call_args_list
        if recorded.args[1].get("msgtype") in media_kinds
    ]
    assert len(media_replies) == 0
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
    """An error raised by reply_stream must not escape from send()."""
    client = _FakeWeComClient()
    # Every attempt to stream a reply blows up.
    client.reply_stream.side_effect = RuntimeError("boom")

    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
    channel._client = client
    channel._generate_req_id = lambda x: f"req_{x}"
    channel._chat_frames["chat1"] = _FakeFrame()

    outbound = OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
    # The await is the assertion: reaching the end without an exception is a pass.
    await channel.send(outbound)
# ── _process_message() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_process_text_message() -> None:
    """A text frame from an allowed user reaches the bus with sender, chat and content set."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    channel._client = _FakeWeComClient()

    payload = {
        "msgid": "msg_text_1",
        "chatid": "chat1",
        "chattype": "single",
        "from": {"userid": "user1"},
        "text": {"content": "hello wecom"},
    }
    await channel._process_message(_FakeFrame(body=payload), "text")

    inbound = await channel.bus.consume_inbound()
    assert inbound.sender_id == "user1"
    assert inbound.chat_id == "chat1"
    assert inbound.content == "hello wecom"
    assert inbound.metadata["msg_type"] == "text"
@pytest.mark.asyncio
async def test_process_image_message() -> None:
    """Image message: download success → media_paths non-empty."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    client = _FakeWeComClient()
    client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
    channel._client = client

    frame = _FakeFrame(body={
        "msgid": "msg_img_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "image": {"url": "https://example.com/img.png", "aeskey": "key123"},
    })

    # A private media directory replaces the earlier NamedTemporaryFile, which was only
    # used to locate the system temp dir and was never deleted. Everything written
    # during the test now lives in this directory and is removed on exit, even when
    # an assertion fails.
    with tempfile.TemporaryDirectory() as media_dir:
        with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)):
            await channel._process_message(frame, "image")

        msg = await channel.bus.consume_inbound()
        assert len(msg.media) == 1
        assert msg.media[0].endswith("photo.png")
        assert "[image:" in msg.content
@pytest.mark.asyncio
async def test_process_file_message() -> None:
    """File message: download success → media_paths non-empty (critical fix verification)."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    client = _FakeWeComClient()
    client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
    channel._client = client

    frame = _FakeFrame(body={
        "msgid": "msg_file_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
    })

    # A private media directory replaces the earlier NamedTemporaryFile, which was only
    # used to locate the system temp dir and was never deleted. Everything written
    # during the test now lives in this directory and is removed on exit, even when
    # an assertion fails.
    with tempfile.TemporaryDirectory() as media_dir:
        with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(media_dir)):
            await channel._process_message(frame, "file")

        msg = await channel.bus.consume_inbound()
        assert len(msg.media) == 1
        assert msg.media[0].endswith("report.pdf")
        assert "[file: report.pdf]" in msg.content
@pytest.mark.asyncio
async def test_process_voice_message() -> None:
    """A voice frame's ``content`` text appears in the bus message next to a ``[voice]`` tag."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    channel._client = _FakeWeComClient()

    payload = {
        "msgid": "msg_voice_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "voice": {"content": "transcribed text here"},
    }
    await channel._process_message(_FakeFrame(body=payload), "voice")

    inbound = await channel.bus.consume_inbound()
    assert "transcribed text here" in inbound.content
    assert "[voice]" in inbound.content
@pytest.mark.asyncio
async def test_process_message_deduplication() -> None:
    """Processing the same frame twice puts only one message on the bus."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    channel._client = _FakeWeComClient()

    duplicate = _FakeFrame(body={
        "msgid": "msg_dup_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "text": {"content": "once"},
    })
    # Deliver the identical frame (same msgid) two times in a row.
    for _ in range(2):
        await channel._process_message(duplicate, "text")

    first = await channel.bus.consume_inbound()
    assert first.content == "once"
    # Second message should not appear on the bus
    assert channel.bus.inbound.empty()
@pytest.mark.asyncio
async def test_process_message_empty_content_skipped() -> None:
    """A text frame whose content is the empty string leaves the inbound queue empty."""
    channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
    channel._client = _FakeWeComClient()

    payload = {
        "msgid": "msg_empty_1",
        "chatid": "chat1",
        "from": {"userid": "user1"},
        "text": {"content": ""},
    }
    await channel._process_message(_FakeFrame(body=payload), "text")

    assert channel.bus.inbound.empty()

View File

@ -1,5 +1,6 @@
"""Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
assert result is not None
assert "Hello" in result.content
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
    self, tmp_path: Path
) -> None:
    """With a follow-up queued, the message-tool reply is the only output.

    The send callback must see exactly one message ("Tool reply"), and
    ``_process_message`` must return None rather than a fallback reply.
    """
    loop = _make_loop(tmp_path)

    message_call = ToolCallRequest(
        id="call1",
        name="message",
        arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
    )
    # Scripted provider output: a plain answer, then a message-tool call,
    # then three empty responses.
    scripted = iter([
        LLMResponse(content="First answer", tool_calls=[]),
        LLMResponse(content="", tool_calls=[message_call]),
        *(LLMResponse(content="", tool_calls=[]) for _ in range(3)),
    ])
    loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(scripted))
    loop.tools.get_definitions = MagicMock(return_value=[])

    # Capture whatever the message tool sends.
    sent: list[OutboundMessage] = []
    message_tool = loop.tools.get("message")
    if isinstance(message_tool, MessageTool):
        message_tool.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))

    # Queue the follow-up before the turn starts so it is already pending.
    pending_queue = asyncio.Queue()
    follow_up = InboundMessage(
        channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up"
    )
    await pending_queue.put(follow_up)

    opening = InboundMessage(
        channel="feishu", sender_id="user1", chat_id="chat123", content="Start"
    )
    result = await loop._process_message(opening, pending_queue=pending_queue)

    assert len(sent) == 1
    assert sent[0].content == "Tool reply"
    assert result is None
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint))
final_content, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done"
assert progress == [