mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
fix(config): preserve {env:VAR} format when saving config
Before this fix, load_config would resolve {env:VAR} references but
save_config would write back the resolved (plaintext) values, losing
the original placeholder format.
Now we record the original {env:VAR} values during load and restore
them during save, so the config file always keeps the placeholder
syntax. This works without requiring changes to any consuming code.
This commit is contained in:
parent
7162549eb3
commit
dfba8e3248
@ -1,10 +1,12 @@
|
||||
"""Configuration loading utilities."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.config.secret_resolver import resolve_config
|
||||
from nanobot.config.secret_resolver import resolve_config, _REF_PATTERN
|
||||
|
||||
|
||||
# Global variable to store current config path (for multi-instance support)
|
||||
@ -39,10 +41,17 @@ def load_config(config_path: Path | None = None) -> Config:
|
||||
if path.exists():
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
data = _migrate_config(data)
|
||||
data = resolve_config(data) # Resolve {env:VAR} references
|
||||
return Config.model_validate(data)
|
||||
raw_data = json.load(f)
|
||||
raw_data = _migrate_config(raw_data)
|
||||
|
||||
# Record original values of fields containing {env:VAR} references
|
||||
env_refs: dict[str, Any] = {}
|
||||
_collect_env_refs(raw_data, "", env_refs)
|
||||
|
||||
resolved_data = resolve_config(copy.deepcopy(raw_data)) # Resolve {env:VAR} references
|
||||
config = Config.model_validate(resolved_data)
|
||||
config._env_refs = env_refs # Preserve original {env:VAR} values for save_config
|
||||
return config
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"Warning: Failed to load config from {path}: {e}")
|
||||
print("Using default configuration.")
|
||||
@ -61,7 +70,11 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
|
||||
path = config_path or get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use raw unresolved data if available to preserve {env:VAR} placeholders
|
||||
# Use model_dump as base, but restore {env:VAR} references from original values
|
||||
data = config.model_dump(by_alias=True)
|
||||
if config._env_refs:
|
||||
_restore_env_refs(data, config._env_refs)
|
||||
|
||||
with open(path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
@ -75,3 +88,35 @@ def _migrate_config(data: dict) -> dict:
|
||||
if "restrictToWorkspace" in exec_cfg and "restrictToWorkspace" not in tools:
|
||||
tools["restrictToWorkspace"] = exec_cfg.pop("restrictToWorkspace")
|
||||
return data
|
||||
|
||||
|
||||
def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None:
|
||||
"""Collect field paths and original values for fields containing {env:VAR}."""
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
child_path = f"{path}.{key}" if path else key
|
||||
_collect_env_refs(value, child_path, refs)
|
||||
elif isinstance(obj, list):
|
||||
for idx, item in enumerate(obj):
|
||||
_collect_env_refs(item, f"{path}[{idx}]", refs)
|
||||
elif isinstance(obj, str) and _REF_PATTERN.search(obj):
|
||||
refs[path] = obj
|
||||
|
||||
|
||||
def _restore_env_refs(data: dict, refs: dict[str, Any]) -> None:
|
||||
"""Restore original {env:VAR} values into data dict."""
|
||||
for path, original_value in refs.items():
|
||||
_set_by_path(data, path, original_value)
|
||||
|
||||
|
||||
def _set_by_path(data: dict, path: str, value: Any) -> None:
|
||||
"""Set a value in nested dict by dot-notation path like 'providers.zhipu.apiKey'."""
|
||||
parts = path.split(".")
|
||||
current = data
|
||||
for part in parts[:-1]:
|
||||
if part not in current:
|
||||
return
|
||||
current = current[part]
|
||||
last_key = parts[-1]
|
||||
if isinstance(current, dict) and last_key in current:
|
||||
current[last_key] = value
|
||||
|
||||
@ -159,6 +159,9 @@ class Config(BaseSettings):
|
||||
gateway: GatewayConfig = Field(default_factory=GatewayConfig)
|
||||
tools: ToolsConfig = Field(default_factory=ToolsConfig)
|
||||
|
||||
# Reserved field to store original {env:VAR} values for save_config
|
||||
_env_refs: dict | None = None
|
||||
|
||||
@property
|
||||
def workspace_path(self) -> Path:
|
||||
"""Get expanded workspace path."""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user