fix(config): harden env ref round-trip behavior

This commit is contained in:
Evan Luo 2026-03-19 15:45:21 +00:00 committed by chengyongru
parent dfba8e3248
commit 11b60c84df
4 changed files with 217 additions and 38 deletions

View File

@ -1,13 +1,11 @@
"""Configuration loading utilities.""" """Configuration loading utilities."""
import copy
import json import json
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from nanobot.config.schema import Config from nanobot.config.schema import Config
from nanobot.config.secret_resolver import resolve_config, _REF_PATTERN from nanobot.config.secret_resolver import has_env_ref, resolve_config, resolve_env_vars
# Global variable to store current config path (for multi-instance support) # Global variable to store current config path (for multi-instance support)
_current_config_path: Path | None = None _current_config_path: Path | None = None
@ -45,10 +43,10 @@ def load_config(config_path: Path | None = None) -> Config:
raw_data = _migrate_config(raw_data) raw_data = _migrate_config(raw_data)
# Record original values of fields containing {env:VAR} references # Record original values of fields containing {env:VAR} references
env_refs: dict[str, Any] = {} env_refs: dict[str, dict[str, str]] = {}
_collect_env_refs(raw_data, "", env_refs) _collect_env_refs(raw_data, "", env_refs)
resolved_data = resolve_config(copy.deepcopy(raw_data)) # Resolve {env:VAR} references resolved_data = resolve_config(raw_data)
config = Config.model_validate(resolved_data) config = Config.model_validate(resolved_data)
config._env_refs = env_refs # Preserve original {env:VAR} values for save_config config._env_refs = env_refs # Preserve original {env:VAR} values for save_config
return config return config
@ -70,8 +68,7 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
path = config_path or get_config_path() path = config_path or get_config_path()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
# Use raw unresolved data if available to preserve {env:VAR} placeholders # Preserve original {env:VAR} placeholders only for values unchanged since load.
# Use model_dump as base, but restore {env:VAR} references from original values
data = config.model_dump(by_alias=True) data = config.model_dump(by_alias=True)
if config._env_refs: if config._env_refs:
_restore_env_refs(data, config._env_refs) _restore_env_refs(data, config._env_refs)
@ -90,8 +87,8 @@ def _migrate_config(data: dict) -> dict:
return data return data
def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None: def _collect_env_refs(obj: Any, path: str, refs: dict[str, dict[str, str]]) -> None:
"""Collect field paths and original values for fields containing {env:VAR}.""" """Collect field paths with original and resolved values for {env:VAR} strings."""
if isinstance(obj, dict): if isinstance(obj, dict):
for key, value in obj.items(): for key, value in obj.items():
child_path = f"{path}.{key}" if path else key child_path = f"{path}.{key}" if path else key
@ -99,24 +96,97 @@ def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None:
elif isinstance(obj, list): elif isinstance(obj, list):
for idx, item in enumerate(obj): for idx, item in enumerate(obj):
_collect_env_refs(item, f"{path}[{idx}]", refs) _collect_env_refs(item, f"{path}[{idx}]", refs)
elif isinstance(obj, str) and _REF_PATTERN.search(obj): elif isinstance(obj, str) and has_env_ref(obj):
refs[path] = obj refs[path] = {
"original": obj,
"resolved": resolve_env_vars(obj),
}
def _restore_env_refs(data: dict, refs: dict[str, Any]) -> None: def _restore_env_refs(data: dict, refs: dict[str, dict[str, str]]) -> None:
"""Restore original {env:VAR} values into data dict.""" """Restore original {env:VAR} values into unchanged fields."""
for path, original_value in refs.items(): for path, record in refs.items():
_set_by_path(data, path, original_value) found, current_value = _get_by_path(data, path)
if not found:
continue
if current_value == record["resolved"]:
_set_by_path(data, path, record["original"])
def _set_by_path(data: dict, path: str, value: Any) -> None: def _parse_path(path: str) -> list[str | int]:
"""Set a value in nested dict by dot-notation path like 'providers.zhipu.apiKey'.""" """Parse dotted/list path like providers.openai.apiKey or args[0]."""
parts = path.split(".") tokens: list[str | int] = []
buf = ""
i = 0
while i < len(path):
ch = path[i]
if ch == ".":
if buf:
tokens.append(buf)
buf = ""
i += 1
continue
if ch == "[":
if buf:
tokens.append(buf)
buf = ""
close = path.find("]", i + 1)
if close == -1:
return []
idx = path[i + 1 : close]
if not idx.isdigit():
return []
tokens.append(int(idx))
i = close + 1
continue
buf += ch
i += 1
if buf:
tokens.append(buf)
return tokens
def _get_by_path(data: Any, path: str) -> tuple[bool, Any]:
"""Get value from nested dict/list path."""
tokens = _parse_path(path)
if not tokens:
return False, None
current = data current = data
for part in parts[:-1]: for token in tokens:
if part not in current: if isinstance(token, int):
return if not isinstance(current, list) or token >= len(current):
current = current[part] return False, None
last_key = parts[-1] current = current[token]
if isinstance(current, dict) and last_key in current: else:
current[last_key] = value if not isinstance(current, dict) or token not in current:
return False, None
current = current[token]
return True, current
def _set_by_path(data: Any, path: str, value: Any) -> None:
"""Set value in nested dict/list path."""
tokens = _parse_path(path)
if not tokens:
return
current = data
for token in tokens[:-1]:
if isinstance(token, int):
if not isinstance(current, list) or token >= len(current):
return
current = current[token]
else:
if not isinstance(current, dict) or token not in current:
return
current = current[token]
last = tokens[-1]
if isinstance(last, int):
if isinstance(current, list) and last < len(current):
current[last] = value
return
if isinstance(current, dict) and last in current:
current[last] = value

View File

@ -5,6 +5,7 @@ Supports {env:VARIABLE_NAME} syntax to reference environment variables.
import os import os
import re import re
from typing import Any
# Pattern matches {env:VAR_NAME} where VAR_NAME follows env var naming conventions # Pattern matches {env:VAR_NAME} where VAR_NAME follows env var naming conventions
_REF_PATTERN = re.compile(r"\{env:([A-Z_][A-Z0-9_]*)\}") _REF_PATTERN = re.compile(r"\{env:([A-Z_][A-Z0-9_]*)\}")
@ -18,19 +19,22 @@ def resolve_env_vars(value: str) -> str:
Returns: Returns:
String with all {env:VAR} references replaced by their values. String with all {env:VAR} references replaced by their values.
Unresolved references are left unchanged. Missing env vars resolve to empty string.
""" """
def replacer(match: re.Match[str]) -> str: def replacer(match: re.Match[str]) -> str:
var_name = match.group(1) var_name = match.group(1)
env_value = os.environ.get(var_name) return os.environ.get(var_name, "")
if env_value is None:
return match.group(0) # Keep original if env var doesn't exist
return env_value
return _REF_PATTERN.sub(replacer, value) return _REF_PATTERN.sub(replacer, value)
def resolve_config(obj): def has_env_ref(value: str) -> bool:
"""Return True if string contains at least one {env:VAR} reference."""
return bool(_REF_PATTERN.search(value))
def resolve_config(obj: Any) -> Any:
"""Recursively resolve {env:VAR} references in a configuration object. """Recursively resolve {env:VAR} references in a configuration object.
Args: Args:

View File

@ -76,7 +76,9 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch)
) )
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) monkeypatch.setattr(
"nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace
)
result = runner.invoke(app, ["onboard"], input="n\n") result = runner.invoke(app, ["onboard"], input="n\n")
@ -109,7 +111,9 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
) )
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) monkeypatch.setattr(
"nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace
)
monkeypatch.setattr( monkeypatch.setattr(
"nanobot.channels.registry.discover_all", "nanobot.channels.registry.discover_all",
lambda: { lambda: {
@ -130,3 +134,107 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
assert result.exit_code == 0 assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8")) saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain" assert saved["channels"]["qq"]["msgFormat"] == "plain"
def test_env_ref_round_trip_preserves_placeholder_after_save(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"providers": {
"openai": {
"apiKey": "{env:OPENAI_API_KEY}",
}
}
}
),
encoding="utf-8",
)
monkeypatch.setenv("OPENAI_API_KEY", "sk-runtime")
config = load_config(config_path)
assert config.providers.openai.api_key == "sk-runtime"
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["providers"]["openai"]["apiKey"] == "{env:OPENAI_API_KEY}"
def test_env_ref_in_list_round_trip_preserves_placeholder(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"tools": {
"mcpServers": {
"demo": {
"command": "npx",
"args": [
"-y",
"run-tool",
"--token",
"{env:MCP_TOKEN}",
],
}
}
}
}
),
encoding="utf-8",
)
monkeypatch.setenv("MCP_TOKEN", "runtime-token")
config = load_config(config_path)
assert config.tools.mcp_servers["demo"].args[3] == "runtime-token"
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["tools"]["mcpServers"]["demo"]["args"][3] == "{env:MCP_TOKEN}"
def test_save_keeps_intentional_in_memory_override_of_env_ref(tmp_path, monkeypatch) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"providers": {
"openai": {
"apiKey": "{env:OPENAI_API_KEY}",
}
}
}
),
encoding="utf-8",
)
monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")
config = load_config(config_path)
config.providers.openai.api_key = "sk-manual-override"
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["providers"]["openai"]["apiKey"] == "sk-manual-override"
def test_missing_env_ref_resolves_empty_at_runtime_but_persists_placeholder(tmp_path) -> None:
config_path = tmp_path / "config.json"
config_path.write_text(
json.dumps(
{
"providers": {
"openai": {
"apiKey": "{env:MISSING_OPENAI_KEY}",
}
}
}
),
encoding="utf-8",
)
config = load_config(config_path)
assert config.providers.openai.api_key == ""
assert config.get_provider_name("openai/gpt-4.1") is None
save_config(config, config_path)
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["providers"]["openai"]["apiKey"] == "{env:MISSING_OPENAI_KEY}"

View File

@ -1,7 +1,5 @@
"""Tests for secret_resolver module.""" """Tests for secret_resolver module."""
import os
import pytest import pytest
from nanobot.config.secret_resolver import resolve_config, resolve_env_vars from nanobot.config.secret_resolver import resolve_config, resolve_env_vars
@ -23,9 +21,8 @@ class TestResolveEnvVars:
monkeypatch.setenv("HOST", "example.com") monkeypatch.setenv("HOST", "example.com")
assert resolve_env_vars("{env:USER}@{env:HOST}") == "alice@example.com" assert resolve_env_vars("{env:USER}@{env:HOST}") == "alice@example.com"
def test_unresolved_var_kept_unchanged(self) -> None: def test_unresolved_var_becomes_empty(self) -> None:
# Environment variable that doesn't exist should remain as-is assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == ""
assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == "{env:NONEXISTENT_VAR_XYZ}"
def test_empty_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: def test_empty_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("EMPTY_VAR", "") monkeypatch.setenv("EMPTY_VAR", "")