Merge branch 'main' into fix/tool-call-result-order-2943

This commit is contained in:
layla 2026-04-11 22:00:07 +08:00 committed by GitHub
commit f25cdb7138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 2246 additions and 211 deletions

View File

@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.skills import BUILTIN_SKILLS_DIR from nanobot.agent.skills import BUILTIN_SKILLS_DIR
@ -207,6 +207,10 @@ class AgentLoop:
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
self._background_tasks: list[asyncio.Task] = [] self._background_tasks: list[asyncio.Task] = []
self._session_locks: dict[str, asyncio.Lock] = {} self._session_locks: dict[str, asyncio.Lock] = {}
# Per-session pending queues for mid-turn message injection.
# When a session has an active task, new messages for that session
# are routed here instead of creating a new task.
self._pending_queues: dict[str, asyncio.Queue] = {}
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3. # NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3")) _max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
self._concurrency_gate: asyncio.Semaphore | None = ( self._concurrency_gate: asyncio.Semaphore | None = (
@ -320,6 +324,12 @@ class AgentLoop:
return format_tool_hints(tool_calls) return format_tool_hints(tool_calls)
def _effective_session_key(self, msg: InboundMessage) -> str:
"""Return the session key used for task routing and mid-turn injections."""
if self._unified_session and not msg.session_key_override:
return UNIFIED_SESSION_KEY
return msg.session_key
async def _run_agent_loop( async def _run_agent_loop(
self, self,
initial_messages: list[dict], initial_messages: list[dict],
@ -331,13 +341,16 @@ class AgentLoop:
channel: str = "cli", channel: str = "cli",
chat_id: str = "direct", chat_id: str = "direct",
message_id: str | None = None, message_id: str | None = None,
) -> tuple[str | None, list[str], list[dict], str]: pending_queue: asyncio.Queue | None = None,
) -> tuple[str | None, list[str], list[dict], str, bool]:
"""Run the agent iteration loop. """Run the agent iteration loop.
*on_stream*: called with each content delta during streaming. *on_stream*: called with each content delta during streaming.
*on_stream_end(resuming)*: called when a streaming session finishes. *on_stream_end(resuming)*: called when a streaming session finishes.
``resuming=True`` means tool calls follow (spinner should restart); ``resuming=True`` means tool calls follow (spinner should restart);
``resuming=False`` means this is the final response. ``resuming=False`` means this is the final response.
Returns (final_content, tools_used, messages, stop_reason, had_injections).
""" """
loop_hook = _LoopHook( loop_hook = _LoopHook(
self, self,
@ -357,8 +370,33 @@ class AgentLoop:
return return
self._set_runtime_checkpoint(session, payload) self._set_runtime_checkpoint(session, payload)
result = await self.runner.run( async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
AgentRunSpec( """Non-blocking drain of follow-up messages from the pending queue."""
if pending_queue is None:
return []
items: list[dict[str, Any]] = []
while len(items) < limit:
try:
pending_msg = pending_queue.get_nowait()
except asyncio.QueueEmpty:
break
user_content = self.context._build_user_content(
pending_msg.content,
pending_msg.media if pending_msg.media else None,
)
runtime_ctx = self.context._build_runtime_context(
pending_msg.channel,
pending_msg.chat_id,
self.context.timezone,
)
if isinstance(user_content, str):
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content
items.append({"role": "user", "content": merged})
return items
result = await self.runner.run(AgentRunSpec(
initial_messages=initial_messages, initial_messages=initial_messages,
tools=self.tools, tools=self.tools,
model=self.model, model=self.model,
@ -374,14 +412,14 @@ class AgentLoop:
provider_retry_mode=self.provider_retry_mode, provider_retry_mode=self.provider_retry_mode,
progress_callback=on_progress, progress_callback=on_progress,
checkpoint_callback=_checkpoint, checkpoint_callback=_checkpoint,
) injection_callback=_drain_pending,
) ))
self._last_usage = result.usage self._last_usage = result.usage
if result.stop_reason == "max_iterations": if result.stop_reason == "max_iterations":
logger.warning("Max iterations ({}) reached", self.max_iterations) logger.warning("Max iterations ({}) reached", self.max_iterations)
elif result.stop_reason == "error": elif result.stop_reason == "error":
logger.error("LLM returned error: {}", (result.final_content or "")[:200]) logger.error("LLM returned error: {}", (result.final_content or "")[:200])
return result.final_content, result.tools_used, result.messages, result.stop_reason return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
async def run(self) -> None: async def run(self) -> None:
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop.""" """Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
@ -412,13 +450,32 @@ class AgentLoop:
if result: if result:
await self.bus.publish_outbound(result) await self.bus.publish_outbound(result)
continue continue
effective_key = self._effective_session_key(msg)
# If this session already has an active pending queue (i.e. a task
# is processing this session), route the message there for mid-turn
# injection instead of creating a competing task.
if effective_key in self._pending_queues:
pending_msg = msg
if effective_key != msg.session_key:
pending_msg = dataclasses.replace(
msg,
session_key_override=effective_key,
)
try:
self._pending_queues[effective_key].put_nowait(pending_msg)
except asyncio.QueueFull:
logger.warning(
"Pending queue full for session {}, falling back to queued task",
effective_key,
)
else:
logger.info(
"Routed follow-up message to pending queue for session {}",
effective_key,
)
continue
# Compute the effective session key before dispatching # Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled # This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = (
UNIFIED_SESSION_KEY
if self._unified_session and not msg.session_key_override
else msg.session_key
)
task = asyncio.create_task(self._dispatch(msg)) task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(effective_key, []).append(task) self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback( task.add_done_callback(
@ -430,10 +487,18 @@ class AgentLoop:
async def _dispatch(self, msg: InboundMessage) -> None: async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent.""" """Process a message: per-session serial, cross-session concurrent."""
if self._unified_session and not msg.session_key_override: session_key = self._effective_session_key(msg)
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) if session_key != msg.session_key:
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) msg = dataclasses.replace(msg, session_key_override=session_key)
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext() gate = self._concurrency_gate or nullcontext()
# Register a pending queue so follow-up messages for this session are
# routed here (mid-turn injection) instead of spawning a new task.
pending = asyncio.Queue(maxsize=20)
self._pending_queues[session_key] = pending
try:
async with lock, gate: async with lock, gate:
try: try:
on_stream = on_stream_end = None on_stream = on_stream_end = None
@ -449,14 +514,11 @@ class AgentLoop:
meta = dict(msg.metadata or {}) meta = dict(msg.metadata or {})
meta["_stream_delta"] = True meta["_stream_delta"] = True
meta["_stream_id"] = _current_stream_id() meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound( await self.bus.publish_outbound(OutboundMessage(
OutboundMessage( channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
content=delta, content=delta,
metadata=meta, metadata=meta,
) ))
)
async def on_stream_end(*, resuming: bool = False) -> None: async def on_stream_end(*, resuming: bool = False) -> None:
nonlocal stream_segment nonlocal stream_segment
@ -464,43 +526,51 @@ class AgentLoop:
meta["_stream_end"] = True meta["_stream_end"] = True
meta["_resuming"] = resuming meta["_resuming"] = resuming
meta["_stream_id"] = _current_stream_id() meta["_stream_id"] = _current_stream_id()
await self.bus.publish_outbound( await self.bus.publish_outbound(OutboundMessage(
OutboundMessage( channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
content="", content="",
metadata=meta, metadata=meta,
) ))
)
stream_segment += 1 stream_segment += 1
response = await self._process_message( response = await self._process_message(
msg, msg, on_stream=on_stream, on_stream_end=on_stream_end,
on_stream=on_stream, pending_queue=pending,
on_stream_end=on_stream_end,
) )
if response is not None: if response is not None:
await self.bus.publish_outbound(response) await self.bus.publish_outbound(response)
elif msg.channel == "cli": elif msg.channel == "cli":
await self.bus.publish_outbound( await self.bus.publish_outbound(OutboundMessage(
OutboundMessage( channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel, content="", metadata=msg.metadata or {},
chat_id=msg.chat_id, ))
content="",
metadata=msg.metadata or {},
)
)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Task cancelled for session {}", msg.session_key) logger.info("Task cancelled for session {}", session_key)
raise raise
except Exception: except Exception:
logger.exception("Error processing message for session {}", msg.session_key) logger.exception("Error processing message for session {}", session_key)
await self.bus.publish_outbound( await self.bus.publish_outbound(OutboundMessage(
OutboundMessage( channel=msg.channel, chat_id=msg.chat_id,
channel=msg.channel,
chat_id=msg.chat_id,
content="Sorry, I encountered an error.", content="Sorry, I encountered an error.",
) ))
finally:
# Drain any messages still in the pending queue and re-publish
# them to the bus so they are processed as fresh inbound messages
# rather than silently lost.
queue = self._pending_queues.pop(session_key, None)
if queue is not None:
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await self.bus.publish_inbound(item)
leftover += 1
if leftover:
logger.info(
"Re-published {} leftover message(s) to bus for session {}",
leftover, session_key,
) )
async def close_mcp(self) -> None: async def close_mcp(self) -> None:
@ -533,6 +603,7 @@ class AgentLoop:
on_progress: Callable[[str], Awaitable[None]] | None = None, on_progress: Callable[[str], Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None, on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None: ) -> OutboundMessage | None:
"""Process a single inbound message and return the response.""" """Process a single inbound message and return the response."""
# System messages: parse origin from chat_id ("channel:chat_id") # System messages: parse origin from chat_id ("channel:chat_id")
@ -559,11 +630,8 @@ class AgentLoop:
session_summary=pending, session_summary=pending,
current_role=current_role, current_role=current_role,
) )
final_content, _, all_msgs, _ = await self._run_agent_loop( final_content, _, all_msgs, _, _ = await self._run_agent_loop(
messages, messages, session=session, channel=channel, chat_id=chat_id,
session=session,
channel=channel,
chat_id=chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
) )
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
@ -623,7 +691,7 @@ class AgentLoop:
) )
) )
final_content, _, all_msgs, stop_reason = await self._run_agent_loop( final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
initial_messages, initial_messages,
on_progress=on_progress or _bus_progress, on_progress=on_progress or _bus_progress,
on_stream=on_stream, on_stream=on_stream,
@ -632,6 +700,7 @@ class AgentLoop:
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"), message_id=msg.metadata.get("message_id"),
pending_queue=pending_queue,
) )
if final_content is None or not final_content.strip(): if final_content is None or not final_content.strip():
@ -642,7 +711,14 @@ class AgentLoop:
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
# When follow-up messages were injected mid-turn, a later natural
# language reply may address those follow-ups and should not be
# suppressed just because MessageTool was used earlier in the turn.
# However, if the turn falls back to the empty-final-response
# placeholder, suppress it when the real user-visible output already
# came from MessageTool.
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
if not had_injections or stop_reason == "empty_final_response":
return None return None
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content preview = final_content[:120] + "..." if len(final_content) > 120 else final_content

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from dataclasses import dataclass, field from dataclasses import dataclass, field
import inspect
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -34,6 +35,8 @@ _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]" _PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
_MAX_EMPTY_RETRIES = 2 _MAX_EMPTY_RETRIES = 2
_MAX_LENGTH_RECOVERIES = 3 _MAX_LENGTH_RECOVERIES = 3
_MAX_INJECTIONS_PER_TURN = 3
_MAX_INJECTION_CYCLES = 5
_SNIP_SAFETY_BUFFER = 1024 _SNIP_SAFETY_BUFFER = 1024
_MICROCOMPACT_KEEP_RECENT = 10 _MICROCOMPACT_KEEP_RECENT = 10
_MICROCOMPACT_MIN_CHARS = 500 _MICROCOMPACT_MIN_CHARS = 500
@ -42,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
"web_search", "web_fetch", "list_dir", "web_search", "web_fetch", "list_dir",
}) })
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]" _BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
@dataclass(slots=True) @dataclass(slots=True)
class AgentRunSpec: class AgentRunSpec:
"""Configuration for a single agent execution.""" """Configuration for a single agent execution."""
@ -66,6 +72,7 @@ class AgentRunSpec:
provider_retry_mode: str = "standard" provider_retry_mode: str = "standard"
progress_callback: Any | None = None progress_callback: Any | None = None
checkpoint_callback: Any | None = None checkpoint_callback: Any | None = None
injection_callback: Any | None = None
@dataclass(slots=True) @dataclass(slots=True)
@ -79,6 +86,7 @@ class AgentRunResult:
stop_reason: str = "completed" stop_reason: str = "completed"
error: str | None = None error: str | None = None
tool_events: list[dict[str, str]] = field(default_factory=list) tool_events: list[dict[str, str]] = field(default_factory=list)
had_injections: bool = False
class AgentRunner: class AgentRunner:
@ -87,6 +95,90 @@ class AgentRunner:
def __init__(self, provider: LLMProvider): def __init__(self, provider: LLMProvider):
self.provider = provider self.provider = provider
@staticmethod
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
if isinstance(left, str) and isinstance(right, str):
return f"{left}\n\n{right}" if left else right
def _to_blocks(value: Any) -> list[dict[str, Any]]:
if isinstance(value, list):
return [
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
for item in value
]
if value is None:
return []
return [{"type": "text", "text": str(value)}]
return _to_blocks(left) + _to_blocks(right)
@classmethod
def _append_injected_messages(
cls,
messages: list[dict[str, Any]],
injections: list[dict[str, Any]],
) -> None:
"""Append injected user messages while preserving role alternation."""
for injection in injections:
if (
messages
and injection.get("role") == "user"
and messages[-1].get("role") == "user"
):
merged = dict(messages[-1])
merged["content"] = cls._merge_message_content(
merged.get("content"),
injection.get("content"),
)
messages[-1] = merged
continue
messages.append(injection)
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback.
Returns normalized user messages (capped by
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
nothing to inject. Messages beyond the cap are logged so they
are not silently lost.
"""
if spec.injection_callback is None:
return []
try:
signature = inspect.signature(spec.injection_callback)
accepts_limit = (
"limit" in signature.parameters
or any(
parameter.kind is inspect.Parameter.VAR_KEYWORD
for parameter in signature.parameters.values()
)
)
if accepts_limit:
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
else:
items = await spec.injection_callback()
except Exception:
logger.exception("injection_callback failed")
return []
if not items:
return []
injected_messages: list[dict[str, Any]] = []
for item in items:
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
injected_messages.append(item)
continue
text = getattr(item, "content", str(item))
if text.strip():
injected_messages.append({"role": "user", "content": text})
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
logger.warning(
"Injection callback returned {} messages, capping to {} ({} dropped)",
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
)
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult: async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook() hook = spec.hook or AgentHook()
messages = list(spec.initial_messages) messages = list(spec.initial_messages)
@ -99,6 +191,8 @@ class AgentRunner:
external_lookup_counts: dict[str, int] = {} external_lookup_counts: dict[str, int] = {}
empty_content_retries = 0 empty_content_retries = 0
length_recovery_count = 0 length_recovery_count = 0
had_injections = False
injection_cycles = 0
for iteration in range(spec.max_iterations): for iteration in range(spec.max_iterations):
try: try:
@ -207,6 +301,17 @@ class AgentRunner:
) )
empty_content_retries = 0 empty_content_retries = 0
length_recovery_count = 0 length_recovery_count = 0
# Checkpoint 1: drain injections after tools, before next LLM call
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after tool execution ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
await hook.after_iteration(context) await hook.after_iteration(context)
continue continue
@ -263,8 +368,49 @@ class AgentRunner:
await hook.after_iteration(context) await hook.after_iteration(context)
continue continue
assistant_message: dict[str, Any] | None = None
if response.finish_reason != "error" and not is_blank_text(clean):
assistant_message = build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
_injected_after_final = False
if injection_cycles < _MAX_INJECTION_CYCLES:
injections = await self._drain_injections(spec)
if injections:
had_injections = True
injection_cycles += 1
_injected_after_final = True
if assistant_message is not None:
messages.append(assistant_message)
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(messages, injections)
logger.info(
"Injected {} follow-up message(s) after final response ({}/{})",
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
)
if hook.wants_streaming(): if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False) await hook.on_stream_end(context, resuming=_injected_after_final)
if _injected_after_final:
await hook.after_iteration(context)
continue
if response.finish_reason == "error": if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
@ -287,7 +433,7 @@ class AgentRunner:
await hook.after_iteration(context) await hook.after_iteration(context)
break break
messages.append(build_assistant_message( messages.append(assistant_message or build_assistant_message(
clean, clean,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
@ -330,6 +476,7 @@ class AgentRunner:
stop_reason=stop_reason, stop_reason=stop_reason,
error=error, error=error,
tool_events=tool_events, tool_events=tool_events,
had_injections=had_injections,
) )
def _build_request_kwargs( def _build_request_kwargs(

View File

@ -242,6 +242,7 @@ class QQChannel(BaseChannel):
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send attachments first, then text.""" """Send attachments first, then text."""
try:
if not self._client: if not self._client:
logger.warning("QQ client not initialized") logger.warning("QQ client not initialized")
return return
@ -279,6 +280,8 @@ class QQChannel(BaseChannel):
msg_id=msg_id, msg_id=msg_id,
content=msg.content.strip(), content=msg.content.strip(),
) )
except Exception:
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
async def _send_text_only( async def _send_text_only(
self, self,
@ -438,15 +441,26 @@ class QQChannel(BaseChannel):
endpoint = "/v2/users/{openid}/files" endpoint = "/v2/users/{openid}/files"
id_key = "openid" id_key = "openid"
payload = { payload: dict[str, Any] = {
id_key: chat_id, id_key: chat_id,
"file_type": file_type, "file_type": file_type,
"file_data": file_data, "file_data": file_data,
"file_name": file_name,
"srv_send_msg": srv_send_msg, "srv_send_msg": srv_send_msg,
} }
# Only pass file_name for non-image types (file_type=4).
# Passing file_name for images causes QQ client to render them as
# file attachments instead of inline images.
if file_type != QQ_FILE_TYPE_IMAGE and file_name:
payload["file_name"] = file_name
route = Route("POST", endpoint, **{id_key: chat_id}) route = Route("POST", endpoint, **{id_key: chat_id})
return await self._client.api._http.request(route, json=payload) result = await self._client.api._http.request(route, json=payload)
# Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
# that may confuse QQ client when sending the media object.
if isinstance(result, dict) and "file_info" in result:
return {"file_info": result["file_info"]}
return result
# --------------------------- # ---------------------------
# Inbound (receive) # Inbound (receive)
@ -454,6 +468,7 @@ class QQChannel(BaseChannel):
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None: async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
"""Parse inbound message, download attachments, and publish to the bus.""" """Parse inbound message, download attachments, and publish to the bus."""
try:
if data.id in self._processed_ids: if data.id in self._processed_ids:
return return
self._processed_ids.append(data.id) self._processed_ids.append(data.id)
@ -464,7 +479,8 @@ class QQChannel(BaseChannel):
self._chat_type_cache[chat_id] = "group" self._chat_type_cache[chat_id] = "group"
else: else:
chat_id = str( chat_id = str(
getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown") getattr(data.author, "id", None)
or getattr(data.author, "user_openid", "unknown")
) )
user_id = chat_id user_id = chat_id
self._chat_type_cache[chat_id] = "c2c" self._chat_type_cache[chat_id] = "c2c"
@ -478,9 +494,15 @@ class QQChannel(BaseChannel):
# Compose content that always contains actionable saved paths # Compose content that always contains actionable saved paths
if recv_lines: if recv_lines:
tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]" tag = (
"[Image]"
if any(_is_image_name(Path(p).name) for p in media_paths)
else "[File]"
)
file_block = "Received files:\n" + "\n".join(recv_lines) file_block = "Received files:\n" + "\n".join(recv_lines)
content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}" content = (
f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
)
if not content and not media_paths: if not content and not media_paths:
return return
@ -506,6 +528,8 @@ class QQChannel(BaseChannel):
"attachments": att_meta, "attachments": att_meta,
}, },
) )
except Exception:
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
async def _handle_attachments( async def _handle_attachments(
self, self,
@ -520,7 +544,9 @@ class QQChannel(BaseChannel):
return media_paths, recv_lines, att_meta return media_paths, recv_lines, att_meta
for att in attachments: for att in attachments:
url, filename, ctype = att.url, att.filename, att.content_type url = getattr(att, "url", None) or ""
filename = getattr(att, "filename", None) or ""
ctype = getattr(att, "content_type", None) or ""
logger.info("Downloading file from QQ: {}", filename or url) logger.info("Downloading file from QQ: {}", filename or url)
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename) local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
@ -555,6 +581,10 @@ class QQChannel(BaseChannel):
Enforces a max download size and writes to a .part temp file Enforces a max download size and writes to a .part temp file
that is atomically renamed on success. that is atomically renamed on success.
""" """
# Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
if url.startswith("//"):
url = f"https:{url}"
if not self._http: if not self._http:
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120)) self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))

View File

@ -1,9 +1,13 @@
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk.""" """WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
import asyncio import asyncio
import base64
import hashlib
import importlib.util import importlib.util
import os import os
import re
from collections import OrderedDict from collections import OrderedDict
from pathlib import Path
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@ -17,6 +21,37 @@ from pydantic import Field
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
# Upload safety limits (matching QQ channel defaults)
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB
# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)
def _sanitize_filename(name: str) -> str:
"""Sanitize filename to avoid traversal and problematic chars."""
name = (name or "").strip()
name = Path(name).name
name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
return name
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}
def _guess_wecom_media_type(filename: str) -> str:
"""Classify file extension as WeCom media_type string."""
ext = Path(filename).suffix.lower()
if ext in _IMAGE_EXTS:
return "image"
if ext in _VIDEO_EXTS:
return "video"
if ext in _AUDIO_EXTS:
return "voice"
return "file"
class WecomConfig(Base): class WecomConfig(Base):
"""WeCom (Enterprise WeChat) AI Bot channel configuration.""" """WeCom (Enterprise WeChat) AI Bot channel configuration."""
@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
chat_id = body.get("chatid", sender_id) chat_id = body.get("chatid", sender_id)
content_parts = [] content_parts = []
media_paths: list[str] = []
if msg_type == "text": if msg_type == "text":
text = body.get("text", {}).get("content", "") text = body.get("text", {}).get("content", "")
@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
file_path = await self._download_and_save_media(file_url, aes_key, "image") file_path = await self._download_and_save_media(file_url, aes_key, "image")
if file_path: if file_path:
filename = os.path.basename(file_path) filename = os.path.basename(file_path)
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]") content_parts.append(f"[image: {filename}]")
media_paths.append(file_path)
else: else:
content_parts.append("[image: download failed]") content_parts.append("[image: download failed]")
else: else:
@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
if file_url and aes_key: if file_url and aes_key:
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name) file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
if file_path: if file_path:
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]") content_parts.append(f"[file: {file_name}]")
media_paths.append(file_path)
else: else:
content_parts.append(f"[file: {file_name}: download failed]") content_parts.append(f"[file: {file_name}: download failed]")
else: else:
@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
self._chat_frames[chat_id] = frame self._chat_frames[chat_id] = frame
# Forward to message bus # Forward to message bus
# Note: media paths are included in content for broader model compatibility
await self._handle_message( await self._handle_message(
sender_id=sender_id, sender_id=sender_id,
chat_id=chat_id, chat_id=chat_id,
content=content, content=content,
media=None, media=media_paths or None,
metadata={ metadata={
"message_id": msg_id, "message_id": msg_id,
"msg_type": msg_type, "msg_type": msg_type,
@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
logger.warning("Failed to download media from WeCom") logger.warning("Failed to download media from WeCom")
return None return None
if len(data) > WECOM_UPLOAD_MAX_BYTES:
logger.warning(
"WeCom inbound media too large: {} bytes (max {})",
len(data),
WECOM_UPLOAD_MAX_BYTES,
)
return None
media_dir = get_media_dir("wecom") media_dir = get_media_dir("wecom")
if not filename: if not filename:
filename = fname or f"{media_type}_{hash(file_url) % 100000}" filename = fname or f"{media_type}_{hash(file_url) % 100000}"
filename = os.path.basename(filename) filename = _sanitize_filename(filename)
file_path = media_dir / filename file_path = media_dir / filename
file_path.write_bytes(data) await asyncio.to_thread(file_path.write_bytes, data)
logger.debug("Downloaded {} to {}", media_type, file_path) logger.debug("Downloaded {} to {}", media_type, file_path)
return str(file_path) return str(file_path)
@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
logger.error("Error downloading media: {}", e) logger.error("Error downloading media: {}", e)
return None return None
async def _upload_media_ws(
self, client: Any, file_path: str,
) -> "tuple[str, str] | tuple[None, None]":
"""Upload a local file to WeCom via WebSocket 3-step protocol (base64).
Uses the WeCom WebSocket upload commands directly via
``client._ws_manager.send_reply()``:
``aibot_upload_media_init`` upload_id
``aibot_upload_media_chunk`` × N (512 KB raw per chunk, base64)
``aibot_upload_media_finish`` media_id
Returns (media_id, media_type) on success, (None, None) on failure.
"""
from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id
try:
fname = os.path.basename(file_path)
media_type = _guess_wecom_media_type(fname)
# Read file size and data in a thread to avoid blocking the event loop
def _read_file():
file_size = os.path.getsize(file_path)
if file_size > WECOM_UPLOAD_MAX_BYTES:
raise ValueError(
f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
)
with open(file_path, "rb") as f:
return file_size, f.read()
file_size, data = await asyncio.to_thread(_read_file)
# MD5 is used for file integrity only, not cryptographic security
md5_hash = hashlib.md5(data).hexdigest()
CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64)
mv = memoryview(data)
chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)]
n_chunks = len(chunk_list)
del mv, data
# Step 1: init
req_id = _gen_req_id("upload_init")
resp = await client._ws_manager.send_reply(req_id, {
"type": media_type,
"filename": fname,
"total_size": file_size,
"total_chunks": n_chunks,
"md5": md5_hash,
}, "aibot_upload_media_init")
if resp.errcode != 0:
logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
return None, None
upload_id = resp.body.get("upload_id") if resp.body else None
if not upload_id:
logger.warning("WeCom upload init: no upload_id in response")
return None, None
# Step 2: send chunks
for i, chunk in enumerate(chunk_list):
req_id = _gen_req_id("upload_chunk")
resp = await client._ws_manager.send_reply(req_id, {
"upload_id": upload_id,
"chunk_index": i,
"base64_data": base64.b64encode(chunk).decode(),
}, "aibot_upload_media_chunk")
if resp.errcode != 0:
logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
return None, None
# Step 3: finish
req_id = _gen_req_id("upload_finish")
resp = await client._ws_manager.send_reply(req_id, {
"upload_id": upload_id,
}, "aibot_upload_media_finish")
if resp.errcode != 0:
logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
return None, None
media_id = resp.body.get("media_id") if resp.body else None
if not media_id:
logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
return None, None
suffix = "..." if len(media_id) > 16 else ""
logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
return media_id, media_type
except ValueError as e:
logger.warning("WeCom upload skipped for {}: {}", file_path, e)
return None, None
except Exception as e:
logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
return None, None
async def send(self, msg: OutboundMessage) -> None: async def send(self, msg: OutboundMessage) -> None:
"""Send a message through WeCom.""" """Send a message through WeCom."""
if not self._client: if not self._client:
@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
return return
try: try:
content = msg.content.strip() content = (msg.content or "").strip()
if not content: is_progress = bool(msg.metadata.get("_progress"))
return
# Get the stored frame for this chat # Get the stored frame for this chat
frame = self._chat_frames.get(msg.chat_id) frame = self._chat_frames.get(msg.chat_id)
if not frame:
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id) # Send media files via WebSocket upload
for file_path in msg.media or []:
if not os.path.isfile(file_path):
logger.warning("WeCom media file not found: {}", file_path)
continue
media_id, media_type = await self._upload_media_ws(self._client, file_path)
if media_id:
if frame:
await self._client.reply(frame, {
"msgtype": media_type,
media_type: {"media_id": media_id},
})
else:
await self._client.send_message(msg.chat_id, {
"msgtype": media_type,
media_type: {"media_id": media_id},
})
logger.debug("WeCom sent {}{}", media_type, msg.chat_id)
else:
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
if not content:
return return
# Use streaming reply for better UX if frame:
# Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
# The plain reply() uses cmd="reply" which does not support "text" msgtype
# and causes errcode=40008 from WeCom API.
stream_id = self._generate_req_id("stream") stream_id = self._generate_req_id("stream")
# Send as streaming message with finish=True
await self._client.reply_stream( await self._client.reply_stream(
frame, frame,
stream_id, stream_id,
content, content,
finish=True, finish=not is_progress,
) )
logger.debug(
"WeCom {} sent to {}",
"progress" if is_progress else "message",
msg.chat_id,
)
else:
# No frame (e.g. cron push): proactive send only supports markdown
await self._client.send_message(msg.chat_id, {
"msgtype": "markdown",
"markdown": {"content": content},
})
logger.info("WeCom proactive send to {}", msg.chat_id)
logger.debug("WeCom message sent to {}", msg.chat_id) except Exception:
logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)
except Exception as e:
logger.error("Error sending WeCom message: {}", e)
raise

View File

@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
) )
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
content, tools_used, messages, _ = await loop._run_agent_loop( content, tools_used, messages, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}] [{"role": "user", "content": "hi"}]
) )
@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
) )
loop.tools.get_definitions = MagicMock(return_value=[]) loop.tools.get_definitions = MagicMock(return_value=[])
content, _, _, _ = await loop._run_agent_loop( content, _, _, _, _ = await loop._run_agent_loop(
[{"role": "user", "content": "hi"}] [{"role": "user", "content": "hi"}]
) )
@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok") loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2 loop.max_iterations = 2
content, tools_used, _, _ = await loop._run_agent_loop([]) content, tools_used, _, _, _ = await loop._run_agent_loop([])
assert content == ( assert content == (
"I reached the maximum number of tool call iterations (2) " "I reached the maximum number of tool call iterations (2) "
"without completing the task. You can try breaking the task into smaller steps." "without completing the task. You can try breaking the task into smaller steps."

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
import os import os
import time import time
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@ -798,7 +799,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
loop.tools.execute = AsyncMock(return_value="ok") loop.tools.execute = AsyncMock(return_value="ok")
loop.max_iterations = 2 loop.max_iterations = 2
final_content, _, _, _ = await loop._run_agent_loop([]) final_content, _, _, _, _ = await loop._run_agent_loop([])
assert final_content == ( assert final_content == (
"I reached the maximum number of tool call iterations (2) " "I reached the maximum number of tool call iterations (2) "
@ -825,7 +826,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
async def on_stream_end(*, resuming: bool = False) -> None: async def on_stream_end(*, resuming: bool = False) -> None:
endings.append(resuming) endings.append(resuming)
final_content, _, _, _ = await loop._run_agent_loop( final_content, _, _, _, _ = await loop._run_agent_loop(
[], [],
on_stream=on_stream, on_stream=on_stream,
on_stream_end=on_stream_end, on_stream_end=on_stream_end,
@ -849,7 +850,7 @@ async def test_loop_retries_think_only_final_response(tmp_path):
loop.provider.chat_with_retry = chat_with_retry loop.provider.chat_with_retry = chat_with_retry
final_content, _, _, _ = await loop._run_agent_loop([]) final_content, _, _, _, _ = await loop._run_agent_loop([])
assert final_content == "Recovered answer" assert final_content == "Recovered answer"
assert call_count["n"] == 2 assert call_count["n"] == 2
@ -1722,3 +1723,690 @@ def test_governance_fallback_still_repairs_orphans():
repaired = AgentRunner._backfill_missing_tool_results(repaired) repaired = AgentRunner._backfill_missing_tool_results(repaired)
# Orphan tool result should be gone. # Orphan tool result should be gone.
assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired) assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired)
# ── Mid-turn injection tests ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_drain_injections_returns_empty_when_no_callback():
"""No injection_callback → empty list."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
runner = AgentRunner(provider)
tools = MagicMock()
tools.get_definitions.return_value = []
spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m",
max_iterations=1, max_tool_result_chars=1000,
injection_callback=None,
)
result = await runner._drain_injections(spec)
assert result == []
@pytest.mark.asyncio
async def test_drain_injections_extracts_content_from_inbound_messages():
"""Should extract .content from InboundMessage objects."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.bus.events import InboundMessage
provider = MagicMock()
runner = AgentRunner(provider)
tools = MagicMock()
tools.get_definitions.return_value = []
msgs = [
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"),
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"),
]
async def cb():
return msgs
spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m",
max_iterations=1, max_tool_result_chars=1000,
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert result == [
{"role": "user", "content": "hello"},
{"role": "user", "content": "world"},
]
@pytest.mark.asyncio
async def test_drain_injections_passes_limit_to_callback_when_supported():
"""Limit-aware callbacks can preserve overflow in their own queue."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
from nanobot.bus.events import InboundMessage
provider = MagicMock()
runner = AgentRunner(provider)
tools = MagicMock()
tools.get_definitions.return_value = []
seen_limits: list[int] = []
msgs = [
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
]
async def cb(*, limit: int):
seen_limits.append(limit)
return msgs[:limit]
spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m",
max_iterations=1, max_tool_result_chars=1000,
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
assert result == [
{"role": "user", "content": "msg0"},
{"role": "user", "content": "msg1"},
{"role": "user", "content": "msg2"},
]
@pytest.mark.asyncio
async def test_drain_injections_skips_empty_content():
"""Messages with blank content should be filtered out."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.bus.events import InboundMessage
provider = MagicMock()
runner = AgentRunner(provider)
tools = MagicMock()
tools.get_definitions.return_value = []
msgs = [
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""),
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "),
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"),
]
async def cb():
return msgs
spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m",
max_iterations=1, max_tool_result_chars=1000,
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert result == [{"role": "user", "content": "valid"}]
@pytest.mark.asyncio
async def test_drain_injections_handles_callback_exception():
"""If the callback raises, return empty list (error is logged)."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
runner = AgentRunner(provider)
tools = MagicMock()
tools.get_definitions.return_value = []
async def cb():
raise RuntimeError("boom")
spec = AgentRunSpec(
initial_messages=[], tools=tools, model="m",
max_iterations=1, max_tool_result_chars=1000,
injection_callback=cb,
)
result = await runner._drain_injections(spec)
assert result == []
@pytest.mark.asyncio
async def test_checkpoint1_injects_after_tool_execution():
"""Follow-up messages are injected after tool execution, before next LLM call."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.bus.events import InboundMessage
provider = MagicMock()
call_count = {"n": 0}
captured_messages = []
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append(list(messages))
if call_count["n"] == 1:
return LLMResponse(
content="using tool",
tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})],
usage={},
)
return LLMResponse(content="final answer", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
tools.execute = AsyncMock(return_value="file content")
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
# Put a follow-up message in the queue before the run starts
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hello"}],
tools=tools,
model="test-model",
max_iterations=5,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
injection_callback=inject_cb,
))
assert result.had_injections is True
assert result.final_content == "final answer"
# The second call should have the injected user message
assert call_count["n"] == 2
last_messages = captured_messages[-1]
injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"]
assert len(injected) == 1
@pytest.mark.asyncio
async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
"""After final response, if injections exist, stream_end should get resuming=True."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.bus.events import InboundMessage
provider = MagicMock()
call_count = {"n": 0}
stream_end_calls = []
class TrackingHook(AgentHook):
def wants_streaming(self) -> bool:
return True
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
stream_end_calls.append(resuming)
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
return content
async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs):
call_count["n"] += 1
if call_count["n"] == 1:
return LLMResponse(content="first answer", tool_calls=[], usage={})
return LLMResponse(content="second answer", tool_calls=[], usage={})
provider.chat_stream_with_retry = chat_stream_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
# Inject a follow-up that arrives during the first response
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up")
)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hello"}],
tools=tools,
model="test-model",
max_iterations=5,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
hook=TrackingHook(),
injection_callback=inject_cb,
))
assert result.had_injections is True
assert result.final_content == "second answer"
assert call_count["n"] == 2
# First stream_end should have resuming=True (because injections found)
assert stream_end_calls[0] is True
# Second (final) stream_end should have resuming=False
assert stream_end_calls[-1] is False
@pytest.mark.asyncio
async def test_checkpoint2_preserves_final_response_in_history_before_followup():
"""A follow-up injected after a final answer must still see that answer in history."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
from nanobot.bus.events import InboundMessage
provider = MagicMock()
call_count = {"n": 0}
captured_messages = []
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append([dict(message) for message in messages])
if call_count["n"] == 1:
return LLMResponse(content="first answer", tool_calls=[], usage={})
return LLMResponse(content="second answer", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
injection_queue = asyncio.Queue()
async def inject_cb():
items = []
while not injection_queue.empty():
items.append(await injection_queue.get())
return items
await injection_queue.put(
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
)
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hello"}],
tools=tools,
model="test-model",
max_iterations=5,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
injection_callback=inject_cb,
))
assert result.final_content == "second answer"
assert call_count["n"] == 2
assert captured_messages[-1] == [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "first answer"},
{"role": "user", "content": "follow-up question"},
]
assert [
{"role": message["role"], "content": message["content"]}
for message in result.messages
if message.get("role") == "assistant"
] == [
{"role": "assistant", "content": "first answer"},
{"role": "assistant", "content": "second answer"},
]
@pytest.mark.asyncio
async def test_loop_injected_followup_preserves_image_media(tmp_path):
"""Mid-turn follow-ups with images should keep multimodal content."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
image_path = tmp_path / "followup.png"
image_path.write_bytes(base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
))
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_messages: list[list[dict]] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append(list(messages))
if call_count["n"] == 1:
return LLMResponse(content="first answer", tool_calls=[], usage={})
return LLMResponse(content="second answer", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
loop.tools.get_definitions = MagicMock(return_value=[])
pending_queue = asyncio.Queue()
await pending_queue.put(InboundMessage(
channel="cli",
sender_id="u",
chat_id="c",
content="",
media=[str(image_path)],
))
final_content, _, _, _, had_injections = await loop._run_agent_loop(
[{"role": "user", "content": "hello"}],
channel="cli",
chat_id="c",
pending_queue=pending_queue,
)
assert final_content == "second answer"
assert had_injections is True
assert call_count["n"] == 2
injected_user_messages = [
message for message in captured_messages[-1]
if message.get("role") == "user" and isinstance(message.get("content"), list)
]
assert injected_user_messages
assert any(
block.get("type") == "image_url"
for block in injected_user_messages[-1]["content"]
if isinstance(block, dict)
)
@pytest.mark.asyncio
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
"""Multiple injected follow-ups should not create lossy consecutive user messages."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
call_count = {"n": 0}
captured_messages = []
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append([dict(message) for message in messages])
if call_count["n"] == 1:
return LLMResponse(content="first answer", tool_calls=[], usage={})
return LLMResponse(content="second answer", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
async def inject_cb():
if call_count["n"] == 1:
return [
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
{"type": "text", "text": "look at this"},
],
},
{"role": "user", "content": "and answer briefly"},
]
return []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hello"}],
tools=tools,
model="test-model",
max_iterations=5,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
injection_callback=inject_cb,
))
assert result.final_content == "second answer"
assert call_count["n"] == 2
second_call = captured_messages[-1]
user_messages = [message for message in second_call if message.get("role") == "user"]
assert len(user_messages) == 2
injected = user_messages[-1]
assert isinstance(injected["content"], list)
assert any(
block.get("type") == "image_url"
for block in injected["content"]
if isinstance(block, dict)
)
assert any(
block.get("type") == "text" and block.get("text") == "and answer briefly"
for block in injected["content"]
if isinstance(block, dict)
)
@pytest.mark.asyncio
async def test_injection_cycles_capped_at_max():
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
from nanobot.bus.events import InboundMessage
provider = MagicMock()
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
drain_count = {"n": 0}
async def inject_cb():
drain_count["n"] += 1
# Only inject for the first _MAX_INJECTION_CYCLES drains
if drain_count["n"] <= _MAX_INJECTION_CYCLES:
return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")]
return []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "start"}],
tools=tools,
model="test-model",
max_iterations=20,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
injection_callback=inject_cb,
))
assert result.had_injections is True
# Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round
assert call_count["n"] == _MAX_INJECTION_CYCLES + 1
@pytest.mark.asyncio
async def test_no_injections_flag_is_false_by_default():
"""had_injections should be False when no injection callback or no messages."""
from nanobot.agent.runner import AgentRunSpec, AgentRunner
provider = MagicMock()
async def chat_with_retry(**kwargs):
return LLMResponse(content="done", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
tools = MagicMock()
tools.get_definitions.return_value = []
runner = AgentRunner(provider)
result = await runner.run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "hi"}],
tools=tools,
model="test-model",
max_iterations=1,
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
))
assert result.had_injections is False
@pytest.mark.asyncio
async def test_pending_queue_cleanup_on_dispatch(tmp_path):
"""_pending_queues should be cleaned up after _dispatch completes."""
loop = _make_loop(tmp_path)
async def chat_with_retry(**kwargs):
return LLMResponse(content="done", tool_calls=[], usage={})
loop.provider.chat_with_retry = chat_with_retry
from nanobot.bus.events import InboundMessage
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello")
# The queue should not exist before dispatch
assert msg.session_key not in loop._pending_queues
await loop._dispatch(msg)
# The queue should be cleaned up after dispatch
assert msg.session_key not in loop._pending_queues
@pytest.mark.asyncio
async def test_followup_routed_to_pending_queue(tmp_path):
"""Unified-session follow-ups should route into the active pending queue."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.bus.events import InboundMessage
loop = _make_loop(tmp_path)
loop._unified_session = True
loop._dispatch = AsyncMock() # type: ignore[method-assign]
pending = asyncio.Queue(maxsize=20)
loop._pending_queues[UNIFIED_SESSION_KEY] = pending
run_task = asyncio.create_task(loop.run())
msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up")
await loop.bus.publish_inbound(msg)
deadline = time.time() + 2
while pending.empty() and time.time() < deadline:
await asyncio.sleep(0.01)
loop.stop()
await asyncio.wait_for(run_task, timeout=2)
assert loop._dispatch.await_count == 0
assert not pending.empty()
queued_msg = pending.get_nowait()
assert queued_msg.content == "follow-up"
assert queued_msg.session_key == UNIFIED_SESSION_KEY
@pytest.mark.asyncio
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
"""Pending queue should leave overflow messages queued for later drains."""
from nanobot.agent.loop import AgentLoop
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
captured_messages: list[list[dict]] = []
call_count = {"n": 0}
async def chat_with_retry(*, messages, **kwargs):
call_count["n"] += 1
captured_messages.append([dict(message) for message in messages])
return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})
provider.chat_with_retry = chat_with_retry
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
loop.tools.get_definitions = MagicMock(return_value=[])
pending_queue = asyncio.Queue()
total_followups = _MAX_INJECTIONS_PER_TURN + 2
for idx in range(total_followups):
await pending_queue.put(InboundMessage(
channel="cli",
sender_id="u",
chat_id="c",
content=f"follow-up-{idx}",
))
final_content, _, _, _, had_injections = await loop._run_agent_loop(
[{"role": "user", "content": "hello"}],
channel="cli",
chat_id="c",
pending_queue=pending_queue,
)
assert final_content == "answer-3"
assert had_injections is True
assert call_count["n"] == 3
flattened_user_content = "\n".join(
message["content"]
for message in captured_messages[-1]
if message.get("role") == "user" and isinstance(message.get("content"), str)
)
for idx in range(total_followups):
assert f"follow-up-{idx}" in flattened_user_content
assert pending_queue.empty()
@pytest.mark.asyncio
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
"""QueueFull should preserve the message by dispatching a queued task."""
from nanobot.bus.events import InboundMessage
loop = _make_loop(tmp_path)
loop._dispatch = AsyncMock() # type: ignore[method-assign]
pending = asyncio.Queue(maxsize=1)
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued"))
loop._pending_queues["cli:c"] = pending
run_task = asyncio.create_task(loop.run())
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
await loop.bus.publish_inbound(msg)
deadline = time.time() + 2
while loop._dispatch.await_count == 0 and time.time() < deadline:
await asyncio.sleep(0.01)
loop.stop()
await asyncio.wait_for(run_task, timeout=2)
assert loop._dispatch.await_count == 1
dispatched_msg = loop._dispatch.await_args.args[0]
assert dispatched_msg.content == "follow-up"
assert pending.qsize() == 1
@pytest.mark.asyncio
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
"""Messages left in the pending queue after _dispatch are re-published to the bus.
This tests the finally-block cleanup that prevents message loss when
the runner exits early (e.g., max_iterations, tool_error) with messages
still in the queue.
"""
from nanobot.bus.events import InboundMessage
loop = _make_loop(tmp_path)
bus = loop.bus
# Simulate a completed dispatch by manually registering a queue
# with leftover messages, then running the cleanup logic directly.
pending = asyncio.Queue(maxsize=20)
session_key = "cli:c"
loop._pending_queues[session_key] = pending
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1"))
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2"))
# Execute the cleanup logic from the finally block
queue = loop._pending_queues.pop(session_key, None)
assert queue is not None
leftover = 0
while True:
try:
item = queue.get_nowait()
except asyncio.QueueEmpty:
break
await bus.publish_inbound(item)
leftover += 1
assert leftover == 2
# Verify the messages are now on the bus
msgs = []
while not bus.inbound.empty():
msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5))
contents = [m.content for m in msgs]
assert "leftover-1" in contents
assert "leftover-2" in contents

View File

@ -0,0 +1,304 @@
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
try:
from nanobot.channels import qq
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
except ImportError:
QQ_AVAILABLE = False
if not QQ_AVAILABLE:
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.qq import (
QQ_FILE_TYPE_FILE,
QQ_FILE_TYPE_IMAGE,
QQChannel,
QQConfig,
_guess_send_file_type,
_is_image_name,
_sanitize_filename,
)
class _FakeApi:
def __init__(self) -> None:
self.c2c_calls: list[dict] = []
self.group_calls: list[dict] = []
async def post_c2c_message(self, **kwargs) -> None:
self.c2c_calls.append(kwargs)
async def post_group_message(self, **kwargs) -> None:
self.group_calls.append(kwargs)
class _FakeHttp:
"""Fake _http for _post_base64file tests."""
def __init__(self, return_value: dict | None = None) -> None:
self.return_value = return_value or {}
self.calls: list[tuple] = []
async def request(self, route, **kwargs):
self.calls.append((route, kwargs))
return self.return_value
class _FakeClient:
def __init__(self, http_return: dict | None = None) -> None:
self.api = _FakeApi()
self.api._http = _FakeHttp(http_return)
# ── Helper function tests (pure, no async) ──────────────────────────
def test_sanitize_filename_strips_path_traversal() -> None:
assert _sanitize_filename("../../etc/passwd") == "passwd"
def test_sanitize_filename_keeps_chinese_chars() -> None:
assert _sanitize_filename("文件1.jpg") == "文件1.jpg"
def test_sanitize_filename_strips_unsafe_chars() -> None:
result = _sanitize_filename('file<>:"|?*.txt')
# All unsafe chars replaced with "_", but * is replaced too
assert result.startswith("file")
assert result.endswith(".txt")
assert "<" not in result
assert ">" not in result
assert '"' not in result
assert "|" not in result
assert "?" not in result
def test_sanitize_filename_empty_input() -> None:
assert _sanitize_filename("") == ""
def test_is_image_name_with_known_extensions() -> None:
for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"):
assert _is_image_name(f"photo{ext}") is True
def test_is_image_name_with_unknown_extension() -> None:
for ext in (".pdf", ".txt", ".mp3", ".mp4"):
assert _is_image_name(f"doc{ext}") is False
def test_guess_send_file_type_image() -> None:
assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE
assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE
def test_guess_send_file_type_file() -> None:
assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE
def test_guess_send_file_type_by_mime() -> None:
# A filename with no known extension but whose mime type is image/*
assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE
# ── send() exception handling ───────────────────────────────────────
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
"""Exceptions inside send() must not propagate."""
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")):
await channel.send(
OutboundMessage(channel="qq", chat_id="user1", content="hello")
)
# No exception raised — test passes if we get here.
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
"""Media is sent before text when both are present."""
import tempfile
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload:
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user1",
content="text after image",
media=[tmp],
metadata={"message_id": "m1"},
)
)
assert mock_upload.called
# Text should have been sent via c2c (default chat type)
text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0]
assert len(text_calls) >= 1
assert text_calls[-1]["content"] == "text after image"
finally:
import os
os.unlink(tmp)
@pytest.mark.asyncio
async def test_send_media_failure_falls_back_to_text() -> None:
"""When _send_media returns False, a failure notice is appended."""
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
await channel.send(
OutboundMessage(
channel="qq",
chat_id="user1",
content="hello",
media=["https://example.com/bad.png"],
metadata={"message_id": "m1"},
)
)
# Should have the failure text among the c2c calls
failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")]
assert len(failure_calls) == 1
assert "bad.png" in failure_calls[0]["content"]
# ── _on_message() exception handling ────────────────────────────────
@pytest.mark.asyncio
async def test_on_message_exception_caught_not_raised() -> None:
"""Missing required attributes should not crash _on_message."""
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
# Construct a message-like object that lacks 'author' — triggers AttributeError
bad_data = SimpleNamespace(id="x1", content="hi")
# Should not raise
await channel._on_message(bad_data, is_group=False)
@pytest.mark.asyncio
async def test_on_message_with_attachments() -> None:
"""Messages with attachments produce media_paths and formatted content."""
import tempfile
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
channel._client = _FakeClient()
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
saved_path = f.name
att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")
# Patch _download_to_media_dir_chunked to return the temp file path
async def fake_download(url, filename_hint=""):
return saved_path
try:
with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
data = SimpleNamespace(
id="att1",
content="look at this",
author=SimpleNamespace(user_openid="u1"),
attachments=[att],
)
await channel._on_message(data, is_group=False)
msg = await channel.bus.consume_inbound()
assert "look at this" in msg.content
assert "screenshot.png" in msg.content
assert "Received files:" in msg.content
assert len(msg.media) == 1
assert msg.media[0] == saved_path
finally:
import os
os.unlink(saved_path)
# ── _post_base64file() ─────────────────────────────────────────────
@pytest.mark.asyncio
async def test_post_base64file_omits_file_name_for_images() -> None:
"""file_type=1 (image) → payload must not contain file_name."""
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
channel._client = _FakeClient(http_return={"file_info": "img_abc"})
await channel._post_base64file(
chat_id="user1",
is_group=False,
file_type=QQ_FILE_TYPE_IMAGE,
file_data="ZmFrZQ==",
file_name="photo.png",
)
http = channel._client.api._http
assert len(http.calls) == 1
payload = http.calls[0][1]["json"]
assert "file_name" not in payload
assert payload["file_type"] == QQ_FILE_TYPE_IMAGE
@pytest.mark.asyncio
async def test_post_base64file_includes_file_name_for_files() -> None:
"""file_type=4 (file) → payload must contain file_name."""
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
channel._client = _FakeClient(http_return={"file_info": "file_abc"})
await channel._post_base64file(
chat_id="user1",
is_group=False,
file_type=QQ_FILE_TYPE_FILE,
file_data="ZmFrZQ==",
file_name="report.pdf",
)
http = channel._client.api._http
assert len(http.calls) == 1
payload = http.calls[0][1]["json"]
assert payload["file_name"] == "report.pdf"
assert payload["file_type"] == QQ_FILE_TYPE_FILE
@pytest.mark.asyncio
async def test_post_base64file_filters_response_to_file_info() -> None:
"""Response with file_info + extra fields must be filtered to only file_info."""
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
channel._client = _FakeClient(http_return={
"file_info": "fi_123",
"file_uuid": "uuid_xxx",
"ttl": 3600,
})
result = await channel._post_base64file(
chat_id="user1",
is_group=False,
file_type=QQ_FILE_TYPE_FILE,
file_data="ZmFrZQ==",
file_name="doc.pdf",
)
assert result == {"file_info": "fi_123"}
assert "file_uuid" not in result
assert "ttl" not in result

View File

@ -0,0 +1,584 @@
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""
import os
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
try:
import importlib.util
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
except ImportError:
WECOM_AVAILABLE = False
if not WECOM_AVAILABLE:
pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.wecom import (
WecomChannel,
WecomConfig,
_guess_wecom_media_type,
_sanitize_filename,
)
# Try to import the real response class; fall back to a stub if unavailable.
try:
from wecom_aibot_sdk.utils import WsResponse
_RealWsResponse = WsResponse
except ImportError:
_RealWsResponse = None
class _FakeResponse:
"""Minimal stand-in for wecom_aibot_sdk WsResponse."""
def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
self.errcode = errcode
self.errmsg = errmsg
self.body = body or {}
class _FakeWsManager:
"""Tracks send_reply calls and returns configurable responses."""
def __init__(self, responses: list[_FakeResponse] | None = None):
self.responses = responses or []
self.calls: list[tuple[str, dict, str]] = []
self._idx = 0
async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
self.calls.append((req_id, data, cmd))
if self._idx < len(self.responses):
resp = self.responses[self._idx]
self._idx += 1
return resp
return _FakeResponse()
class _FakeFrame:
"""Minimal frame object with a body dict."""
def __init__(self, body: dict | None = None):
self.body = body or {}
class _FakeWeComClient:
"""Fake WeCom client with mock methods."""
def __init__(self, ws_responses: list[_FakeResponse] | None = None):
self._ws_manager = _FakeWsManager(ws_responses)
self.download_file = AsyncMock(return_value=(None, None))
self.reply = AsyncMock()
self.reply_stream = AsyncMock()
self.send_message = AsyncMock()
self.reply_welcome = AsyncMock()
# ── Helper function tests (pure, no async) ──────────────────────────
def test_sanitize_filename_strips_path_traversal() -> None:
assert _sanitize_filename("../../etc/passwd") == "passwd"
def test_sanitize_filename_keeps_chinese_chars() -> None:
assert _sanitize_filename("文件1.jpg") == "文件1.jpg"
def test_sanitize_filename_empty_input() -> None:
assert _sanitize_filename("") == ""
def test_guess_wecom_media_type_image() -> None:
for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"):
assert _guess_wecom_media_type(f"photo{ext}") == "image"
def test_guess_wecom_media_type_video() -> None:
for ext in (".mp4", ".avi", ".mov"):
assert _guess_wecom_media_type(f"video{ext}") == "video"
def test_guess_wecom_media_type_voice() -> None:
for ext in (".amr", ".mp3", ".wav", ".ogg"):
assert _guess_wecom_media_type(f"audio{ext}") == "voice"
def test_guess_wecom_media_type_file_fallback() -> None:
for ext in (".pdf", ".doc", ".xlsx", ".zip"):
assert _guess_wecom_media_type(f"doc{ext}") == "file"
def test_guess_wecom_media_type_case_insensitive() -> None:
assert _guess_wecom_media_type("photo.PNG") == "image"
assert _guess_wecom_media_type("photo.Jpg") == "image"
# ── _download_and_save_media() ──────────────────────────────────────
@pytest.mark.asyncio
async def test_download_and_save_success() -> None:
"""Successful download writes file and returns sanitized path."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
fake_data = b"\x89PNG\r\nfake image"
client.download_file.return_value = (fake_data, "raw_photo.png")
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png")
assert path is not None
assert os.path.isfile(path)
assert os.path.basename(path) == "photo.png"
# Cleanup
os.unlink(path)
@pytest.mark.asyncio
async def test_download_and_save_oversized_rejected() -> None:
"""Data exceeding 200MB is rejected → returns None."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
big_data = b"\x00" * (200 * 1024 * 1024 + 1) # 200MB + 1 byte
client.download_file.return_value = (big_data, "big.bin")
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin")
assert result is None
@pytest.mark.asyncio
async def test_download_and_save_failure() -> None:
"""SDK returns None data → returns None."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
client.download_file.return_value = (None, None)
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image")
assert result is None
# ── _upload_media_ws() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upload_media_ws_success() -> None:
"""Happy path: init → chunk → finish → returns (media_id, media_type)."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=0, body={}),
_FakeResponse(errcode=0, body={"media_id": "media_abc"}),
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
media_id, media_type = await channel._upload_media_ws(client, tmp)
assert media_id == "media_abc"
assert media_type == "image"
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_upload_media_ws_oversized_file() -> None:
"""File >200MB triggers ValueError → returns (None, None)."""
# Instead of creating a real 200MB+ file, mock os.path.getsize and open
with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 1), \
patch("builtins.open", MagicMock()):
client = _FakeWeComClient()
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
result = await channel._upload_media_ws(client, "/fake/large.bin")
assert result == (None, None)
@pytest.mark.asyncio
async def test_upload_media_ws_init_failure() -> None:
"""Init step returns errcode != 0 → returns (None, None)."""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
f.write(b"hello")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=50001, errmsg="invalid"),
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
result = await channel._upload_media_ws(client, tmp)
assert result == (None, None)
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_upload_media_ws_chunk_failure() -> None:
"""Chunk step returns errcode != 0 → returns (None, None)."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=50002, errmsg="chunk fail"),
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
result = await channel._upload_media_ws(client, tmp)
assert result == (None, None)
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_upload_media_ws_finish_no_media_id() -> None:
"""Finish step returns empty media_id → returns (None, None)."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=0, body={}),
_FakeResponse(errcode=0, body={}), # no media_id
]
client = _FakeWeComClient(responses)
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
channel._client = client
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
result = await channel._upload_media_ws(client, tmp)
assert result == (None, None)
finally:
os.unlink(tmp)
# ── send() ──────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_send_text_with_frame() -> None:
"""When frame is stored, send uses reply_stream for final text."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
)
client.reply_stream.assert_called_once()
call_args = client.reply_stream.call_args
assert call_args[0][2] == "hello" # content arg
@pytest.mark.asyncio
async def test_send_progress_with_frame() -> None:
"""When metadata has _progress, send uses reply_stream with finish=False."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True})
)
client.reply_stream.assert_called_once()
call_args = client.reply_stream.call_args
assert call_args[0][2] == "thinking..." # content arg
assert call_args[1]["finish"] is False
@pytest.mark.asyncio
async def test_send_proactive_without_frame() -> None:
"""Without stored frame, send uses send_message with markdown."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
)
client.send_message.assert_called_once()
call_args = client.send_message.call_args
assert call_args[0][0] == "chat1"
assert call_args[0][1]["msgtype"] == "markdown"
@pytest.mark.asyncio
async def test_send_media_then_text() -> None:
"""Media files are uploaded and sent before text content."""
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
tmp = f.name
try:
responses = [
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
_FakeResponse(errcode=0, body={}),
_FakeResponse(errcode=0, body={"media_id": "media_123"}),
]
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient(responses)
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp])
)
# Media should have been sent via reply
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"]
assert len(media_calls) == 1
assert media_calls[0][0][1]["image"]["media_id"] == "media_123"
# Text should have been sent via reply_stream
client.reply_stream.assert_called_once()
finally:
os.unlink(tmp)
@pytest.mark.asyncio
async def test_send_media_file_not_found() -> None:
"""Non-existent media path is skipped with a warning."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"])
)
# reply_stream should still be called for the text part
client.reply_stream.assert_called_once()
# No media reply should happen
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")]
assert len(media_calls) == 0
@pytest.mark.asyncio
async def test_send_exception_caught_not_raised() -> None:
"""Exceptions inside send() must not propagate."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
channel._generate_req_id = lambda x: f"req_{x}"
channel._chat_frames["chat1"] = _FakeFrame()
# Make reply_stream raise
client.reply_stream.side_effect = RuntimeError("boom")
await channel.send(
OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
)
# No exception — test passes if we reach here.
# ── _process_message() ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_process_text_message() -> None:
"""Text message is routed to bus with correct fields."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_text_1",
"chatid": "chat1",
"chattype": "single",
"from": {"userid": "user1"},
"text": {"content": "hello wecom"},
})
await channel._process_message(frame, "text")
msg = await channel.bus.consume_inbound()
assert msg.sender_id == "user1"
assert msg.chat_id == "chat1"
assert msg.content == "hello wecom"
assert msg.metadata["msg_type"] == "text"
@pytest.mark.asyncio
async def test_process_image_message() -> None:
"""Image message: download success → media_paths non-empty."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
f.write(b"\x89PNG\r\n")
saved = f.name
client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
channel._client = client
try:
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
frame = _FakeFrame(body={
"msgid": "msg_img_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"image": {"url": "https://example.com/img.png", "aeskey": "key123"},
})
await channel._process_message(frame, "image")
msg = await channel.bus.consume_inbound()
assert len(msg.media) == 1
assert msg.media[0].endswith("photo.png")
assert "[image:" in msg.content
finally:
if os.path.exists(saved):
pass # may have been overwritten; clean up if exists
# Clean up any photo.png in tempdir
p = os.path.join(os.path.dirname(saved), "photo.png")
if os.path.exists(p):
os.unlink(p)
@pytest.mark.asyncio
async def test_process_file_message() -> None:
"""File message: download success → media_paths non-empty (critical fix verification)."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
f.write(b"%PDF-1.4 fake")
saved = f.name
client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
channel._client = client
try:
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
frame = _FakeFrame(body={
"msgid": "msg_file_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
})
await channel._process_message(frame, "file")
msg = await channel.bus.consume_inbound()
assert len(msg.media) == 1
assert msg.media[0].endswith("report.pdf")
assert "[file: report.pdf]" in msg.content
finally:
p = os.path.join(os.path.dirname(saved), "report.pdf")
if os.path.exists(p):
os.unlink(p)
@pytest.mark.asyncio
async def test_process_voice_message() -> None:
"""Voice message: transcribed text is included in content."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_voice_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"voice": {"content": "transcribed text here"},
})
await channel._process_message(frame, "voice")
msg = await channel.bus.consume_inbound()
assert "transcribed text here" in msg.content
assert "[voice]" in msg.content
@pytest.mark.asyncio
async def test_process_message_deduplication() -> None:
"""Same msg_id is not processed twice."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_dup_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"text": {"content": "once"},
})
await channel._process_message(frame, "text")
await channel._process_message(frame, "text")
msg = await channel.bus.consume_inbound()
assert msg.content == "once"
# Second message should not appear on the bus
assert channel.bus.inbound.empty()
@pytest.mark.asyncio
async def test_process_message_empty_content_skipped() -> None:
"""Message with empty content produces no bus message."""
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
client = _FakeWeComClient()
channel._client = client
frame = _FakeFrame(body={
"msgid": "msg_empty_1",
"chatid": "chat1",
"from": {"userid": "user1"},
"text": {"content": ""},
})
await channel._process_message(frame, "text")
assert channel.bus.inbound.empty()

View File

@ -1,5 +1,6 @@
"""Test message tool suppress logic for final replies.""" """Test message tool suppress logic for final replies."""
import asyncio
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
assert result is not None assert result is not None
assert "Hello" in result.content assert "Hello" in result.content
@pytest.mark.asyncio
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
self, tmp_path: Path
) -> None:
loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(
id="call1", name="message",
arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
)
calls = iter([
LLMResponse(content="First answer", tool_calls=[]),
LLMResponse(content="", tool_calls=[tool_call]),
LLMResponse(content="", tool_calls=[]),
LLMResponse(content="", tool_calls=[]),
LLMResponse(content="", tool_calls=[]),
])
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
loop.tools.get_definitions = MagicMock(return_value=[])
sent: list[OutboundMessage] = []
mt = loop.tools.get("message")
if isinstance(mt, MessageTool):
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
pending_queue = asyncio.Queue()
await pending_queue.put(
InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
)
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
result = await loop._process_message(msg, pending_queue=pending_queue)
assert len(sent) == 1
assert sent[0].content == "Tool reply"
assert result is None
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None: async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
loop = _make_loop(tmp_path) loop = _make_loop(tmp_path)
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"}) tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
async def on_progress(content: str, *, tool_hint: bool = False) -> None: async def on_progress(content: str, *, tool_hint: bool = False) -> None:
progress.append((content, tool_hint)) progress.append((content, tool_hint))
final_content, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress) final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
assert final_content == "Done" assert final_content == "Done"
assert progress == [ assert progress == [