mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 08:32:25 +00:00
fix(onboard): allow empty strings and falsy values in input fields
Fixes two related input-handling bugs in the onboard wizard: 1. _input_text treated "" as None, preventing users from clearing optional string fields or entering empty strings intentionally. 2. _input_model_with_autocomplete used `if value else None`, which discarded falsy values such as empty strings or 0. To support clearing optional string fields, add _is_str_or_none() and normalize empty strings to None inside _configure_pydantic_model only when the field annotation is `str | None`. Required str fields keep "" as a valid value. Also included: - Remember last selected item in provider/channel/model menus for better UX when configuring multiple items. - Rename _SIMPLE_TYPES and _MENU_DISPATCH to lowercase to follow Python naming conventions (they are local variables, not constants). - Remove unused imports in test file. Extracted from PR #3358.
This commit is contained in:
parent
12005c20f0
commit
e34b7fd086
@ -191,13 +191,13 @@ def _get_field_type_info(field_info) -> FieldTypeInfo:
|
|||||||
origin = get_origin(annotation)
|
origin = get_origin(annotation)
|
||||||
args = get_args(annotation)
|
args = get_args(annotation)
|
||||||
|
|
||||||
_SIMPLE_TYPES: dict[type, str] = {bool: "bool", int: "int", float: "float"}
|
_simple_types: dict[type, str] = {bool: "bool", int: "int", float: "float"}
|
||||||
|
|
||||||
if origin is list or (hasattr(origin, "__name__") and origin.__name__ == "List"):
|
if origin is list or (hasattr(origin, "__name__") and origin.__name__ == "List"):
|
||||||
return FieldTypeInfo("list", args[0] if args else str)
|
return FieldTypeInfo("list", args[0] if args else str)
|
||||||
if origin is dict or (hasattr(origin, "__name__") and origin.__name__ == "Dict"):
|
if origin is dict or (hasattr(origin, "__name__") and origin.__name__ == "Dict"):
|
||||||
return FieldTypeInfo("dict", None)
|
return FieldTypeInfo("dict", None)
|
||||||
for py_type, name in _SIMPLE_TYPES.items():
|
for py_type, name in _simple_types.items():
|
||||||
if annotation is py_type:
|
if annotation is py_type:
|
||||||
return FieldTypeInfo(name, None)
|
return FieldTypeInfo(name, None)
|
||||||
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
|
||||||
@ -403,7 +403,7 @@ def _input_text(display_name: str, current: Any, field_type: str, field_info=Non
|
|||||||
|
|
||||||
value = _get_questionary().text(f"{display_name}:", default=default).ask()
|
value = _get_questionary().text(f"{display_name}:", default=default).ask()
|
||||||
|
|
||||||
if value is None or value == "":
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if field_type == "int":
|
if field_type == "int":
|
||||||
@ -507,7 +507,7 @@ def _input_model_with_autocomplete(
|
|||||||
qmark=">",
|
qmark=">",
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
return value if value else None
|
return value if value is not None else None
|
||||||
|
|
||||||
|
|
||||||
def _input_context_window_with_recommendation(
|
def _input_context_window_with_recommendation(
|
||||||
@ -594,6 +594,15 @@ _FIELD_HANDLERS: dict[str, Any] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _is_str_or_none(annotation: Any) -> bool:
|
||||||
|
"""Check whether a field annotation is ``str | None`` (or ``Optional[str]``)."""
|
||||||
|
origin = get_origin(annotation)
|
||||||
|
if origin is None:
|
||||||
|
return False
|
||||||
|
args = get_args(annotation)
|
||||||
|
return str in args and type(None) in args
|
||||||
|
|
||||||
|
|
||||||
def _configure_pydantic_model(
|
def _configure_pydantic_model(
|
||||||
model: BaseModel,
|
model: BaseModel,
|
||||||
display_name: str,
|
display_name: str,
|
||||||
@ -626,11 +635,20 @@ def _configure_pydantic_model(
|
|||||||
items.append(f"{display}: {formatted}")
|
items.append(f"{display}: {formatted}")
|
||||||
return items + ["[Done]"]
|
return items + ["[Done]"]
|
||||||
|
|
||||||
|
last_field_name: str | None = None
|
||||||
while True:
|
while True:
|
||||||
console.clear()
|
console.clear()
|
||||||
_show_config_panel(display_name, working_model, fields)
|
_show_config_panel(display_name, working_model, fields)
|
||||||
choices = get_choices()
|
choices = get_choices()
|
||||||
answer = _select_with_back("Select field to configure:", choices)
|
default_choice = None
|
||||||
|
if last_field_name:
|
||||||
|
for idx, (fname, _) in enumerate(fields):
|
||||||
|
if fname == last_field_name:
|
||||||
|
default_choice = choices[idx]
|
||||||
|
break
|
||||||
|
answer = _select_with_back(
|
||||||
|
"Select field to configure:", choices, default=default_choice
|
||||||
|
)
|
||||||
|
|
||||||
if answer is _BACK_PRESSED or answer is None:
|
if answer is _BACK_PRESSED or answer is None:
|
||||||
return None
|
return None
|
||||||
@ -641,6 +659,8 @@ def _configure_pydantic_model(
|
|||||||
if field_idx < 0 or field_idx >= len(fields):
|
if field_idx < 0 or field_idx >= len(fields):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
last_field_name = fields[field_idx][0]
|
||||||
|
|
||||||
field_name, field_info = fields[field_idx]
|
field_name, field_info = fields[field_idx]
|
||||||
current_value = getattr(working_model, field_name, None)
|
current_value = getattr(working_model, field_name, None)
|
||||||
ftype = _get_field_type_info(field_info)
|
ftype = _get_field_type_info(field_info)
|
||||||
@ -697,6 +717,10 @@ def _configure_pydantic_model(
|
|||||||
else:
|
else:
|
||||||
new_value = _input_with_existing(field_display, current_value, ftype.type_name, field_info=field_info)
|
new_value = _input_with_existing(field_display, current_value, ftype.type_name, field_info=field_info)
|
||||||
if new_value is not None:
|
if new_value is not None:
|
||||||
|
# Normalize empty string to None for optional string fields so that
|
||||||
|
# clearing an api_key / api_base actually removes the value.
|
||||||
|
if new_value == "" and _is_str_or_none(field_info.annotation):
|
||||||
|
new_value = None
|
||||||
setattr(working_model, field_name, new_value)
|
setattr(working_model, field_name, new_value)
|
||||||
|
|
||||||
|
|
||||||
@ -795,12 +819,23 @@ def _configure_providers(config: Config) -> None:
|
|||||||
choices.append(display)
|
choices.append(display)
|
||||||
return choices + ["<- Back"]
|
return choices + ["<- Back"]
|
||||||
|
|
||||||
|
last_provider_key: str | None = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
console.clear()
|
console.clear()
|
||||||
_show_section_header("LLM Providers", "Select a provider to configure API key and endpoint")
|
_show_section_header("LLM Providers", "Select a provider to configure API key and endpoint")
|
||||||
choices = get_provider_choices()
|
choices = get_provider_choices()
|
||||||
answer = _select_with_back("Select provider:", choices)
|
default_choice = None
|
||||||
|
if last_provider_key:
|
||||||
|
display = _get_provider_names().get(last_provider_key)
|
||||||
|
if display:
|
||||||
|
for c in choices:
|
||||||
|
if c.replace(" *", "") == display:
|
||||||
|
default_choice = c
|
||||||
|
break
|
||||||
|
answer = _select_with_back(
|
||||||
|
"Select provider:", choices, default=default_choice
|
||||||
|
)
|
||||||
|
|
||||||
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
||||||
break
|
break
|
||||||
@ -812,6 +847,7 @@ def _configure_providers(config: Config) -> None:
|
|||||||
# Find the actual provider key from display names
|
# Find the actual provider key from display names
|
||||||
for name, display in _get_provider_names().items():
|
for name, display in _get_provider_names().items():
|
||||||
if display == provider_name:
|
if display == provider_name:
|
||||||
|
last_provider_key = name
|
||||||
_configure_provider(config, name)
|
_configure_provider(config, name)
|
||||||
break
|
break
|
||||||
|
|
||||||
@ -885,17 +921,21 @@ def _configure_channels(config: Config) -> None:
|
|||||||
channel_names = list(_get_channel_names().keys())
|
channel_names = list(_get_channel_names().keys())
|
||||||
choices = channel_names + ["<- Back"]
|
choices = channel_names + ["<- Back"]
|
||||||
|
|
||||||
|
last_choice: str | None = None
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
console.clear()
|
console.clear()
|
||||||
_show_section_header("Chat Channels", "Select a channel to configure connection settings")
|
_show_section_header("Chat Channels", "Select a channel to configure connection settings")
|
||||||
answer = _select_with_back("Select channel:", choices)
|
answer = _select_with_back(
|
||||||
|
"Select channel:", choices, default=last_choice
|
||||||
|
)
|
||||||
|
|
||||||
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
if answer is _BACK_PRESSED or answer is None or answer == "<- Back":
|
||||||
break
|
break
|
||||||
|
|
||||||
# Type guard: answer is now guaranteed to be a string
|
# Type guard: answer is now guaranteed to be a string
|
||||||
assert isinstance(answer, str)
|
assert isinstance(answer, str)
|
||||||
|
last_choice = answer
|
||||||
_configure_channel(config, answer)
|
_configure_channel(config, answer)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
console.print("\n[dim]Returning to main menu...[/dim]")
|
console.print("\n[dim]Returning to main menu...[/dim]")
|
||||||
@ -1073,6 +1113,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
original_config = base_config.model_copy(deep=True)
|
original_config = base_config.model_copy(deep=True)
|
||||||
config = base_config.model_copy(deep=True)
|
config = base_config.model_copy(deep=True)
|
||||||
|
|
||||||
|
last_main_choice: str | None = None
|
||||||
while True:
|
while True:
|
||||||
console.clear()
|
console.clear()
|
||||||
_show_main_menu_header()
|
_show_main_menu_header()
|
||||||
@ -1092,6 +1133,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
"[S] Save and Exit",
|
"[S] Save and Exit",
|
||||||
"[X] Exit Without Saving",
|
"[X] Exit Without Saving",
|
||||||
],
|
],
|
||||||
|
default=last_main_choice,
|
||||||
qmark=">",
|
qmark=">",
|
||||||
).ask()
|
).ask()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@ -1105,7 +1147,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
return OnboardResult(config=original_config, should_save=False)
|
return OnboardResult(config=original_config, should_save=False)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
_MENU_DISPATCH = {
|
_menu_dispatch = {
|
||||||
"[P] LLM Provider": lambda: _configure_providers(config),
|
"[P] LLM Provider": lambda: _configure_providers(config),
|
||||||
"[C] Chat Channel": lambda: _configure_channels(config),
|
"[C] Chat Channel": lambda: _configure_channels(config),
|
||||||
"[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"),
|
"[H] Channel Common": lambda: _configure_general_settings(config, "Channel Common"),
|
||||||
@ -1121,6 +1163,7 @@ def run_onboard(initial_config: Config | None = None) -> OnboardResult:
|
|||||||
if answer == "[X] Exit Without Saving":
|
if answer == "[X] Exit Without Saving":
|
||||||
return OnboardResult(config=original_config, should_save=False)
|
return OnboardResult(config=original_config, should_save=False)
|
||||||
|
|
||||||
action_fn = _MENU_DISPATCH.get(answer)
|
action_fn = _menu_dispatch.get(answer)
|
||||||
if action_fn:
|
if action_fn:
|
||||||
|
last_main_choice = answer
|
||||||
action_fn()
|
action_fn()
|
||||||
|
|||||||
@ -4,7 +4,6 @@ These tests focus on the business logic behind the onboard wizard,
|
|||||||
without testing the interactive UI components.
|
without testing the interactive UI components.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
@ -13,18 +12,15 @@ import pytest
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from nanobot.cli import onboard as onboard_wizard
|
from nanobot.cli import onboard as onboard_wizard
|
||||||
|
|
||||||
# Import functions to test
|
|
||||||
from nanobot.cli.commands import _merge_missing_defaults
|
from nanobot.cli.commands import _merge_missing_defaults
|
||||||
from nanobot.cli.onboard import (
|
from nanobot.cli.onboard import (
|
||||||
_BACK_PRESSED,
|
_BACK_PRESSED,
|
||||||
_configure_pydantic_model,
|
_configure_pydantic_model,
|
||||||
_format_value,
|
_format_value,
|
||||||
|
_get_constraint_hint,
|
||||||
_get_field_display_name,
|
_get_field_display_name,
|
||||||
_get_field_type_info,
|
_get_field_type_info,
|
||||||
_get_constraint_hint,
|
|
||||||
_input_text,
|
_input_text,
|
||||||
_validate_field_constraint,
|
|
||||||
run_onboard,
|
run_onboard,
|
||||||
)
|
)
|
||||||
from nanobot.config.schema import Config
|
from nanobot.config.schema import Config
|
||||||
@ -960,3 +956,121 @@ class TestMainMenuUpdate:
|
|||||||
|
|
||||||
assert result.should_save is True
|
assert result.should_save is True
|
||||||
assert pause_called["n"] == 1
|
assert pause_called["n"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestInputTextEmptyString:
|
||||||
|
"""Tests for _input_text empty-string handling bug fix."""
|
||||||
|
|
||||||
|
def test_empty_string_returned_not_none(self, monkeypatch):
|
||||||
|
"""_input_text should return empty string, not None, when user enters ''."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
onboard_wizard,
|
||||||
|
"_get_questionary",
|
||||||
|
lambda: SimpleNamespace(text=lambda *a, **kw: SimpleNamespace(ask=lambda: "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _input_text("Name", "old", "str")
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
def test_none_still_returns_none(self, monkeypatch):
|
||||||
|
"""_input_text should return None when questionary returns None."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
onboard_wizard,
|
||||||
|
"_get_questionary",
|
||||||
|
lambda: SimpleNamespace(text=lambda *a, **kw: SimpleNamespace(ask=lambda: None)),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _input_text("Name", "old", "str")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsStrOrNone:
|
||||||
|
"""Tests for _is_str_or_none helper."""
|
||||||
|
|
||||||
|
def test_str_or_none_true(self):
|
||||||
|
from nanobot.cli.onboard import _is_str_or_none
|
||||||
|
|
||||||
|
assert _is_str_or_none(str | None) is True
|
||||||
|
|
||||||
|
def test_optional_str_true(self):
|
||||||
|
from typing import Optional
|
||||||
|
from nanobot.cli.onboard import _is_str_or_none
|
||||||
|
|
||||||
|
assert _is_str_or_none(Optional[str]) is True
|
||||||
|
|
||||||
|
def test_str_only_false(self):
|
||||||
|
from nanobot.cli.onboard import _is_str_or_none
|
||||||
|
|
||||||
|
assert _is_str_or_none(str) is False
|
||||||
|
|
||||||
|
def test_int_or_none_false(self):
|
||||||
|
from nanobot.cli.onboard import _is_str_or_none
|
||||||
|
|
||||||
|
assert _is_str_or_none(int | None) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigurePydanticModelEmptyString:
|
||||||
|
"""Tests that optional string fields are cleared when empty string is entered."""
|
||||||
|
|
||||||
|
def test_optional_str_empty_string_becomes_none(self, monkeypatch):
|
||||||
|
"""Entering '' for an optional str field should set it to None."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from nanobot.cli.onboard import _is_str_or_none
|
||||||
|
|
||||||
|
class M(BaseModel):
|
||||||
|
api_key: str | None = None
|
||||||
|
|
||||||
|
model = M(api_key="secret")
|
||||||
|
|
||||||
|
call_count = {"select": 0}
|
||||||
|
|
||||||
|
def fake_select(_prompt, choices, default=None):
|
||||||
|
call_count["select"] += 1
|
||||||
|
# First call: select the api_key field, then Done
|
||||||
|
if call_count["select"] == 1:
|
||||||
|
for c in choices:
|
||||||
|
if "Api Key" in c:
|
||||||
|
return c
|
||||||
|
return choices[0]
|
||||||
|
return "[Done]"
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *a, **kw: None)
|
||||||
|
# Simulate user entering empty string
|
||||||
|
monkeypatch.setattr(
|
||||||
|
onboard_wizard, "_input_with_existing", lambda *a, **kw: ""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _configure_pydantic_model(model, "Test")
|
||||||
|
assert result is not None
|
||||||
|
assert result.api_key is None
|
||||||
|
|
||||||
|
def test_required_str_empty_string_kept(self, monkeypatch):
|
||||||
|
"""Entering '' for a required str field should keep the empty string."""
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class M(BaseModel):
|
||||||
|
api_key: str = ""
|
||||||
|
|
||||||
|
model = M(api_key="secret")
|
||||||
|
|
||||||
|
call_count = {"select": 0}
|
||||||
|
|
||||||
|
def fake_select(_prompt, choices, default=None):
|
||||||
|
call_count["select"] += 1
|
||||||
|
if call_count["select"] == 1:
|
||||||
|
for c in choices:
|
||||||
|
if "Api Key" in c:
|
||||||
|
return c
|
||||||
|
return choices[0]
|
||||||
|
return "[Done]"
|
||||||
|
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_select_with_back", fake_select)
|
||||||
|
monkeypatch.setattr(onboard_wizard, "_show_config_panel", lambda *a, **kw: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
onboard_wizard, "_input_with_existing", lambda *a, **kw: ""
|
||||||
|
)
|
||||||
|
|
||||||
|
result = _configure_pydantic_model(model, "Test")
|
||||||
|
assert result is not None
|
||||||
|
assert result.api_key == ""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user