mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(onboard): allow empty strings and falsy values in input fields
Fixes two related input-handling bugs in the onboard wizard: 1. _input_text treated "" as None, preventing users from clearing optional string fields or entering empty strings intentionally. 2. _input_model_with_autocomplete used `if value else None`, which discarded falsy values such as empty strings or 0. To support clearing optional string fields, add _is_str_or_none() and normalize empty strings to None inside _configure_pydantic_model only when the field annotation is `str | None`. Required str fields keep "" as a valid value. Also included: - Remember last selected item in provider/channel/model menus for better UX when configuring multiple items. - Rename _SIMPLE_TYPES and _MENU_DISPATCH to lowercase to follow Python naming conventions (they are local variables, not constants). - Remove unused imports in test file. Extracted from PR #3358.
This commit is contained in:
parent
6a3069514c
commit
3a2f47d720
@ -191,13 +191,13 @@ def _get_field_type_info(field_info) -> FieldTypeInfo:
|
||||
origin = get_origin(annotation)
|
||||
args = get_args(annotation)
|
||||
|
||||
_SIMPLE_TYPES: dict[type, str] = {bool: "bool", int: "int", float: "float"}
|
||||
_simple_types: dict[type, str] = {bool: "bool", int: "int", float: "float"}
|
||||
|
||||
if origin is list or (hasattr(origin, "__name__") and origin.__name__ == "List"):
|
||||
return FieldTypeInfo("list", args[0] if args else str)
|
||||
if origin is dict or (hasattr(origin, "__name__") and origin.__name__ == "Dict"):
|
||||
return FieldTypeInfo("dict", None)
|
||||
for py_type, name in _SIMPLE_TYPES.items():
|
||||
for py_type, name in _simple_types.items():
|
||||
if annotation is py_type:
|
||||
return FieldTypeInfo(name, None)
|
||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||
@ -403,7 +403,7 @@ def _input_text(display_name: str, current: Any, field_type: str, field_info=Non
|
||||
|
||||
value = _get_questionary().text(f"{display_name}:", default=default).ask()
|
||||
|
||||
if value is None or value == "":
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
if field_type == "int":
|
||||
@ -507,7 +507,7 @@ def _input_model_with_autocomplete(
|
||||
qmark=">",
|
||||
).ask()
|
||||
|
||||
return value if value else None
|
||||
return value if value is not None else None
|
||||
|
||||
|
||||
def _input_context_window_with_recommendation(
|
||||
@ -594,6 +594,15 @@ _FIELD_HANDLERS: dict[str, Any] = {
|
||||
}
|
||||
|
||||
|
||||
def _is_str_or_none(annotation: Any) -> bool:
|
||||
"""Check whether a field annotation is ``str | None`` (or ``Optional[str]``)."""
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
return False
|
||||
args = get_args(annotation)
|
||||
return str in args and type(None) in args
|
||||
|
||||
|
||||
def _configure_pydantic_model(
|
||||
model: BaseModel,
|
||||
display_name: str,
|
||||
@ -626,11 +635,20 @@ def _configure_pydantic_model(
|
||||
items.append(f"{display}: {formatted}")
|
||||
return items + ["[Done]"]
|
||||
|
||||
last_field_name: str | None = None
|
||||
while True:
|
||||
console.clear()
|
||||
_show_config_panel(display_name, working_model, fields)
|
||||
choices = get_choices()
|
||||
answer = _select_with_back("Select field to configure:", choices)
|
||||
default_choice = None
|
||||
if last_field_name:
|
||||
for idx, (fname, _) in enumerate(fields):
|
||||
if fname == last_field_name:
|
||||
default_choice = choices[idx]
|
||||
break
|
||||
answer = _select_with_back(
|
||||
"Select field to configure:", choices, default=default_choice
|
||||
)
|
||||
|
||||
if answer is _BACK_PRESSED or answer is None:
|
||||
return None
|
||||
@ -641,6 +659,8 @@ def _configure_pydantic_model(
|
||||
if field_idx < 0 or field_idx >= len(fields):
|
||||
return None
|
||||
|
||||
last_field_name = fields[field_idx][0]
|
||||
|
||||
field_name, field_info = fields[field_idx]
|
||||
current_value = getattr(working_model, field_name, None)
|
||||
ftype = _get_field_type_info(field_info)
|
||||
@ -697,6 +717,10 @@ def _configure_pydantic_model(
|
||||
else:
|
||||
new_value = _input_with_existing(field_display, current_value, ftype.type_name, field_info=field_info)
|
||||
if new_value is not None:
|
||||
# Normalize empty string to None for optional string fields so that
|
||||
# clearing an api_key / api_base actually removes the value.
|
||||
if new_value == "" and _is_str_or_none(field_info.annotation):
|
||||
new_value = None
|
||||
setattr(working_model, field_name, new_value)
|
||||
|
||||
|
||||
@ -795,12 +819,23 @@ def _configure_providers(config: Config) -> None:
|
||||
choices.append(display)
|
||||
return choices + ["<- Back"]
|
||||
|
||||
last_provider_key: str | None = None
|
||||
while True:
|
||||
try:
|
||||
console.clear()
|
||||
_show_section_header("LLM Providers", "Select a provider to configure API key and endpoint")
|
||||
choices = get_provider_choices()
|
||||
answer = _select_with_back("Select provider:", choices)
|
||||
default_choice = None
|
||||
if last_provider_key:
|
||||
display = _get_provider_names().get(last_provider_key)
|
||||
if display:
|
||||
for c in choices:
|
||||
if c.replace(" *", "") == display:
|
||||
default_choice = c
|
||||
break
|
||||
answer = _select_with_back(
|
||||
"Select provider:", choices, default=default_choice
|
||||
)
|
||||
|
||||
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
||||
break
|
||||
@ -812,6 +847,7 @@ def _configure_providers(config: Config) -> None:
|
||||
# Find the actual provider key from display names
|
||||
for name, display in _get_provider_names().items():
|
||||
if display == provider_name:
|
||||
last_provider_key = name
|
||||
_configure_provider(config, name)
|
||||
break
|
||||
|
||||
@ -885,17 +921,21 @@ def _configure_channels(config: Config) -> None:
|
||||
channel_names = list(_get_channel_names().keys())
|
||||
choices = channel_names + ["<- Back"]
|
||||
|
||||
last_choice: str | None = None
|
||||
while True:
|
||||
try:
|
||||
console.clear()
|
||||
_show_section_header("Chat Channels", "Select a channel to configure connection settings")
|
||||
answer = _select_with_back("Select channel:", choices)
|
||||
answer = _select_with_back(
|
||||
"Select channel:", choices, default=last_choice
|
||||
)
|
||||
|
||||
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
||||
break
|
||||
|
||||
# Type guard: answer is now guaranteed to be a string
|
||||
assert isinstance(answer, str)
|
||||
last_choice = answer
|
||||
_configure_channel(config, answer)
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[dim]Returning to main menu...[/dim]")
|
||||
@ -1073,6 +1113,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
original_config = base_config.model_copy(deep=True)
|
||||
config = base_config.model_copy(deep=True)
|
||||
|
||||
last_main_choice: str | None = None
|
||||
while True:
|
||||
console.clear()
|
||||
_show_main_menu_header()
|
||||
@ -1092,6 +1133,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
"[S] Save and Exit",
|
||||
"[X] Exit Without Saving",
|
||||
],
|
||||
default=last_main_choice,
|
||||
qmark=">",
|
||||
).ask()
|
||||
except KeyboardInterrupt:
|
||||
@ -1105,7 +1147,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
return OnboardResult(config=original_config, should_save=False)
|
||||
continue
|
||||
|
||||
_MENU_DISPATCH = {
|
||||
_menu_dispatch = {
|
||||
"[P] LLM Provider": lambda: _configure_providers(config),
|
||||
"[C] Chat Channel": lambda: _configure_channels(config),
|
||||
"[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"),
|
||||
@ -1121,6 +1163,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
||||
if answer == "[X] Exit Without Saving":
|
||||
return OnboardResult(config=original_config, should_save=False)
|
||||
|
||||
action_fn = _MENU_DISPATCH.get(answer)
|
||||
action_fn = _menu_dispatch.get(answer)
|
||||
if action_fn:
|
||||
last_main_choice = answer
|
||||
action_fn()
|
||||
|
||||
@ -4,7 +4,6 @@ These tests focus on the business logic behind the onboard wizard,
|
||||
without testing the interactive UI components.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
@ -13,18 +12,15 @@ import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from nanobot.cli import onboard as onboard_wizard
|
||||
|
||||
# Import functions to test
|
||||
from nanobot.cli.commands import _merge_missing_defaults
|
||||
from nanobot.cli.onboard import (
|
||||
_BACK_PRESSED,
|
||||
_configure_pydantic_model,
|
||||
_format_value,
|
||||
_get_constraint_hint,
|
||||
_get_field_display_name,
|
||||
_get_field_type_info,
|
||||
_get_constraint_hint,
|
||||
_input_text,
|
||||
_validate_field_constraint,
|
||||
run_onboard,
|
||||
)
|
||||
from nanobot.config.schema import Config
|
||||
@ -960,3 +956,121 @@ class TestMainMenuUpdate:
|
||||
|
||||
assert result.should_save is True
|
||||
assert pause_called["n"] == 1
|
||||
|
||||
|
||||
class TestInputTextEmptyString:
|
||||
"""Tests for _input_text empty-string handling bug fix."""
|
||||
|
||||
def test_empty_string_returned_not_none(self, monkeypatch):
|
||||
"""_input_text should return empty string, not None, when user enters ''."""
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard,
|
||||
"_get_questionary",
|
||||
lambda: SimpleNamespace(text=lambda *a, **kw: SimpleNamespace(ask=lambda: "")),
|
||||
)
|
||||
|
||||
result = _input_text("Name", "old", "str")
|
||||
assert result == ""
|
||||
|
||||
def test_none_still_returns_none(self, monkeypatch):
|
||||
"""_input_text should return None when questionary returns None."""
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard,
|
||||
"_get_questionary",
|
||||
lambda: SimpleNamespace(text=lambda *a, **kw: SimpleNamespace(ask=lambda: None)),
|
||||
)
|
||||
|
||||
result = _input_text("Name", "old", "str")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsStrOrNone:
|
||||
"""Tests for _is_str_or_none helper."""
|
||||
|
||||
def test_str_or_none_true(self):
|
||||
from nanobot.cli.onboard import _is_str_or_none
|
||||
|
||||
assert _is_str_or_none(str | None) is True
|
||||
|
||||
def test_optional_str_true(self):
|
||||
from typing import Optional
|
||||
from nanobot.cli.onboard import _is_str_or_none
|
||||
|
||||
assert _is_str_or_none(Optional[str]) is True
|
||||
|
||||
def test_str_only_false(self):
|
||||
from nanobot.cli.onboard import _is_str_or_none
|
||||
|
||||
assert _is_str_or_none(str) is False
|
||||
|
||||
def test_int_or_none_false(self):
|
||||
from nanobot.cli.onboard import _is_str_or_none
|
||||
|
||||
assert _is_str_or_none(int | None) is False
|
||||
|
||||
|
||||
class TestConfigurePydanticModelEmptyString:
|
||||
"""Tests that optional string fields are cleared when empty string is entered."""
|
||||
|
||||
def test_optional_str_empty_string_becomes_none(self, monkeypatch):
|
||||
"""Entering '' for an optional str field should set it to None."""
|
||||
from pydantic import BaseModel
|
||||
from nanobot.cli.onboard import _is_str_or_none
|
||||
|
||||
class M(BaseModel):
|
||||
api_key: str | None = None
|
||||
|
||||
model = M(api_key="secret")
|
||||
|
||||
call_count = {"select": 0}
|
||||
|
||||
def fake_select(_prompt, choices, default=None):
|
||||
call_count["select"] += 1
|
||||
# First call: select the api_key field, then Done
|
||||
if call_count["select"] == 1:
|
||||
for c in choices:
|
||||
if "Api Key" in c:
|
||||
return c
|
||||
return choices[0]
|
||||
return "[Done]"
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
|
||||
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *a, **kw: None)
|
||||
# Simulate user entering empty string
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard, "_input_with_existing", lambda *a, **kw: ""
|
||||
)
|
||||
|
||||
result = _configure_pydantic_model(model, "Test")
|
||||
assert result is not None
|
||||
assert result.api_key is None
|
||||
|
||||
def test_required_str_empty_string_kept(self, monkeypatch):
|
||||
"""Entering '' for a required str field should keep the empty string."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class M(BaseModel):
|
||||
api_key: str = ""
|
||||
|
||||
model = M(api_key="secret")
|
||||
|
||||
call_count = {"select": 0}
|
||||
|
||||
def fake_select(_prompt, choices, default=None):
|
||||
call_count["select"] += 1
|
||||
if call_count["select"] == 1:
|
||||
for c in choices:
|
||||
if "Api Key" in c:
|
||||
return c
|
||||
return choices[0]
|
||||
return "[Done]"
|
||||
|
||||
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
|
||||
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
onboard_wizard, "_input_with_existing", lambda *a, **kw: ""
|
||||
)
|
||||
|
||||
result = _configure_pydantic_model(model, "Test")
|
||||
assert result is not None
|
||||
assert result.api_key == ""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user