diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 2e106ba90..92a086d6e 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -1,11 +1,11 @@ """CLI commands for nanobot.""" import asyncio -from contextlib import contextmanager, nullcontext import os import select import signal import sys +from contextlib import contextmanager, nullcontext from pathlib import Path from typing import Any @@ -64,6 +64,7 @@ def _flush_pending_tty_input() -> None: try: import termios + termios.tcflush(fd, termios.TCIFLUSH) return except Exception: @@ -86,6 +87,7 @@ def _restore_terminal() -> None: return try: import termios + termios.tcsetattr(sys.stdin.fileno(), termios.TCSADRAIN, _SAVED_TERM_ATTRS) except Exception: pass @@ -98,6 +100,7 @@ def _init_prompt_session() -> None: # Save terminal state so we can restore it on exit try: import termios + _SAVED_TERM_ATTRS = termios.tcgetattr(sys.stdin.fileno()) except Exception: pass @@ -110,7 +113,7 @@ def _init_prompt_session() -> None: _PROMPT_SESSION = PromptSession( history=FileHistory(str(history_file)), enable_open_in_editor=False, - multiline=False, # Enter submits (single line mode) + multiline=False, # Enter submits (single line mode) ) @@ -143,10 +146,9 @@ def _print_agent_response(response: str, render_markdown: bool) -> None: async def _print_interactive_line(text: str) -> None: """Print async interactive updates with prompt_toolkit-safe Rich styling.""" + def _write() -> None: - ansi = _render_interactive_ansi( - lambda c: c.print(f" [dim]↳ {text}[/dim]") - ) + ansi = _render_interactive_ansi(lambda c: c.print(f" [dim]↳ {text}[/dim]")) print_formatted_text(ANSI(ansi), end="") await run_in_terminal(_write) @@ -154,6 +156,7 @@ async def _print_interactive_line(text: str) -> None: async def _print_interactive_response(response: str, render_markdown: bool) -> None: """Print async interactive replies with prompt_toolkit-safe Rich styling.""" + def _write() -> None: content = response or "" ansi = _render_interactive_ansi( @@ -173,9 +176,9 @@ class _ThinkingSpinner: """Spinner wrapper with pause support for clean progress output.""" def __init__(self, enabled: bool): - self._spinner = console.status( - "[dim]nanobot is thinking...[/dim]", spinner="dots" - ) if enabled else None + self._spinner = ( + console.status("[dim]nanobot is thinking...[/dim]", spinner="dots") if enabled else None + ) self._active = False def __enter__(self): @@ -238,7 +241,6 @@ async def _read_interactive_input_async() -> str: raise KeyboardInterrupt from exc - def version_callback(value: bool): if value: console.print(f"{__logo__} nanobot v{__version__}") @@ -247,9 +249,7 @@ def version_callback(value: bool): @app.callback() def main( - version: bool = typer.Option( - None, "--version", "-v", callback=version_callback, is_eager=True - ), + version: bool = typer.Option(None, "--version", "-v", callback=version_callback, is_eager=True), ): """nanobot - Personal AI Assistant.""" pass @@ -264,7 +264,9 @@ def main( def onboard( workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), config: str | None = typer.Option(None, "--config", "-c", help="Path to config file"), - non_interactive: bool = typer.Option(False, "--non-interactive", help="Skip interactive wizard"), + non_interactive: bool = typer.Option( + False, "--non-interactive", help="Skip interactive wizard" + ), ): """Initialize nanobot configuration and workspace.""" from nanobot.config.loader import get_config_path, load_config, save_config, set_config_path @@ -282,41 +284,53 @@ def onboard( loaded.agents.defaults.workspace = workspace return loaded + cfg: Config + # Non-interactive mode: simple config creation/update if non_interactive: if config_path.exists(): console.print(f"[yellow]Config already exists at {config_path}[/yellow]") - console.print(" [bold]y[/bold] = overwrite with defaults (existing values will be lost)") - console.print(" [bold]N[/bold] = refresh config, keeping existing values and adding new fields") + console.print( + " [bold]y[/bold] = overwrite with defaults (existing values will be lost)" + ) + console.print( + " [bold]N[/bold] = refresh config, keeping existing values and adding new fields" + ) if typer.confirm("Overwrite?"): - config = _apply_workspace_override(Config()) - save_config(config, config_path) + cfg = _apply_workspace_override(Config()) + save_config(cfg, config_path) console.print(f"[green]✓[/green] Config reset to defaults at {config_path}") else: - config = _apply_workspace_override(load_config(config_path)) - save_config(config, config_path) - console.print(f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)") + cfg = _apply_workspace_override(load_config(config_path)) + save_config(cfg, config_path) + console.print( + f"[green]✓[/green] Config refreshed at {config_path} (existing values preserved)" + ) else: - config = _apply_workspace_override(Config()) - save_config(config, config_path) + cfg = _apply_workspace_override(Config()) + save_config(cfg, config_path) console.print(f"[green]✓[/green] Created config at {config_path}") - console.print("[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]") + console.print( + "[dim]Config template now uses `maxTokens` + `contextWindowTokens`; `memoryWindow` is no longer a runtime setting.[/dim]" + ) else: # Interactive mode: use wizard if config_path.exists(): - config = load_config() + cfg = _apply_workspace_override(load_config(config_path)) else: - config = Config() - save_config(config) - console.print(f"[green]✓[/green] Created config at {config_path}") + cfg = _apply_workspace_override(Config()) # Run interactive wizard from nanobot.cli.onboard_wizard import run_onboard try: - # Pass the config with workspace override applied as initial config - config = run_onboard(initial_config=config) - save_config(config, config_path) + result = run_onboard(initial_config=cfg) + if not result.should_save: + console.print("[yellow]Configuration discarded. No changes were saved.[/yellow]") + return + + cfg = result.config + save_config(cfg, config_path) console.print(f"[green]✓[/green] Config saved at {config_path}") except Exception as e: console.print(f"[red]✗[/red] Error during configuration: {e}") @@ -326,15 +340,15 @@ def onboard( _onboard_plugins(config_path) # Create workspace, preferring the configured workspace path. - workspace = get_workspace_path(config.workspace_path) - if not workspace.exists(): - workspace.mkdir(parents=True, exist_ok=True) - console.print(f"[green]✓[/green] Created workspace at {workspace}") + workspace_path = get_workspace_path(cfg.workspace_path) + if not workspace_path.exists(): + workspace_path.mkdir(parents=True, exist_ok=True) + console.print(f"[green]✓[/green] Created workspace at {workspace_path}") - sync_workspace_templates(workspace) + sync_workspace_templates(workspace_path) agent_cmd = 'nanobot agent -m "Hello!"' - if config: + if cfg: agent_cmd += f" --config {config_path}" console.print(f"\n{__logo__} nanobot is ready!") @@ -344,9 +358,11 @@ def onboard( console.print(" Get one at: https://openrouter.ai/keys") console.print(f" 2. Chat: [cyan]{agent_cmd}[/cyan]") else: - console.print(" 1. Chat: [cyan]nanobot agent -m \"Hello!\"[/cyan]") + console.print(' 1. Chat: [cyan]nanobot agent -m "Hello!"[/cyan]') console.print(" 2. Start gateway: [cyan]nanobot gateway[/cyan]") - console.print("\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]") + console.print( + "\n[dim]Want Telegram/WhatsApp? See: https://github.com/HKUDS/nanobot#-chat-apps[/dim]" + ) def _merge_missing_defaults(existing: Any, defaults: Any) -> Any: @@ -403,6 +419,7 @@ def _make_provider(config: Config): # Custom: direct OpenAI-compatible endpoint, bypasses LiteLLM elif provider_name == "custom": from nanobot.providers.custom_provider import CustomProvider + provider = CustomProvider( api_key=p.api_key if p else "no-key", api_base=config.get_api_base(model) or "http://localhost:8000/v1", @@ -424,6 +441,7 @@ def _make_provider(config: Config): # OpenVINO Model Server: direct OpenAI-compatible endpoint at /v3 elif provider_name == "ovms": from nanobot.providers.custom_provider import CustomProvider + provider = CustomProvider( api_key=p.api_key if p else "no-key", api_base=config.get_api_base(model) or "http://localhost:8000/v3", @@ -432,8 +450,13 @@ def _make_provider(config: Config): else: from nanobot.providers.litellm_provider import LiteLLMProvider from nanobot.providers.registry import find_by_name + spec = find_by_name(provider_name) - if not model.startswith("bedrock/") and not (p and p.api_key) and not (spec and (spec.is_oauth or spec.is_local)): + if ( + not model.startswith("bedrock/") + and not (p and p.api_key) + and not (spec and (spec.is_oauth or spec.is_local)) + ): console.print("[red]Error: No API key configured.[/red]") console.print("Set one in ~/.nanobot/config.json under providers section") raise typer.Exit(1) @@ -507,6 +530,7 @@ def gateway( if verbose: import logging + logging.basicConfig(level=logging.DEBUG) config = _load_runtime_config(config, workspace) @@ -576,16 +600,23 @@ def gateway( if job.payload.deliver and job.payload.to and response: should_notify = await evaluate_response( - response, job.payload.message, provider, agent.model, + response, + job.payload.message, + provider, + agent.model, ) if should_notify: from nanobot.bus.events import OutboundMessage - await bus.publish_outbound(OutboundMessage( - channel=job.payload.channel or "cli", - chat_id=job.payload.to, - content=response, - )) + + await bus.publish_outbound( + OutboundMessage( + channel=job.payload.channel or "cli", + chat_id=job.payload.to, + content=response, + ) + ) return response + cron.on_job = on_cron_job # Create channel manager @@ -626,10 +657,13 @@ def gateway( async def on_heartbeat_notify(response: str) -> None: """Deliver a heartbeat response to the user's channel.""" from nanobot.bus.events import OutboundMessage + channel, chat_id = _pick_heartbeat_target() if channel == "cli": return # No external channel available to deliver to - await bus.publish_outbound(OutboundMessage(channel=channel, chat_id=chat_id, content=response)) + await bus.publish_outbound( + OutboundMessage(channel=channel, chat_id=chat_id, content=response) + ) hb_cfg = config.gateway.heartbeat heartbeat = HeartbeatService( @@ -665,6 +699,7 @@ def gateway( console.print("\nShutting down...") except Exception: import traceback + console.print("\n[red]Error: Gateway crashed unexpectedly[/red]") console.print(traceback.format_exc()) finally: @@ -677,8 +712,6 @@ def gateway( asyncio.run(run()) - - # ============================================================================ # Agent Commands # ============================================================================ @@ -690,8 +723,12 @@ def agent( session_id: str = typer.Option("cli:direct", "--session", "-s", help="Session ID"), workspace: str | None = typer.Option(None, "--workspace", "-w", help="Workspace directory"), config: str | None = typer.Option(None, "--config", "-c", help="Config file path"), - markdown: bool = typer.Option(True, "--markdown/--no-markdown", help="Render assistant output as Markdown"), - logs: bool = typer.Option(False, "--logs/--no-logs", help="Show nanobot runtime logs during chat"), + markdown: bool = typer.Option( + True, "--markdown/--no-markdown", help="Render assistant output as Markdown" + ), + logs: bool = typer.Option( + False, "--logs/--no-logs", help="Show nanobot runtime logs during chat" + ), ): """Interact with the agent directly.""" from loguru import logger @@ -751,7 +788,9 @@ def agent( nonlocal _thinking _thinking = _ThinkingSpinner(enabled=not logs) with _thinking: - response = await agent_loop.process_direct(message, session_id, on_progress=_cli_progress) + response = await agent_loop.process_direct( + message, session_id, on_progress=_cli_progress + ) _thinking = None _print_agent_response(response, render_markdown=markdown) await agent_loop.close_mcp() @@ -760,8 +799,11 @@ def agent( else: # Interactive mode — route through bus like other channels from nanobot.bus.events import InboundMessage + _init_prompt_session() - console.print(f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n") + console.print( + f"{__logo__} Interactive mode (type [bold]exit[/bold] or [bold]Ctrl+C[/bold] to quit)\n" + ) if ":" in session_id: cli_channel, cli_chat_id = session_id.split(":", 1) @@ -777,11 +819,11 @@ def agent( signal.signal(signal.SIGINT, _handle_signal) signal.signal(signal.SIGTERM, _handle_signal) # SIGHUP is not available on Windows - if hasattr(signal, 'SIGHUP'): + if hasattr(signal, "SIGHUP"): signal.signal(signal.SIGHUP, _handle_signal) # Ignore SIGPIPE to prevent silent process termination when writing to closed pipes # SIGPIPE is not available on Windows - if hasattr(signal, 'SIGPIPE'): + if hasattr(signal, "SIGPIPE"): signal.signal(signal.SIGPIPE, signal.SIG_IGN) async def run_interactive(): @@ -835,12 +877,14 @@ def agent( turn_done.clear() turn_response.clear() - await bus.publish_inbound(InboundMessage( - channel=cli_channel, - sender_id="user", - chat_id=cli_chat_id, - content=user_input, - )) + await bus.publish_inbound( + InboundMessage( + channel=cli_channel, + sender_id="user", + chat_id=cli_chat_id, + content=user_input, + ) + ) nonlocal _thinking _thinking = _ThinkingSpinner(enabled=not logs) @@ -982,7 +1026,11 @@ def channels_login(): env = {**os.environ} wa_cfg = getattr(config.channels, "whatsapp", None) or {} - bridge_token = wa_cfg.get("bridgeToken", "") if isinstance(wa_cfg, dict) else getattr(wa_cfg, "bridge_token", "") + bridge_token = ( + wa_cfg.get("bridgeToken", "") + if isinstance(wa_cfg, dict) + else getattr(wa_cfg, "bridge_token", "") + ) if bridge_token: env["BRIDGE_TOKEN"] = bridge_token env["AUTH_DIR"] = str(get_runtime_subdir("whatsapp-auth")) @@ -1056,8 +1104,12 @@ def status(): console.print(f"{__logo__} nanobot Status\n") - console.print(f"Config: {config_path} {'[green]✓[/green]' if config_path.exists() else '[red]✗[/red]'}") - console.print(f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}") + console.print( + f"Config: {config_path} {'[green]✓[/green]' if config_path.exists() else '[red]✗[/red]'}" + ) + console.print( + f"Workspace: {workspace} {'[green]✓[/green]' if workspace.exists() else '[red]✗[/red]'}" + ) if config_path.exists(): from nanobot.providers.registry import PROVIDERS @@ -1079,7 +1131,9 @@ def status(): console.print(f"{spec.label}: [dim]not set[/dim]") else: has_key = bool(p.api_key) - console.print(f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}") + console.print( + f"{spec.label}: {'[green]✓[/green]' if has_key else '[dim]not set[/dim]'}" + ) # ============================================================================ @@ -1097,12 +1151,15 @@ def _register_login(name: str): def decorator(fn): _LOGIN_HANDLERS[name] = fn return fn + return decorator @provider_app.command("login") def provider_login( - provider: str = typer.Argument(..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')"), + provider: str = typer.Argument( + ..., help="OAuth provider (e.g. 'openai-codex', 'github-copilot')" + ), ): """Authenticate with an OAuth provider.""" from nanobot.providers.registry import PROVIDERS @@ -1127,6 +1184,7 @@ def provider_login( def _login_openai_codex() -> None: try: from oauth_cli_kit import get_token, login_oauth_interactive + token = None try: token = get_token() @@ -1141,7 +1199,9 @@ def _login_openai_codex() -> None: if not (token and token.access): console.print("[red]✗ Authentication failed[/red]") raise typer.Exit(1) - console.print(f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]") + console.print( + f"[green]✓ Authenticated with OpenAI Codex[/green] [dim]{token.account_id}[/dim]" + ) except ImportError: console.print("[red]oauth_cli_kit not installed. Run: pip install oauth-cli-kit[/red]") raise typer.Exit(1) @@ -1155,7 +1215,12 @@ def _login_github_copilot() -> None: async def _trigger(): from litellm import acompletion - await acompletion(model="github_copilot/gpt-4o", messages=[{"role": "user", "content": "hi"}], max_tokens=1) + + await acompletion( + model="github_copilot/gpt-4o", + messages=[{"role": "user", "content": "hi"}], + max_tokens=1, + ) try: asyncio.run(_trigger()) diff --git a/nanobot/cli/onboard_wizard.py b/nanobot/cli/onboard_wizard.py index 811434827..a3344d52c 100644 --- a/nanobot/cli/onboard_wizard.py +++ b/nanobot/cli/onboard_wizard.py @@ -2,7 +2,8 @@ import json import types -from typing import Any, Callable, get_args, get_origin +from dataclasses import dataclass +from typing import Any, get_args, get_origin import questionary from loguru import logger @@ -21,6 +22,15 @@ from nanobot.config.schema import Config console = Console() + +@dataclass +class OnboardResult: + """Result of an onboarding session.""" + + config: Config + should_save: bool + + # --- Field Hints for Select Fields --- # Maps field names to (choices, hint_text) # To add a new select field with hints, add an entry: @@ -128,10 +138,12 @@ def _select_with_back( event.app.exit() # Style - style = Style.from_dict({ - "selected": "fg:green bold", - "question": "fg:cyan", - }) + style = Style.from_dict( + { + "selected": "fg:green bold", + "question": "fg:cyan", + } + ) app = Application(layout=layout, key_bindings=bindings, style=style) try: @@ -142,6 +154,7 @@ def _select_with_back( return state["result"] + # --- Type Introspection --- @@ -268,9 +281,7 @@ def _show_main_menu_header() -> None: # Use Align.CENTER for the single line of text from rich.align import Align - console.print( - Align.center(f"{__logo__} [bold cyan]nanobot[{__version__}][/bold cyan]") - ) + console.print(Align.center(f"{__logo__} [bold cyan]nanobot[{__version__}][/bold cyan]")) console.print() @@ -329,9 +340,7 @@ def _input_text(display_name: str, current: Any, field_type: str) -> Any: return value -def _input_with_existing( - display_name: str, current: Any, field_type: str -) -> Any: +def _input_with_existing(display_name: str, current: Any, field_type: str) -> Any: """Handle input with 'keep existing' option for non-empty values.""" has_existing = current is not None and current != "" and current != {} and current != [] @@ -357,12 +366,8 @@ def _get_current_provider(model: BaseModel) -> str: return "auto" -def _input_model_with_autocomplete( - display_name: str, current: Any, provider: str -) -> str | None: - """Get model input with autocomplete suggestions. - - """ +def _input_model_with_autocomplete(display_name: str, current: Any, provider: str) -> str | None: + """Get model input with autocomplete suggestions.""" from prompt_toolkit.completion import Completer, Completion default = str(current) if current else "" @@ -431,7 +436,9 @@ def _input_context_window_with_recommendation( context_limit = get_model_context_limit(model_name, provider) if context_limit: - console.print(f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]") + console.print( + f"[green]✓ Recommended context window: {format_token_count(context_limit)} tokens[/green]" + ) return context_limit else: console.print("[yellow]⚠ Could not fetch model info, please enter manually[/yellow]") @@ -458,83 +465,88 @@ def _configure_pydantic_model( display_name: str, *, skip_fields: set[str] | None = None, - finalize_hook: Callable | None = None, -) -> None: - """Configure a Pydantic model interactively.""" +) -> BaseModel | None: + """Configure a Pydantic model interactively. + + Returns the updated model only when the user explicitly selects "Done". + Back and cancel actions discard the section draft. + """ skip_fields = skip_fields or set() + working_model = model.model_copy(deep=True) fields = [] - for field_name, field_info in type(model).model_fields.items(): + for field_name, field_info in type(working_model).model_fields.items(): if field_name in skip_fields: continue fields.append((field_name, field_info)) if not fields: console.print(f"[dim]{display_name}: No configurable fields[/dim]") - return + return working_model def get_choices() -> list[str]: choices = [] for field_name, field_info in fields: - value = getattr(model, field_name, None) + value = getattr(working_model, field_name, None) display = _get_field_display_name(field_name, field_info) formatted = _format_value(value, rich=False) choices.append(f"{display}: {formatted}") return choices + ["✓ Done"] while True: - _show_config_panel(display_name, model, fields) + _show_config_panel(display_name, working_model, fields) choices = get_choices() answer = _select_with_back("Select field to configure:", choices) - if answer is _BACK_PRESSED: - # User pressed Escape or Left arrow - go back - if finalize_hook: - finalize_hook(model) - break + if answer is _BACK_PRESSED or answer is None: + return None - if answer == "✓ Done" or answer is None: - if finalize_hook: - finalize_hook(model) - break + if answer == "✓ Done": + return working_model field_idx = next((i for i, c in enumerate(choices) if c == answer), -1) if field_idx < 0 or field_idx >= len(fields): - break + return None field_name, field_info = fields[field_idx] - current_value = getattr(model, field_name, None) + current_value = getattr(working_model, field_name, None) field_type, _ = _get_field_type_info(field_info) field_display = _get_field_display_name(field_name, field_info) if field_type == "model": nested_model = current_value + created_nested_model = nested_model is None if nested_model is None: _, nested_cls = _get_field_type_info(field_info) if nested_cls: nested_model = nested_cls() - setattr(model, field_name, nested_model) if nested_model and isinstance(nested_model, BaseModel): - _configure_pydantic_model(nested_model, field_display) + updated_nested_model = _configure_pydantic_model(nested_model, field_display) + if updated_nested_model is not None: + setattr(working_model, field_name, updated_nested_model) + elif created_nested_model: + setattr(working_model, field_name, None) continue # Special handling for model field (autocomplete) if field_name == "model": - provider = _get_current_provider(model) + provider = _get_current_provider(working_model) new_value = _input_model_with_autocomplete(field_display, current_value, provider) if new_value is not None and new_value != current_value: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) # Auto-fill context_window_tokens if it's at default value - _try_auto_fill_context_window(model, new_value) + _try_auto_fill_context_window(working_model, new_value) continue # Special handling for context_window_tokens field if field_name == "context_window_tokens": - new_value = _input_context_window_with_recommendation(field_display, current_value, model) + new_value = _input_context_window_with_recommendation( + field_display, current_value, working_model + ) if new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) continue # Special handling for select fields with hints (e.g., reasoning_effort) @@ -542,23 +554,25 @@ def _configure_pydantic_model( choices_list, hint = _SELECT_FIELD_HINTS[field_name] select_choices = choices_list + ["(clear/unset)"] console.print(f"[dim] Hint: {hint}[/dim]") - new_value = _select_with_back(field_display, select_choices, default=current_value or select_choices[0]) + new_value = _select_with_back( + field_display, select_choices, default=current_value or select_choices[0] + ) if new_value is _BACK_PRESSED: continue if new_value == "(clear/unset)": - setattr(model, field_name, None) + setattr(working_model, field_name, None) elif new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) continue if field_type == "bool": new_value = _input_bool(field_display, current_value) if new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) else: new_value = _input_with_existing(field_display, current_value, field_type) if new_value is not None: - setattr(model, field_name, new_value) + setattr(working_model, field_name, new_value) def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None: @@ -589,7 +603,9 @@ def _try_auto_fill_context_window(model: BaseModel, new_model_name: str) -> None if context_limit: setattr(model, "context_window_tokens", context_limit) - console.print(f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]") + console.print( + f"[green]✓ Auto-filled context window: {format_token_count(context_limit)} tokens[/green]" + ) else: console.print("[dim]ℹ Could not auto-fill context window (model not in database)[/dim]") @@ -637,10 +653,12 @@ def _configure_provider(config: Config, provider_name: str) -> None: if default_api_base and not provider_config.api_base: provider_config.api_base = default_api_base - _configure_pydantic_model( + updated_provider = _configure_pydantic_model( provider_config, display_name, ) + if updated_provider is not None: + setattr(config.providers, provider_name, updated_provider) def _configure_providers(config: Config) -> None: @@ -747,15 +765,13 @@ def _configure_channel(config: Config, channel_name: str) -> None: model = config_cls.model_validate(channel_dict) if channel_dict else config_cls() - def finalize(model: BaseModel): - new_dict = model.model_dump(by_alias=True, exclude_none=True) - setattr(config.channels, channel_name, new_dict) - - _configure_pydantic_model( + updated_channel = _configure_pydantic_model( model, display_name, - finalize_hook=finalize, ) + if updated_channel is not None: + new_dict = updated_channel.model_dump(by_alias=True, exclude_none=True) + setattr(config.channels, channel_name, new_dict) def _configure_channels(config: Config) -> None: @@ -798,13 +814,25 @@ def _configure_general_settings(config: Config, section: str) -> None: model, display_name = section_map[section] if section == "Tools": - _configure_pydantic_model( + updated_model = _configure_pydantic_model( model, display_name, skip_fields={"mcp_servers"}, ) else: - _configure_pydantic_model(model, display_name) + updated_model = _configure_pydantic_model(model, display_name) + + if updated_model is None: + return + + if section == "Agent Settings": + config.agents.defaults = updated_model + elif section == "Gateway": + config.gateway = updated_model + elif section == "Tools": + config.tools = updated_model + elif section == "Channel Common": + config.channels = updated_model def _configure_agents(config: Config) -> None: @@ -938,7 +966,35 @@ def _show_summary(config: Config) -> None: # --- Main Entry Point --- -def run_onboard(initial_config: Config | None = None) -> Config: +def _has_unsaved_changes(original: Config, current: Config) -> bool: + """Return True when the onboarding session has committed changes.""" + return original.model_dump(by_alias=True) != current.model_dump(by_alias=True) + + +def _prompt_main_menu_exit(has_unsaved_changes: bool) -> str: + """Resolve how to leave the main menu.""" + if not has_unsaved_changes: + return "discard" + + answer = questionary.select( + "You have unsaved changes. What would you like to do?", + choices=[ + "💾 Save and Exit", + "🗑️ Exit Without Saving", + "↩ Resume Editing", + ], + default="↩ Resume Editing", + qmark="→", + ).ask() + + if answer == "💾 Save and Exit": + return "save" + if answer == "🗑️ Exit Without Saving": + return "discard" + return "resume" + + +def run_onboard(initial_config: Config | None = None) -> OnboardResult: """Run the interactive onboarding questionnaire. Args: @@ -946,18 +1002,21 @@ def run_onboard(initial_config: Config | None = None) -> Config: If None, loads from config file or creates new default. """ if initial_config is not None: - config = initial_config + base_config = initial_config.model_copy(deep=True) else: config_path = get_config_path() if config_path.exists(): - config = load_config() + base_config = load_config() else: - config = Config() + base_config = Config() + + original_config = base_config.model_copy(deep=True) + config = base_config.model_copy(deep=True) while True: - try: - _show_main_menu_header() + _show_main_menu_header() + try: answer = questionary.select( "What would you like to configure?", choices=[ @@ -969,30 +1028,36 @@ def run_onboard(initial_config: Config | None = None) -> Config: "🔧 Configure Tools", "📋 View Configuration Summary", "💾 Save and Exit", + "🗑️ Exit Without Saving", ], qmark="→", ).ask() - - if answer == "🔌 Configure LLM Provider": - _configure_providers(config) - elif answer == "💬 Configure Chat Channel": - _configure_channels(config) - elif answer == "⚙️ Configure Channel Common": - _configure_general_settings(config, "Channel Common") - elif answer == "🤖 Configure Agent Settings": - _configure_agents(config) - elif answer == "🌐 Configure Gateway": - _configure_gateway(config) - elif answer == "🔧 Configure Tools": - _configure_tools(config) - elif answer == "📋 View Configuration Summary": - _show_summary(config) - elif answer == "💾 Save and Exit": - break except KeyboardInterrupt: - console.print( - "\n\n[yellow]Operation cancelled. Use 'Save and Exit' to save changes.[/yellow]" - ) - break + answer = None - return config + if answer is None: + action = _prompt_main_menu_exit(_has_unsaved_changes(original_config, config)) + if action == "save": + return OnboardResult(config=config, should_save=True) + if action == "discard": + return OnboardResult(config=original_config, should_save=False) + continue + + if answer == "🔌 Configure LLM Provider": + _configure_providers(config) + elif answer == "💬 Configure Chat Channel": + _configure_channels(config) + elif answer == "⚙️ Configure Channel Common": + _configure_general_settings(config, "Channel Common") + elif answer == "🤖 Configure Agent Settings": + _configure_agents(config) + elif answer == "🌐 Configure Gateway": + _configure_gateway(config) + elif answer == "🔧 Configure Tools": + _configure_tools(config) + elif answer == "📋 View Configuration Summary": + _show_summary(config) + elif answer == "💾 Save and Exit": + return OnboardResult(config=config, should_save=True) + elif answer == "🗑️ Exit Without Saving": + return OnboardResult(config=original_config, should_save=False) diff --git a/tests/test_commands.py b/tests/test_commands.py index 08a6397fd..e08ef8788 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -16,24 +16,25 @@ from nanobot.providers.registry import find_by_model runner = CliRunner() -class _StopGateway(RuntimeError): +class _StopGatewayError(RuntimeError): pass def _strip_ansi(text): """Remove ANSI escape codes from text.""" - ansi_escape = re.compile(r'\x1b\[[0-9;]*m') - return ansi_escape.sub('', text) + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + return ansi_escape.sub("", text) @pytest.fixture def mock_paths(): """Mock config/workspace paths for test isolation.""" - with patch("nanobot.config.loader.get_config_path") as mock_cp, \ - patch("nanobot.config.loader.save_config") as mock_sc, \ - patch("nanobot.config.loader.load_config") as mock_lc, \ - patch("nanobot.cli.commands.get_workspace_path") as mock_ws: - + with ( + patch("nanobot.config.loader.get_config_path") as mock_cp, + patch("nanobot.config.loader.save_config") as mock_sc, + patch("nanobot.config.loader.load_config") as mock_lc, + patch("nanobot.cli.commands.get_workspace_path") as mock_ws, + ): base_dir = Path("./test_onboard_data") if base_dir.exists(): shutil.rmtree(base_dir) @@ -130,6 +131,24 @@ def test_onboard_help_shows_workspace_and_config_options(): assert "--dir" not in stripped_output +def test_onboard_interactive_discard_does_not_save_or_create_workspace(mock_paths, monkeypatch): + config_file, workspace_dir, _ = mock_paths + + from nanobot.cli.onboard_wizard import OnboardResult + + monkeypatch.setattr( + "nanobot.cli.onboard_wizard.run_onboard", + lambda initial_config: OnboardResult(config=initial_config, should_save=False), + ) + + result = runner.invoke(app, ["onboard"]) + + assert result.exit_code == 0 + assert "No changes were saved" in result.stdout + assert not config_file.exists() + assert not workspace_dir.exists() + + def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch): config_path = tmp_path / "instance" / "config.json" workspace_path = tmp_path / "workspace" @@ -138,7 +157,14 @@ def test_onboard_uses_explicit_config_and_workspace_paths(tmp_path, monkeypatch) result = runner.invoke( app, - ["onboard", "--config", str(config_path), "--workspace", str(workspace_path), "--non-interactive"], + [ + "onboard", + "--config", + str(config_path), + "--workspace", + str(workspace_path), + "--non-interactive", + ], ) assert result.exit_code == 0 @@ -278,15 +304,16 @@ def mock_agent_runtime(tmp_path): config.agents.defaults.workspace = str(tmp_path / "default-workspace") cron_dir = tmp_path / "data" / "cron" - with patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, \ - patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), \ - patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, \ - patch("nanobot.cli.commands._make_provider", return_value=object()), \ - patch("nanobot.cli.commands._print_agent_response") as mock_print_response, \ - patch("nanobot.bus.queue.MessageBus"), \ - patch("nanobot.cron.service.CronService"), \ - patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls: - + with ( + patch("nanobot.config.loader.load_config", return_value=config) as mock_load_config, + patch("nanobot.config.paths.get_cron_dir", return_value=cron_dir), + patch("nanobot.cli.commands.sync_workspace_templates") as mock_sync_templates, + patch("nanobot.cli.commands._make_provider", return_value=object()), + patch("nanobot.cli.commands._print_agent_response") as mock_print_response, + patch("nanobot.bus.queue.MessageBus"), + patch("nanobot.cron.service.CronService"), + patch("nanobot.agent.loop.AgentLoop") as mock_agent_loop_cls, + ): agent_loop = MagicMock() agent_loop.channels_config = None agent_loop.process_direct = AsyncMock(return_value="mock-response") @@ -326,7 +353,9 @@ def test_agent_uses_default_config_when_no_workspace_or_config_flags(mock_agent_ mock_agent_runtime["config"].workspace_path ) mock_agent_runtime["agent_loop"].process_direct.assert_awaited_once() - mock_agent_runtime["print_response"].assert_called_once_with("mock-response", render_markdown=True) + mock_agent_runtime["print_response"].assert_called_once_with( + "mock-response", render_markdown=True + ) def test_agent_uses_explicit_config_path(mock_agent_runtime, tmp_path: Path): @@ -369,7 +398,9 @@ def test_agent_config_sets_active_path(monkeypatch, tmp_path: Path) -> None: return None monkeypatch.setattr("nanobot.agent.loop.AgentLoop", _FakeAgentLoop) - monkeypatch.setattr("nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + "nanobot.cli.commands._print_agent_response", lambda *_args, **_kwargs: None + ) result = runner.invoke(app, ["agent", "-m", "hello", "-c", str(config_file)]) @@ -435,12 +466,12 @@ def test_gateway_uses_workspace_from_config_by_default(monkeypatch, tmp_path: Pa ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert seen["config_path"] == config_file.resolve() assert seen["workspace"] == Path(config.agents.defaults.workspace) @@ -463,7 +494,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) ) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke( @@ -471,7 +502,7 @@ def test_gateway_workspace_option_overrides_config(monkeypatch, tmp_path: Path) ["gateway", "--config", str(config_file), "--workspace", str(override)], ) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert seen["workspace"] == override assert config.workspace_path == override @@ -489,15 +520,16 @@ def test_gateway_warns_about_deprecated_memory_window(monkeypatch, tmp_path: Pat monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert "memoryWindow" in result.stdout assert "contextWindowTokens" in result.stdout + def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Path) -> None: config_file = tmp_path / "instance" / "config.json" config_file.parent.mkdir(parents=True) @@ -518,13 +550,13 @@ def test_gateway_uses_config_directory_for_cron_store(monkeypatch, tmp_path: Pat class _StopCron: def __init__(self, store_path: Path) -> None: seen["cron_store"] = store_path - raise _StopGateway("stop") + raise _StopGatewayError("stop") monkeypatch.setattr("nanobot.cron.service.CronService", _StopCron) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert seen["cron_store"] == config_file.parent / "cron" / "jobs.json" @@ -541,12 +573,12 @@ def test_gateway_uses_configured_port_when_cli_flag_is_missing(monkeypatch, tmp_ monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file)]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert "port 18791" in result.stdout @@ -563,10 +595,10 @@ def test_gateway_cli_port_overrides_configured_port(monkeypatch, tmp_path: Path) monkeypatch.setattr("nanobot.cli.commands.sync_workspace_templates", lambda _path: None) monkeypatch.setattr( "nanobot.cli.commands._make_provider", - lambda _config: (_ for _ in ()).throw(_StopGateway("stop")), + lambda _config: (_ for _ in ()).throw(_StopGatewayError("stop")), ) result = runner.invoke(app, ["gateway", "--config", str(config_file), "--port", "18792"]) - assert isinstance(result.exception, _StopGateway) + assert isinstance(result.exception, _StopGatewayError) assert "port 18792" in result.stdout diff --git a/tests/test_onboard_logic.py b/tests/test_onboard_logic.py index a7c8d9603..64284f495 100644 --- a/tests/test_onboard_logic.py +++ b/tests/test_onboard_logic.py @@ -4,24 +4,40 @@ These tests focus on the business logic behind the onboard wizard, without testing the interactive UI components. """ -import json from pathlib import Path from types import SimpleNamespace -from typing import Any +from typing import Any, cast -import pytest from pydantic import BaseModel, Field +from nanobot.cli import onboard_wizard + # Import functions to test from nanobot.cli.commands import _merge_missing_defaults from nanobot.cli.onboard_wizard import ( + _BACK_PRESSED, + _configure_pydantic_model, _format_value, _get_field_display_name, _get_field_type_info, + run_onboard, ) +from nanobot.config.schema import Config from nanobot.utils.helpers import sync_workspace_templates +class _SimpleDraftModel(BaseModel): + api_key: str = "" + + +class _NestedDraftModel(BaseModel): + api_key: str = "" + + +class _OuterDraftModel(BaseModel): + nested: _NestedDraftModel = Field(default_factory=_NestedDraftModel) + + class TestMergeMissingDefaults: """Tests for _merge_missing_defaults recursive config merging.""" @@ -192,6 +208,7 @@ class TestGetFieldTypeInfo: def test_handles_none_annotation(self): """Field with None annotation defaults to str.""" + class Model(BaseModel): field: Any = None @@ -371,3 +388,104 @@ class TestProviderChannelInfo: for provider_name, value in info.items(): assert isinstance(value, tuple) assert len(value) == 4 # (display_name, needs_api_key, needs_api_base, env_var) + + +class TestConfigurePydanticModelDrafts: + @staticmethod + def _patch_prompt_helpers(monkeypatch, tokens, text_value="secret"): + sequence = iter(tokens) + + def fake_select(_prompt, choices, default=None): + token = next(sequence) + if token == "first": + return choices[0] + if token == "done": + return "✓ Done" + if token == "back": + return _BACK_PRESSED + return token + + monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select) + monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + onboard_wizard, "_input_with_existing", lambda *_args, **_kwargs: text_value + ) + + def test_discarding_section_keeps_original_model_unchanged(self, monkeypatch): + model = _SimpleDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "back"]) + + result = _configure_pydantic_model(model, "Simple") + + assert result is None + assert model.api_key == "" + + def test_completing_section_returns_updated_draft(self, monkeypatch): + model = _SimpleDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "done"]) + + result = _configure_pydantic_model(model, "Simple") + + assert result is not None + updated = cast(_SimpleDraftModel, result) + assert updated.api_key == "secret" + assert model.api_key == "" + + def test_nested_section_back_discards_nested_edits(self, monkeypatch): + model = _OuterDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "first", "back", "done"]) + + result = _configure_pydantic_model(model, "Outer") + + assert result is not None + updated = cast(_OuterDraftModel, result) + assert updated.nested.api_key == "" + assert model.nested.api_key == "" + + def test_nested_section_done_commits_nested_edits(self, monkeypatch): + model = _OuterDraftModel() + self._patch_prompt_helpers(monkeypatch, ["first", "first", "done", "done"]) + + result = _configure_pydantic_model(model, "Outer") + + assert result is not None + updated = cast(_OuterDraftModel, result) + assert updated.nested.api_key == "secret" + assert model.nested.api_key == "" + + +class TestRunOnboardExitBehavior: + def test_main_menu_interrupt_can_discard_unsaved_session_changes(self, monkeypatch): + initial_config = Config() + + responses = iter( + [ + "🤖 Configure Agent Settings", + KeyboardInterrupt(), + "🗑️ Exit Without Saving", + ] + ) + + class FakePrompt: + def __init__(self, response): + self.response = response + + def ask(self): + if isinstance(self.response, BaseException): + raise self.response + return self.response + + def fake_select(*_args, **_kwargs): + return FakePrompt(next(responses)) + + def fake_configure_agents(config): + config.agents.defaults.model = "test/provider-model" + + monkeypatch.setattr(onboard_wizard, "_show_main_menu_header", lambda: None) + monkeypatch.setattr(onboard_wizard.questionary, "select", fake_select) + monkeypatch.setattr(onboard_wizard, "_configure_agents", fake_configure_agents) + + result = run_onboard(initial_config=initial_config) + + assert result.should_save is False + assert result.config.model_dump(by_alias=True) == initial_config.model_dump(by_alias=True)