mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-22 01:22:48 +00:00
Merge PR #3454: feat(webui): add ask-user choices and model settings
feat(webui): add ask-user choices and model settings
This commit is contained in:
commit
c64ec3e73c
@ -42,6 +42,7 @@ from nanobot.bus.queue import MessageBus
|
|||||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||||
from nanobot.config.schema import AgentDefaults
|
from nanobot.config.schema import AgentDefaults
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.providers.factory import ProviderSnapshot
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
from nanobot.utils.document import extract_documents
|
from nanobot.utils.document import extract_documents
|
||||||
from nanobot.utils.helpers import image_placeholder_text
|
from nanobot.utils.helpers import image_placeholder_text
|
||||||
@ -195,6 +196,8 @@ class AgentLoop:
|
|||||||
unified_session: bool = False,
|
unified_session: bool = False,
|
||||||
disabled_skills: list[str] | None = None,
|
disabled_skills: list[str] | None = None,
|
||||||
tools_config: ToolsConfig | None = None,
|
tools_config: ToolsConfig | None = None,
|
||||||
|
provider_snapshot_loader: Callable[[], ProviderSnapshot] | None = None,
|
||||||
|
provider_signature: tuple[object, ...] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ExecToolConfig, ToolsConfig, WebToolsConfig
|
from nanobot.config.schema import ExecToolConfig, ToolsConfig, WebToolsConfig
|
||||||
|
|
||||||
@ -203,6 +206,8 @@ class AgentLoop:
|
|||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.channels_config = channels_config
|
self.channels_config = channels_config
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
|
self._provider_snapshot_loader = provider_snapshot_loader
|
||||||
|
self._provider_signature = provider_signature
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = (
|
self.max_iterations = (
|
||||||
@ -290,6 +295,36 @@ class AgentLoop:
|
|||||||
self.commands = CommandRouter()
|
self.commands = CommandRouter()
|
||||||
register_builtin_commands(self.commands)
|
register_builtin_commands(self.commands)
|
||||||
|
|
||||||
|
def _apply_provider_snapshot(self, snapshot: ProviderSnapshot) -> None:
|
||||||
|
"""Swap model/provider for future turns without disturbing an active one."""
|
||||||
|
provider = snapshot.provider
|
||||||
|
model = snapshot.model
|
||||||
|
context_window_tokens = snapshot.context_window_tokens
|
||||||
|
if self.provider is provider and self.model == model:
|
||||||
|
return
|
||||||
|
old_model = self.model
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.context_window_tokens = context_window_tokens
|
||||||
|
self.runner.provider = provider
|
||||||
|
self.subagents.set_provider(provider, model)
|
||||||
|
self.consolidator.set_provider(provider, model, context_window_tokens)
|
||||||
|
self.dream.set_provider(provider, model)
|
||||||
|
self._provider_signature = snapshot.signature
|
||||||
|
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||||
|
|
||||||
|
def _refresh_provider_snapshot(self) -> None:
|
||||||
|
if self._provider_snapshot_loader is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
snapshot = self._provider_snapshot_loader()
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to refresh provider config")
|
||||||
|
return
|
||||||
|
if snapshot.signature == self._provider_signature:
|
||||||
|
return
|
||||||
|
self._apply_provider_snapshot(snapshot)
|
||||||
|
|
||||||
def _register_default_tools(self) -> None:
|
def _register_default_tools(self) -> None:
|
||||||
"""Register the default set of tools."""
|
"""Register the default set of tools."""
|
||||||
allowed_dir = (
|
allowed_dir = (
|
||||||
@ -768,6 +803,7 @@ class AgentLoop:
|
|||||||
pending_queue: asyncio.Queue | None = None,
|
pending_queue: asyncio.Queue | None = None,
|
||||||
) -> OutboundMessage | None:
|
) -> OutboundMessage | None:
|
||||||
"""Process a single inbound message and return the response."""
|
"""Process a single inbound message and return the response."""
|
||||||
|
self._refresh_provider_snapshot()
|
||||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||||
if msg.channel == "system":
|
if msg.channel == "system":
|
||||||
channel, chat_id = (
|
channel, chat_id = (
|
||||||
|
|||||||
@ -450,6 +450,17 @@ class Consolidator:
|
|||||||
weakref.WeakValueDictionary()
|
weakref.WeakValueDictionary()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_provider(
|
||||||
|
self,
|
||||||
|
provider: LLMProvider,
|
||||||
|
model: str,
|
||||||
|
context_window_tokens: int,
|
||||||
|
) -> None:
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.context_window_tokens = context_window_tokens
|
||||||
|
self.max_completion_tokens = provider.generation.max_tokens
|
||||||
|
|
||||||
def get_lock(self, session_key: str) -> asyncio.Lock:
|
def get_lock(self, session_key: str) -> asyncio.Lock:
|
||||||
"""Return the shared consolidation lock for one session."""
|
"""Return the shared consolidation lock for one session."""
|
||||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||||
@ -710,6 +721,11 @@ class Dream:
|
|||||||
self._runner = AgentRunner(provider)
|
self._runner = AgentRunner(provider)
|
||||||
self._tools = self._build_tools()
|
self._tools = self._build_tools()
|
||||||
|
|
||||||
|
def set_provider(self, provider: LLMProvider, model: str) -> None:
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self._runner.provider = provider
|
||||||
|
|
||||||
# -- tool registry -------------------------------------------------------
|
# -- tool registry -------------------------------------------------------
|
||||||
|
|
||||||
def _build_tools(self) -> ToolRegistry:
|
def _build_tools(self) -> ToolRegistry:
|
||||||
|
|||||||
@ -96,6 +96,11 @@ class SubagentManager:
|
|||||||
self._task_statuses: dict[str, SubagentStatus] = {}
|
self._task_statuses: dict[str, SubagentStatus] = {}
|
||||||
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
self._session_tasks: dict[str, set[str]] = {} # session_key -> {task_id, ...}
|
||||||
|
|
||||||
|
def set_provider(self, provider: LLMProvider, model: str) -> None:
|
||||||
|
self.provider = provider
|
||||||
|
self.model = model
|
||||||
|
self.runner.provider = provider
|
||||||
|
|
||||||
async def spawn(
|
async def spawn(
|
||||||
self,
|
self,
|
||||||
task: str,
|
task: str,
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from typing import Any
|
|||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
from nanobot.agent.tools.schema import ArraySchema, StringSchema, tool_parameters_schema
|
||||||
|
|
||||||
BUTTON_CHANNELS = frozenset({"telegram"})
|
STRUCTURED_BUTTON_CHANNELS = frozenset({"telegram", "websocket"})
|
||||||
|
|
||||||
|
|
||||||
class AskUserInterrupt(BaseException):
|
class AskUserInterrupt(BaseException):
|
||||||
@ -130,7 +130,7 @@ def ask_user_outbound(
|
|||||||
) -> tuple[str | None, list[list[str]]]:
|
) -> tuple[str | None, list[list[str]]]:
|
||||||
if not options:
|
if not options:
|
||||||
return content, []
|
return content, []
|
||||||
if channel in BUTTON_CHANNELS:
|
if channel in STRUCTURED_BUTTON_CHANNELS:
|
||||||
return content, [options]
|
return content, [options]
|
||||||
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
option_text = "\n".join(f"{index}. {option}" for index, option in enumerate(options, 1))
|
||||||
return f"{content}\n\n{option_text}" if content else option_text, []
|
return f"{content}\n\n{option_text}" if content else option_text, []
|
||||||
|
|||||||
@ -54,6 +54,14 @@ def _normalize_config_path(path: str) -> str:
|
|||||||
return _strip_trailing_slash(path)
|
return _strip_trailing_slash(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _append_buttons_as_text(text: str, buttons: list[list[str]]) -> str:
|
||||||
|
labels = [label for row in buttons for label in row if label]
|
||||||
|
if not labels:
|
||||||
|
return text
|
||||||
|
fallback = "\n".join(f"{index}. {label}" for index, label in enumerate(labels, 1))
|
||||||
|
return f"{text}\n\n{fallback}" if text else fallback
|
||||||
|
|
||||||
|
|
||||||
class WebSocketConfig(Base):
|
class WebSocketConfig(Base):
|
||||||
"""WebSocket server channel configuration.
|
"""WebSocket server channel configuration.
|
||||||
|
|
||||||
@ -531,6 +539,12 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if got == "/api/sessions":
|
if got == "/api/sessions":
|
||||||
return self._handle_sessions_list(request)
|
return self._handle_sessions_list(request)
|
||||||
|
|
||||||
|
if got == "/api/settings":
|
||||||
|
return self._handle_settings(request)
|
||||||
|
|
||||||
|
if got == "/api/settings/update":
|
||||||
|
return self._handle_settings_update(request)
|
||||||
|
|
||||||
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
m = re.match(r"^/api/sessions/([^/]+)/messages$", got)
|
||||||
if m:
|
if m:
|
||||||
return self._handle_session_messages(request, m.group(1))
|
return self._handle_session_messages(request, m.group(1))
|
||||||
@ -639,6 +653,75 @@ class WebSocketChannel(BaseChannel):
|
|||||||
]
|
]
|
||||||
return _http_json_response({"sessions": cleaned})
|
return _http_json_response({"sessions": cleaned})
|
||||||
|
|
||||||
|
def _settings_payload(self, *, requires_restart: bool = False) -> dict[str, Any]:
|
||||||
|
from nanobot.config.loader import get_config_path, load_config
|
||||||
|
from nanobot.providers.registry import PROVIDERS, find_by_name
|
||||||
|
|
||||||
|
config = load_config()
|
||||||
|
defaults = config.agents.defaults
|
||||||
|
provider_name = config.get_provider_name(defaults.model) or defaults.provider
|
||||||
|
provider = config.get_provider(defaults.model)
|
||||||
|
selected_provider = provider_name
|
||||||
|
if defaults.provider != "auto":
|
||||||
|
spec = find_by_name(defaults.provider)
|
||||||
|
selected_provider = spec.name if spec else provider_name
|
||||||
|
return {
|
||||||
|
"agent": {
|
||||||
|
"model": defaults.model,
|
||||||
|
"provider": selected_provider,
|
||||||
|
"resolved_provider": provider_name,
|
||||||
|
"has_api_key": bool(provider and provider.api_key),
|
||||||
|
},
|
||||||
|
"providers": [
|
||||||
|
{"name": "auto", "label": "Auto"}
|
||||||
|
] + [
|
||||||
|
{"name": spec.name, "label": spec.label}
|
||||||
|
for spec in PROVIDERS
|
||||||
|
],
|
||||||
|
"runtime": {
|
||||||
|
"config_path": str(get_config_path().expanduser()),
|
||||||
|
},
|
||||||
|
"requires_restart": requires_restart,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _handle_settings(self, request: WsRequest) -> Response:
|
||||||
|
if not self._check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
return _http_json_response(self._settings_payload())
|
||||||
|
|
||||||
|
def _handle_settings_update(self, request: WsRequest) -> Response:
|
||||||
|
if not self._check_api_token(request):
|
||||||
|
return _http_error(401, "Unauthorized")
|
||||||
|
from nanobot.config.loader import load_config, save_config
|
||||||
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
|
query = _parse_query(request.path)
|
||||||
|
config = load_config()
|
||||||
|
defaults = config.agents.defaults
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
model = _query_first(query, "model")
|
||||||
|
if model is not None:
|
||||||
|
model = model.strip()
|
||||||
|
if not model:
|
||||||
|
return _http_error(400, "model is required")
|
||||||
|
if defaults.model != model:
|
||||||
|
defaults.model = model
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
provider = _query_first(query, "provider")
|
||||||
|
if provider is not None:
|
||||||
|
provider = provider.strip() or "auto"
|
||||||
|
if provider != "auto" and find_by_name(provider) is None:
|
||||||
|
return _http_error(400, "unknown provider")
|
||||||
|
if defaults.provider != provider:
|
||||||
|
defaults.provider = provider
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if changed:
|
||||||
|
save_config(config)
|
||||||
|
return _http_json_response(self._settings_payload(requires_restart=changed))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _is_webui_session_key(key: str) -> bool:
|
def _is_webui_session_key(key: str) -> bool:
|
||||||
"""Return True when *key* belongs to the webui's websocket-only surface."""
|
"""Return True when *key* belongs to the webui's websocket-only surface."""
|
||||||
@ -1146,11 +1229,17 @@ class WebSocketChannel(BaseChannel):
|
|||||||
if not conns:
|
if not conns:
|
||||||
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
|
logger.warning("websocket: no active subscribers for chat_id={}", msg.chat_id)
|
||||||
return
|
return
|
||||||
|
text = msg.content
|
||||||
|
if msg.buttons:
|
||||||
|
text = _append_buttons_as_text(text, msg.buttons)
|
||||||
payload: dict[str, Any] = {
|
payload: dict[str, Any] = {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"chat_id": msg.chat_id,
|
"chat_id": msg.chat_id,
|
||||||
"text": msg.content,
|
"text": text,
|
||||||
}
|
}
|
||||||
|
if msg.buttons:
|
||||||
|
payload["buttons"] = msg.buttons
|
||||||
|
payload["button_prompt"] = msg.content
|
||||||
if msg.media:
|
if msg.media:
|
||||||
payload["media"] = msg.media
|
payload["media"] = msg.media
|
||||||
urls: list[dict[str, str]] = []
|
urls: list[dict[str, str]] = []
|
||||||
|
|||||||
@ -412,73 +412,13 @@ def _make_provider(config: Config):
|
|||||||
|
|
||||||
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
Routing is driven by ``ProviderSpec.backend`` in the registry.
|
||||||
"""
|
"""
|
||||||
from nanobot.providers.base import GenerationSettings
|
from nanobot.providers.factory import make_provider
|
||||||
from nanobot.providers.registry import find_by_name
|
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
try:
|
||||||
provider_name = config.get_provider_name(model)
|
return make_provider(config)
|
||||||
p = config.get_provider(model)
|
except ValueError as exc:
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
console.print(f"[red]Error: {exc}[/red]")
|
||||||
backend = spec.backend if spec else "openai_compat"
|
raise typer.Exit(1) from exc
|
||||||
|
|
||||||
# --- validation ---
|
|
||||||
if backend == "azure_openai":
|
|
||||||
if not p or not p.api_key or not p.api_base:
|
|
||||||
console.print("[red]Error: Azure OpenAI requires api_key and api_base.[/red]")
|
|
||||||
console.print("Set them in ~/.nanobot/config.json under providers.azure_openai section")
|
|
||||||
console.print("Use the model field to specify the deployment name.")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
|
||||||
needs_key = not (p and p.api_key)
|
|
||||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
|
||||||
if needs_key and not exempt:
|
|
||||||
console.print("[red]Error: No API key configured.[/red]")
|
|
||||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# --- instantiation by backend ---
|
|
||||||
if backend == "openai_codex":
|
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
|
||||||
|
|
||||||
provider = OpenAICodexProvider(default_model=model)
|
|
||||||
elif backend == "azure_openai":
|
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|
||||||
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key=p.api_key,
|
|
||||||
api_base=p.api_base,
|
|
||||||
default_model=model,
|
|
||||||
)
|
|
||||||
elif backend == "github_copilot":
|
|
||||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
|
||||||
provider = GitHubCopilotProvider(default_model=model)
|
|
||||||
elif backend == "anthropic":
|
|
||||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
|
||||||
|
|
||||||
provider = AnthropicProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
|
||||||
|
|
||||||
provider = OpenAICompatProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
spec=spec,
|
|
||||||
)
|
|
||||||
|
|
||||||
defaults = config.agents.defaults
|
|
||||||
provider.generation = GenerationSettings(
|
|
||||||
temperature=defaults.temperature,
|
|
||||||
max_tokens=defaults.max_tokens,
|
|
||||||
reasoning_effort=defaults.reasoning_effort,
|
|
||||||
)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
def _load_runtime_config(config: str | None = None, workspace: str | None = None) -> Config:
|
||||||
@ -664,6 +604,7 @@ def _run_gateway(
|
|||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
from nanobot.cron.types import CronJob
|
from nanobot.cron.types import CronJob
|
||||||
from nanobot.heartbeat.service import HeartbeatService
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
|
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
|
||||||
from nanobot.session.manager import SessionManager
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
port = port if port is not None else config.gateway.port
|
port = port if port is not None else config.gateway.port
|
||||||
@ -671,7 +612,12 @@ def _run_gateway(
|
|||||||
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...")
|
||||||
sync_workspace_templates(config.workspace_path)
|
sync_workspace_templates(config.workspace_path)
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
provider = _make_provider(config)
|
try:
|
||||||
|
provider_snapshot = build_provider_snapshot(config)
|
||||||
|
except ValueError as exc:
|
||||||
|
console.print(f"[red]Error: {exc}[/red]")
|
||||||
|
raise typer.Exit(1) from exc
|
||||||
|
provider = provider_snapshot.provider
|
||||||
session_manager = SessionManager(config.workspace_path)
|
session_manager = SessionManager(config.workspace_path)
|
||||||
|
|
||||||
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
# Preserve existing single-workspace installs, but keep custom workspaces clean.
|
||||||
@ -687,9 +633,9 @@ def _run_gateway(
|
|||||||
bus=bus,
|
bus=bus,
|
||||||
provider=provider,
|
provider=provider,
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
model=config.agents.defaults.model,
|
model=provider_snapshot.model,
|
||||||
max_iterations=config.agents.defaults.max_tool_iterations,
|
max_iterations=config.agents.defaults.max_tool_iterations,
|
||||||
context_window_tokens=config.agents.defaults.context_window_tokens,
|
context_window_tokens=provider_snapshot.context_window_tokens,
|
||||||
web_config=config.tools.web,
|
web_config=config.tools.web,
|
||||||
context_block_limit=config.agents.defaults.context_block_limit,
|
context_block_limit=config.agents.defaults.context_block_limit,
|
||||||
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
max_tool_result_chars=config.agents.defaults.max_tool_result_chars,
|
||||||
@ -706,6 +652,8 @@ def _run_gateway(
|
|||||||
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
session_ttl_minutes=config.agents.defaults.session_ttl_minutes,
|
||||||
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
consolidation_ratio=config.agents.defaults.consolidation_ratio,
|
||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
|
provider_snapshot_loader=load_provider_snapshot,
|
||||||
|
provider_signature=provider_snapshot.signature,
|
||||||
)
|
)
|
||||||
|
|
||||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
|||||||
@ -120,62 +120,6 @@ class Nanobot:
|
|||||||
|
|
||||||
def _make_provider(config: Any) -> Any:
|
def _make_provider(config: Any) -> Any:
|
||||||
"""Create the LLM provider from config (extracted from CLI)."""
|
"""Create the LLM provider from config (extracted from CLI)."""
|
||||||
from nanobot.providers.base import GenerationSettings
|
from nanobot.providers.factory import make_provider
|
||||||
from nanobot.providers.registry import find_by_name
|
|
||||||
|
|
||||||
model = config.agents.defaults.model
|
return make_provider(config)
|
||||||
provider_name = config.get_provider_name(model)
|
|
||||||
p = config.get_provider(model)
|
|
||||||
spec = find_by_name(provider_name) if provider_name else None
|
|
||||||
backend = spec.backend if spec else "openai_compat"
|
|
||||||
|
|
||||||
if backend == "azure_openai":
|
|
||||||
if not p or not p.api_key or not p.api_base:
|
|
||||||
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
|
||||||
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
|
||||||
needs_key = not (p and p.api_key)
|
|
||||||
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
|
||||||
if needs_key and not exempt:
|
|
||||||
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
|
||||||
|
|
||||||
if backend == "openai_codex":
|
|
||||||
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
|
||||||
|
|
||||||
provider = OpenAICodexProvider(default_model=model)
|
|
||||||
elif backend == "github_copilot":
|
|
||||||
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
|
||||||
|
|
||||||
provider = GitHubCopilotProvider(default_model=model)
|
|
||||||
elif backend == "azure_openai":
|
|
||||||
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
|
||||||
|
|
||||||
provider = AzureOpenAIProvider(
|
|
||||||
api_key=p.api_key, api_base=p.api_base, default_model=model
|
|
||||||
)
|
|
||||||
elif backend == "anthropic":
|
|
||||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
|
||||||
|
|
||||||
provider = AnthropicProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
|
||||||
|
|
||||||
provider = OpenAICompatProvider(
|
|
||||||
api_key=p.api_key if p else None,
|
|
||||||
api_base=config.get_api_base(model),
|
|
||||||
default_model=model,
|
|
||||||
extra_headers=p.extra_headers if p else None,
|
|
||||||
spec=spec,
|
|
||||||
)
|
|
||||||
|
|
||||||
defaults = config.agents.defaults
|
|
||||||
provider.generation = GenerationSettings(
|
|
||||||
temperature=defaults.temperature,
|
|
||||||
max_tokens=defaults.max_tokens,
|
|
||||||
reasoning_effort=defaults.reasoning_effort,
|
|
||||||
)
|
|
||||||
return provider
|
|
||||||
|
|||||||
112
nanobot/providers/factory.py
Normal file
112
nanobot/providers/factory.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
"""Create LLM providers from config."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from nanobot.config.schema import Config
|
||||||
|
from nanobot.providers.base import GenerationSettings, LLMProvider
|
||||||
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ProviderSnapshot:
|
||||||
|
provider: LLMProvider
|
||||||
|
model: str
|
||||||
|
context_window_tokens: int
|
||||||
|
signature: tuple[object, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def make_provider(config: Config) -> LLMProvider:
|
||||||
|
"""Create the LLM provider implied by config."""
|
||||||
|
model = config.agents.defaults.model
|
||||||
|
provider_name = config.get_provider_name(model)
|
||||||
|
p = config.get_provider(model)
|
||||||
|
spec = find_by_name(provider_name) if provider_name else None
|
||||||
|
backend = spec.backend if spec else "openai_compat"
|
||||||
|
|
||||||
|
if backend == "azure_openai":
|
||||||
|
if not p or not p.api_key or not p.api_base:
|
||||||
|
raise ValueError("Azure OpenAI requires api_key and api_base in config.")
|
||||||
|
elif backend == "openai_compat" and not model.startswith("bedrock/"):
|
||||||
|
needs_key = not (p and p.api_key)
|
||||||
|
exempt = spec and (spec.is_oauth or spec.is_local or spec.is_direct)
|
||||||
|
if needs_key and not exempt:
|
||||||
|
raise ValueError(f"No API key configured for provider '{provider_name}'.")
|
||||||
|
|
||||||
|
if backend == "openai_codex":
|
||||||
|
from nanobot.providers.openai_codex_provider import OpenAICodexProvider
|
||||||
|
|
||||||
|
provider = OpenAICodexProvider(default_model=model)
|
||||||
|
elif backend == "azure_openai":
|
||||||
|
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
|
||||||
|
|
||||||
|
provider = AzureOpenAIProvider(
|
||||||
|
api_key=p.api_key,
|
||||||
|
api_base=p.api_base,
|
||||||
|
default_model=model,
|
||||||
|
)
|
||||||
|
elif backend == "github_copilot":
|
||||||
|
from nanobot.providers.github_copilot_provider import GitHubCopilotProvider
|
||||||
|
|
||||||
|
provider = GitHubCopilotProvider(default_model=model)
|
||||||
|
elif backend == "anthropic":
|
||||||
|
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||||
|
|
||||||
|
provider = AnthropicProvider(
|
||||||
|
api_key=p.api_key if p else None,
|
||||||
|
api_base=config.get_api_base(model),
|
||||||
|
default_model=model,
|
||||||
|
extra_headers=p.extra_headers if p else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||||
|
|
||||||
|
provider = OpenAICompatProvider(
|
||||||
|
api_key=p.api_key if p else None,
|
||||||
|
api_base=config.get_api_base(model),
|
||||||
|
default_model=model,
|
||||||
|
extra_headers=p.extra_headers if p else None,
|
||||||
|
spec=spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
defaults = config.agents.defaults
|
||||||
|
provider.generation = GenerationSettings(
|
||||||
|
temperature=defaults.temperature,
|
||||||
|
max_tokens=defaults.max_tokens,
|
||||||
|
reasoning_effort=defaults.reasoning_effort,
|
||||||
|
)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def provider_signature(config: Config) -> tuple[object, ...]:
|
||||||
|
"""Return the config fields that affect the primary LLM provider."""
|
||||||
|
model = config.agents.defaults.model
|
||||||
|
defaults = config.agents.defaults
|
||||||
|
return (
|
||||||
|
model,
|
||||||
|
defaults.provider,
|
||||||
|
config.get_provider_name(model),
|
||||||
|
config.get_api_key(model),
|
||||||
|
config.get_api_base(model),
|
||||||
|
defaults.max_tokens,
|
||||||
|
defaults.temperature,
|
||||||
|
defaults.reasoning_effort,
|
||||||
|
defaults.context_window_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_provider_snapshot(config: Config) -> ProviderSnapshot:
|
||||||
|
return ProviderSnapshot(
|
||||||
|
provider=make_provider(config),
|
||||||
|
model=config.agents.defaults.model,
|
||||||
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
signature=provider_signature(config),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_provider_snapshot(config_path: Path | None = None) -> ProviderSnapshot:
|
||||||
|
from nanobot.config.loader import load_config, resolve_config_env_vars
|
||||||
|
|
||||||
|
return build_provider_snapshot(resolve_config_env_vars(load_config(config_path)))
|
||||||
@ -205,3 +205,37 @@ async def test_ask_user_keeps_buttons_for_telegram(tmp_path):
|
|||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.content == "Install the optional package?"
|
assert response.content == "Install the optional package?"
|
||||||
assert response.buttons == [["Install", "Skip"]]
|
assert response.buttons == [["Install", "Skip"]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ask_user_keeps_buttons_for_websocket(tmp_path):
|
||||||
|
async def chat_with_retry(**kwargs):
|
||||||
|
return LLMResponse(
|
||||||
|
content="",
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="call_ask",
|
||||||
|
name="ask_user",
|
||||||
|
arguments={
|
||||||
|
"question": "Install the optional package?",
|
||||||
|
"options": ["Install", "Skip"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=_make_provider(chat_with_retry),
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="test-model",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await loop._process_message(
|
||||||
|
InboundMessage(channel="websocket", sender_id="user", chat_id="123", content="set it up")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert response.content == "Install the optional package?"
|
||||||
|
assert response.buttons == [["Install", "Skip"]]
|
||||||
|
|||||||
49
tests/agent/test_runtime_refresh.py
Normal file
49
tests/agent/test_runtime_refresh.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.factory import ProviderSnapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _provider(default_model: str, max_tokens: int = 123) -> MagicMock:
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = default_model
|
||||||
|
provider.generation = SimpleNamespace(max_tokens=max_tokens)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
|
||||||
|
old_provider = _provider("old-model")
|
||||||
|
new_provider = _provider("new-model", max_tokens=456)
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=old_provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="old-model",
|
||||||
|
context_window_tokens=1000,
|
||||||
|
provider_snapshot_loader=lambda: ProviderSnapshot(
|
||||||
|
provider=new_provider,
|
||||||
|
model="new-model",
|
||||||
|
context_window_tokens=2000,
|
||||||
|
signature=("new-model",),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
loop._refresh_provider_snapshot()
|
||||||
|
|
||||||
|
assert loop.provider is new_provider
|
||||||
|
assert loop.model == "new-model"
|
||||||
|
assert loop.context_window_tokens == 2000
|
||||||
|
assert loop.runner.provider is new_provider
|
||||||
|
assert loop.subagents.provider is new_provider
|
||||||
|
assert loop.subagents.model == "new-model"
|
||||||
|
assert loop.subagents.runner.provider is new_provider
|
||||||
|
assert loop.consolidator.provider is new_provider
|
||||||
|
assert loop.consolidator.model == "new-model"
|
||||||
|
assert loop.consolidator.context_window_tokens == 2000
|
||||||
|
assert loop.consolidator.max_completion_tokens == 456
|
||||||
|
assert loop.dream.provider is new_provider
|
||||||
|
assert loop.dream.model == "new-model"
|
||||||
|
assert loop.dream._runner.provider is new_provider
|
||||||
@ -26,6 +26,8 @@ from nanobot.channels.websocket import (
|
|||||||
_parse_query,
|
_parse_query,
|
||||||
_parse_request_path,
|
_parse_request_path,
|
||||||
)
|
)
|
||||||
|
from nanobot.config.loader import load_config, save_config
|
||||||
|
from nanobot.config.schema import Config
|
||||||
|
|
||||||
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
# -- Shared helpers (aligned with test_websocket_integration.py) ---------------
|
||||||
|
|
||||||
@ -178,6 +180,7 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
|||||||
content="hello",
|
content="hello",
|
||||||
reply_to="m1",
|
reply_to="m1",
|
||||||
media=["/tmp/a.png"],
|
media=["/tmp/a.png"],
|
||||||
|
buttons=[["Yes", "No"]],
|
||||||
)
|
)
|
||||||
await channel.send(msg)
|
await channel.send(msg)
|
||||||
|
|
||||||
@ -185,9 +188,11 @@ async def test_send_delivers_json_message_with_media_and_reply() -> None:
|
|||||||
payload = json.loads(mock_ws.send.call_args[0][0])
|
payload = json.loads(mock_ws.send.call_args[0][0])
|
||||||
assert payload["event"] == "message"
|
assert payload["event"] == "message"
|
||||||
assert payload["chat_id"] == "chat-1"
|
assert payload["chat_id"] == "chat-1"
|
||||||
assert payload["text"] == "hello"
|
assert payload["text"] == "hello\n\n1. Yes\n2. No"
|
||||||
|
assert payload["button_prompt"] == "hello"
|
||||||
assert payload["reply_to"] == "m1"
|
assert payload["reply_to"] == "m1"
|
||||||
assert payload["media"] == ["/tmp/a.png"]
|
assert payload["media"] == ["/tmp/a.png"]
|
||||||
|
assert payload["buttons"] == [["Yes", "No"]]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -436,6 +441,72 @@ async def test_http_route_issues_token_then_websocket_requires_it(bus: MagicMock
|
|||||||
await server_task
|
await server_task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_settings_api_returns_safe_subset_and_updates_whitelist(
|
||||||
|
bus: MagicMock,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
) -> None:
|
||||||
|
port = 29891
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config = Config()
|
||||||
|
config.agents.defaults.model = "openai/gpt-4o"
|
||||||
|
config.providers.openai.api_key = "secret-key"
|
||||||
|
save_config(config, config_path)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
|
||||||
|
|
||||||
|
channel = _ch(bus, port=port)
|
||||||
|
channel._api_tokens["tok"] = time.monotonic() + 300
|
||||||
|
|
||||||
|
server_task = asyncio.create_task(channel.start())
|
||||||
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
|
try:
|
||||||
|
settings = await _http_get(
|
||||||
|
f"http://127.0.0.1:{port}/api/settings",
|
||||||
|
headers={"Authorization": "Bearer tok"},
|
||||||
|
)
|
||||||
|
assert settings.status_code == 200
|
||||||
|
body = settings.json()
|
||||||
|
assert body["agent"]["model"] == "openai/gpt-4o"
|
||||||
|
assert body["agent"]["provider"] == "openai"
|
||||||
|
assert {"name": "auto", "label": "Auto"} in body["providers"]
|
||||||
|
assert body["agent"]["has_api_key"] is True
|
||||||
|
assert "secret-key" not in settings.text
|
||||||
|
|
||||||
|
updated = await _http_get(
|
||||||
|
"http://127.0.0.1:"
|
||||||
|
f"{port}/api/settings/update?model=openrouter/test"
|
||||||
|
"&provider=openrouter",
|
||||||
|
headers={"Authorization": "Bearer tok"},
|
||||||
|
)
|
||||||
|
assert updated.status_code == 200
|
||||||
|
assert updated.json()["requires_restart"] is True
|
||||||
|
|
||||||
|
saved = load_config(config_path)
|
||||||
|
assert saved.agents.defaults.model == "openrouter/test"
|
||||||
|
assert saved.agents.defaults.provider == "openrouter"
|
||||||
|
finally:
|
||||||
|
await channel.stop()
|
||||||
|
await server_task
|
||||||
|
|
||||||
|
|
||||||
|
def test_settings_payload_normalizes_camel_case_provider(
|
||||||
|
bus: MagicMock,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
) -> None:
|
||||||
|
config_path = tmp_path / "config.json"
|
||||||
|
config = Config()
|
||||||
|
config.agents.defaults.provider = "minimaxAnthropic"
|
||||||
|
save_config(config, config_path)
|
||||||
|
monkeypatch.setattr("nanobot.config.loader._current_config_path", config_path)
|
||||||
|
|
||||||
|
body = _ch(bus)._settings_payload()
|
||||||
|
|
||||||
|
assert body["agent"]["provider"] == "minimax_anthropic"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
|
async def test_end_to_end_server_pushes_streaming_deltas_to_client(bus: MagicMock) -> None:
|
||||||
port = 29880
|
port = 29880
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from nanobot.bus.events import OutboundMessage
|
|||||||
from nanobot.cli.commands import _make_provider, app
|
from nanobot.cli.commands import _make_provider, app
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
from nanobot.cron.types import CronJob, CronPayload
|
from nanobot.cron.types import CronJob, CronPayload
|
||||||
|
from nanobot.providers.factory import ProviderSnapshot
|
||||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||||
from nanobot.providers.registry import find_by_name
|
from nanobot.providers.registry import find_by_name
|
||||||
|
|
||||||
@ -776,6 +777,15 @@ def _stop_gateway_provider(_config) -> object:
|
|||||||
raise _StopGatewayError("stop")
|
raise _StopGatewayError("stop")
|
||||||
|
|
||||||
|
|
||||||
|
def _test_provider_snapshot(provider: object, config: Config) -> ProviderSnapshot:
|
||||||
|
return ProviderSnapshot(
|
||||||
|
provider=provider,
|
||||||
|
model=config.agents.defaults.model,
|
||||||
|
context_window_tokens=config.agents.defaults.context_window_tokens,
|
||||||
|
signature=("test",),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _patch_cli_command_runtime(
|
def _patch_cli_command_runtime(
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
config: Config,
|
config: Config,
|
||||||
@ -788,6 +798,8 @@ def _patch_cli_command_runtime(
|
|||||||
cron_service=None,
|
cron_service=None,
|
||||||
get_cron_dir=None,
|
get_cron_dir=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
provider_factory = make_provider or (lambda _config: object())
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.config.loader.set_config_path",
|
"nanobot.config.loader.set_config_path",
|
||||||
set_config_path or (lambda _path: None),
|
set_config_path or (lambda _path: None),
|
||||||
@ -800,7 +812,15 @@ def _patch_cli_command_runtime(
|
|||||||
)
|
)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
"nanobot.cli.commands._make_provider",
|
"nanobot.cli.commands._make_provider",
|
||||||
make_provider or (lambda _config: object()),
|
provider_factory,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.providers.factory.build_provider_snapshot",
|
||||||
|
lambda _config: _test_provider_snapshot(provider_factory(_config), _config),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.providers.factory.load_provider_snapshot",
|
||||||
|
lambda _config_path=None: _test_provider_snapshot(provider_factory(config), config),
|
||||||
)
|
)
|
||||||
|
|
||||||
if message_bus is not None:
|
if message_bus is not None:
|
||||||
@ -941,6 +961,14 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: provider)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.providers.factory.build_provider_snapshot",
|
||||||
|
lambda _config: _test_provider_snapshot(provider, _config),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.providers.factory.load_provider_snapshot",
|
||||||
|
lambda _config_path=None: _test_provider_snapshot(provider, config),
|
||||||
|
)
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||||
|
|
||||||
class _FakeSession:
|
class _FakeSession:
|
||||||
@ -1082,6 +1110,14 @@ def test_gateway_cron_job_suppresses_intermediate_progress(
|
|||||||
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
monkeypatch.setattr("nanobot.config.loader.load_config", lambda _path=None: config)
|
||||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||||
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
monkeypatch.setattr("nanobot.cli.commands._make_provider", lambda _config: object())
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.providers.factory.build_provider_snapshot",
|
||||||
|
lambda _config: _test_provider_snapshot(object(), _config),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.providers.factory.load_provider_snapshot",
|
||||||
|
lambda _config_path=None: _test_provider_snapshot(object(), config),
|
||||||
|
)
|
||||||
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
monkeypatch.setattr("nanobot.bus.queue.MessageBus", lambda: bus)
|
||||||
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
monkeypatch.setattr("nanobot.session.manager.SessionManager", lambda _workspace: object())
|
||||||
|
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
|||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
import { DeleteConfirm } from "@/components/DeleteConfirm";
|
import { DeleteConfirm } from "@/components/DeleteConfirm";
|
||||||
import { Sidebar } from "@/components/Sidebar";
|
import { Sidebar } from "@/components/Sidebar";
|
||||||
|
import { SettingsView } from "@/components/settings/SettingsView";
|
||||||
import { ThreadShell } from "@/components/thread/ThreadShell";
|
import { ThreadShell } from "@/components/thread/ThreadShell";
|
||||||
import { Sheet, SheetContent } from "@/components/ui/sheet";
|
import { Sheet, SheetContent } from "@/components/ui/sheet";
|
||||||
import { preloadMarkdownText } from "@/components/MarkdownText";
|
import { preloadMarkdownText } from "@/components/MarkdownText";
|
||||||
@ -25,6 +26,7 @@ type BootState =
|
|||||||
|
|
||||||
const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
|
const SIDEBAR_STORAGE_KEY = "nanobot-webui.sidebar";
|
||||||
const SIDEBAR_WIDTH = 279;
|
const SIDEBAR_WIDTH = 279;
|
||||||
|
type ShellView = "chat" | "settings";
|
||||||
|
|
||||||
function readSidebarOpen(): boolean {
|
function readSidebarOpen(): boolean {
|
||||||
if (typeof window === "undefined") return true;
|
if (typeof window === "undefined") return true;
|
||||||
@ -136,22 +138,29 @@ export default function App() {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const handleModelNameChange = (modelName: string | null) => {
|
||||||
|
setState((current) =>
|
||||||
|
current.status === "ready" ? { ...current, modelName } : current,
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ClientProvider
|
<ClientProvider
|
||||||
client={state.client}
|
client={state.client}
|
||||||
token={state.token}
|
token={state.token}
|
||||||
modelName={state.modelName}
|
modelName={state.modelName}
|
||||||
>
|
>
|
||||||
<Shell />
|
<Shell onModelNameChange={handleModelNameChange} />
|
||||||
</ClientProvider>
|
</ClientProvider>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function Shell() {
|
function Shell({ onModelNameChange }: { onModelNameChange: (modelName: string | null) => void }) {
|
||||||
const { t, i18n } = useTranslation();
|
const { t, i18n } = useTranslation();
|
||||||
const { theme, toggle } = useTheme();
|
const { theme, toggle } = useTheme();
|
||||||
const { sessions, loading, refresh, createChat, deleteChat } = useSessions();
|
const { sessions, loading, refresh, createChat, deleteChat } = useSessions();
|
||||||
const [activeKey, setActiveKey] = useState<string | null>(null);
|
const [activeKey, setActiveKey] = useState<string | null>(null);
|
||||||
|
const [view, setView] = useState<ShellView>("chat");
|
||||||
const [desktopSidebarOpen, setDesktopSidebarOpen] =
|
const [desktopSidebarOpen, setDesktopSidebarOpen] =
|
||||||
useState<boolean>(readSidebarOpen);
|
useState<boolean>(readSidebarOpen);
|
||||||
const [mobileSidebarOpen, setMobileSidebarOpen] = useState(false);
|
const [mobileSidebarOpen, setMobileSidebarOpen] = useState(false);
|
||||||
@ -208,6 +217,7 @@ function Shell() {
|
|||||||
try {
|
try {
|
||||||
const chatId = await createChat();
|
const chatId = await createChat();
|
||||||
setActiveKey(`websocket:${chatId}`);
|
setActiveKey(`websocket:${chatId}`);
|
||||||
|
setView("chat");
|
||||||
setMobileSidebarOpen(false);
|
setMobileSidebarOpen(false);
|
||||||
return chatId;
|
return chatId;
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
@ -219,6 +229,7 @@ function Shell() {
|
|||||||
const onSelectChat = useCallback(
|
const onSelectChat = useCallback(
|
||||||
(key: string) => {
|
(key: string) => {
|
||||||
setActiveKey(key);
|
setActiveKey(key);
|
||||||
|
setView("chat");
|
||||||
setMobileSidebarOpen(false);
|
setMobileSidebarOpen(false);
|
||||||
},
|
},
|
||||||
[],
|
[],
|
||||||
@ -266,6 +277,11 @@ function Shell() {
|
|||||||
onRefresh: () => void refresh(),
|
onRefresh: () => void refresh(),
|
||||||
onRequestDelete: (key: string, label: string) =>
|
onRequestDelete: (key: string, label: string) =>
|
||||||
setPendingDelete({ key, label }),
|
setPendingDelete({ key, label }),
|
||||||
|
activeView: view,
|
||||||
|
onOpenSettings: () => {
|
||||||
|
setView("settings" as const);
|
||||||
|
setMobileSidebarOpen(false);
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
@ -303,6 +319,14 @@ function Shell() {
|
|||||||
</Sheet>
|
</Sheet>
|
||||||
|
|
||||||
<main className="flex h-full min-w-0 flex-1 flex-col">
|
<main className="flex h-full min-w-0 flex-1 flex-col">
|
||||||
|
{view === "settings" ? (
|
||||||
|
<SettingsView
|
||||||
|
theme={theme}
|
||||||
|
onToggleTheme={toggle}
|
||||||
|
onBackToChat={() => setView("chat")}
|
||||||
|
onModelNameChange={onModelNameChange}
|
||||||
|
/>
|
||||||
|
) : (
|
||||||
<ThreadShell
|
<ThreadShell
|
||||||
session={activeSession}
|
session={activeSession}
|
||||||
title={headerTitle}
|
title={headerTitle}
|
||||||
@ -311,6 +335,7 @@ function Shell() {
|
|||||||
onNewChat={onNewChat}
|
onNewChat={onNewChat}
|
||||||
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
hideSidebarToggleOnDesktop={desktopSidebarOpen}
|
||||||
/>
|
/>
|
||||||
|
)}
|
||||||
</main>
|
</main>
|
||||||
|
|
||||||
<DeleteConfirm
|
<DeleteConfirm
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
import { Moon, PanelLeftClose, Plus, RefreshCcw, Sun } from "lucide-react";
|
import { Moon, PanelLeftClose, RefreshCcw, Settings, SquarePen, Sun } from "lucide-react";
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
|
|
||||||
import { ChatList } from "@/components/ChatList";
|
import { ChatList } from "@/components/ChatList";
|
||||||
import { ConnectionBadge } from "@/components/ConnectionBadge";
|
import { ConnectionBadge } from "@/components/ConnectionBadge";
|
||||||
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
import { Separator } from "@/components/ui/separator";
|
import { Separator } from "@/components/ui/separator";
|
||||||
import type { ChatSummary } from "@/lib/types";
|
import type { ChatSummary } from "@/lib/types";
|
||||||
@ -19,22 +18,25 @@ interface SidebarProps {
|
|||||||
onRefresh: () => void;
|
onRefresh: () => void;
|
||||||
onRequestDelete: (key: string, label: string) => void;
|
onRequestDelete: (key: string, label: string) => void;
|
||||||
onCollapse: () => void;
|
onCollapse: () => void;
|
||||||
|
activeView?: "chat" | "settings";
|
||||||
|
onOpenSettings: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function Sidebar(props: SidebarProps) {
|
export function Sidebar(props: SidebarProps) {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
return (
|
return (
|
||||||
<aside className="flex h-full w-full flex-col border-r border-sidebar-border/70 bg-sidebar text-sidebar-foreground">
|
<aside className="flex h-full w-full flex-col border-r border-sidebar-border/70 bg-sidebar text-sidebar-foreground">
|
||||||
<div className="flex items-center justify-between px-2 py-2">
|
<div className="flex items-center justify-between px-3 pb-2 pt-3">
|
||||||
<Button
|
<picture className="block min-w-0">
|
||||||
variant="ghost"
|
<source srcSet="/brand/nanobot_logo.webp" type="image/webp" />
|
||||||
size="icon"
|
<img
|
||||||
aria-label={t("sidebar.collapse")}
|
src="/brand/nanobot_logo.png"
|
||||||
onClick={props.onCollapse}
|
alt="nanobot"
|
||||||
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
className="h-7 w-auto select-none object-contain"
|
||||||
>
|
draggable={false}
|
||||||
<PanelLeftClose className="h-3.5 w-3.5" />
|
/>
|
||||||
</Button>
|
</picture>
|
||||||
|
<div className="flex items-center gap-0.5">
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
size="icon"
|
size="icon"
|
||||||
@ -48,19 +50,28 @@ export function Sidebar(props: SidebarProps) {
|
|||||||
<Moon className="h-3.5 w-3.5" />
|
<Moon className="h-3.5 w-3.5" />
|
||||||
)}
|
)}
|
||||||
</Button>
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
aria-label={t("sidebar.collapse")}
|
||||||
|
onClick={props.onCollapse}
|
||||||
|
className="h-7 w-7 rounded-lg text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||||
|
>
|
||||||
|
<PanelLeftClose className="h-3.5 w-3.5" />
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
<div className="px-2 pb-2.5">
|
</div>
|
||||||
|
<div className="px-2 pb-2">
|
||||||
<Button
|
<Button
|
||||||
onClick={props.onNewChat}
|
onClick={props.onNewChat}
|
||||||
className="h-8.5 w-full justify-start gap-2 rounded-lg border border-sidebar-border/80 bg-card/25 px-3 text-[13px] font-medium text-sidebar-foreground shadow-none hover:bg-sidebar-accent/80"
|
className="h-9 w-full justify-start gap-2 rounded-full px-3 text-[13px] font-medium text-sidebar-foreground/90 hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||||
variant="outline"
|
variant="ghost"
|
||||||
>
|
>
|
||||||
<Plus className="h-3.5 w-3.5" />
|
<SquarePen className="h-3.5 w-3.5" />
|
||||||
{t("sidebar.newChat")}
|
{t("sidebar.newChat")}
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
<Separator className="bg-sidebar-border/70" />
|
<div className="flex items-center justify-between px-3 pb-1.5 pt-2.5 text-[11px] font-medium text-muted-foreground">
|
||||||
<div className="flex items-center justify-between px-2.5 py-2 text-[11px] font-medium text-muted-foreground">
|
|
||||||
<span>{t("sidebar.recent")}</span>
|
<span>{t("sidebar.recent")}</span>
|
||||||
<Button
|
<Button
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
@ -81,10 +92,17 @@ export function Sidebar(props: SidebarProps) {
|
|||||||
onRequestDelete={props.onRequestDelete}
|
onRequestDelete={props.onRequestDelete}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
<Separator className="bg-sidebar-border/70" />
|
<Separator className="bg-sidebar-border/50" />
|
||||||
<div className="flex items-center justify-between gap-2 px-2.5 py-2 text-xs">
|
<div className="flex items-center justify-between gap-2 px-2.5 py-2 text-xs">
|
||||||
<ConnectionBadge />
|
<ConnectionBadge />
|
||||||
<LanguageSwitcher />
|
<Button
|
||||||
|
onClick={props.onOpenSettings}
|
||||||
|
className="h-7 gap-1.5 rounded-md px-2 text-[11px] text-muted-foreground hover:bg-sidebar-accent hover:text-sidebar-foreground"
|
||||||
|
variant={props.activeView === "settings" ? "secondary" : "ghost"}
|
||||||
|
>
|
||||||
|
<Settings className="h-3.5 w-3.5" />
|
||||||
|
Settings
|
||||||
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
</aside>
|
</aside>
|
||||||
);
|
);
|
||||||
|
|||||||
245
webui/src/components/settings/SettingsView.tsx
Normal file
245
webui/src/components/settings/SettingsView.tsx
Normal file
@ -0,0 +1,245 @@
|
|||||||
|
import { useCallback, useEffect, useMemo, useState } from "react";
|
||||||
|
import { ChevronLeft, Loader2 } from "lucide-react";
|
||||||
|
|
||||||
|
import { LanguageSwitcher } from "@/components/LanguageSwitcher";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { fetchSettings, updateSettings } from "@/lib/api";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
import { useClient } from "@/providers/ClientProvider";
|
||||||
|
import type { SettingsPayload } from "@/lib/types";
|
||||||
|
|
||||||
|
interface SettingsViewProps {
|
||||||
|
theme: "light" | "dark";
|
||||||
|
onToggleTheme: () => void;
|
||||||
|
onBackToChat: () => void;
|
||||||
|
onModelNameChange: (modelName: string | null) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function SettingsView({
|
||||||
|
onBackToChat,
|
||||||
|
onModelNameChange,
|
||||||
|
}: SettingsViewProps) {
|
||||||
|
const { token } = useClient();
|
||||||
|
const [settings, setSettings] = useState<SettingsPayload | null>(null);
|
||||||
|
const [loading, setLoading] = useState(true);
|
||||||
|
const [saving, setSaving] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const [form, setForm] = useState({
|
||||||
|
model: "",
|
||||||
|
provider: "auto",
|
||||||
|
});
|
||||||
|
|
||||||
|
const applyPayload = useCallback((payload: SettingsPayload) => {
|
||||||
|
setSettings(payload);
|
||||||
|
setForm({
|
||||||
|
model: payload.agent.model,
|
||||||
|
provider: payload.agent.provider,
|
||||||
|
});
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let cancelled = false;
|
||||||
|
setLoading(true);
|
||||||
|
fetchSettings(token)
|
||||||
|
.then((payload) => {
|
||||||
|
if (!cancelled) {
|
||||||
|
applyPayload(payload);
|
||||||
|
setError(null);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
if (!cancelled) setError((err as Error).message);
|
||||||
|
})
|
||||||
|
.finally(() => {
|
||||||
|
if (!cancelled) setLoading(false);
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
cancelled = true;
|
||||||
|
};
|
||||||
|
}, [applyPayload, token]);
|
||||||
|
|
||||||
|
const dirty = useMemo(() => {
|
||||||
|
if (!settings) return false;
|
||||||
|
return (
|
||||||
|
form.model !== settings.agent.model ||
|
||||||
|
form.provider !== settings.agent.provider
|
||||||
|
);
|
||||||
|
}, [form, settings]);
|
||||||
|
|
||||||
|
const save = async () => {
|
||||||
|
if (!dirty || saving) return;
|
||||||
|
setSaving(true);
|
||||||
|
try {
|
||||||
|
const payload = await updateSettings(token, form);
|
||||||
|
applyPayload(payload);
|
||||||
|
onModelNameChange(payload.agent.model || null);
|
||||||
|
setError(null);
|
||||||
|
} catch (err) {
|
||||||
|
setError((err as Error).message);
|
||||||
|
} finally {
|
||||||
|
setSaving(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="min-h-0 flex-1 overflow-y-auto bg-background">
|
||||||
|
<main className="mx-auto w-full max-w-[1000px] px-6 py-6">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={onBackToChat}
|
||||||
|
className="mb-4 inline-flex items-center gap-1.5 text-xs font-medium text-muted-foreground hover:text-foreground"
|
||||||
|
>
|
||||||
|
<ChevronLeft className="h-3.5 w-3.5" />
|
||||||
|
Back to chat
|
||||||
|
</button>
|
||||||
|
|
||||||
|
<h1 className="mb-6 text-base font-semibold tracking-tight">General</h1>
|
||||||
|
|
||||||
|
{loading ? (
|
||||||
|
<div className="flex h-48 items-center justify-center text-sm text-muted-foreground">
|
||||||
|
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
|
||||||
|
Loading settings...
|
||||||
|
</div>
|
||||||
|
) : error ? (
|
||||||
|
<SettingsGroup>
|
||||||
|
<SettingsRow title="Could not load settings">
|
||||||
|
<span className="max-w-[520px] text-sm text-muted-foreground">{error}</span>
|
||||||
|
</SettingsRow>
|
||||||
|
</SettingsGroup>
|
||||||
|
) : settings ? (
|
||||||
|
<SettingsSection
|
||||||
|
form={form}
|
||||||
|
setForm={setForm}
|
||||||
|
settings={settings}
|
||||||
|
dirty={dirty}
|
||||||
|
saving={saving}
|
||||||
|
onSave={save}
|
||||||
|
/>
|
||||||
|
) : null}
|
||||||
|
</main>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function SettingsSection({
|
||||||
|
form,
|
||||||
|
setForm,
|
||||||
|
settings,
|
||||||
|
dirty,
|
||||||
|
saving,
|
||||||
|
onSave,
|
||||||
|
}: {
|
||||||
|
form: {
|
||||||
|
model: string;
|
||||||
|
provider: string;
|
||||||
|
};
|
||||||
|
setForm: React.Dispatch<React.SetStateAction<{
|
||||||
|
model: string;
|
||||||
|
provider: string;
|
||||||
|
}>>;
|
||||||
|
settings: SettingsPayload;
|
||||||
|
dirty: boolean;
|
||||||
|
saving: boolean;
|
||||||
|
onSave: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="space-y-7">
|
||||||
|
<section>
|
||||||
|
<h2 className="mb-2 px-2 text-xs font-medium text-muted-foreground">AI</h2>
|
||||||
|
<SettingsGroup>
|
||||||
|
<SettingsRow title="Provider">
|
||||||
|
<select
|
||||||
|
value={form.provider}
|
||||||
|
onChange={(event) => setForm((prev) => ({ ...prev, provider: event.target.value }))}
|
||||||
|
className={cn(
|
||||||
|
"h-8 w-[210px] rounded-md border border-input bg-background px-2 text-sm",
|
||||||
|
"outline-none transition-colors hover:bg-accent focus-visible:ring-2 focus-visible:ring-ring",
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
{settings.providers.map((provider) => (
|
||||||
|
<option key={provider.name} value={provider.name}>
|
||||||
|
{provider.label}
|
||||||
|
</option>
|
||||||
|
))}
|
||||||
|
</select>
|
||||||
|
</SettingsRow>
|
||||||
|
|
||||||
|
<SettingsRow title="Model">
|
||||||
|
<Input
|
||||||
|
value={form.model}
|
||||||
|
onChange={(event) => setForm((prev) => ({ ...prev, model: event.target.value }))}
|
||||||
|
className="h-8 w-[280px]"
|
||||||
|
/>
|
||||||
|
</SettingsRow>
|
||||||
|
|
||||||
|
{(dirty || saving || settings.requires_restart) ? (
|
||||||
|
<SettingsFooter
|
||||||
|
dirty={dirty}
|
||||||
|
saving={saving}
|
||||||
|
saved={settings.requires_restart && !dirty}
|
||||||
|
onSave={onSave}
|
||||||
|
/>
|
||||||
|
) : null}
|
||||||
|
</SettingsGroup>
|
||||||
|
</section>
|
||||||
|
|
||||||
|
<section>
|
||||||
|
<h2 className="mb-2 px-2 text-xs font-medium text-muted-foreground">Interface</h2>
|
||||||
|
<SettingsGroup>
|
||||||
|
<SettingsRow title="Language">
|
||||||
|
<LanguageSwitcher />
|
||||||
|
</SettingsRow>
|
||||||
|
</SettingsGroup>
|
||||||
|
</section>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function SettingsGroup({ children }: { children: React.ReactNode }) {
|
||||||
|
return (
|
||||||
|
<div className="overflow-hidden rounded-xl border border-border/60 bg-card/80">
|
||||||
|
<div className="divide-y divide-border/50">{children}</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function SettingsRow({
|
||||||
|
title,
|
||||||
|
children,
|
||||||
|
}: {
|
||||||
|
title: string;
|
||||||
|
children?: React.ReactNode;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="flex min-h-[52px] flex-col gap-3 px-3 py-2.5 sm:flex-row sm:items-center sm:justify-between">
|
||||||
|
<div className="min-w-0">
|
||||||
|
<div className="text-sm font-medium leading-5">{title}</div>
|
||||||
|
</div>
|
||||||
|
{children ? <div className="shrink-0 sm:ml-6">{children}</div> : null}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function SettingsFooter({
|
||||||
|
dirty,
|
||||||
|
saving,
|
||||||
|
saved,
|
||||||
|
onSave,
|
||||||
|
}: {
|
||||||
|
dirty: boolean;
|
||||||
|
saving: boolean;
|
||||||
|
saved: boolean;
|
||||||
|
onSave: () => void;
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="flex min-h-[52px] items-center justify-between gap-4 px-3 py-2.5">
|
||||||
|
<div className="text-sm text-muted-foreground">
|
||||||
|
{saved ? "Saved. Restart nanobot to apply." : "Unsaved changes."}
|
||||||
|
</div>
|
||||||
|
<Button size="sm" variant="outline" onClick={onSave} disabled={!dirty || saving}>
|
||||||
|
{saving ? "Saving" : "Save"}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
108
webui/src/components/thread/AskUserPrompt.tsx
Normal file
108
webui/src/components/thread/AskUserPrompt.tsx
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
import { useCallback, useEffect, useRef, useState } from "react";
|
||||||
|
import { MessageSquareText } from "lucide-react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
interface AskUserPromptProps {
|
||||||
|
question: string;
|
||||||
|
buttons: string[][];
|
||||||
|
onAnswer: (answer: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function AskUserPrompt({
|
||||||
|
question,
|
||||||
|
buttons,
|
||||||
|
onAnswer,
|
||||||
|
}: AskUserPromptProps) {
|
||||||
|
const [customOpen, setCustomOpen] = useState(false);
|
||||||
|
const [custom, setCustom] = useState("");
|
||||||
|
const inputRef = useRef<HTMLTextAreaElement>(null);
|
||||||
|
const options = buttons.flat().filter(Boolean);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (customOpen) {
|
||||||
|
inputRef.current?.focus();
|
||||||
|
}
|
||||||
|
}, [customOpen]);
|
||||||
|
|
||||||
|
const submitCustom = useCallback(() => {
|
||||||
|
const answer = custom.trim();
|
||||||
|
if (!answer) return;
|
||||||
|
onAnswer(answer);
|
||||||
|
setCustom("");
|
||||||
|
setCustomOpen(false);
|
||||||
|
}, [custom, onAnswer]);
|
||||||
|
|
||||||
|
if (options.length === 0) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
"mx-auto mb-2 w-full max-w-[49.5rem] rounded-[16px] border border-primary/30",
|
||||||
|
"bg-card/95 p-3 shadow-sm backdrop-blur",
|
||||||
|
)}
|
||||||
|
role="group"
|
||||||
|
aria-label="Question"
|
||||||
|
>
|
||||||
|
<div className="mb-2 flex items-start gap-2">
|
||||||
|
<div className="mt-0.5 rounded-full bg-primary/10 p-1.5 text-primary">
|
||||||
|
<MessageSquareText className="h-3.5 w-3.5" aria-hidden />
|
||||||
|
</div>
|
||||||
|
<p className="min-w-0 flex-1 text-sm font-medium leading-5 text-foreground">
|
||||||
|
{question}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="grid gap-1.5 sm:grid-cols-2">
|
||||||
|
{options.map((option) => (
|
||||||
|
<Button
|
||||||
|
key={option}
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => onAnswer(option)}
|
||||||
|
className="justify-start rounded-[10px] px-3 text-left"
|
||||||
|
>
|
||||||
|
<span className="truncate">{option}</span>
|
||||||
|
</Button>
|
||||||
|
))}
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
onClick={() => setCustomOpen((open) => !open)}
|
||||||
|
className="justify-start rounded-[10px] px-3 text-muted-foreground"
|
||||||
|
>
|
||||||
|
Other...
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{customOpen ? (
|
||||||
|
<div className="mt-2 flex gap-2">
|
||||||
|
<textarea
|
||||||
|
ref={inputRef}
|
||||||
|
value={custom}
|
||||||
|
onChange={(event) => setCustom(event.target.value)}
|
||||||
|
onKeyDown={(event) => {
|
||||||
|
if (event.key === "Enter" && !event.shiftKey && !event.nativeEvent.isComposing) {
|
||||||
|
event.preventDefault();
|
||||||
|
submitCustom();
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
rows={1}
|
||||||
|
placeholder="Type your own answer..."
|
||||||
|
className={cn(
|
||||||
|
"min-h-9 flex-1 resize-none rounded-[10px] border border-border/70 bg-background",
|
||||||
|
"px-3 py-2 text-sm leading-5 outline-none placeholder:text-muted-foreground",
|
||||||
|
"focus-visible:ring-1 focus-visible:ring-primary/40",
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
<Button type="button" size="sm" onClick={submitCustom} disabled={!custom.trim()}>
|
||||||
|
Send
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
) : null}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
|
|
||||||
|
import { AskUserPrompt } from "@/components/thread/AskUserPrompt";
|
||||||
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
import { ThreadComposer } from "@/components/thread/ThreadComposer";
|
||||||
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
import { ThreadHeader } from "@/components/thread/ThreadHeader";
|
||||||
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
import { StreamErrorNotice } from "@/components/thread/StreamErrorNotice";
|
||||||
@ -57,6 +58,21 @@ export function ThreadShell({
|
|||||||
dismissStreamError,
|
dismissStreamError,
|
||||||
} = useNanobotStream(chatId, initial);
|
} = useNanobotStream(chatId, initial);
|
||||||
const showHeroComposer = messages.length === 0 && !loading;
|
const showHeroComposer = messages.length === 0 && !loading;
|
||||||
|
const pendingAsk = useMemo(() => {
|
||||||
|
for (let index = messages.length - 1; index >= 0; index -= 1) {
|
||||||
|
const message = messages[index];
|
||||||
|
if (message.kind === "trace") continue;
|
||||||
|
if (message.role === "user") return null;
|
||||||
|
if (message.role === "assistant" && message.buttons?.some((row) => row.length > 0)) {
|
||||||
|
return {
|
||||||
|
question: message.content,
|
||||||
|
buttons: message.buttons,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if (message.role === "assistant") return null;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}, [messages]);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (!chatId || loading) return;
|
if (!chatId || loading) return;
|
||||||
@ -152,6 +168,13 @@ export function ThreadShell({
|
|||||||
onDismiss={dismissStreamError}
|
onDismiss={dismissStreamError}
|
||||||
/>
|
/>
|
||||||
) : null}
|
) : null}
|
||||||
|
{pendingAsk ? (
|
||||||
|
<AskUserPrompt
|
||||||
|
question={pendingAsk.question}
|
||||||
|
buttons={pendingAsk.buttons}
|
||||||
|
onAnswer={send}
|
||||||
|
/>
|
||||||
|
) : null}
|
||||||
{session ? (
|
{session ? (
|
||||||
<ThreadComposer
|
<ThreadComposer
|
||||||
onSend={send}
|
onSend={send}
|
||||||
|
|||||||
@ -160,13 +160,15 @@ export function useNanobotStream(
|
|||||||
setIsStreaming(false);
|
setIsStreaming(false);
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
const filtered = activeId ? prev.filter((m) => m.id !== activeId) : prev;
|
||||||
|
const content = ev.buttons?.length ? (ev.button_prompt ?? ev.text) : ev.text;
|
||||||
return [
|
return [
|
||||||
...filtered,
|
...filtered,
|
||||||
{
|
{
|
||||||
id: crypto.randomUUID(),
|
id: crypto.randomUUID(),
|
||||||
role: "assistant",
|
role: "assistant",
|
||||||
content: ev.text,
|
content,
|
||||||
createdAt: Date.now(),
|
createdAt: Date.now(),
|
||||||
|
...(ev.buttons && ev.buttons.length > 0 ? { buttons: ev.buttons } : {}),
|
||||||
...(media && media.length > 0 ? { media } : {}),
|
...(media && media.length > 0 ? { media } : {}),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type { ChatSummary } from "./types";
|
import type { ChatSummary, SettingsPayload, SettingsUpdate } from "./types";
|
||||||
|
|
||||||
export class ApiError extends Error {
|
export class ApiError extends Error {
|
||||||
status: number;
|
status: number;
|
||||||
@ -104,3 +104,21 @@ export async function deleteSession(
|
|||||||
);
|
);
|
||||||
return body.deleted;
|
return body.deleted;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export async function fetchSettings(
|
||||||
|
token: string,
|
||||||
|
base: string = "",
|
||||||
|
): Promise<SettingsPayload> {
|
||||||
|
return request<SettingsPayload>(`${base}/api/settings`, token);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function updateSettings(
|
||||||
|
token: string,
|
||||||
|
update: SettingsUpdate,
|
||||||
|
base: string = "",
|
||||||
|
): Promise<SettingsPayload> {
|
||||||
|
const query = new URLSearchParams();
|
||||||
|
if (update.model !== undefined) query.set("model", update.model);
|
||||||
|
if (update.provider !== undefined) query.set("provider", update.provider);
|
||||||
|
return request<SettingsPayload>(`${base}/api/settings/update?${query}`, token);
|
||||||
|
}
|
||||||
|
|||||||
@ -44,6 +44,8 @@ export interface UIMessage {
|
|||||||
images?: UIImage[];
|
images?: UIImage[];
|
||||||
/** Signed or local UI-renderable media attachments. */
|
/** Signed or local UI-renderable media attachments. */
|
||||||
media?: UIMediaAttachment[];
|
media?: UIMediaAttachment[];
|
||||||
|
/** Optional answer choices for a pending ask_user question. */
|
||||||
|
buttons?: string[][];
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ChatSummary {
|
export interface ChatSummary {
|
||||||
@ -64,6 +66,28 @@ export interface BootstrapResponse {
|
|||||||
model_name?: string | null;
|
model_name?: string | null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface SettingsPayload {
|
||||||
|
agent: {
|
||||||
|
model: string;
|
||||||
|
provider: string;
|
||||||
|
resolved_provider: string | null;
|
||||||
|
has_api_key: boolean;
|
||||||
|
};
|
||||||
|
providers: Array<{
|
||||||
|
name: string;
|
||||||
|
label: string;
|
||||||
|
}>;
|
||||||
|
runtime: {
|
||||||
|
config_path: string;
|
||||||
|
};
|
||||||
|
requires_restart: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface SettingsUpdate {
|
||||||
|
model?: string;
|
||||||
|
provider?: string;
|
||||||
|
}
|
||||||
|
|
||||||
export type ConnectionStatus =
|
export type ConnectionStatus =
|
||||||
| "idle"
|
| "idle"
|
||||||
| "connecting"
|
| "connecting"
|
||||||
@ -82,6 +106,9 @@ export type InboundEvent =
|
|||||||
reply_to?: string;
|
reply_to?: string;
|
||||||
media?: string[];
|
media?: string[];
|
||||||
media_urls?: Array<{ url: string; name?: string }>;
|
media_urls?: Array<{ url: string; name?: string }>;
|
||||||
|
buttons?: string[][];
|
||||||
|
/** Original prompt before the websocket text fallback appends buttons. */
|
||||||
|
button_prompt?: string;
|
||||||
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
/** Present when the frame is an agent breadcrumb (e.g. tool hint,
|
||||||
* generic progress line) rather than a conversational reply. */
|
* generic progress line) rather than a conversational reply. */
|
||||||
kind?: "tool_hint" | "progress";
|
kind?: "tool_hint" | "progress";
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||||
|
|
||||||
import { deleteSession, fetchSessionMessages } from "@/lib/api";
|
import { deleteSession, fetchSessionMessages, updateSettings } from "@/lib/api";
|
||||||
|
|
||||||
describe("webui API helpers", () => {
|
describe("webui API helpers", () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@ -34,4 +34,18 @@ describe("webui API helpers", () => {
|
|||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("serializes settings updates as a narrow query string", async () => {
|
||||||
|
await updateSettings("tok", {
|
||||||
|
model: "openrouter/test",
|
||||||
|
provider: "openrouter",
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(fetch).toHaveBeenCalledWith(
|
||||||
|
"/api/settings/update?model=openrouter%2Ftest&provider=openrouter",
|
||||||
|
expect.objectContaining({
|
||||||
|
headers: { Authorization: "Bearer tok" },
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -146,4 +146,44 @@ describe("App layout", () => {
|
|||||||
expect(screen.queryByText('Delete “First chat”?')).not.toBeInTheDocument();
|
expect(screen.queryByText('Delete “First chat”?')).not.toBeInTheDocument();
|
||||||
expect(document.body.style.pointerEvents).not.toBe("none");
|
expect(document.body.style.pointerEvents).not.toBe("none");
|
||||||
}, 15_000);
|
}, 15_000);
|
||||||
|
|
||||||
|
it("opens the Cursor-style settings view from the sidebar", async () => {
|
||||||
|
vi.stubGlobal(
|
||||||
|
"fetch",
|
||||||
|
vi.fn(async (input: RequestInfo | URL) => {
|
||||||
|
if (String(input).includes("/api/settings")) {
|
||||||
|
return {
|
||||||
|
ok: true,
|
||||||
|
status: 200,
|
||||||
|
json: async () => ({
|
||||||
|
agent: {
|
||||||
|
model: "openai/gpt-4o",
|
||||||
|
provider: "auto",
|
||||||
|
resolved_provider: "openai",
|
||||||
|
has_api_key: true,
|
||||||
|
},
|
||||||
|
providers: [
|
||||||
|
{ name: "auto", label: "Auto" },
|
||||||
|
{ name: "openai", label: "OpenAI" },
|
||||||
|
],
|
||||||
|
runtime: {
|
||||||
|
config_path: "/tmp/config.json",
|
||||||
|
},
|
||||||
|
requires_restart: false,
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return { ok: false, status: 404, json: async () => ({}) };
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
render(<App />);
|
||||||
|
|
||||||
|
await waitFor(() => expect(connectSpy).toHaveBeenCalled());
|
||||||
|
fireEvent.click(screen.getByRole("button", { name: "Settings" }));
|
||||||
|
|
||||||
|
expect(await screen.findByRole("heading", { name: "General" })).toBeInTheDocument();
|
||||||
|
expect(screen.getByText("AI")).toBeInTheDocument();
|
||||||
|
expect(screen.getByDisplayValue("openai/gpt-4o")).toBeInTheDocument();
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -7,11 +7,22 @@ import { ClientProvider } from "@/providers/ClientProvider";
|
|||||||
|
|
||||||
function makeClient() {
|
function makeClient() {
|
||||||
const errorHandlers = new Set<(err: { kind: string }) => void>();
|
const errorHandlers = new Set<(err: { kind: string }) => void>();
|
||||||
|
const chatHandlers = new Map<string, Set<(ev: import("@/lib/types").InboundEvent) => void>>();
|
||||||
return {
|
return {
|
||||||
status: "open" as const,
|
status: "open" as const,
|
||||||
defaultChatId: null as string | null,
|
defaultChatId: null as string | null,
|
||||||
onStatus: () => () => {},
|
onStatus: () => () => {},
|
||||||
onChat: () => () => {},
|
onChat: (chatId: string, handler: (ev: import("@/lib/types").InboundEvent) => void) => {
|
||||||
|
let handlers = chatHandlers.get(chatId);
|
||||||
|
if (!handlers) {
|
||||||
|
handlers = new Set();
|
||||||
|
chatHandlers.set(chatId, handlers);
|
||||||
|
}
|
||||||
|
handlers.add(handler);
|
||||||
|
return () => {
|
||||||
|
handlers?.delete(handler);
|
||||||
|
};
|
||||||
|
},
|
||||||
onError: (handler: (err: { kind: string }) => void) => {
|
onError: (handler: (err: { kind: string }) => void) => {
|
||||||
errorHandlers.add(handler);
|
errorHandlers.add(handler);
|
||||||
return () => {
|
return () => {
|
||||||
@ -21,6 +32,9 @@ function makeClient() {
|
|||||||
_emitError(err: { kind: string }) {
|
_emitError(err: { kind: string }) {
|
||||||
for (const h of errorHandlers) h(err);
|
for (const h of errorHandlers) h(err);
|
||||||
},
|
},
|
||||||
|
_emitChat(chatId: string, ev: import("@/lib/types").InboundEvent) {
|
||||||
|
for (const h of chatHandlers.get(chatId) ?? []) h(ev);
|
||||||
|
},
|
||||||
sendMessage: vi.fn(),
|
sendMessage: vi.fn(),
|
||||||
newChat: vi.fn(),
|
newChat: vi.fn(),
|
||||||
attach: vi.fn(),
|
attach: vi.fn(),
|
||||||
@ -411,4 +425,46 @@ describe("ThreadShell", () => {
|
|||||||
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
await waitFor(() => expect(screen.getByText("from chat b")).toBeInTheDocument());
|
||||||
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
expect(screen.queryByText("from chat a")).not.toBeInTheDocument();
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("renders ask_user options above the composer and sends selected answers", async () => {
|
||||||
|
const client = makeClient();
|
||||||
|
const onNewChat = vi.fn().mockResolvedValue("chat-a");
|
||||||
|
|
||||||
|
render(
|
||||||
|
wrap(
|
||||||
|
client,
|
||||||
|
<ThreadShell
|
||||||
|
session={session("chat-a")}
|
||||||
|
title="Chat chat-a"
|
||||||
|
onToggleSidebar={() => {}}
|
||||||
|
onGoHome={() => {}}
|
||||||
|
onNewChat={onNewChat}
|
||||||
|
/>,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
client._emitChat("chat-a", {
|
||||||
|
event: "message",
|
||||||
|
chat_id: "chat-a",
|
||||||
|
text: "How should I continue?",
|
||||||
|
buttons: [["Short answer", "Detailed answer"]],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(screen.getByRole("group", { name: "Question" })).toHaveTextContent(
|
||||||
|
"How should I continue?",
|
||||||
|
);
|
||||||
|
|
||||||
|
fireEvent.click(screen.getByRole("button", { name: "Short answer" }));
|
||||||
|
|
||||||
|
expect(client.sendMessage).toHaveBeenCalledWith(
|
||||||
|
"chat-a",
|
||||||
|
"Short answer",
|
||||||
|
undefined,
|
||||||
|
);
|
||||||
|
await waitFor(() => {
|
||||||
|
expect(screen.queryByRole("group", { name: "Question" })).not.toBeInTheDocument();
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@ -113,4 +113,27 @@ describe("useNanobotStream", () => {
|
|||||||
{ kind: "video", url: "/api/media/sig/payload", name: "demo.mp4" },
|
{ kind: "video", url: "/api/media/sig/payload", name: "demo.mp4" },
|
||||||
]);
|
]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("keeps assistant buttons on complete messages", () => {
|
||||||
|
const fake = fakeClient();
|
||||||
|
const { result } = renderHook(() => useNanobotStream("chat-q", []), {
|
||||||
|
wrapper: wrap(fake.client),
|
||||||
|
});
|
||||||
|
|
||||||
|
act(() => {
|
||||||
|
fake.emit("chat-q", {
|
||||||
|
event: "message",
|
||||||
|
chat_id: "chat-q",
|
||||||
|
text: "How should I continue?\n\n1. Short answer\n2. Detailed answer",
|
||||||
|
button_prompt: "How should I continue?",
|
||||||
|
buttons: [["Short answer", "Detailed answer"]],
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(result.current.messages).toHaveLength(1);
|
||||||
|
expect(result.current.messages[0].content).toBe("How should I continue?");
|
||||||
|
expect(result.current.messages[0].buttons).toEqual([
|
||||||
|
["Short answer", "Detailed answer"],
|
||||||
|
]);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user