mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-14 23:19:55 +00:00
Merge branch 'main' into fix/tool-call-result-order-2943
This commit is contained in:
commit
f25cdb7138
@ -17,7 +17,7 @@ from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
@ -207,6 +207,10 @@ class AgentLoop:
|
||||
self._active_tasks: dict[str, list[asyncio.Task]] = {} # session_key -> tasks
|
||||
self._background_tasks: list[asyncio.Task] = []
|
||||
self._session_locks: dict[str, asyncio.Lock] = {}
|
||||
# Per-session pending queues for mid-turn message injection.
|
||||
# When a session has an active task, new messages for that session
|
||||
# are routed here instead of creating a new task.
|
||||
self._pending_queues: dict[str, asyncio.Queue] = {}
|
||||
# NANOBOT_MAX_CONCURRENT_REQUESTS: <=0 means unlimited; default 3.
|
||||
_max = int(os.environ.get("NANOBOT_MAX_CONCURRENT_REQUESTS", "3"))
|
||||
self._concurrency_gate: asyncio.Semaphore | None = (
|
||||
@ -320,6 +324,12 @@ class AgentLoop:
|
||||
|
||||
return format_tool_hints(tool_calls)
|
||||
|
||||
def _effective_session_key(self, msg: InboundMessage) -> str:
|
||||
"""Return the session key used for task routing and mid-turn injections."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
return UNIFIED_SESSION_KEY
|
||||
return msg.session_key
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
@ -331,13 +341,16 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict], str]:
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict], str, bool]:
|
||||
"""Run the agent iteration loop.
|
||||
|
||||
*on_stream*: called with each content delta during streaming.
|
||||
*on_stream_end(resuming)*: called when a streaming session finishes.
|
||||
``resuming=True`` means tool calls follow (spinner should restart);
|
||||
``resuming=False`` means this is the final response.
|
||||
|
||||
Returns (final_content, tools_used, messages, stop_reason, had_injections).
|
||||
"""
|
||||
loop_hook = _LoopHook(
|
||||
self,
|
||||
@ -357,31 +370,56 @@ class AgentLoop:
|
||||
return
|
||||
self._set_runtime_checkpoint(session, payload)
|
||||
|
||||
result = await self.runner.run(
|
||||
AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
)
|
||||
)
|
||||
async def _drain_pending(*, limit: int = _MAX_INJECTIONS_PER_TURN) -> list[dict[str, Any]]:
|
||||
"""Non-blocking drain of follow-up messages from the pending queue."""
|
||||
if pending_queue is None:
|
||||
return []
|
||||
items: list[dict[str, Any]] = []
|
||||
while len(items) < limit:
|
||||
try:
|
||||
pending_msg = pending_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
user_content = self.context._build_user_content(
|
||||
pending_msg.content,
|
||||
pending_msg.media if pending_msg.media else None,
|
||||
)
|
||||
runtime_ctx = self.context._build_runtime_context(
|
||||
pending_msg.channel,
|
||||
pending_msg.chat_id,
|
||||
self.context.timezone,
|
||||
)
|
||||
if isinstance(user_content, str):
|
||||
merged: str | list[dict[str, Any]] = f"{runtime_ctx}\n\n{user_content}"
|
||||
else:
|
||||
merged = [{"type": "text", "text": runtime_ctx}] + user_content
|
||||
items.append({"role": "user", "content": merged})
|
||||
return items
|
||||
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
tools=self.tools,
|
||||
model=self.model,
|
||||
max_iterations=self.max_iterations,
|
||||
max_tool_result_chars=self.max_tool_result_chars,
|
||||
hook=hook,
|
||||
error_message="Sorry, I encountered an error calling the AI model.",
|
||||
concurrent_tools=True,
|
||||
workspace=self.workspace,
|
||||
session_key=session.key if session else None,
|
||||
context_window_tokens=self.context_window_tokens,
|
||||
context_block_limit=self.context_block_limit,
|
||||
provider_retry_mode=self.provider_retry_mode,
|
||||
progress_callback=on_progress,
|
||||
checkpoint_callback=_checkpoint,
|
||||
injection_callback=_drain_pending,
|
||||
))
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
elif result.stop_reason == "error":
|
||||
logger.error("LLM returned error: {}", (result.final_content or "")[:200])
|
||||
return result.final_content, result.tools_used, result.messages, result.stop_reason
|
||||
return result.final_content, result.tools_used, result.messages, result.stop_reason, result.had_injections
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the agent loop, dispatching messages as tasks to stay responsive to /stop."""
|
||||
@ -412,13 +450,32 @@ class AgentLoop:
|
||||
if result:
|
||||
await self.bus.publish_outbound(result)
|
||||
continue
|
||||
effective_key = self._effective_session_key(msg)
|
||||
# If this session already has an active pending queue (i.e. a task
|
||||
# is processing this session), route the message there for mid-turn
|
||||
# injection instead of creating a competing task.
|
||||
if effective_key in self._pending_queues:
|
||||
pending_msg = msg
|
||||
if effective_key != msg.session_key:
|
||||
pending_msg = dataclasses.replace(
|
||||
msg,
|
||||
session_key_override=effective_key,
|
||||
)
|
||||
try:
|
||||
self._pending_queues[effective_key].put_nowait(pending_msg)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
"Pending queue full for session {}, falling back to queued task",
|
||||
effective_key,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"Routed follow-up message to pending queue for session {}",
|
||||
effective_key,
|
||||
)
|
||||
continue
|
||||
# Compute the effective session key before dispatching
|
||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||
effective_key = (
|
||||
UNIFIED_SESSION_KEY
|
||||
if self._unified_session and not msg.session_key_override
|
||||
else msg.session_key
|
||||
)
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||
task.add_done_callback(
|
||||
@ -430,78 +487,91 @@ class AgentLoop:
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
|
||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||
session_key = self._effective_session_key(msg)
|
||||
if session_key != msg.session_key:
|
||||
msg = dataclasses.replace(msg, session_key_override=session_key)
|
||||
lock = self._session_locks.setdefault(session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
# Split one answer into distinct stream segments.
|
||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
# Register a pending queue so follow-up messages for this session are
|
||||
# routed here (mid-turn injection) instead of spawning a new task.
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
self._pending_queues[session_key] = pending
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
try:
|
||||
async with lock, gate:
|
||||
try:
|
||||
on_stream = on_stream_end = None
|
||||
if msg.metadata.get("_wants_stream"):
|
||||
# Split one answer into distinct stream segments.
|
||||
stream_base_id = f"{msg.session_key}:{time.time_ns()}"
|
||||
stream_segment = 0
|
||||
|
||||
def _current_stream_id() -> str:
|
||||
return f"{stream_base_id}:{stream_segment}"
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_delta"] = True
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content=delta,
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
))
|
||||
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
nonlocal stream_segment
|
||||
meta = dict(msg.metadata or {})
|
||||
meta["_stream_end"] = True
|
||||
meta["_resuming"] = resuming
|
||||
meta["_stream_id"] = _current_stream_id()
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=meta,
|
||||
)
|
||||
)
|
||||
stream_segment += 1
|
||||
))
|
||||
stream_segment += 1
|
||||
|
||||
response = await self._process_message(
|
||||
msg,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content="",
|
||||
metadata=msg.metadata or {},
|
||||
)
|
||||
response = await self._process_message(
|
||||
msg, on_stream=on_stream, on_stream_end=on_stream_end,
|
||||
pending_queue=pending,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", msg.session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", msg.session_key)
|
||||
await self.bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
if response is not None:
|
||||
await self.bus.publish_outbound(response)
|
||||
elif msg.channel == "cli":
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Task cancelled for session {}", session_key)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error processing message for session {}", session_key)
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="Sorry, I encountered an error.",
|
||||
))
|
||||
finally:
|
||||
# Drain any messages still in the pending queue and re-publish
|
||||
# them to the bus so they are processed as fresh inbound messages
|
||||
# rather than silently lost.
|
||||
queue = self._pending_queues.pop(session_key, None)
|
||||
if queue is not None:
|
||||
leftover = 0
|
||||
while True:
|
||||
try:
|
||||
item = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await self.bus.publish_inbound(item)
|
||||
leftover += 1
|
||||
if leftover:
|
||||
logger.info(
|
||||
"Re-published {} leftover message(s) to bus for session {}",
|
||||
leftover, session_key,
|
||||
)
|
||||
)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
@ -533,6 +603,7 @@ class AgentLoop:
|
||||
on_progress: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
@ -559,11 +630,8 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs, _ = await self._run_agent_loop(
|
||||
messages,
|
||||
session=session,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
@ -623,7 +691,7 @@ class AgentLoop:
|
||||
)
|
||||
)
|
||||
|
||||
final_content, _, all_msgs, stop_reason = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
on_stream=on_stream,
|
||||
@ -632,6 +700,7 @@ class AgentLoop:
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
if final_content is None or not final_content.strip():
|
||||
@ -642,8 +711,15 @@ class AgentLoop:
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
|
||||
# When follow-up messages were injected mid-turn, a later natural
|
||||
# language reply may address those follow-ups and should not be
|
||||
# suppressed just because MessageTool was used earlier in the turn.
|
||||
# However, if the turn falls back to the empty-final-response
|
||||
# placeholder, suppress it when the real user-visible output already
|
||||
# came from MessageTool.
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
if not had_injections or stop_reason == "empty_final_response":
|
||||
return None
|
||||
|
||||
preview = final_content[:120] + "..." if len(final_content) > 120 else final_content
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import inspect
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -34,6 +35,8 @@ _DEFAULT_ERROR_MESSAGE = "Sorry, I encountered an error calling the AI model."
|
||||
_PERSISTED_MODEL_ERROR_PLACEHOLDER = "[Assistant reply unavailable due to model error.]"
|
||||
_MAX_EMPTY_RETRIES = 2
|
||||
_MAX_LENGTH_RECOVERIES = 3
|
||||
_MAX_INJECTIONS_PER_TURN = 3
|
||||
_MAX_INJECTION_CYCLES = 5
|
||||
_SNIP_SAFETY_BUFFER = 1024
|
||||
_MICROCOMPACT_KEEP_RECENT = 10
|
||||
_MICROCOMPACT_MIN_CHARS = 500
|
||||
@ -42,6 +45,9 @@ _COMPACTABLE_TOOLS = frozenset({
|
||||
"web_search", "web_fetch", "list_dir",
|
||||
})
|
||||
_BACKFILL_CONTENT = "[Tool result unavailable — call was interrupted or lost]"
|
||||
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AgentRunSpec:
|
||||
"""Configuration for a single agent execution."""
|
||||
@ -66,6 +72,7 @@ class AgentRunSpec:
|
||||
provider_retry_mode: str = "standard"
|
||||
progress_callback: Any | None = None
|
||||
checkpoint_callback: Any | None = None
|
||||
injection_callback: Any | None = None
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -79,6 +86,7 @@ class AgentRunResult:
|
||||
stop_reason: str = "completed"
|
||||
error: str | None = None
|
||||
tool_events: list[dict[str, str]] = field(default_factory=list)
|
||||
had_injections: bool = False
|
||||
|
||||
|
||||
class AgentRunner:
|
||||
@ -87,6 +95,90 @@ class AgentRunner:
|
||||
def __init__(self, provider: LLMProvider):
|
||||
self.provider = provider
|
||||
|
||||
@staticmethod
|
||||
def _merge_message_content(left: Any, right: Any) -> str | list[dict[str, Any]]:
|
||||
if isinstance(left, str) and isinstance(right, str):
|
||||
return f"{left}\n\n{right}" if left else right
|
||||
|
||||
def _to_blocks(value: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(value, list):
|
||||
return [
|
||||
item if isinstance(item, dict) else {"type": "text", "text": str(item)}
|
||||
for item in value
|
||||
]
|
||||
if value is None:
|
||||
return []
|
||||
return [{"type": "text", "text": str(value)}]
|
||||
|
||||
return _to_blocks(left) + _to_blocks(right)
|
||||
|
||||
@classmethod
|
||||
def _append_injected_messages(
|
||||
cls,
|
||||
messages: list[dict[str, Any]],
|
||||
injections: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Append injected user messages while preserving role alternation."""
|
||||
for injection in injections:
|
||||
if (
|
||||
messages
|
||||
and injection.get("role") == "user"
|
||||
and messages[-1].get("role") == "user"
|
||||
):
|
||||
merged = dict(messages[-1])
|
||||
merged["content"] = cls._merge_message_content(
|
||||
merged.get("content"),
|
||||
injection.get("content"),
|
||||
)
|
||||
messages[-1] = merged
|
||||
continue
|
||||
messages.append(injection)
|
||||
|
||||
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
|
||||
"""Drain pending user messages via the injection callback.
|
||||
|
||||
Returns normalized user messages (capped by
|
||||
``_MAX_INJECTIONS_PER_TURN``), or an empty list when there is
|
||||
nothing to inject. Messages beyond the cap are logged so they
|
||||
are not silently lost.
|
||||
"""
|
||||
if spec.injection_callback is None:
|
||||
return []
|
||||
try:
|
||||
signature = inspect.signature(spec.injection_callback)
|
||||
accepts_limit = (
|
||||
"limit" in signature.parameters
|
||||
or any(
|
||||
parameter.kind is inspect.Parameter.VAR_KEYWORD
|
||||
for parameter in signature.parameters.values()
|
||||
)
|
||||
)
|
||||
if accepts_limit:
|
||||
items = await spec.injection_callback(limit=_MAX_INJECTIONS_PER_TURN)
|
||||
else:
|
||||
items = await spec.injection_callback()
|
||||
except Exception:
|
||||
logger.exception("injection_callback failed")
|
||||
return []
|
||||
if not items:
|
||||
return []
|
||||
injected_messages: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict) and item.get("role") == "user" and "content" in item:
|
||||
injected_messages.append(item)
|
||||
continue
|
||||
text = getattr(item, "content", str(item))
|
||||
if text.strip():
|
||||
injected_messages.append({"role": "user", "content": text})
|
||||
if len(injected_messages) > _MAX_INJECTIONS_PER_TURN:
|
||||
dropped = len(injected_messages) - _MAX_INJECTIONS_PER_TURN
|
||||
logger.warning(
|
||||
"Injection callback returned {} messages, capping to {} ({} dropped)",
|
||||
len(injected_messages), _MAX_INJECTIONS_PER_TURN, dropped,
|
||||
)
|
||||
injected_messages = injected_messages[:_MAX_INJECTIONS_PER_TURN]
|
||||
return injected_messages
|
||||
|
||||
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
|
||||
hook = spec.hook or AgentHook()
|
||||
messages = list(spec.initial_messages)
|
||||
@ -99,6 +191,8 @@ class AgentRunner:
|
||||
external_lookup_counts: dict[str, int] = {}
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
had_injections = False
|
||||
injection_cycles = 0
|
||||
|
||||
for iteration in range(spec.max_iterations):
|
||||
try:
|
||||
@ -207,6 +301,17 @@ class AgentRunner:
|
||||
)
|
||||
empty_content_retries = 0
|
||||
length_recovery_count = 0
|
||||
# Checkpoint 1: drain injections after tools, before next LLM call
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after tool execution ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
@ -263,8 +368,49 @@ class AgentRunner:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
assistant_message: dict[str, Any] | None = None
|
||||
if response.finish_reason != "error" and not is_blank_text(clean):
|
||||
assistant_message = build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
|
||||
# Check for mid-turn injections BEFORE signaling stream end.
|
||||
# If injections are found we keep the stream alive (resuming=True)
|
||||
# so streaming channels don't prematurely finalize the card.
|
||||
_injected_after_final = False
|
||||
if injection_cycles < _MAX_INJECTION_CYCLES:
|
||||
injections = await self._drain_injections(spec)
|
||||
if injections:
|
||||
had_injections = True
|
||||
injection_cycles += 1
|
||||
_injected_after_final = True
|
||||
if assistant_message is not None:
|
||||
messages.append(assistant_message)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
"phase": "final_response",
|
||||
"iteration": iteration,
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [],
|
||||
},
|
||||
)
|
||||
self._append_injected_messages(messages, injections)
|
||||
logger.info(
|
||||
"Injected {} follow-up message(s) after final response ({}/{})",
|
||||
len(injections), injection_cycles, _MAX_INJECTION_CYCLES,
|
||||
)
|
||||
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.on_stream_end(context, resuming=_injected_after_final)
|
||||
|
||||
if _injected_after_final:
|
||||
await hook.after_iteration(context)
|
||||
continue
|
||||
|
||||
if response.finish_reason == "error":
|
||||
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
|
||||
@ -287,7 +433,7 @@ class AgentRunner:
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
|
||||
messages.append(build_assistant_message(
|
||||
messages.append(assistant_message or build_assistant_message(
|
||||
clean,
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
@ -330,6 +476,7 @@ class AgentRunner:
|
||||
stop_reason=stop_reason,
|
||||
error=error,
|
||||
tool_events=tool_events,
|
||||
had_injections=had_injections,
|
||||
)
|
||||
|
||||
def _build_request_kwargs(
|
||||
|
||||
@ -242,43 +242,46 @@ class QQChannel(BaseChannel):
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send attachments first, then text."""
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
try:
|
||||
if not self._client:
|
||||
logger.warning("QQ client not initialized")
|
||||
return
|
||||
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
is_group = chat_type == "group"
|
||||
msg_id = msg.metadata.get("message_id")
|
||||
chat_type = self._chat_type_cache.get(msg.chat_id, "c2c")
|
||||
is_group = chat_type == "group"
|
||||
|
||||
# 1) Send media
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media(
|
||||
chat_id=msg.chat_id,
|
||||
media_ref=media_ref,
|
||||
msg_id=msg_id,
|
||||
is_group=is_group,
|
||||
)
|
||||
if not ok:
|
||||
filename = (
|
||||
os.path.basename(urlparse(media_ref).path)
|
||||
or os.path.basename(media_ref)
|
||||
or "file"
|
||||
# 1) Send media
|
||||
for media_ref in msg.media or []:
|
||||
ok = await self._send_media(
|
||||
chat_id=msg.chat_id,
|
||||
media_ref=media_ref,
|
||||
msg_id=msg_id,
|
||||
is_group=is_group,
|
||||
)
|
||||
if not ok:
|
||||
filename = (
|
||||
os.path.basename(urlparse(media_ref).path)
|
||||
or os.path.basename(media_ref)
|
||||
or "file"
|
||||
)
|
||||
await self._send_text_only(
|
||||
chat_id=msg.chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=msg_id,
|
||||
content=f"[Attachment send failed: {filename}]",
|
||||
)
|
||||
|
||||
# 2) Send text
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_text_only(
|
||||
chat_id=msg.chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=msg_id,
|
||||
content=f"[Attachment send failed: {filename}]",
|
||||
content=msg.content.strip(),
|
||||
)
|
||||
|
||||
# 2) Send text
|
||||
if msg.content and msg.content.strip():
|
||||
await self._send_text_only(
|
||||
chat_id=msg.chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=msg_id,
|
||||
content=msg.content.strip(),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error sending QQ message to chat_id={}", msg.chat_id)
|
||||
|
||||
async def _send_text_only(
|
||||
self,
|
||||
@ -438,15 +441,26 @@ class QQChannel(BaseChannel):
|
||||
endpoint = "/v2/users/{openid}/files"
|
||||
id_key = "openid"
|
||||
|
||||
payload = {
|
||||
payload: dict[str, Any] = {
|
||||
id_key: chat_id,
|
||||
"file_type": file_type,
|
||||
"file_data": file_data,
|
||||
"file_name": file_name,
|
||||
"srv_send_msg": srv_send_msg,
|
||||
}
|
||||
# Only pass file_name for non-image types (file_type=4).
|
||||
# Passing file_name for images causes QQ client to render them as
|
||||
# file attachments instead of inline images.
|
||||
if file_type != QQ_FILE_TYPE_IMAGE and file_name:
|
||||
payload["file_name"] = file_name
|
||||
|
||||
route = Route("POST", endpoint, **{id_key: chat_id})
|
||||
return await self._client.api._http.request(route, json=payload)
|
||||
result = await self._client.api._http.request(route, json=payload)
|
||||
|
||||
# Extract only the file_info field to avoid extra fields (file_uuid, ttl, etc.)
|
||||
# that may confuse QQ client when sending the media object.
|
||||
if isinstance(result, dict) and "file_info" in result:
|
||||
return {"file_info": result["file_info"]}
|
||||
return result
|
||||
|
||||
# ---------------------------
|
||||
# Inbound (receive)
|
||||
@ -454,58 +468,68 @@ class QQChannel(BaseChannel):
|
||||
|
||||
async def _on_message(self, data: C2CMessage | GroupMessage, is_group: bool = False) -> None:
|
||||
"""Parse inbound message, download attachments, and publish to the bus."""
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
try:
|
||||
if data.id in self._processed_ids:
|
||||
return
|
||||
self._processed_ids.append(data.id)
|
||||
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(
|
||||
getattr(data.author, "id", None) or getattr(data.author, "user_openid", "unknown")
|
||||
)
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
content = (data.content or "").strip()
|
||||
|
||||
# the data used by tests don't contain attachments property
|
||||
# so we use getattr with a default of [] to avoid AttributeError in tests
|
||||
attachments = getattr(data, "attachments", None) or []
|
||||
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
|
||||
|
||||
# Compose content that always contains actionable saved paths
|
||||
if recv_lines:
|
||||
tag = "[Image]" if any(_is_image_name(Path(p).name) for p in media_paths) else "[File]"
|
||||
file_block = "Received files:\n" + "\n".join(recv_lines)
|
||||
content = f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
if self.config.ack_message:
|
||||
try:
|
||||
await self._send_text_only(
|
||||
chat_id=chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=data.id,
|
||||
content=self.config.ack_message,
|
||||
if is_group:
|
||||
chat_id = data.group_openid
|
||||
user_id = data.author.member_openid
|
||||
self._chat_type_cache[chat_id] = "group"
|
||||
else:
|
||||
chat_id = str(
|
||||
getattr(data.author, "id", None)
|
||||
or getattr(data.author, "user_openid", "unknown")
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||
user_id = chat_id
|
||||
self._chat_type_cache[chat_id] = "c2c"
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media_paths if media_paths else None,
|
||||
metadata={
|
||||
"message_id": data.id,
|
||||
"attachments": att_meta,
|
||||
},
|
||||
)
|
||||
content = (data.content or "").strip()
|
||||
|
||||
# the data used by tests don't contain attachments property
|
||||
# so we use getattr with a default of [] to avoid AttributeError in tests
|
||||
attachments = getattr(data, "attachments", None) or []
|
||||
media_paths, recv_lines, att_meta = await self._handle_attachments(attachments)
|
||||
|
||||
# Compose content that always contains actionable saved paths
|
||||
if recv_lines:
|
||||
tag = (
|
||||
"[Image]"
|
||||
if any(_is_image_name(Path(p).name) for p in media_paths)
|
||||
else "[File]"
|
||||
)
|
||||
file_block = "Received files:\n" + "\n".join(recv_lines)
|
||||
content = (
|
||||
f"{content}\n\n{file_block}".strip() if content else f"{tag}\n{file_block}"
|
||||
)
|
||||
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
if self.config.ack_message:
|
||||
try:
|
||||
await self._send_text_only(
|
||||
chat_id=chat_id,
|
||||
is_group=is_group,
|
||||
msg_id=data.id,
|
||||
content=self.config.ack_message,
|
||||
)
|
||||
except Exception:
|
||||
logger.debug("QQ ack message failed for chat_id={}", chat_id)
|
||||
|
||||
await self._handle_message(
|
||||
sender_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media_paths if media_paths else None,
|
||||
metadata={
|
||||
"message_id": data.id,
|
||||
"attachments": att_meta,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling QQ inbound message id={}", getattr(data, "id", "?"))
|
||||
|
||||
async def _handle_attachments(
|
||||
self,
|
||||
@ -520,7 +544,9 @@ class QQChannel(BaseChannel):
|
||||
return media_paths, recv_lines, att_meta
|
||||
|
||||
for att in attachments:
|
||||
url, filename, ctype = att.url, att.filename, att.content_type
|
||||
url = getattr(att, "url", None) or ""
|
||||
filename = getattr(att, "filename", None) or ""
|
||||
ctype = getattr(att, "content_type", None) or ""
|
||||
|
||||
logger.info("Downloading file from QQ: {}", filename or url)
|
||||
local_path = await self._download_to_media_dir_chunked(url, filename_hint=filename)
|
||||
@ -555,6 +581,10 @@ class QQChannel(BaseChannel):
|
||||
Enforces a max download size and writes to a .part temp file
|
||||
that is atomically renamed on success.
|
||||
"""
|
||||
# Handle protocol-relative URLs (e.g. "//multimedia.nt.qq.com/...")
|
||||
if url.startswith("//"):
|
||||
url = f"https:{url}"
|
||||
|
||||
if not self._http:
|
||||
self._http = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=120))
|
||||
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
"""WeCom (Enterprise WeChat) channel implementation using wecom_aibot_sdk."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
@ -17,6 +21,37 @@ from pydantic import Field
|
||||
|
||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||
|
||||
# Upload safety limits (matching QQ channel defaults)
|
||||
WECOM_UPLOAD_MAX_BYTES = 1024 * 1024 * 200 # 200MB
|
||||
|
||||
# Replace unsafe characters with "_", keep Chinese and common safe punctuation.
|
||||
_SAFE_NAME_RE = re.compile(r"[^\w.\-()\[\]()【】\u4e00-\u9fff]+", re.UNICODE)
|
||||
|
||||
|
||||
def _sanitize_filename(name: str) -> str:
|
||||
"""Sanitize filename to avoid traversal and problematic chars."""
|
||||
name = (name or "").strip()
|
||||
name = Path(name).name
|
||||
name = _SAFE_NAME_RE.sub("_", name).strip("._ ")
|
||||
return name
|
||||
|
||||
|
||||
_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
|
||||
_VIDEO_EXTS = {".mp4", ".avi", ".mov"}
|
||||
_AUDIO_EXTS = {".amr", ".mp3", ".wav", ".ogg"}
|
||||
|
||||
|
||||
def _guess_wecom_media_type(filename: str) -> str:
|
||||
"""Classify file extension as WeCom media_type string."""
|
||||
ext = Path(filename).suffix.lower()
|
||||
if ext in _IMAGE_EXTS:
|
||||
return "image"
|
||||
if ext in _VIDEO_EXTS:
|
||||
return "video"
|
||||
if ext in _AUDIO_EXTS:
|
||||
return "voice"
|
||||
return "file"
|
||||
|
||||
class WecomConfig(Base):
|
||||
"""WeCom (Enterprise WeChat) AI Bot channel configuration."""
|
||||
|
||||
@ -217,6 +252,7 @@ class WecomChannel(BaseChannel):
|
||||
chat_id = body.get("chatid", sender_id)
|
||||
|
||||
content_parts = []
|
||||
media_paths: list[str] = []
|
||||
|
||||
if msg_type == "text":
|
||||
text = body.get("text", {}).get("content", "")
|
||||
@ -232,7 +268,8 @@ class WecomChannel(BaseChannel):
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "image")
|
||||
if file_path:
|
||||
filename = os.path.basename(file_path)
|
||||
content_parts.append(f"[image: {filename}]\n[Image: source: {file_path}]")
|
||||
content_parts.append(f"[image: {filename}]")
|
||||
media_paths.append(file_path)
|
||||
else:
|
||||
content_parts.append("[image: download failed]")
|
||||
else:
|
||||
@ -256,7 +293,8 @@ class WecomChannel(BaseChannel):
|
||||
if file_url and aes_key:
|
||||
file_path = await self._download_and_save_media(file_url, aes_key, "file", file_name)
|
||||
if file_path:
|
||||
content_parts.append(f"[file: {file_name}]\n[File: source: {file_path}]")
|
||||
content_parts.append(f"[file: {file_name}]")
|
||||
media_paths.append(file_path)
|
||||
else:
|
||||
content_parts.append(f"[file: {file_name}: download failed]")
|
||||
else:
|
||||
@ -286,12 +324,11 @@ class WecomChannel(BaseChannel):
|
||||
self._chat_frames[chat_id] = frame
|
||||
|
||||
# Forward to message bus
|
||||
# Note: media paths are included in content for broader model compatibility
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=None,
|
||||
media=media_paths or None,
|
||||
metadata={
|
||||
"message_id": msg_id,
|
||||
"msg_type": msg_type,
|
||||
@ -322,13 +359,21 @@ class WecomChannel(BaseChannel):
|
||||
logger.warning("Failed to download media from WeCom")
|
||||
return None
|
||||
|
||||
if len(data) > WECOM_UPLOAD_MAX_BYTES:
|
||||
logger.warning(
|
||||
"WeCom inbound media too large: {} bytes (max {})",
|
||||
len(data),
|
||||
WECOM_UPLOAD_MAX_BYTES,
|
||||
)
|
||||
return None
|
||||
|
||||
media_dir = get_media_dir("wecom")
|
||||
if not filename:
|
||||
filename = fname or f"{media_type}_{hash(file_url) % 100000}"
|
||||
filename = os.path.basename(filename)
|
||||
filename = _sanitize_filename(filename)
|
||||
|
||||
file_path = media_dir / filename
|
||||
file_path.write_bytes(data)
|
||||
await asyncio.to_thread(file_path.write_bytes, data)
|
||||
logger.debug("Downloaded {} to {}", media_type, file_path)
|
||||
return str(file_path)
|
||||
|
||||
@ -336,6 +381,100 @@ class WecomChannel(BaseChannel):
|
||||
logger.error("Error downloading media: {}", e)
|
||||
return None
|
||||
|
||||
async def _upload_media_ws(
|
||||
self, client: Any, file_path: str,
|
||||
) -> "tuple[str, str] | tuple[None, None]":
|
||||
"""Upload a local file to WeCom via WebSocket 3-step protocol (base64).
|
||||
|
||||
Uses the WeCom WebSocket upload commands directly via
|
||||
``client._ws_manager.send_reply()``:
|
||||
|
||||
``aibot_upload_media_init`` → upload_id
|
||||
``aibot_upload_media_chunk`` × N (≤512 KB raw per chunk, base64)
|
||||
``aibot_upload_media_finish`` → media_id
|
||||
|
||||
Returns (media_id, media_type) on success, (None, None) on failure.
|
||||
"""
|
||||
from wecom_aibot_sdk.utils import generate_req_id as _gen_req_id
|
||||
|
||||
try:
|
||||
fname = os.path.basename(file_path)
|
||||
media_type = _guess_wecom_media_type(fname)
|
||||
|
||||
# Read file size and data in a thread to avoid blocking the event loop
|
||||
def _read_file():
|
||||
file_size = os.path.getsize(file_path)
|
||||
if file_size > WECOM_UPLOAD_MAX_BYTES:
|
||||
raise ValueError(
|
||||
f"File too large: {file_size} bytes (max {WECOM_UPLOAD_MAX_BYTES})"
|
||||
)
|
||||
with open(file_path, "rb") as f:
|
||||
return file_size, f.read()
|
||||
|
||||
file_size, data = await asyncio.to_thread(_read_file)
|
||||
# MD5 is used for file integrity only, not cryptographic security
|
||||
md5_hash = hashlib.md5(data).hexdigest()
|
||||
|
||||
CHUNK_SIZE = 512 * 1024 # 512 KB raw (before base64)
|
||||
mv = memoryview(data)
|
||||
chunk_list = [bytes(mv[i : i + CHUNK_SIZE]) for i in range(0, file_size, CHUNK_SIZE)]
|
||||
n_chunks = len(chunk_list)
|
||||
del mv, data
|
||||
|
||||
# Step 1: init
|
||||
req_id = _gen_req_id("upload_init")
|
||||
resp = await client._ws_manager.send_reply(req_id, {
|
||||
"type": media_type,
|
||||
"filename": fname,
|
||||
"total_size": file_size,
|
||||
"total_chunks": n_chunks,
|
||||
"md5": md5_hash,
|
||||
}, "aibot_upload_media_init")
|
||||
if resp.errcode != 0:
|
||||
logger.warning("WeCom upload init failed ({}): {}", resp.errcode, resp.errmsg)
|
||||
return None, None
|
||||
upload_id = resp.body.get("upload_id") if resp.body else None
|
||||
if not upload_id:
|
||||
logger.warning("WeCom upload init: no upload_id in response")
|
||||
return None, None
|
||||
|
||||
# Step 2: send chunks
|
||||
for i, chunk in enumerate(chunk_list):
|
||||
req_id = _gen_req_id("upload_chunk")
|
||||
resp = await client._ws_manager.send_reply(req_id, {
|
||||
"upload_id": upload_id,
|
||||
"chunk_index": i,
|
||||
"base64_data": base64.b64encode(chunk).decode(),
|
||||
}, "aibot_upload_media_chunk")
|
||||
if resp.errcode != 0:
|
||||
logger.warning("WeCom upload chunk {} failed ({}): {}", i, resp.errcode, resp.errmsg)
|
||||
return None, None
|
||||
|
||||
# Step 3: finish
|
||||
req_id = _gen_req_id("upload_finish")
|
||||
resp = await client._ws_manager.send_reply(req_id, {
|
||||
"upload_id": upload_id,
|
||||
}, "aibot_upload_media_finish")
|
||||
if resp.errcode != 0:
|
||||
logger.warning("WeCom upload finish failed ({}): {}", resp.errcode, resp.errmsg)
|
||||
return None, None
|
||||
|
||||
media_id = resp.body.get("media_id") if resp.body else None
|
||||
if not media_id:
|
||||
logger.warning("WeCom upload finish: no media_id in response body={}", resp.body)
|
||||
return None, None
|
||||
|
||||
suffix = "..." if len(media_id) > 16 else ""
|
||||
logger.debug("WeCom uploaded {} ({}) → media_id={}", fname, media_type, media_id[:16] + suffix)
|
||||
return media_id, media_type
|
||||
|
||||
except ValueError as e:
|
||||
logger.warning("WeCom upload skipped for {}: {}", file_path, e)
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error("WeCom _upload_media_ws error for {}: {}", file_path, e)
|
||||
return None, None
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
"""Send a message through WeCom."""
|
||||
if not self._client:
|
||||
@ -343,29 +482,59 @@ class WecomChannel(BaseChannel):
|
||||
return
|
||||
|
||||
try:
|
||||
content = msg.content.strip()
|
||||
if not content:
|
||||
return
|
||||
content = (msg.content or "").strip()
|
||||
is_progress = bool(msg.metadata.get("_progress"))
|
||||
|
||||
# Get the stored frame for this chat
|
||||
frame = self._chat_frames.get(msg.chat_id)
|
||||
if not frame:
|
||||
logger.warning("No frame found for chat {}, cannot reply", msg.chat_id)
|
||||
|
||||
# Send media files via WebSocket upload
|
||||
for file_path in msg.media or []:
|
||||
if not os.path.isfile(file_path):
|
||||
logger.warning("WeCom media file not found: {}", file_path)
|
||||
continue
|
||||
media_id, media_type = await self._upload_media_ws(self._client, file_path)
|
||||
if media_id:
|
||||
if frame:
|
||||
await self._client.reply(frame, {
|
||||
"msgtype": media_type,
|
||||
media_type: {"media_id": media_id},
|
||||
})
|
||||
else:
|
||||
await self._client.send_message(msg.chat_id, {
|
||||
"msgtype": media_type,
|
||||
media_type: {"media_id": media_id},
|
||||
})
|
||||
logger.debug("WeCom sent {} → {}", media_type, msg.chat_id)
|
||||
else:
|
||||
content += f"\n[file upload failed: {os.path.basename(file_path)}]"
|
||||
|
||||
if not content:
|
||||
return
|
||||
|
||||
# Use streaming reply for better UX
|
||||
stream_id = self._generate_req_id("stream")
|
||||
if frame:
|
||||
# Both progress and final messages must use reply_stream (cmd="aibot_respond_msg").
|
||||
# The plain reply() uses cmd="reply" which does not support "text" msgtype
|
||||
# and causes errcode=40008 from WeCom API.
|
||||
stream_id = self._generate_req_id("stream")
|
||||
await self._client.reply_stream(
|
||||
frame,
|
||||
stream_id,
|
||||
content,
|
||||
finish=not is_progress,
|
||||
)
|
||||
logger.debug(
|
||||
"WeCom {} sent to {}",
|
||||
"progress" if is_progress else "message",
|
||||
msg.chat_id,
|
||||
)
|
||||
else:
|
||||
# No frame (e.g. cron push): proactive send only supports markdown
|
||||
await self._client.send_message(msg.chat_id, {
|
||||
"msgtype": "markdown",
|
||||
"markdown": {"content": content},
|
||||
})
|
||||
logger.info("WeCom proactive send to {}", msg.chat_id)
|
||||
|
||||
# Send as streaming message with finish=True
|
||||
await self._client.reply_stream(
|
||||
frame,
|
||||
stream_id,
|
||||
content,
|
||||
finish=True,
|
||||
)
|
||||
|
||||
logger.debug("WeCom message sent to {}", msg.chat_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error sending WeCom message: {}", e)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Error sending WeCom message to chat_id={}", msg.chat_id)
|
||||
|
||||
@ -307,7 +307,7 @@ async def test_agent_loop_extra_hook_receives_calls(tmp_path):
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, tools_used, messages, _ = await loop._run_agent_loop(
|
||||
content, tools_used, messages, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
@ -331,7 +331,7 @@ async def test_agent_loop_extra_hook_error_isolation(tmp_path):
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
content, _, _, _ = await loop._run_agent_loop(
|
||||
content, _, _, _, _ = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hi"}]
|
||||
)
|
||||
|
||||
@ -373,7 +373,7 @@ async def test_agent_loop_no_hooks_backward_compat(tmp_path):
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
content, tools_used, _, _ = await loop._run_agent_loop([])
|
||||
content, tools_used, _, _, _ = await loop._run_agent_loop([])
|
||||
assert content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@ -798,7 +799,7 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
final_content, _, _, _ = await loop._run_agent_loop([])
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
@ -825,7 +826,7 @@ async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp
|
||||
async def on_stream_end(*, resuming: bool = False) -> None:
|
||||
endings.append(resuming)
|
||||
|
||||
final_content, _, _, _ = await loop._run_agent_loop(
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop(
|
||||
[],
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
@ -849,7 +850,7 @@ async def test_loop_retries_think_only_final_response(tmp_path):
|
||||
|
||||
loop.provider.chat_with_retry = chat_with_retry
|
||||
|
||||
final_content, _, _, _ = await loop._run_agent_loop([])
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([])
|
||||
|
||||
assert final_content == "Recovered answer"
|
||||
assert call_count["n"] == 2
|
||||
@ -1722,3 +1723,690 @@ def test_governance_fallback_still_repairs_orphans():
|
||||
repaired = AgentRunner._backfill_missing_tool_results(repaired)
|
||||
# Orphan tool result should be gone.
|
||||
assert not any(m.get("tool_call_id") == "orphan_tc" for m in repaired)
|
||||
# ── Mid-turn injection tests ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_returns_empty_when_no_callback():
|
||||
"""No injection_callback → empty list."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=None,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_extracts_content_from_inbound_messages():
|
||||
"""Should extract .content from InboundMessage objects."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello"),
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="world"),
|
||||
]
|
||||
|
||||
async def cb():
|
||||
return msgs
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "user", "content": "world"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_passes_limit_to_callback_when_supported():
|
||||
"""Limit-aware callbacks can preserve overflow in their own queue."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTIONS_PER_TURN
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
seen_limits: list[int] = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg{i}")
|
||||
for i in range(_MAX_INJECTIONS_PER_TURN + 3)
|
||||
]
|
||||
|
||||
async def cb(*, limit: int):
|
||||
seen_limits.append(limit)
|
||||
return msgs[:limit]
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert seen_limits == [_MAX_INJECTIONS_PER_TURN]
|
||||
assert result == [
|
||||
{"role": "user", "content": "msg0"},
|
||||
{"role": "user", "content": "msg1"},
|
||||
{"role": "user", "content": "msg2"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_skips_empty_content():
|
||||
"""Messages with blank content should be filtered out."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
msgs = [
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=""),
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content=" "),
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="valid"),
|
||||
]
|
||||
|
||||
async def cb():
|
||||
return msgs
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == [{"role": "user", "content": "valid"}]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drain_injections_handles_callback_exception():
|
||||
"""If the callback raises, return empty list (error is logged)."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
runner = AgentRunner(provider)
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
async def cb():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
spec = AgentRunSpec(
|
||||
initial_messages=[], tools=tools, model="m",
|
||||
max_iterations=1, max_tool_result_chars=1000,
|
||||
injection_callback=cb,
|
||||
)
|
||||
result = await runner._drain_injections(spec)
|
||||
assert result == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint1_injects_after_tool_execution():
|
||||
"""Follow-up messages are injected after tool execution, before next LLM call."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
captured_messages = []
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
captured_messages.append(list(messages))
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(
|
||||
content="using tool",
|
||||
tool_calls=[ToolCallRequest(id="c1", name="read_file", arguments={"path": "x"})],
|
||||
usage={},
|
||||
)
|
||||
return LLMResponse(content="final answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
tools.execute = AsyncMock(return_value="file content")
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
|
||||
# Put a follow-up message in the queue before the run starts
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||
)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hello"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=5,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.had_injections is True
|
||||
assert result.final_content == "final answer"
|
||||
# The second call should have the injected user message
|
||||
assert call_count["n"] == 2
|
||||
last_messages = captured_messages[-1]
|
||||
injected = [m for m in last_messages if m.get("role") == "user" and m.get("content") == "follow-up question"]
|
||||
assert len(injected) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint2_injects_after_final_response_with_resuming_stream():
|
||||
"""After final response, if injections exist, stream_end should get resuming=True."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
stream_end_calls = []
|
||||
|
||||
class TrackingHook(AgentHook):
|
||||
def wants_streaming(self) -> bool:
|
||||
return True
|
||||
|
||||
async def on_stream_end(self, context: AgentHookContext, *, resuming: bool) -> None:
|
||||
stream_end_calls.append(resuming)
|
||||
|
||||
def finalize_content(self, context: AgentHookContext, content: str | None) -> str | None:
|
||||
return content
|
||||
|
||||
async def chat_stream_with_retry(*, messages, on_content_delta=None, **kwargs):
|
||||
call_count["n"] += 1
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content="first answer", tool_calls=[], usage={})
|
||||
return LLMResponse(content="second answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
|
||||
# Inject a follow-up that arrives during the first response
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="quick follow-up")
|
||||
)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hello"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=5,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
hook=TrackingHook(),
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.had_injections is True
|
||||
assert result.final_content == "second answer"
|
||||
assert call_count["n"] == 2
|
||||
# First stream_end should have resuming=True (because injections found)
|
||||
assert stream_end_calls[0] is True
|
||||
# Second (final) stream_end should have resuming=False
|
||||
assert stream_end_calls[-1] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_checkpoint2_preserves_final_response_in_history_before_followup():
|
||||
"""A follow-up injected after a final answer must still see that answer in history."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
captured_messages = []
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
captured_messages.append([dict(message) for message in messages])
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content="first answer", tool_calls=[], usage={})
|
||||
return LLMResponse(content="second answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
injection_queue = asyncio.Queue()
|
||||
|
||||
async def inject_cb():
|
||||
items = []
|
||||
while not injection_queue.empty():
|
||||
items.append(await injection_queue.get())
|
||||
return items
|
||||
|
||||
await injection_queue.put(
|
||||
InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up question")
|
||||
)
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hello"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=5,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.final_content == "second answer"
|
||||
assert call_count["n"] == 2
|
||||
assert captured_messages[-1] == [
|
||||
{"role": "user", "content": "hello"},
|
||||
{"role": "assistant", "content": "first answer"},
|
||||
{"role": "user", "content": "follow-up question"},
|
||||
]
|
||||
assert [
|
||||
{"role": message["role"], "content": message["content"]}
|
||||
for message in result.messages
|
||||
if message.get("role") == "assistant"
|
||||
] == [
|
||||
{"role": "assistant", "content": "first answer"},
|
||||
{"role": "assistant", "content": "second answer"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_injected_followup_preserves_image_media(tmp_path):
|
||||
"""Mid-turn follow-ups with images should keep multimodal content."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
|
||||
image_path = tmp_path / "followup.png"
|
||||
image_path.write_bytes(base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+yF9kAAAAASUVORK5CYII="
|
||||
))
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
captured_messages: list[list[dict]] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
captured_messages.append(list(messages))
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content="first answer", tool_calls=[], usage={})
|
||||
return LLMResponse(content="second answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
pending_queue = asyncio.Queue()
|
||||
await pending_queue.put(InboundMessage(
|
||||
channel="cli",
|
||||
sender_id="u",
|
||||
chat_id="c",
|
||||
content="",
|
||||
media=[str(image_path)],
|
||||
))
|
||||
|
||||
final_content, _, _, _, had_injections = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
channel="cli",
|
||||
chat_id="c",
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
assert final_content == "second answer"
|
||||
assert had_injections is True
|
||||
assert call_count["n"] == 2
|
||||
injected_user_messages = [
|
||||
message for message in captured_messages[-1]
|
||||
if message.get("role") == "user" and isinstance(message.get("content"), list)
|
||||
]
|
||||
assert injected_user_messages
|
||||
assert any(
|
||||
block.get("type") == "image_url"
|
||||
for block in injected_user_messages[-1]["content"]
|
||||
if isinstance(block, dict)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_merges_multiple_injected_user_messages_without_losing_media():
|
||||
"""Multiple injected follow-ups should not create lossy consecutive user messages."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
captured_messages = []
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
captured_messages.append([dict(message) for message in messages])
|
||||
if call_count["n"] == 1:
|
||||
return LLMResponse(content="first answer", tool_calls=[], usage={})
|
||||
return LLMResponse(content="second answer", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
async def inject_cb():
|
||||
if call_count["n"] == 1:
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}},
|
||||
{"type": "text", "text": "look at this"},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "and answer briefly"},
|
||||
]
|
||||
return []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hello"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=5,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.final_content == "second answer"
|
||||
assert call_count["n"] == 2
|
||||
second_call = captured_messages[-1]
|
||||
user_messages = [message for message in second_call if message.get("role") == "user"]
|
||||
assert len(user_messages) == 2
|
||||
injected = user_messages[-1]
|
||||
assert isinstance(injected["content"], list)
|
||||
assert any(
|
||||
block.get("type") == "image_url"
|
||||
for block in injected["content"]
|
||||
if isinstance(block, dict)
|
||||
)
|
||||
assert any(
|
||||
block.get("type") == "text" and block.get("text") == "and answer briefly"
|
||||
for block in injected["content"]
|
||||
if isinstance(block, dict)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_injection_cycles_capped_at_max():
|
||||
"""Injection cycles should be capped at _MAX_INJECTION_CYCLES."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner, _MAX_INJECTION_CYCLES
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
provider = MagicMock()
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
drain_count = {"n": 0}
|
||||
|
||||
async def inject_cb():
|
||||
drain_count["n"] += 1
|
||||
# Only inject for the first _MAX_INJECTION_CYCLES drains
|
||||
if drain_count["n"] <= _MAX_INJECTION_CYCLES:
|
||||
return [InboundMessage(channel="cli", sender_id="u", chat_id="c", content=f"msg-{drain_count['n']}")]
|
||||
return []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "start"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=20,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
injection_callback=inject_cb,
|
||||
))
|
||||
|
||||
assert result.had_injections is True
|
||||
# Should be capped: _MAX_INJECTION_CYCLES injection rounds + 1 final round
|
||||
assert call_count["n"] == _MAX_INJECTION_CYCLES + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_injections_flag_is_false_by_default():
|
||||
"""had_injections should be False when no injection callback or no messages."""
|
||||
from nanobot.agent.runner import AgentRunSpec, AgentRunner
|
||||
|
||||
provider = MagicMock()
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
tools = MagicMock()
|
||||
tools.get_definitions.return_value = []
|
||||
|
||||
runner = AgentRunner(provider)
|
||||
result = await runner.run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "hi"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=1,
|
||||
max_tool_result_chars=_MAX_TOOL_RESULT_CHARS,
|
||||
))
|
||||
|
||||
assert result.had_injections is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_queue_cleanup_on_dispatch(tmp_path):
|
||||
"""_pending_queues should be cleaned up after _dispatch completes."""
|
||||
loop = _make_loop(tmp_path)
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(content="done", tool_calls=[], usage={})
|
||||
|
||||
loop.provider.chat_with_retry = chat_with_retry
|
||||
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="hello")
|
||||
# The queue should not exist before dispatch
|
||||
assert msg.session_key not in loop._pending_queues
|
||||
|
||||
await loop._dispatch(msg)
|
||||
|
||||
# The queue should be cleaned up after dispatch
|
||||
assert msg.session_key not in loop._pending_queues
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_followup_routed_to_pending_queue(tmp_path):
|
||||
"""Unified-session follow-ups should route into the active pending queue."""
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop._unified_session = True
|
||||
loop._dispatch = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
loop._pending_queues[UNIFIED_SESSION_KEY] = pending
|
||||
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
msg = InboundMessage(channel="discord", sender_id="u", chat_id="c", content="follow-up")
|
||||
await loop.bus.publish_inbound(msg)
|
||||
|
||||
deadline = time.time() + 2
|
||||
while pending.empty() and time.time() < deadline:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
loop.stop()
|
||||
await asyncio.wait_for(run_task, timeout=2)
|
||||
|
||||
assert loop._dispatch.await_count == 0
|
||||
assert not pending.empty()
|
||||
queued_msg = pending.get_nowait()
|
||||
assert queued_msg.content == "follow-up"
|
||||
assert queued_msg.session_key == UNIFIED_SESSION_KEY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_queue_preserves_overflow_for_next_injection_cycle(tmp_path):
|
||||
"""Pending queue should leave overflow messages queued for later drains."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
captured_messages: list[list[dict]] = []
|
||||
call_count = {"n": 0}
|
||||
|
||||
async def chat_with_retry(*, messages, **kwargs):
|
||||
call_count["n"] += 1
|
||||
captured_messages.append([dict(message) for message in messages])
|
||||
return LLMResponse(content=f"answer-{call_count['n']}", tool_calls=[], usage={})
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model")
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
pending_queue = asyncio.Queue()
|
||||
total_followups = _MAX_INJECTIONS_PER_TURN + 2
|
||||
for idx in range(total_followups):
|
||||
await pending_queue.put(InboundMessage(
|
||||
channel="cli",
|
||||
sender_id="u",
|
||||
chat_id="c",
|
||||
content=f"follow-up-{idx}",
|
||||
))
|
||||
|
||||
final_content, _, _, _, had_injections = await loop._run_agent_loop(
|
||||
[{"role": "user", "content": "hello"}],
|
||||
channel="cli",
|
||||
chat_id="c",
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
assert final_content == "answer-3"
|
||||
assert had_injections is True
|
||||
assert call_count["n"] == 3
|
||||
flattened_user_content = "\n".join(
|
||||
message["content"]
|
||||
for message in captured_messages[-1]
|
||||
if message.get("role") == "user" and isinstance(message.get("content"), str)
|
||||
)
|
||||
for idx in range(total_followups):
|
||||
assert f"follow-up-{idx}" in flattened_user_content
|
||||
assert pending_queue.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_queue_full_falls_back_to_queued_task(tmp_path):
|
||||
"""QueueFull should preserve the message by dispatching a queued task."""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
loop._dispatch = AsyncMock() # type: ignore[method-assign]
|
||||
|
||||
pending = asyncio.Queue(maxsize=1)
|
||||
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="already queued"))
|
||||
loop._pending_queues["cli:c"] = pending
|
||||
|
||||
run_task = asyncio.create_task(loop.run())
|
||||
msg = InboundMessage(channel="cli", sender_id="u", chat_id="c", content="follow-up")
|
||||
await loop.bus.publish_inbound(msg)
|
||||
|
||||
deadline = time.time() + 2
|
||||
while loop._dispatch.await_count == 0 and time.time() < deadline:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
loop.stop()
|
||||
await asyncio.wait_for(run_task, timeout=2)
|
||||
|
||||
assert loop._dispatch.await_count == 1
|
||||
dispatched_msg = loop._dispatch.await_args.args[0]
|
||||
assert dispatched_msg.content == "follow-up"
|
||||
assert pending.qsize() == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_republishes_leftover_queue_messages(tmp_path):
|
||||
"""Messages left in the pending queue after _dispatch are re-published to the bus.
|
||||
|
||||
This tests the finally-block cleanup that prevents message loss when
|
||||
the runner exits early (e.g., max_iterations, tool_error) with messages
|
||||
still in the queue.
|
||||
"""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
loop = _make_loop(tmp_path)
|
||||
bus = loop.bus
|
||||
|
||||
# Simulate a completed dispatch by manually registering a queue
|
||||
# with leftover messages, then running the cleanup logic directly.
|
||||
pending = asyncio.Queue(maxsize=20)
|
||||
session_key = "cli:c"
|
||||
loop._pending_queues[session_key] = pending
|
||||
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-1"))
|
||||
pending.put_nowait(InboundMessage(channel="cli", sender_id="u", chat_id="c", content="leftover-2"))
|
||||
|
||||
# Execute the cleanup logic from the finally block
|
||||
queue = loop._pending_queues.pop(session_key, None)
|
||||
assert queue is not None
|
||||
leftover = 0
|
||||
while True:
|
||||
try:
|
||||
item = queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
await bus.publish_inbound(item)
|
||||
leftover += 1
|
||||
|
||||
assert leftover == 2
|
||||
|
||||
# Verify the messages are now on the bus
|
||||
msgs = []
|
||||
while not bus.inbound.empty():
|
||||
msgs.append(await asyncio.wait_for(bus.consume_inbound(), timeout=0.5))
|
||||
contents = [m.content for m in msgs]
|
||||
assert "leftover-1" in contents
|
||||
assert "leftover-2" in contents
|
||||
|
||||
304
tests/channels/test_qq_media.py
Normal file
304
tests/channels/test_qq_media.py
Normal file
@ -0,0 +1,304 @@
|
||||
"""Tests for QQ channel media support: helpers, send, inbound, and upload."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
from nanobot.channels import qq
|
||||
|
||||
QQ_AVAILABLE = getattr(qq, "QQ_AVAILABLE", False)
|
||||
except ImportError:
|
||||
QQ_AVAILABLE = False
|
||||
|
||||
if not QQ_AVAILABLE:
|
||||
pytest.skip("QQ dependencies not installed (qq-botpy)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.qq import (
|
||||
QQ_FILE_TYPE_FILE,
|
||||
QQ_FILE_TYPE_IMAGE,
|
||||
QQChannel,
|
||||
QQConfig,
|
||||
_guess_send_file_type,
|
||||
_is_image_name,
|
||||
_sanitize_filename,
|
||||
)
|
||||
|
||||
|
||||
class _FakeApi:
|
||||
def __init__(self) -> None:
|
||||
self.c2c_calls: list[dict] = []
|
||||
self.group_calls: list[dict] = []
|
||||
|
||||
async def post_c2c_message(self, **kwargs) -> None:
|
||||
self.c2c_calls.append(kwargs)
|
||||
|
||||
async def post_group_message(self, **kwargs) -> None:
|
||||
self.group_calls.append(kwargs)
|
||||
|
||||
|
||||
class _FakeHttp:
|
||||
"""Fake _http for _post_base64file tests."""
|
||||
|
||||
def __init__(self, return_value: dict | None = None) -> None:
|
||||
self.return_value = return_value or {}
|
||||
self.calls: list[tuple] = []
|
||||
|
||||
async def request(self, route, **kwargs):
|
||||
self.calls.append((route, kwargs))
|
||||
return self.return_value
|
||||
|
||||
|
||||
class _FakeClient:
|
||||
def __init__(self, http_return: dict | None = None) -> None:
|
||||
self.api = _FakeApi()
|
||||
self.api._http = _FakeHttp(http_return)
|
||||
|
||||
|
||||
# ── Helper function tests (pure, no async) ──────────────────────────
|
||||
|
||||
|
||||
def test_sanitize_filename_strips_path_traversal() -> None:
|
||||
assert _sanitize_filename("../../etc/passwd") == "passwd"
|
||||
|
||||
|
||||
def test_sanitize_filename_keeps_chinese_chars() -> None:
|
||||
assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"
|
||||
|
||||
|
||||
def test_sanitize_filename_strips_unsafe_chars() -> None:
|
||||
result = _sanitize_filename('file<>:"|?*.txt')
|
||||
# All unsafe chars replaced with "_", but * is replaced too
|
||||
assert result.startswith("file")
|
||||
assert result.endswith(".txt")
|
||||
assert "<" not in result
|
||||
assert ">" not in result
|
||||
assert '"' not in result
|
||||
assert "|" not in result
|
||||
assert "?" not in result
|
||||
|
||||
|
||||
def test_sanitize_filename_empty_input() -> None:
|
||||
assert _sanitize_filename("") == ""
|
||||
|
||||
|
||||
def test_is_image_name_with_known_extensions() -> None:
|
||||
for ext in (".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".tif", ".tiff", ".ico", ".svg"):
|
||||
assert _is_image_name(f"photo{ext}") is True
|
||||
|
||||
|
||||
def test_is_image_name_with_unknown_extension() -> None:
|
||||
for ext in (".pdf", ".txt", ".mp3", ".mp4"):
|
||||
assert _is_image_name(f"doc{ext}") is False
|
||||
|
||||
|
||||
def test_guess_send_file_type_image() -> None:
|
||||
assert _guess_send_file_type("photo.png") == QQ_FILE_TYPE_IMAGE
|
||||
assert _guess_send_file_type("pic.jpg") == QQ_FILE_TYPE_IMAGE
|
||||
|
||||
|
||||
def test_guess_send_file_type_file() -> None:
|
||||
assert _guess_send_file_type("doc.pdf") == QQ_FILE_TYPE_FILE
|
||||
|
||||
|
||||
def test_guess_send_file_type_by_mime() -> None:
|
||||
# A filename with no known extension but whose mime type is image/*
|
||||
assert _guess_send_file_type("photo.xyz_image_test") == QQ_FILE_TYPE_FILE
|
||||
|
||||
|
||||
# ── send() exception handling ───────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_exception_caught_not_raised() -> None:
|
||||
"""Exceptions inside send() must not propagate."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with patch.object(channel, "_send_text_only", new_callable=AsyncMock, side_effect=RuntimeError("boom")):
|
||||
await channel.send(
|
||||
OutboundMessage(channel="qq", chat_id="user1", content="hello")
|
||||
)
|
||||
# No exception raised — test passes if we get here.
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_then_text() -> None:
|
||||
"""Media is sent before text when both are present."""
|
||||
import tempfile
|
||||
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
with patch.object(channel, "_post_base64file", new_callable=AsyncMock, return_value={"file_info": "1"}) as mock_upload:
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user1",
|
||||
content="text after image",
|
||||
media=[tmp],
|
||||
metadata={"message_id": "m1"},
|
||||
)
|
||||
)
|
||||
assert mock_upload.called
|
||||
|
||||
# Text should have been sent via c2c (default chat type)
|
||||
text_calls = [c for c in channel._client.api.c2c_calls if c.get("msg_type") == 0]
|
||||
assert len(text_calls) >= 1
|
||||
assert text_calls[-1]["content"] == "text after image"
|
||||
finally:
|
||||
import os
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_failure_falls_back_to_text() -> None:
|
||||
"""When _send_media returns False, a failure notice is appended."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with patch.object(channel, "_send_media", new_callable=AsyncMock, return_value=False):
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="qq",
|
||||
chat_id="user1",
|
||||
content="hello",
|
||||
media=["https://example.com/bad.png"],
|
||||
metadata={"message_id": "m1"},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have the failure text among the c2c calls
|
||||
failure_calls = [c for c in channel._client.api.c2c_calls if "Attachment send failed" in c.get("content", "")]
|
||||
assert len(failure_calls) == 1
|
||||
assert "bad.png" in failure_calls[0]["content"]
|
||||
|
||||
|
||||
# ── _on_message() exception handling ────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_exception_caught_not_raised() -> None:
|
||||
"""Missing required attributes should not crash _on_message."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
# Construct a message-like object that lacks 'author' — triggers AttributeError
|
||||
bad_data = SimpleNamespace(id="x1", content="hi")
|
||||
# Should not raise
|
||||
await channel._on_message(bad_data, is_group=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_with_attachments() -> None:
|
||||
"""Messages with attachments produce media_paths and formatted content."""
|
||||
import tempfile
|
||||
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret", allow_from=["*"]), MessageBus())
|
||||
channel._client = _FakeClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
saved_path = f.name
|
||||
|
||||
att = SimpleNamespace(url="", filename="screenshot.png", content_type="image/png")
|
||||
|
||||
# Patch _download_to_media_dir_chunked to return the temp file path
|
||||
async def fake_download(url, filename_hint=""):
|
||||
return saved_path
|
||||
|
||||
try:
|
||||
with patch.object(channel, "_download_to_media_dir_chunked", side_effect=fake_download):
|
||||
data = SimpleNamespace(
|
||||
id="att1",
|
||||
content="look at this",
|
||||
author=SimpleNamespace(user_openid="u1"),
|
||||
attachments=[att],
|
||||
)
|
||||
await channel._on_message(data, is_group=False)
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert "look at this" in msg.content
|
||||
assert "screenshot.png" in msg.content
|
||||
assert "Received files:" in msg.content
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0] == saved_path
|
||||
finally:
|
||||
import os
|
||||
os.unlink(saved_path)
|
||||
|
||||
|
||||
# ── _post_base64file() ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_base64file_omits_file_name_for_images() -> None:
|
||||
"""file_type=1 (image) → payload must not contain file_name."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
channel._client = _FakeClient(http_return={"file_info": "img_abc"})
|
||||
|
||||
await channel._post_base64file(
|
||||
chat_id="user1",
|
||||
is_group=False,
|
||||
file_type=QQ_FILE_TYPE_IMAGE,
|
||||
file_data="ZmFrZQ==",
|
||||
file_name="photo.png",
|
||||
)
|
||||
|
||||
http = channel._client.api._http
|
||||
assert len(http.calls) == 1
|
||||
payload = http.calls[0][1]["json"]
|
||||
assert "file_name" not in payload
|
||||
assert payload["file_type"] == QQ_FILE_TYPE_IMAGE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_base64file_includes_file_name_for_files() -> None:
|
||||
"""file_type=4 (file) → payload must contain file_name."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
channel._client = _FakeClient(http_return={"file_info": "file_abc"})
|
||||
|
||||
await channel._post_base64file(
|
||||
chat_id="user1",
|
||||
is_group=False,
|
||||
file_type=QQ_FILE_TYPE_FILE,
|
||||
file_data="ZmFrZQ==",
|
||||
file_name="report.pdf",
|
||||
)
|
||||
|
||||
http = channel._client.api._http
|
||||
assert len(http.calls) == 1
|
||||
payload = http.calls[0][1]["json"]
|
||||
assert payload["file_name"] == "report.pdf"
|
||||
assert payload["file_type"] == QQ_FILE_TYPE_FILE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_post_base64file_filters_response_to_file_info() -> None:
|
||||
"""Response with file_info + extra fields must be filtered to only file_info."""
|
||||
channel = QQChannel(QQConfig(app_id="app", secret="secret"), MessageBus())
|
||||
channel._client = _FakeClient(http_return={
|
||||
"file_info": "fi_123",
|
||||
"file_uuid": "uuid_xxx",
|
||||
"ttl": 3600,
|
||||
})
|
||||
|
||||
result = await channel._post_base64file(
|
||||
chat_id="user1",
|
||||
is_group=False,
|
||||
file_type=QQ_FILE_TYPE_FILE,
|
||||
file_data="ZmFrZQ==",
|
||||
file_name="doc.pdf",
|
||||
)
|
||||
|
||||
assert result == {"file_info": "fi_123"}
|
||||
assert "file_uuid" not in result
|
||||
assert "ttl" not in result
|
||||
584
tests/channels/test_wecom_channel.py
Normal file
584
tests/channels/test_wecom_channel.py
Normal file
@ -0,0 +1,584 @@
|
||||
"""Tests for WeCom channel: helpers, download, upload, send, and message processing."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
try:
|
||||
import importlib.util
|
||||
|
||||
WECOM_AVAILABLE = importlib.util.find_spec("wecom_aibot_sdk") is not None
|
||||
except ImportError:
|
||||
WECOM_AVAILABLE = False
|
||||
|
||||
if not WECOM_AVAILABLE:
|
||||
pytest.skip("WeCom dependencies not installed (wecom_aibot_sdk)", allow_module_level=True)
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.wecom import (
|
||||
WecomChannel,
|
||||
WecomConfig,
|
||||
_guess_wecom_media_type,
|
||||
_sanitize_filename,
|
||||
)
|
||||
|
||||
# Try to import the real response class; fall back to a stub if unavailable.
|
||||
try:
|
||||
from wecom_aibot_sdk.utils import WsResponse
|
||||
|
||||
_RealWsResponse = WsResponse
|
||||
except ImportError:
|
||||
_RealWsResponse = None
|
||||
|
||||
|
||||
class _FakeResponse:
|
||||
"""Minimal stand-in for wecom_aibot_sdk WsResponse."""
|
||||
|
||||
def __init__(self, errcode: int = 0, body: dict | None = None, errmsg: str = "ok"):
|
||||
self.errcode = errcode
|
||||
self.errmsg = errmsg
|
||||
self.body = body or {}
|
||||
|
||||
|
||||
class _FakeWsManager:
|
||||
"""Tracks send_reply calls and returns configurable responses."""
|
||||
|
||||
def __init__(self, responses: list[_FakeResponse] | None = None):
|
||||
self.responses = responses or []
|
||||
self.calls: list[tuple[str, dict, str]] = []
|
||||
self._idx = 0
|
||||
|
||||
async def send_reply(self, req_id: str, data: dict, cmd: str) -> _FakeResponse:
|
||||
self.calls.append((req_id, data, cmd))
|
||||
if self._idx < len(self.responses):
|
||||
resp = self.responses[self._idx]
|
||||
self._idx += 1
|
||||
return resp
|
||||
return _FakeResponse()
|
||||
|
||||
|
||||
class _FakeFrame:
|
||||
"""Minimal frame object with a body dict."""
|
||||
|
||||
def __init__(self, body: dict | None = None):
|
||||
self.body = body or {}
|
||||
|
||||
|
||||
class _FakeWeComClient:
|
||||
"""Fake WeCom client with mock methods."""
|
||||
|
||||
def __init__(self, ws_responses: list[_FakeResponse] | None = None):
|
||||
self._ws_manager = _FakeWsManager(ws_responses)
|
||||
self.download_file = AsyncMock(return_value=(None, None))
|
||||
self.reply = AsyncMock()
|
||||
self.reply_stream = AsyncMock()
|
||||
self.send_message = AsyncMock()
|
||||
self.reply_welcome = AsyncMock()
|
||||
|
||||
|
||||
# ── Helper function tests (pure, no async) ──────────────────────────
|
||||
|
||||
|
||||
def test_sanitize_filename_strips_path_traversal() -> None:
|
||||
assert _sanitize_filename("../../etc/passwd") == "passwd"
|
||||
|
||||
|
||||
def test_sanitize_filename_keeps_chinese_chars() -> None:
|
||||
assert _sanitize_filename("文件(1).jpg") == "文件(1).jpg"
|
||||
|
||||
|
||||
def test_sanitize_filename_empty_input() -> None:
|
||||
assert _sanitize_filename("") == ""
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_image() -> None:
|
||||
for ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"):
|
||||
assert _guess_wecom_media_type(f"photo{ext}") == "image"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_video() -> None:
|
||||
for ext in (".mp4", ".avi", ".mov"):
|
||||
assert _guess_wecom_media_type(f"video{ext}") == "video"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_voice() -> None:
|
||||
for ext in (".amr", ".mp3", ".wav", ".ogg"):
|
||||
assert _guess_wecom_media_type(f"audio{ext}") == "voice"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_file_fallback() -> None:
|
||||
for ext in (".pdf", ".doc", ".xlsx", ".zip"):
|
||||
assert _guess_wecom_media_type(f"doc{ext}") == "file"
|
||||
|
||||
|
||||
def test_guess_wecom_media_type_case_insensitive() -> None:
|
||||
assert _guess_wecom_media_type("photo.PNG") == "image"
|
||||
assert _guess_wecom_media_type("photo.Jpg") == "image"
|
||||
|
||||
|
||||
# ── _download_and_save_media() ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_and_save_success() -> None:
|
||||
"""Successful download writes file and returns sanitized path."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
fake_data = b"\x89PNG\r\nfake image"
|
||||
client.download_file.return_value = (fake_data, "raw_photo.png")
|
||||
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
|
||||
path = await channel._download_and_save_media("https://example.com/img.png", "aes_key", "image", "photo.png")
|
||||
|
||||
assert path is not None
|
||||
assert os.path.isfile(path)
|
||||
assert os.path.basename(path) == "photo.png"
|
||||
# Cleanup
|
||||
os.unlink(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_and_save_oversized_rejected() -> None:
|
||||
"""Data exceeding 200MB is rejected → returns None."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
big_data = b"\x00" * (200 * 1024 * 1024 + 1) # 200MB + 1 byte
|
||||
client.download_file.return_value = (big_data, "big.bin")
|
||||
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
|
||||
result = await channel._download_and_save_media("https://example.com/big.bin", "key", "file", "big.bin")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_and_save_failure() -> None:
|
||||
"""SDK returns None data → returns None."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
client.download_file.return_value = (None, None)
|
||||
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(tempfile.gettempdir())):
|
||||
result = await channel._download_and_save_media("https://example.com/fail.png", "key", "image")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ── _upload_media_ws() ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_success() -> None:
|
||||
"""Happy path: init → chunk → finish → returns (media_id, media_type)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=0, body={}),
|
||||
_FakeResponse(errcode=0, body={"media_id": "media_abc"}),
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
media_id, media_type = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert media_id == "media_abc"
|
||||
assert media_type == "image"
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_oversized_file() -> None:
|
||||
"""File >200MB triggers ValueError → returns (None, None)."""
|
||||
# Instead of creating a real 200MB+ file, mock os.path.getsize and open
|
||||
with patch("os.path.getsize", return_value=200 * 1024 * 1024 + 1), \
|
||||
patch("builtins.open", MagicMock()):
|
||||
client = _FakeWeComClient()
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
result = await channel._upload_media_ws(client, "/fake/large.bin")
|
||||
assert result == (None, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_init_failure() -> None:
|
||||
"""Init step returns errcode != 0 → returns (None, None)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as f:
|
||||
f.write(b"hello")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=50001, errmsg="invalid"),
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
result = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert result == (None, None)
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_chunk_failure() -> None:
|
||||
"""Chunk step returns errcode != 0 → returns (None, None)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=50002, errmsg="chunk fail"),
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
result = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert result == (None, None)
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_media_ws_finish_no_media_id() -> None:
|
||||
"""Finish step returns empty media_id → returns (None, None)."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=0, body={}),
|
||||
_FakeResponse(errcode=0, body={}), # no media_id
|
||||
]
|
||||
|
||||
client = _FakeWeComClient(responses)
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
channel._client = client
|
||||
|
||||
with patch("wecom_aibot_sdk.utils.generate_req_id", side_effect=lambda x: f"req_{x}"):
|
||||
result = await channel._upload_media_ws(client, tmp)
|
||||
|
||||
assert result == (None, None)
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
# ── send() ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_text_with_frame() -> None:
|
||||
"""When frame is stored, send uses reply_stream for final text."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="hello")
|
||||
)
|
||||
|
||||
client.reply_stream.assert_called_once()
|
||||
call_args = client.reply_stream.call_args
|
||||
assert call_args[0][2] == "hello" # content arg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_progress_with_frame() -> None:
|
||||
"""When metadata has _progress, send uses reply_stream with finish=False."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="thinking...", metadata={"_progress": True})
|
||||
)
|
||||
|
||||
client.reply_stream.assert_called_once()
|
||||
call_args = client.reply_stream.call_args
|
||||
assert call_args[0][2] == "thinking..." # content arg
|
||||
assert call_args[1]["finish"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_proactive_without_frame() -> None:
|
||||
"""Without stored frame, send uses send_message with markdown."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="proactive msg")
|
||||
)
|
||||
|
||||
client.send_message.assert_called_once()
|
||||
call_args = client.send_message.call_args
|
||||
assert call_args[0][0] == "chat1"
|
||||
assert call_args[0][1]["msgtype"] == "markdown"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_then_text() -> None:
|
||||
"""Media files are uploaded and sent before text content."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
tmp = f.name
|
||||
|
||||
try:
|
||||
responses = [
|
||||
_FakeResponse(errcode=0, body={"upload_id": "up_1"}),
|
||||
_FakeResponse(errcode=0, body={}),
|
||||
_FakeResponse(errcode=0, body={"media_id": "media_123"}),
|
||||
]
|
||||
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient(responses)
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="see image", media=[tmp])
|
||||
)
|
||||
|
||||
# Media should have been sent via reply
|
||||
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") == "image"]
|
||||
assert len(media_calls) == 1
|
||||
assert media_calls[0][0][1]["image"]["media_id"] == "media_123"
|
||||
|
||||
# Text should have been sent via reply_stream
|
||||
client.reply_stream.assert_called_once()
|
||||
finally:
|
||||
os.unlink(tmp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_media_file_not_found() -> None:
|
||||
"""Non-existent media path is skipped with a warning."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="hello", media=["/nonexistent/file.png"])
|
||||
)
|
||||
|
||||
# reply_stream should still be called for the text part
|
||||
client.reply_stream.assert_called_once()
|
||||
# No media reply should happen
|
||||
media_calls = [c for c in client.reply.call_args_list if c[0][1].get("msgtype") in ("image", "file", "video")]
|
||||
assert len(media_calls) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_exception_caught_not_raised() -> None:
|
||||
"""Exceptions inside send() must not propagate."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["*"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
channel._generate_req_id = lambda x: f"req_{x}"
|
||||
channel._chat_frames["chat1"] = _FakeFrame()
|
||||
|
||||
# Make reply_stream raise
|
||||
client.reply_stream.side_effect = RuntimeError("boom")
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(channel="wecom", chat_id="chat1", content="fail test")
|
||||
)
|
||||
# No exception — test passes if we reach here.
|
||||
|
||||
|
||||
# ── _process_message() ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_text_message() -> None:
|
||||
"""Text message is routed to bus with correct fields."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_text_1",
|
||||
"chatid": "chat1",
|
||||
"chattype": "single",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": "hello wecom"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.sender_id == "user1"
|
||||
assert msg.chat_id == "chat1"
|
||||
assert msg.content == "hello wecom"
|
||||
assert msg.metadata["msg_type"] == "text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_image_message() -> None:
|
||||
"""Image message: download success → media_paths non-empty."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
f.write(b"\x89PNG\r\n")
|
||||
saved = f.name
|
||||
|
||||
client.download_file.return_value = (b"\x89PNG\r\n", "photo.png")
|
||||
channel._client = client
|
||||
|
||||
try:
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_img_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"image": {"url": "https://example.com/img.png", "aeskey": "key123"},
|
||||
})
|
||||
await channel._process_message(frame, "image")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0].endswith("photo.png")
|
||||
assert "[image:" in msg.content
|
||||
finally:
|
||||
if os.path.exists(saved):
|
||||
pass # may have been overwritten; clean up if exists
|
||||
# Clean up any photo.png in tempdir
|
||||
p = os.path.join(os.path.dirname(saved), "photo.png")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_file_message() -> None:
|
||||
"""File message: download success → media_paths non-empty (critical fix verification)."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as f:
|
||||
f.write(b"%PDF-1.4 fake")
|
||||
saved = f.name
|
||||
|
||||
client.download_file.return_value = (b"%PDF-1.4 fake", "report.pdf")
|
||||
channel._client = client
|
||||
|
||||
try:
|
||||
with patch("nanobot.channels.wecom.get_media_dir", return_value=Path(os.path.dirname(saved))):
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_file_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"file": {"url": "https://example.com/report.pdf", "aeskey": "key456", "name": "report.pdf"},
|
||||
})
|
||||
await channel._process_message(frame, "file")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert len(msg.media) == 1
|
||||
assert msg.media[0].endswith("report.pdf")
|
||||
assert "[file: report.pdf]" in msg.content
|
||||
finally:
|
||||
p = os.path.join(os.path.dirname(saved), "report.pdf")
|
||||
if os.path.exists(p):
|
||||
os.unlink(p)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_voice_message() -> None:
|
||||
"""Voice message: transcribed text is included in content."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_voice_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"voice": {"content": "transcribed text here"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "voice")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert "transcribed text here" in msg.content
|
||||
assert "[voice]" in msg.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_deduplication() -> None:
|
||||
"""Same msg_id is not processed twice."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_dup_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": "once"},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
msg = await channel.bus.consume_inbound()
|
||||
assert msg.content == "once"
|
||||
|
||||
# Second message should not appear on the bus
|
||||
assert channel.bus.inbound.empty()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_empty_content_skipped() -> None:
|
||||
"""Message with empty content produces no bus message."""
|
||||
channel = WecomChannel(WecomConfig(bot_id="b", secret="s", allow_from=["user1"]), MessageBus())
|
||||
client = _FakeWeComClient()
|
||||
channel._client = client
|
||||
|
||||
frame = _FakeFrame(body={
|
||||
"msgid": "msg_empty_1",
|
||||
"chatid": "chat1",
|
||||
"from": {"userid": "user1"},
|
||||
"text": {"content": ""},
|
||||
})
|
||||
|
||||
await channel._process_message(frame, "text")
|
||||
|
||||
assert channel.bus.inbound.empty()
|
||||
@ -1,5 +1,6 @@
|
||||
"""Test message tool suppress logic for final replies."""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
@ -86,6 +87,42 @@ class TestMessageToolSuppressLogic:
|
||||
assert result is not None
|
||||
assert "Hello" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_injected_followup_with_message_tool_does_not_emit_empty_fallback(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(
|
||||
id="call1", name="message",
|
||||
arguments={"content": "Tool reply", "channel": "feishu", "chat_id": "chat123"},
|
||||
)
|
||||
calls = iter([
|
||||
LLMResponse(content="First answer", tool_calls=[]),
|
||||
LLMResponse(content="", tool_calls=[tool_call]),
|
||||
LLMResponse(content="", tool_calls=[]),
|
||||
LLMResponse(content="", tool_calls=[]),
|
||||
LLMResponse(content="", tool_calls=[]),
|
||||
])
|
||||
loop.provider.chat_with_retry = AsyncMock(side_effect=lambda *a, **kw: next(calls))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
sent: list[OutboundMessage] = []
|
||||
mt = loop.tools.get("message")
|
||||
if isinstance(mt, MessageTool):
|
||||
mt.set_send_callback(AsyncMock(side_effect=lambda m: sent.append(m)))
|
||||
|
||||
pending_queue = asyncio.Queue()
|
||||
await pending_queue.put(
|
||||
InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="follow-up")
|
||||
)
|
||||
|
||||
msg = InboundMessage(channel="feishu", sender_id="user1", chat_id="chat123", content="Start")
|
||||
result = await loop._process_message(msg, pending_queue=pending_queue)
|
||||
|
||||
assert len(sent) == 1
|
||||
assert sent[0].content == "Tool reply"
|
||||
assert result is None
|
||||
|
||||
async def test_progress_hides_internal_reasoning(self, tmp_path: Path) -> None:
|
||||
loop = _make_loop(tmp_path)
|
||||
tool_call = ToolCallRequest(id="call1", name="read_file", arguments={"path": "foo.txt"})
|
||||
@ -107,7 +144,7 @@ class TestMessageToolSuppressLogic:
|
||||
async def on_progress(content: str, *, tool_hint: bool = False) -> None:
|
||||
progress.append((content, tool_hint))
|
||||
|
||||
final_content, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
final_content, _, _, _, _ = await loop._run_agent_loop([], on_progress=on_progress)
|
||||
|
||||
assert final_content == "Done"
|
||||
assert progress == [
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user