mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
fix(onboard): require explicit save in interactive wizard
Keep onboarding edits in draft state until users choose Done or Save and Exit, so backing out or discarding the wizard no longer persists partial changes.
This commit is contained in:
parent
f45329aee4
commit
d6acf1abcb
@ -1,11 +1,11 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager, nullcontext
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@ -64,6 +64,7 @@ def _flush_pending_tty_input() -> None:
|
||||
|
||||
try:
|
||||
import termios
|
||||
|
||||
termios.tcflush(fd, termios.TCIFLUSH)
|
||||
return
|
||||
except Exception:
|
||||
@ -86,6 +87,7 @@ def _restore_terminal() -> None:
|
||||
return
|
||||
try:
|
||||
import termios
|
||||
|
||||
termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS)
|
||||
except Exception:
|
||||
pass
|
||||
@ -98,6 +100,7 @@ def _init_prompt_session() -> None:
|
||||
# Save terminal state so we can restore it on exit
|
||||
try:
|
||||
import termios
|
||||
|
||||
_SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno())
|
||||
except Exception:
|
||||
pass
|
||||
@ -110,7 +113,7 @@ def _init_prompt_session() -> None:
|
||||
_PROMPT_SESSION = PromptSession(
|
||||
history=FileHistory(str(history_file)),
|
||||
enable_open_in_editor=False,
|
||||
multiline=False, # Enter submits (single line mode)
|
||||
multiline=False, # Enter submits (single line mode)
|
||||
)
|
||||
|
||||
|
||||
@ -143,10 +146,9 @@ def _print_agent_response(response: str, render_markdown: bool) -> None:
|
||||
|
||||
async def _print_interactive_line(text: str) -> None:
|
||||
"""Print async interactive updates with prompt_toolkit-safe Rich styling."""
|
||||
|
||||
def _write() -> None:
|
||||
ansi = _render_interactive_ansi(
|
||||
lambda c: c.print(f" [dim]↳ {text}[/dim]")
|
||||
)
|
||||
ansi = _render_interactive_ansi(lambda c: c.print(f" [dim]↳ {text}[/dim]"))
|
||||
print_formatted_text(ANSI(ansi), end="")
|
||||
|
||||
await run_in_terminal(_write)
|
||||
@ -154,6 +156,7 @@ async def _print_interactive_line(text: str) -> None:
|
||||
|
||||
async def _print_interactive_response(response: str, render_markdown: bool) -> None:
|
||||
"""Print async interactive replies with prompt_toolkit-safe Rich styling."""
|
||||
|
||||
def _write() -> None:
|
||||
content = response or ""
|
||||
ansi = _render_interactive_ansi(
|
||||
@ -173,9 +176,9 @@ class _ThinkingSpinner:
|
||||
"""Spinner wrapper with pause support for clean progress output."""
|
||||
|
||||
def __init__(self, enabled: bool):
|
||||
self._spinner = console.status(
|
||||
"[dim]nanobot is thinking...[/dim]", spinner="dots"
|
||||
) if enabled else None
|
||||
self._spinner = (
|
||||
console.status("[dim]nanobot is thinking...[/dim]", spinner="dots") if enabled else None
|
||||
)
|
||||
self._active = False
|
||||
|
||||
def __enter__(self):
|
||||
@ -238,7 +241,6 @@ async def _read_interactive_input_async() -> str:
|
||||
raise KeyboardInterrupt from exc
|
||||
|
||||
|
||||
|
||||
def version_callback(value: bool):
|
||||
if value:
|
||||
console.print(f"{__logo__} nanobot v{__version__}")
|
||||
@ -247,9 +249,7 @@ def version_callback(value: bool):
|
||||
|
||||
@app.callback()
|
||||
def main(
|
||||
version: bool = typer.Option(
|
||||
None, "--version", "-v", callback=version_callback, is_eager=True
|
||||
),
|
||||
version: bool = typer.Option(None, "--version", "-v", callback=version_callback, is_eager=True),
|
||||
):
|
||||
"""nanobot - Personal AI Assistant."""
|
||||
pass
|
||||
@ -264,7 +264,9 @@ def main(
|
||||
def onboard(
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"),
|
||||
non_interactive: bool = typer.Option(False, "--non-interactive", help="Skip interactive wizard"),
|
||||
non_interactive: bool = typer.Option(
|
||||
False, "--non-interactive", help="Skip interactive wizard"
|
||||
),
|
||||
):
|
||||
"""Initialize nanobot configuration and workspace."""
|
||||
from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path
|
||||
@ -282,41 +284,53 @@ def onboard(
|
||||
loaded.agents.defaults.workspace = workspace
|
||||
return loaded
|
||||
|
||||
cfg: Config
|
||||
|
||||
# Non-interactive mode: simple config creation/update
|
||||
if non_interactive:
|
||||
if config_path.exists():
|
||||
console.print(f"[yellow]Config already exists at {config_path}[/yellow]")
|
||||
console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)")
|
||||
console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields")
|
||||
console.print(
|
||||
" [bold]y[/bold] = overwrite with defaults (existing values will be lost)"
|
||||
)
|
||||
console.print(
|
||||
" [bold]N[/bold] = refresh config, keeping existing values and adding new fields"
|
||||
)
|
||||
if typer.confirm("Overwrite?"):
|
||||
config = _apply_workspace_override(Config())
|
||||
save_config(config, config_path)
|
||||
cfg = _apply_workspace_override(Config())
|
||||
save_config(cfg, config_path)
|
||||
console.print(f"[green]✓[/green] Config reset to defaults at {config_path}")
|
||||
else:
|
||||
config = _apply_workspace_override(load_config(config_path))
|
||||
save_config(config, config_path)
|
||||
console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)")
|
||||
cfg = _apply_workspace_override(load_config(config_path))
|
||||
save_config(cfg, config_path)
|
||||
console.print(
|
||||
f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)"
|
||||
)
|
||||
else:
|
||||
config = _apply_workspace_override(Config())
|
||||
save_config(config, config_path)
|
||||
cfg = _apply_workspace_override(Config())
|
||||
save_config(cfg, config_path)
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]")
|
||||
console.print(
|
||||
"[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]"
|
||||
)
|
||||
else:
|
||||
# Interactive mode: use wizard
|
||||
if config_path.exists():
|
||||
config = load_config()
|
||||
cfg = _apply_workspace_override(load_config(config_path))
|
||||
else:
|
||||
config = Config()
|
||||
save_config(config)
|
||||
console.print(f"[green]✓[/green] Created config at {config_path}")
|
||||
cfg = _apply_workspace_override(Config())
|
||||
|
||||
# Run interactive wizard
|
||||
from nanobot.cli.onboard_wizard import run_onboard
|
||||
|
||||
try:
|
||||
# Pass the config with workspace override applied as initial config
|
||||
config = run_onboard(initial_config=config)
|
||||
save_config(config, config_path)
|
||||
result = run_onboard(initial_config=cfg)
|
||||
if not result.should_save:
|
||||
console.print("[yellow]Configuration discarded. No changes were saved.[/yellow]")
|
||||
return
|
||||
|
||||
cfg = result.config
|
||||
save_config(cfg, config_path)
|
||||
console.print(f"[green]✓[/green] Config saved at {config_path}")
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗[/red] Error during configuration: {e}")
|
||||
@ -326,15 +340,15 @@ def onboard(
|
||||
_onboard_plugins(config_path)
|
||||
|
||||
# Create workspace, preferring the configured workspace path.
|
||||
workspace = get_workspace_path(config.workspace_path)
|
||||
if not workspace.exists():
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
console.print(f"[green]✓[/green] Created workspace at {workspace}")
|
||||
workspace_path = get_workspace_path(cfg.workspace_path)
|
||||
if not workspace_path.exists():
|
||||
workspace_path.mkdir(parents=True, exist_ok=True)
|
||||
console.print(f"[green]✓[/green] Created workspace at {workspace_path}")
|
||||
|
||||
sync_workspace_templates(workspace)
|
||||
sync_workspace_templates(workspace_path)
|
||||
|
||||
agent_cmd = 'nanobot agent -m "Hello!"'
|
||||
if config:
|
||||
if cfg:
|
||||
agent_cmd += f" --config {config_path}"
|
||||
|
||||
console.print(f"\n{__logo__} nanobot is ready!")
|
||||
@ -344,9 +358,11 @@ def onboard(
|
||||
console.print(" Get one at: https://openrouter.ai/keys")
|
||||
console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]")
|
||||
else:
|
||||
console.print(" 1. Chat: [cyan]nanobot agent -m \"Hello!\"[/cyan]")
|
||||
console.print(' 1. Chat: [cyan]nanobot agent -m "Hello!"[/cyan]')
|
||||
console.print(" 2. Start gateway: [cyan]nanobot gateway[/cyan]")
|
||||
console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]")
|
||||
console.print(
|
||||
"\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]"
|
||||
)
|
||||
|
||||
|
||||
def _merge_missing_defaults(existing: Any, defaults: Any) -> Any:
|
||||
@ -403,6 +419,7 @@ def _make_provider(config: Config):
|
||||
# Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM
|
||||
elif provider_name == "custom":
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
|
||||
provider = CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v1",
|
||||
@ -424,6 +441,7 @@ def _make_provider(config: Config):
|
||||
# OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3
|
||||
elif provider_name == "ovms":
|
||||
from nanobot.providers.custom_provider import CustomProvider
|
||||
|
||||
provider = CustomProvider(
|
||||
api_key=p.api_key if p else "no-key",
|
||||
api_base=config.get_api_base(model) or "http://localhost:8000/v3",
|
||||
@ -432,8 +450,13 @@ def _make_provider(config: Config):
|
||||
else:
|
||||
from nanobot.providers.litellm_provider import LiteLLMProvider
|
||||
from nanobot.providers.registry import find_by_name
|
||||
|
||||
spec = find_by_name(provider_name)
|
||||
if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)):
|
||||
if (
|
||||
not model.startswith("bedrock/")
|
||||
and not (p and p.api_key)
|
||||
and not (spec and (spec.is_oauth or spec.is_local))
|
||||
):
|
||||
console.print("[red]Error: No API key configured.[/red]")
|
||||
console.print("Set one in ~/.nanobot/config.json under providers section")
|
||||
raise typer.Exit(1)
|
||||
@ -507,6 +530,7 @@ def gateway(
|
||||
|
||||
if verbose:
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
config = _load_runtime_config(config, workspace)
|
||||
@ -576,16 +600,23 @@ def gateway(
|
||||
|
||||
if job.payload.deliver and job.payload.to and response:
|
||||
should_notify = await evaluate_response(
|
||||
response, job.payload.message, provider, agent.model,
|
||||
response,
|
||||
job.payload.message,
|
||||
provider,
|
||||
agent.model,
|
||||
)
|
||||
if should_notify:
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
await bus.publish_outbound(OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
))
|
||||
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(
|
||||
channel=job.payload.channel or "cli",
|
||||
chat_id=job.payload.to,
|
||||
content=response,
|
||||
)
|
||||
)
|
||||
return response
|
||||
|
||||
cron.on_job = on_cron_job
|
||||
|
||||
# Create channel manager
|
||||
@ -626,10 +657,13 @@ def gateway(
|
||||
async def on_heartbeat_notify(response: str) -> None:
|
||||
"""Deliver a heartbeat response to the user's channel."""
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
if channel == "cli":
|
||||
return # No external channel available to deliver to
|
||||
await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response))
|
||||
await bus.publish_outbound(
|
||||
OutboundMessage(channel=channel, chat_id=chat_id, content=response)
|
||||
)
|
||||
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
@ -665,6 +699,7 @@ def gateway(
|
||||
console.print("\nShutting down...")
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
console.print("\n[red]Error: Gateway crashed unexpectedly[/red]")
|
||||
console.print(traceback.format_exc())
|
||||
finally:
|
||||
@ -677,8 +712,6 @@ def gateway(
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Agent Commands
|
||||
# ============================================================================
|
||||
@ -690,8 +723,12 @@ def agent(
|
||||
session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"),
|
||||
workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"),
|
||||
config: str | None = typer.Option(None, "--config", "-c", help="Config file path"),
|
||||
markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"),
|
||||
logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"),
|
||||
markdown: bool = typer.Option(
|
||||
True, "--markdown/--no-markdown", help="Render assistant output as Markdown"
|
||||
),
|
||||
logs: bool = typer.Option(
|
||||
False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"
|
||||
),
|
||||
):
|
||||
"""Interact with the agent directly."""
|
||||
from loguru import logger
|
||||
@ -751,7 +788,9 @@ def agent(
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
with _thinking:
|
||||
response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress)
|
||||
response = await agent_loop.process_direct(
|
||||
message, session_id, on_progress=_cli_progress
|
||||
)
|
||||
_thinking = None
|
||||
_print_agent_response(response, render_markdown=markdown)
|
||||
await agent_loop.close_mcp()
|
||||
@ -760,8 +799,11 @@ def agent(
|
||||
else:
|
||||
# Interactive mode — route through bus like other channels
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
_init_prompt_session()
|
||||
console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n")
|
||||
console.print(
|
||||
f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n"
|
||||
)
|
||||
|
||||
if ":" in session_id:
|
||||
cli_channel, cli_chat_id = session_id.split(":", 1)
|
||||
@ -777,11 +819,11 @@ def agent(
|
||||
signal.signal(signal.SIGINT, _handle_signal)
|
||||
signal.signal(signal.SIGTERM, _handle_signal)
|
||||
# SIGHUP is not available on Windows
|
||||
if hasattr(signal, 'SIGHUP'):
|
||||
if hasattr(signal, "SIGHUP"):
|
||||
signal.signal(signal.SIGHUP, _handle_signal)
|
||||
# Ignore SIGPIPE to prevent silent process termination when writing to closed pipes
|
||||
# SIGPIPE is not available on Windows
|
||||
if hasattr(signal, 'SIGPIPE'):
|
||||
if hasattr(signal, "SIGPIPE"):
|
||||
signal.signal(signal.SIGPIPE, signal.SIG_IGN)
|
||||
|
||||
async def run_interactive():
|
||||
@ -835,12 +877,14 @@ def agent(
|
||||
turn_done.clear()
|
||||
turn_response.clear()
|
||||
|
||||
await bus.publish_inbound(InboundMessage(
|
||||
channel=cli_channel,
|
||||
sender_id="user",
|
||||
chat_id=cli_chat_id,
|
||||
content=user_input,
|
||||
))
|
||||
await bus.publish_inbound(
|
||||
InboundMessage(
|
||||
channel=cli_channel,
|
||||
sender_id="user",
|
||||
chat_id=cli_chat_id,
|
||||
content=user_input,
|
||||
)
|
||||
)
|
||||
|
||||
nonlocal _thinking
|
||||
_thinking = _ThinkingSpinner(enabled=not logs)
|
||||
@ -982,7 +1026,11 @@ def channels_login():
|
||||
|
||||
env = {**os.environ}
|
||||
wa_cfg = getattr(config.channels, "whatsapp", None) or {}
|
||||
bridge_token = wa_cfg.get("bridgeToken", "") if isinstance(wa_cfg, dict) else getattr(wa_cfg, "bridge_token", "")
|
||||
bridge_token = (
|
||||
wa_cfg.get("bridgeToken", "")
|
||||
if isinstance(wa_cfg, dict)
|
||||
else getattr(wa_cfg, "bridge_token", "")
|
||||
)
|
||||
if bridge_token:
|
||||
env["BRIDGE_TOKEN"] = bridge_token
|
||||
env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth"))
|
||||
@ -1056,8 +1104,12 @@ def status():
|
||||
|
||||
console.print(f"{__logo__} nanobot Status\n")
|
||||
|
||||
console.print(f"Config: {config_path} {'[green]✓[/green]' if config_path.exists() else '[red]✗[/red]'}")
|
||||
console.print(f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}")
|
||||
console.print(
|
||||
f"Config: {config_path} {'[green]✓[/green]' if config_path.exists() else '[red]✗[/red]'}"
|
||||
)
|
||||
console.print(
|
||||
f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}"
|
||||
)
|
||||
|
||||
if config_path.exists():
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
@ -1079,7 +1131,9 @@ def status():
|
||||
console.print(f"{spec.label}: [dim]not set[/dim]")
|
||||
else:
|
||||
has_key = bool(p.api_key)
|
||||
console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}")
|
||||
console.print(
|
||||
f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@ -1097,12 +1151,15 @@ def _register_login(name: str):
|
||||
def decorator(fn):
|
||||
_LOGIN_HANDLERS[name] = fn
|
||||
return fn
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@provider_app.command("login")
|
||||
def provider_login(
|
||||
provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"),
|
||||
provider: str = typer.Argument(
|
||||
..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"
|
||||
),
|
||||
):
|
||||
"""Authenticate with an OAuth provider."""
|
||||
from nanobot.providers.registry import PROVIDERS
|
||||
@ -1127,6 +1184,7 @@ def provider_login(
|
||||
def _login_openai_codex() -> None:
|
||||
try:
|
||||
from oauth_cli_kit import get_token, login_oauth_interactive
|
||||
|
||||
token = None
|
||||
try:
|
||||
token = get_token()
|
||||
@ -1141,7 +1199,9 @@ def _login_openai_codex() -> None:
|
||||
if not (token and token.access):
|
||||
console.print("[red]✗ Authentication failed[/red]")
|
||||
raise typer.Exit(1)
|
||||
console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]")
|
||||
console.print(
|
||||
f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]"
|
||||
)
|
||||
except ImportError:
|
||||
console.print("[red]oauth_cli_kit not installed. Run: pip install oauth-cli-kit[/red]")
|
||||
raise typer.Exit(1)
|
||||
@ -1155,7 +1215,12 @@ def _login_github_copilot() -> None:
|
||||
|
||||
async def _trigger():
|
||||
from litellm import acompletion
|
||||
await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1)
|
||||
|
||||
await acompletion(
|
||||
model="github_copilot/gpt-4o",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
max_tokens=1,
|
||||
)
|
||||
|
||||
try:
|
||||
asyncio.run(_trigger())
|
||||
|
||||
@ -2,7 +2,8 @@
|
||||
|
||||
import json
|
||||
import types
|
||||
from typing import Any, Callable, get_args, get_origin
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
import questionary
|
||||
from loguru import logger
|
||||
@ -21,6 +22,15 @@ from nanobot.config.schema import Config
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
@dataclass
|
||||
class OnboardResult:
|
||||
"""Result of an onboarding session."""
|
||||
|
||||
config: Config
|
||||
should_save: bool
|
||||
|
||||
|
||||
# --- Field Hints for Select Fields ---
|
||||
# Maps field names to (choices, hint_text)
|
||||
# To add a new select field with hints, add an entry:
|
||||
@ -128,10 +138,12 @@ def _select_with_back(
|
||||
event.app.exit()
|
||||
|
||||
# Style
|
||||
style = Style.from_dict({
|
||||
"selected": "fg:green bold",
|
||||
"question": "fg:cyan",
|
||||
})
|
||||
style = Style.from_dict(
|
||||
{
|
||||
"selected": "fg:green bold",
|
||||
"question": "fg:cyan",
|
||||
}
|
||||
)
|
||||
|
||||
app = Application(layout=layout, key_bindings=bindings, style=style)
|
||||
try:
|
||||
@ -142,6 +154,7 @@ def _select_with_back(
|
||||
|
||||
return state["result"]
|
||||
|
||||
|
||||
# --- Type Introspection ---
|
||||
|
||||
|
||||
@ -268,9 +281,7 @@ def _show_main_menu_header() -> None:
|
||||
# Use Align.CENTER for the single line of text
|
||||
from rich.align import Align
|
||||
|
||||
console.print(
|
||||
Align.center(f"{__logo__} [bold cyan]nanobot[{__version__}][/bold cyan]")
|
||||
)
|
||||
console.print(Align.center(f"{__logo__} [bold cyan]nanobot[{__version__}][/bold cyan]"))
|
||||
console.print()
|
||||
|
||||
|
||||
@ -329,9 +340,7 @@ def _input_text(display_name: str, current: Any, field_type: str) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _input_with_existing(
|
||||
display_name: str, current: Any, field_type: str
|
||||
) -> Any:
|
||||
def _input_with_existing(display_name: str, current: Any, field_type: str) -> Any:
|
||||
"""Handle input with 'keep existing' option for non-empty values."""
|
||||
has_existing = current is not None and current != "" and current != {} and current != []
|
||||
|
||||
@ -357,12 +366,8 @@ def _get_current_provider(model: BaseModel) -> str:
|
||||
return "auto"
|
||||
|
||||
|
||||
def _input_model_with_autocomplete(
|
||||
display_name: str, current: Any, provider: str
|
||||
) -> str | None:
|
||||
"""Get model input with autocomplete suggestions.
|
||||
|
||||
"""
|
||||
def _input_model_with_autocomplete(display_name: str, current: Any, provider: str) -> str | None:
|
||||
"""Get model input with autocomplete suggestions."""
|
||||
from prompt_toolkit.completion import Completer, Completion
|
||||
|
||||
default = str(current) if current else ""
|
||||
@ -431,7 +436,9 @@ def _input_context_window_with_recommendation(
|
||||
context_limit = get_model_context_limit(model_name, provider)
|
||||
|
||||
if context_limit:
|
||||
console.print(f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]")
|
||||
console.print(
|
||||
f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]"
|
||||
)
|
||||
return context_limit
|
||||
else:
|
||||
console.print("[yellow]⚠ Could not fetch model info, please enter manually[/yellow]")
|
||||
@ -458,83 +465,88 @@ def _configure_pydantic_model(
|
||||
display_name: str,
|
||||
*,
|
||||
skip_fields: set[str] | None = None,
|
||||
finalize_hook: Callable | None = None,
|
||||
) -> None:
|
||||
"""Configure a Pydantic model interactively."""
|
||||
) -> BaseModel | None:
|
||||
"""Configure a Pydantic model interactively.
|
||||
|
||||
Returns the updated model only when the user explicitly selects "Done".
|
||||
Back and cancel actions discard the section draft.
|
||||
"""
|
||||
skip_fields = skip_fields or set()
|
||||
working_model = model.model_copy(deep=True)
|
||||
|
||||
fields = []
|
||||
for field_name, field_info in type(model).model_fields.items():
|
||||
for field_name, field_info in type(working_model).model_fields.items():
|
||||
if field_name in skip_fields:
|
||||
continue
|
||||
fields.append((field_name, field_info))
|
||||
|
||||
if not fields:
|
||||
console.print(f"[dim]{display_name}: No configurable fields[/dim]")
|
||||
return
|
||||
return working_model
|
||||
|
||||
def get_choices() -> list[str]:
|
||||
choices = []
|
||||
for field_name, field_info in fields:
|
||||
value = getattr(model, field_name, None)
|
||||
value = getattr(working_model, field_name, None)
|
||||
display = _get_field_display_name(field_name, field_info)
|
||||
formatted = _format_value(value, rich=False)
|
||||
choices.append(f"{display}: {formatted}")
|
||||
return choices + ["✓ Done"]
|
||||
|
||||
while True:
|
||||
_show_config_panel(display_name, model, fields)
|
||||
_show_config_panel(display_name, working_model, fields)
|
||||
choices = get_choices()
|
||||
|
||||
answer = _select_with_back("Select field to configure:", choices)
|
||||
|
||||
if answer is _BACK_PRESSED:
|
||||
# User pressed Escape or Left arrow - go back
|
||||
if finalize_hook:
|
||||
finalize_hook(model)
|
||||
break
|
||||
if answer is _BACK_PRESSED or answer is None:
|
||||
return None
|
||||
|
||||
if answer == "✓ Done" or answer is None:
|
||||
if finalize_hook:
|
||||
finalize_hook(model)
|
||||
break
|
||||
if answer == "✓ Done":
|
||||
return working_model
|
||||
|
||||
field_idx = next((i for i, c in enumerate(choices) if c == answer), -1)
|
||||
if field_idx < 0 or field_idx >= len(fields):
|
||||
break
|
||||
return None
|
||||
|
||||
field_name, field_info = fields[field_idx]
|
||||
current_value = getattr(model, field_name, None)
|
||||
current_value = getattr(working_model, field_name, None)
|
||||
field_type, _ = _get_field_type_info(field_info)
|
||||
field_display = _get_field_display_name(field_name, field_info)
|
||||
|
||||
if field_type == "model":
|
||||
nested_model = current_value
|
||||
created_nested_model = nested_model is None
|
||||
if nested_model is None:
|
||||
_, nested_cls = _get_field_type_info(field_info)
|
||||
if nested_cls:
|
||||
nested_model = nested_cls()
|
||||
setattr(model, field_name, nested_model)
|
||||
|
||||
if nested_model and isinstance(nested_model, BaseModel):
|
||||
_configure_pydantic_model(nested_model, field_display)
|
||||
updated_nested_model = _configure_pydantic_model(nested_model, field_display)
|
||||
if updated_nested_model is not None:
|
||||
setattr(working_model, field_name, updated_nested_model)
|
||||
elif created_nested_model:
|
||||
setattr(working_model, field_name, None)
|
||||
continue
|
||||
|
||||
# Special handling for model field (autocomplete)
|
||||
if field_name == "model":
|
||||
provider = _get_current_provider(model)
|
||||
provider = _get_current_provider(working_model)
|
||||
new_value = _input_model_with_autocomplete(field_display, current_value, provider)
|
||||
if new_value is not None and new_value != current_value:
|
||||
setattr(model, field_name, new_value)
|
||||
setattr(working_model, field_name, new_value)
|
||||
# Auto-fill context_window_tokens if it's at default value
|
||||
_try_auto_fill_context_window(model, new_value)
|
||||
_try_auto_fill_context_window(working_model, new_value)
|
||||
continue
|
||||
|
||||
# Special handling for context_window_tokens field
|
||||
if field_name == "context_window_tokens":
|
||||
new_value = _input_context_window_with_recommendation(field_display, current_value, model)
|
||||
new_value = _input_context_window_with_recommendation(
|
||||
field_display, current_value, working_model
|
||||
)
|
||||
if new_value is not None:
|
||||
setattr(model, field_name, new_value)
|
||||
setattr(working_model, field_name, new_value)
|
||||
continue
|
||||
|
||||
# Special handling for select fields with hints (e.g., reasoning_effort)
|
||||
@ -542,23 +554,25 @@ def _configure_pydantic_model(
|
||||
choices_list, hint = _SELECT_FIELD_HINTS[field_name]
|
||||
select_choices = choices_list + ["(clear/unset)"]
|
||||
console.print(f"[dim] Hint: {hint}[/dim]")
|
||||
new_value = _select_with_back(field_display, select_choices, default=current_value or select_choices[0])
|
||||
new_value = _select_with_back(
|
||||
field_display, select_choices, default=current_value or select_choices[0]
|
||||
)
|
||||
if new_value is _BACK_PRESSED:
|
||||
continue
|
||||
if new_value == "(clear/unset)":
|
||||
setattr(model, field_name, None)
|
||||
setattr(working_model, field_name, None)
|
||||
elif new_value is not None:
|
||||
setattr(model, field_name, new_value)
|
||||
setattr(working_model, field_name, new_value)
|
||||
continue
|
||||
|
||||
if field_type == "bool":
|
||||
new_value = _input_bool(field_display, current_value)
|
||||
if new_value is not None:
|
||||
setattr(model, field_name, new_value)
|
||||
setattr(working_model, field_name, new_value)
|
||||
else:
|
||||
new_value = _input_with_existing(field_display, current_value, field_type)
|
||||
if new_value is not None:
|
||||
setattr(model, field_name, new_value)
|
||||
setattr(working_model, field_name, new_value)
|
||||
|
||||
|
||||
def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None:
|
||||
@ -589,7 +603,9 @@ def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None
|
||||
|
||||
if context_limit:
|
||||
setattr(model, "context_window_tokens", context_limit)
|
||||
console.print(f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]")
|
||||
console.print(
|
||||
f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]"
|
||||
)
|
||||
else:
|
||||
console.print("[dim]ℹ Could not auto-fill context window (model not in database)[/dim]")
|
||||
|
||||
@ -637,10 +653,12 @@ def _configure_provider(config: Config, provider_name: str) -> None:
|
||||
if default_api_base and not provider_config.api_base:
|
||||
provider_config.api_base = default_api_base
|
||||
|
||||
_configure_pydantic_model(
|
||||
updated_provider = _configure_pydantic_model(
|
||||
provider_config,
|
||||
display_name,
|
||||
)
|
||||
if updated_provider is not None:
|
||||
setattr(config.providers, provider_name, updated_provider)
|
||||
|
||||
|
||||
def _configure_providers(config: Config) -> None:
|
||||
@ -747,15 +765,13 @@ def _configure_channel(config: Config, channel_name: str) -> None:
|
||||
|
||||
model = config_cls.model_validate(channel_dict) if channel_dict else config_cls()
|
||||
|
||||
def finalize(model: BaseModel):
|
||||
new_dict = model.model_dump(by_alias=True, exclude_none=True)
|
||||
setattr(config.channels, channel_name, new_dict)
|
||||
|
||||
_configure_pydantic_model(
|
||||
updated_channel = _configure_pydantic_model(
|
||||
model,
|
||||
display_name,
|
||||
finalize_hook=finalize,
|
||||
)
|
||||
if updated_channel is not None:
|
||||
new_dict = updated_channel.model_dump(by_alias=True, exclude_none=True)
|
||||
setattr(config.channels, channel_name, new_dict)
|
||||
|
||||
|
||||
def _configure_channels(config: Config) -> None:
|
||||
@ -798,13 +814,25 @@ def _configure_general_settings(config: Config, section: str) -> None:
|
||||
model, display_name = section_map[section]
|
||||
|
||||
if section == "Tools":
|
||||
_configure_pydantic_model(
|
||||
updated_model = _configure_pydantic_model(
|
||||
model,
|
||||
display_name,
|
||||
skip_fields={"mcp_servers"},
|
||||
)
|
||||
else:
|
||||
_configure_pydantic_model(model, display_name)
|
||||
updated_model = _configure_pydantic_model(model, display_name)
|
||||
|
||||
if updated_model is None:
|
||||
return
|
||||
|
||||
if section == "Agent Settings":
|
||||
config.agents.defaults = updated_model
|
||||
elif section == "Gateway":
|
||||
config.gateway = updated_model
|
||||
elif section == "Tools":
|
||||
config.tools = updated_model
|
||||
elif section == "Channel Common":
|
||||
config.channels = updated_model
|
||||
|
||||
|
||||
def _configure_agents(config: Config) -> None:
|
||||
@ -938,7 +966,35 @@ def _show_summary(config: Config) -> None:
|
||||
# --- Main Entry Point ---
|
||||
|
||||
|
||||
def run_onboard(initial_config: Config | None = None) -> Config:
|
||||
def _has_unsaved_changes(original: Config, current: Config) -> bool:
|
||||
"""Return True when the onboarding session has committed changes."""
|
||||
return original.model_dump(by_alias=True) != current.model_dump(by_alias=True)
|
||||
|
||||
|
||||
def _prompt_main_menu_exit(has_unsaved_changes: bool) -> str:
|
||||
"""Resolve how to leave the main menu."""
|
||||
if not has_unsaved_changes:
|
||||
return "discard"
|
||||
|
||||
answer = questionary.select(
|
||||
"You have unsaved changes. What would you like to do?",
|
||||
choices=[
|
||||
"💾 Save and Exit",
|
||||
"🗑️ Exit Without Saving",
|
||||
"↩ Resume Editing",
|
||||
],
|
||||
default="↩ Resume Editing",
|
||||
qmark="→",
|
||||
).ask()
|
||||
|
||||
if answer == "💾 Save and Exit":
|
||||
return "save"
|
||||
if answer == "🗑️ Exit Without Saving":
|
||||
return "discard"
|
||||
return "resume"
|
||||
|
||||
|
||||
def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
"""Run the interactive onboarding questionnaire.
|
||||
|
||||
Args:
|
||||
@ -946,18 +1002,21 @@ def run_onboard(initial_config: Config | None = None) -> Config:
|
||||
If None, loads from config file or creates new default.
|
||||
"""
|
||||
if initial_config is not None:
|
||||
config = initial_config
|
||||
base_config = initial_config.model_copy(deep=True)
|
||||
else:
|
||||
config_path = get_config_path()
|
||||
if config_path.exists():
|
||||
config = load_config()
|
||||
base_config = load_config()
|
||||
else:
|
||||
config = Config()
|
||||
base_config = Config()
|
||||
|
||||
original_config = base_config.model_copy(deep=True)
|
||||
config = base_config.model_copy(deep=True)
|
||||
|
||||
while True:
|
||||
try:
|
||||
_show_main_menu_header()
|
||||
_show_main_menu_header()
|
||||
|
||||
try:
|
||||
answer = questionary.select(
|
||||
"What would you like to configure?",
|
||||
choices=[
|
||||
@ -969,30 +1028,36 @@ def run_onboard(initial_config: Config | None = None) -> Config:
|
||||
"🔧 Configure Tools",
|
||||
"📋 View Configuration Summary",
|
||||
"💾 Save and Exit",
|
||||
"🗑️ Exit Without Saving",
|
||||
],
|
||||
qmark="→",
|
||||
).ask()
|
||||
|
||||
if answer == "🔌 Configure LLM Provider":
|
||||
_configure_providers(config)
|
||||
elif answer == "💬 Configure Chat Channel":
|
||||
_configure_channels(config)
|
||||
elif answer == "⚙️ Configure Channel Common":
|
||||
_configure_general_settings(config, "Channel Common")
|
||||
elif answer == "🤖 Configure Agent Settings":
|
||||
_configure_agents(config)
|
||||
elif answer == "🌐 Configure Gateway":
|
||||
_configure_gateway(config)
|
||||
elif answer == "🔧 Configure Tools":
|
||||
_configure_tools(config)
|
||||
elif answer == "📋 View Configuration Summary":
|
||||
_show_summary(config)
|
||||
elif answer == "💾 Save and Exit":
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
console.print(
|
||||
"\n\n[yellow]Operation cancelled. Use 'Save and Exit' to save changes.[/yellow]"
|
||||
)
|
||||
break
|
||||
answer = None
|
||||
|
||||
return config
|
||||
if answer is None:
|
||||
action = _prompt_main_menu_exit(_has_unsaved_changes(original_config, config))
|
||||
if action == "save":
|
||||
return OnboardResult(config=config, should_save=True)
|
||||
if action == "discard":
|
||||
return OnboardResult(config=original_config, should_save=False)
|
||||
continue
|
||||
|
||||
if answer == "🔌 Configure LLM Provider":
|
||||
_configure_providers(config)
|
||||
elif answer == "💬 Configure Chat Channel":
|
||||
_configure_channels(config)
|
||||
elif answer == "⚙️ Configure Channel Common":
|
||||
_configure_general_settings(config, "Channel Common")
|
||||
elif answer == "🤖 Configure Agent Settings":
|
||||
_configure_agents(config)
|
||||
elif answer == "🌐 Configure Gateway":
|
||||
_configure_gateway(config)
|
||||
elif answer == "🔧 Configure Tools":
|
||||
_configure_tools(config)
|
||||
elif answer == "📋 View Configuration Summary":
|
||||
_show_summary(config)
|
||||
elif answer == "💾 Save and Exit":
|
||||
return OnboardResult(config=config, should_save=True)
|
||||
elif answer == "🗑️ Exit Without Saving":
|
||||
return OnboardResult(config=original_config, should_save=False)
|
||||
|
||||
@ -16,24 +16,25 @@ from nanobot.providers.registry import find_by_model
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
class _StopGateway(RuntimeError):
|
||||
class _StopGatewayError(RuntimeError):
|
||||
pass
|
||||
|
||||
|
||||
def _strip_ansi(text):
|
||||
"""Remove ANSI escape codes from text."""
|
||||
ansi_escape = re.compile(r'\x1b\[[0-9;]*m')
|
||||
return ansi_escape.sub('', text)
|
||||
ansi_escape = re.compile(r"\x1b\[[0-9;]*m")
|
||||
return ansi_escape.sub("", text)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paths():
|
||||
"""Mock config/workspace paths for test isolation."""
|
||||
with patch("nanobot.config.loader.get_config_path") as mock_cp, \
|
||||
patch("nanobot.config.loader.save_config") as mock_sc, \
|
||||
patch("nanobot.config.loader.load_config") as mock_lc, \
|
||||
patch("nanobot.cli.commands.get_workspace_path") as mock_ws:
|
||||
|
||||
with (
|
||||
patch("nanobot.config.loader.get_config_path") as mock_cp,
|
||||
patch("nanobot.config.loader.save_config") as mock_sc,
|
||||
patch("nanobot.config.loader.load_config") as mock_lc,
|
||||
patch("nanobot.cli.commands.get_workspace_path") as mock_ws,
|
||||
):
|
||||
base_dir = Path("./test_onboard_data")
|
||||
if base_dir.exists():
|
||||
shutil.rmtree(base_dir)
|
||||
@ -130,6 +131,24 @@ def test_onboard_help_shows_workspace_and_config_options():
|
||||
assert "--dir" not in stripped_output
|
||||
|
||||
|
||||
def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch):
|
||||
config_file, workspace_dir, _ = mock_paths
|
||||
|
||||
from nanobot.cli.onboard_wizard import OnboardResult
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.onboard_wizard.run_onboard",
|
||||
lambda initial_config: OnboardResult(config=initial_config, should_save=False),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["onboard"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No changes were saved" in result.stdout
|
||||
assert not config_file.exists()
|
||||
assert not workspace_dir.exists()
|
||||
|
||||
|
||||
def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch):
|
||||
config_path = tmp_path / "instance" / "config.json"
|
||||
workspace_path = tmp_path / "workspace"
|
||||
@ -138,7 +157,14 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch)
|
||||
|
||||
result = runner.invoke(
|
||||
app,
|
||||
["onboard", "--config", str(config_path), "--workspace", str(workspace_path), "--non-interactive"],
|
||||
[
|
||||
"onboard",
|
||||
"--config",
|
||||
str(config_path),
|
||||
"--workspace",
|
||||
str(workspace_path),
|
||||
"--non-interactive",
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
@ -278,15 +304,16 @@ def mock_agent_runtime(tmp_path):
|
||||
config.agents.defaults.workspace = str(tmp_path / "default-workspace")
|
||||
cron_dir = tmp_path / "data" / "cron"
|
||||
|
||||
with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \
|
||||
patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \
|
||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \
|
||||
patch("nanobot.cli.commands._make_provider", return_value=object()), \
|
||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \
|
||||
patch("nanobot.bus.queue.MessageBus"), \
|
||||
patch("nanobot.cron.service.CronService"), \
|
||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls:
|
||||
|
||||
with (
|
||||
patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config,
|
||||
patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir),
|
||||
patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates,
|
||||
patch("nanobot.cli.commands._make_provider", return_value=object()),
|
||||
patch("nanobot.cli.commands._print_agent_response") as mock_print_response,
|
||||
patch("nanobot.bus.queue.MessageBus"),
|
||||
patch("nanobot.cron.service.CronService"),
|
||||
patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls,
|
||||
):
|
||||
agent_loop = MagicMock()
|
||||
agent_loop.channels_config = None
|
||||
agent_loop.process_direct = AsyncMock(return_value="mock-response")
|
||||
@ -326,7 +353,9 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_
|
||||
mock_agent_runtime["config"].workspace_path
|
||||
)
|
||||
mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once()
|
||||
mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True)
|
||||
mock_agent_runtime["print_response"].assert_called_once_with(
|
||||
"mock-response", render_markdown=True
|
||||
)
|
||||
|
||||
|
||||
def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path):
|
||||
@ -369,7 +398,9 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop)
|
||||
monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)])
|
||||
|
||||
@ -435,12 +466,12 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["config_path"] == config_file.resolve()
|
||||
assert seen["workspace"] == Path(config.agents.defaults.workspace)
|
||||
|
||||
@ -463,7 +494,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(
|
||||
@ -471,7 +502,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path)
|
||||
["gateway", "--config", str(config_file), "--workspace", str(override)],
|
||||
)
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["workspace"] == override
|
||||
assert config.workspace_path == override
|
||||
|
||||
@ -489,15 +520,16 @@ def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Pat
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert "memoryWindow" in result.stdout
|
||||
assert "contextWindowTokens" in result.stdout
|
||||
|
||||
|
||||
def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
@ -518,13 +550,13 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat
|
||||
class _StopCron:
|
||||
def __init__(self, store_path: Path) -> None:
|
||||
seen["cron_store"] = store_path
|
||||
raise _StopGateway("stop")
|
||||
raise _StopGatewayError("stop")
|
||||
|
||||
monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json"
|
||||
|
||||
|
||||
@ -541,12 +573,12 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file)])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert "port 18791" in result.stdout
|
||||
|
||||
|
||||
@ -563,10 +595,10 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands._make_provider",
|
||||
lambda _config: (_ for _ in ()).throw(_StopGateway("stop")),
|
||||
lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")),
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"])
|
||||
|
||||
assert isinstance(result.exception, _StopGateway)
|
||||
assert isinstance(result.exception, _StopGatewayError)
|
||||
assert "port 18792" in result.stdout
|
||||
|
||||
@ -4,24 +4,40 @@ These tests focus on the business logic behind the onboard wizard,
|
||||
without testing the interactive UI components.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from nanobot.cli import onboard_wizard
|
||||
|
||||
# Import functions to test
|
||||
from nanobot.cli.commands import _merge_missing_defaults
|
||||
from nanobot.cli.onboard_wizard import (
|
||||
_BACK_PRESSED,
|
||||
_configure_pydantic_model,
|
||||
_format_value,
|
||||
_get_field_display_name,
|
||||
_get_field_type_info,
|
||||
run_onboard,
|
||||
)
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.utils.helpers import sync_workspace_templates
|
||||
|
||||
|
||||
class _SimpleDraftModel(BaseModel):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class _NestedDraftModel(BaseModel):
|
||||
api_key: str = ""
|
||||
|
||||
|
||||
class _OuterDraftModel(BaseModel):
|
||||
nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel)
|
||||
|
||||
|
||||
class TestMergeMissingDefaults:
|
||||
"""Tests for _merge_missing_defaults recursive config merging."""
|
||||
|
||||
@ -192,6 +208,7 @@ class TestGetFieldTypeInfo:
|
||||
|
||||
def test_handles_none_annotation(self):
|
||||
"""Field with None annotation defaults to str."""
|
||||
|
||||
class Model(BaseModel):
|
||||
field: Any = None
|
||||
|
||||
@ -371,3 +388,104 @@ class TestProviderChannelInfo:
|
||||
for provider_name, value in info.items():
|
||||
assert isinstance(value, tuple)
|
||||
assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var)
|
||||
|
||||
|
||||
class TestConfigurePydanticModelDrafts:
|
||||
@staticmethod
|
||||
def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"):
|
||||
sequence = iter(tokens)
|
||||
|
||||
def fake_select(_prompt, choices, default=None):
|
||||
token = next(sequence)
|
||||
if token == "first":
|
||||
return choices[0]
|
||||
if token == "done":
|
||||
return "✓ Done"
|
||||
if token == "back":
|
||||
return _BACK_PRESSED
|
||||
return token
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
|
||||
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value
|
||||
)
|
||||
|
||||
def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch):
|
||||
model = _SimpleDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "back"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Simple")
|
||||
|
||||
assert result is None
|
||||
assert model.api_key == ""
|
||||
|
||||
def test_completing_section_returns_updated_draft(self, monkeypatch):
|
||||
model = _SimpleDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "done"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Simple")
|
||||
|
||||
assert result is not None
|
||||
updated = cast(_SimpleDraftModel, result)
|
||||
assert updated.api_key == "secret"
|
||||
assert model.api_key == ""
|
||||
|
||||
def test_nested_section_back_discards_nested_edits(self, monkeypatch):
|
||||
model = _OuterDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Outer")
|
||||
|
||||
assert result is not None
|
||||
updated = cast(_OuterDraftModel, result)
|
||||
assert updated.nested.api_key == ""
|
||||
assert model.nested.api_key == ""
|
||||
|
||||
def test_nested_section_done_commits_nested_edits(self, monkeypatch):
|
||||
model = _OuterDraftModel()
|
||||
self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"])
|
||||
|
||||
result = _configure_pydantic_model(model, "Outer")
|
||||
|
||||
assert result is not None
|
||||
updated = cast(_OuterDraftModel, result)
|
||||
assert updated.nested.api_key == "secret"
|
||||
assert model.nested.api_key == ""
|
||||
|
||||
|
||||
class TestRunOnboardExitBehavior:
|
||||
def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch):
|
||||
initial_config = Config()
|
||||
|
||||
responses = iter(
|
||||
[
|
||||
"🤖 Configure Agent Settings",
|
||||
KeyboardInterrupt(),
|
||||
"🗑️ Exit Without Saving",
|
||||
]
|
||||
)
|
||||
|
||||
class FakePrompt:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
def ask(self):
|
||||
if isinstance(self.response, BaseException):
|
||||
raise self.response
|
||||
return self.response
|
||||
|
||||
def fake_select(*_args, **_kwargs):
|
||||
return FakePrompt(next(responses))
|
||||
|
||||
def fake_configure_agents(config):
|
||||
config.agents.defaults.model = "test/provider-model"
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None)
|
||||
monkeypatch.setattr(onboard_wizard.questionary, "select", fake_select)
|
||||
monkeypatch.setattr(onboard_wizard, "_configure_agents", fake_configure_agents)
|
||||
|
||||
result = run_onboard(initial_config=initial_config)
|
||||
|
||||
assert result.should_save is False
|
||||
assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user