fix(config): harden env ref round-trip behavior

This commit is contained in:
Evan Luo 2026-03-19 15:45:21 +00:00 committed by chengyongru
parent dfba8e3248
commit 11b60c84df
4 changed files with 217 additions and 38 deletions

View File

@ -1,13 +1,11 @@
"""Configuration loading utilities."""
import copy
import json
from pathlib import Path
from typing import Any
from nanobot.config.schema import Config
from nanobot.config.secret_resolver import resolve_config, _REF_PATTERN
from nanobot.config.secret_resolver import has_env_ref, resolve_config, resolve_env_vars
# Global variable to store current config path (for multi-instance support)
_current_config_path: Path | None = None
@ -45,10 +43,10 @@ def load_config(config_path: Path | None = None) -> Config:
raw_data = _migrate_config(raw_data)
# Record original values of fields containing {env:VAR} references
env_refs: dict[str, Any] = {}
env_refs: dict[str, dict[str, str]] = {}
_collect_env_refs(raw_data, "", env_refs)
resolved_data = resolve_config(copy.deepcopy(raw_data)) # Resolve {env:VAR} references
resolved_data = resolve_config(raw_data)
config = Config.model_validate(resolved_data)
config._env_refs = env_refs # Preserve original {env:VAR} values for save_config
return config
@ -70,8 +68,7 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
path = config_path or get_config_path()
path.parent.mkdir(parents=True, exist_ok=True)
# Use raw unresolved data if available to preserve {env:VAR} placeholders
# Use model_dump as base, but restore {env:VAR} references from original values
# Preserve original {env:VAR} placeholders only for values unchanged since load.
data = config.model_dump(by_alias=True)
if config._env_refs:
_restore_env_refs(data, config._env_refs)
@ -90,8 +87,8 @@ def _migrate_config(data: dict) -> dict:
return data
def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None:
"""Collect field paths and original values for fields containing {env:VAR}."""
def _collect_env_refs(obj: Any, path: str, refs: dict[str, dict[str, str]]) -> None:
"""Collect field paths with original and resolved values for {env:VAR} strings."""
if isinstance(obj, dict):
for key, value in obj.items():
child_path = f"{path}.{key}" if path else key
@ -99,24 +96,97 @@ def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None:
elif isinstance(obj, list):
for idx, item in enumerate(obj):
_collect_env_refs(item, f"{path}[{idx}]", refs)
elif isinstance(obj, str) and _REF_PATTERN.search(obj):
refs[path] = obj
elif isinstance(obj, str) and has_env_ref(obj):
refs[path] = {
"original": obj,
"resolved": resolve_env_vars(obj),
}
def _restore_env_refs(data: dict, refs: dict[str, Any]) -> None:
"""Restore original {env:VAR} values into data dict."""
for path, original_value in refs.items():
_set_by_path(data, path, original_value)
def _restore_env_refs(data: dict, refs: dict[str, dict[str, str]]) -> None:
    """Put ``{env:VAR}`` placeholders back into fields that were not edited.

    Args:
        data: Serialised config about to be written; modified in place.
        refs: Mapping of field path to a record holding the ``"original"``
            placeholder string and the ``"resolved"`` value seen at load time.

    A field is rewritten only when its path still exists in ``data`` and its
    current value equals the recorded resolved value. Anything else (a path
    that no longer resolves, or a value that differs) is left as it is.
    """
    for location, snapshot in refs.items():
        present, live_value = _get_by_path(data, location)
        # Only an untouched value is swapped back for its placeholder.
        if present and live_value == snapshot["resolved"]:
            _set_by_path(data, location, snapshot["original"])
def _set_by_path(data: dict, path: str, value: Any) -> None:
"""Set a value in nested dict by dot-notation path like 'providers.zhipu.apiKey'."""
parts = path.split(".")
def _parse_path(path: str) -> list[str | int]:
"""Parse dotted/list path like providers.openai.apiKey or args[0]."""
tokens: list[str | int] = []
buf = ""
i = 0
while i < len(path):
ch = path[i]
if ch == ".":
if buf:
tokens.append(buf)
buf = ""
i += 1
continue
if ch == "[":
if buf:
tokens.append(buf)
buf = ""
close = path.find("]", i + 1)
if close == -1:
return []
idx = path[i + 1 : close]
if not idx.isdigit():
return []
tokens.append(int(idx))
i = close + 1
continue
buf += ch
i += 1
if buf:
tokens.append(buf)
return tokens
def _get_by_path(data: Any, path: str) -> tuple[bool, Any]:
"""Look up the value at a dotted/indexed path in nested dicts and lists.

Returns a ``(found, value)`` pair. ``found`` is False (with ``value`` None)
when the path parses to no tokens, when an integer token meets something
that is not a list or is past its end, or when a string token meets
something that is not a dict or lacks that key. The flag lets callers tell
a missing path apart from a stored ``None``.
"""
tokens = _parse_path(path)
if not tokens:
return False, None
current = data
# NOTE(review): the next seven lines use `parts`, `last_key` and `value`,
# none of which are defined in this function. They look like removed-side
# diff residue from the earlier `_set_by_path` body (which began with
# `parts = path.split(".")`) - confirm and drop.
for part in parts[:-1]:
if part not in current:
return
current = current[part]
last_key = parts[-1]
if isinstance(current, dict) and last_key in current:
current[last_key] = value
# Walk one token at a time; the container type must match the token type.
for token in tokens:
if isinstance(token, int):
if not isinstance(current, list) or token >= len(current):
return False, None
current = current[token]
else:
if not isinstance(current, dict) or token not in current:
return False, None
current = current[token]
return True, current
def _set_by_path(data: Any, path: str, value: Any) -> None:
    """Overwrite an existing slot at a dotted/indexed path, in place.

    Args:
        data: Nested dict/list structure to modify.
        path: Location in the form understood by ``_parse_path``.
        value: Replacement value.

    The call does nothing when the path parses to no tokens, when any
    intermediate step cannot be followed, or when the final key or index is
    not already present. It never creates keys or grows lists.
    """
    tokens = _parse_path(path)
    if not tokens:
        return

    *parents, leaf = tokens
    node = data
    for step in parents:
        if isinstance(step, int):
            reachable = isinstance(node, list) and step < len(node)
        else:
            reachable = isinstance(node, dict) and step in node
        if not reachable:
            return
        node = node[step]

    # The final slot must already exist in a container of the matching kind.
    if isinstance(leaf, int):
        writable = isinstance(node, list) and leaf < len(node)
    else:
        writable = isinstance(node, dict) and leaf in node
    if writable:
        node[leaf] = value

View File

@ -5,6 +5,7 @@ Supports {env:VARIABLE_NAME} syntax to reference environment variables.
import os
import re
from typing import Any
# Pattern matches {env:VAR_NAME} where VAR_NAME follows env var naming conventions
_REF_PATTERN = re.compile(r"\{env:([A-Z_][A-Z0-9_]*)\}")
@ -18,19 +19,22 @@ def resolve_env_vars(value: str) -> str:
Returns:
String with all {env:VAR} references replaced by their values.
Unresolved references are left unchanged.
Missing env vars resolve to empty string.
"""
def replacer(match: re.Match[str]) -> str:
var_name = match.group(1)
env_value = os.environ.get(var_name)
if env_value is None:
return match.group(0) # Keep original if env var doesn't exist
return env_value
return os.environ.get(var_name, "")
return _REF_PATTERN.sub(replacer, value)
def resolve_config(obj):
def has_env_ref(value: str) -> bool:
    """Report whether ``value`` holds one or more ``{env:VAR}`` placeholders."""
    match = _REF_PATTERN.search(value)
    return match is not None
def resolve_config(obj: Any) -> Any:
"""Recursively resolve {env:VAR} references in a configuration object.
Args:

View File

@ -76,7 +76,9 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch)
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
monkeypatch.setattr(
"nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace
)
result = runner.invoke(app, ["onboard"], input="n\n")
@ -109,7 +111,9 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
)
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
monkeypatch.setattr(
"nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace
)
monkeypatch.setattr(
"nanobot.channels.registry.discover_all",
lambda: {
@ -130,3 +134,107 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
assert result.exit_code == 0
saved = json.loads(config_path.read_text(encoding="utf-8"))
assert saved["channels"]["qq"]["msgFormat"] == "plain"
def test_env_ref_round_trip_preserves_placeholder_after_save(tmp_path, monkeypatch) -> None:
    """An untouched env-backed field is saved as its original placeholder."""
    placeholder = "{env:OPENAI_API_KEY}"
    payload = {"providers": {"openai": {"apiKey": placeholder}}}
    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(payload), encoding="utf-8")
    monkeypatch.setenv("OPENAI_API_KEY", "sk-runtime")

    config = load_config(config_path)
    # At runtime the placeholder is replaced by the environment value.
    assert config.providers.openai.api_key == "sk-runtime"

    save_config(config, config_path)
    on_disk = json.loads(config_path.read_text(encoding="utf-8"))
    assert on_disk["providers"]["openai"]["apiKey"] == placeholder
def test_env_ref_in_list_round_trip_preserves_placeholder(tmp_path, monkeypatch) -> None:
    """A placeholder stored as a list element survives a load/save cycle."""
    placeholder = "{env:MCP_TOKEN}"
    demo_server = {
        "command": "npx",
        "args": ["-y", "run-tool", "--token", placeholder],
    }
    payload = {"tools": {"mcpServers": {"demo": demo_server}}}
    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(payload), encoding="utf-8")
    monkeypatch.setenv("MCP_TOKEN", "runtime-token")

    config = load_config(config_path)
    # Index 3 is the element that held the placeholder.
    assert config.tools.mcp_servers["demo"].args[3] == "runtime-token"

    save_config(config, config_path)
    on_disk = json.loads(config_path.read_text(encoding="utf-8"))
    assert on_disk["tools"]["mcpServers"]["demo"]["args"][3] == placeholder
def test_save_keeps_intentional_in_memory_override_of_env_ref(tmp_path, monkeypatch) -> None:
    """A value changed after load is saved as-is, not reverted to the placeholder."""
    payload = {"providers": {"openai": {"apiKey": "{env:OPENAI_API_KEY}"}}}
    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(payload), encoding="utf-8")
    monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")

    config = load_config(config_path)
    override = "sk-manual-override"
    config.providers.openai.api_key = override

    save_config(config, config_path)
    on_disk = json.loads(config_path.read_text(encoding="utf-8"))
    assert on_disk["providers"]["openai"]["apiKey"] == override
def test_missing_env_ref_resolves_empty_at_runtime_but_persists_placeholder(tmp_path) -> None:
    """An unset variable loads as an empty string yet is saved as its placeholder."""
    placeholder = "{env:MISSING_OPENAI_KEY}"
    payload = {"providers": {"openai": {"apiKey": placeholder}}}
    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(payload), encoding="utf-8")

    config = load_config(config_path)
    assert config.providers.openai.api_key == ""
    assert config.get_provider_name("openai/gpt-4.1") is None

    save_config(config, config_path)
    on_disk = json.loads(config_path.read_text(encoding="utf-8"))
    assert on_disk["providers"]["openai"]["apiKey"] == placeholder

View File

@ -1,7 +1,5 @@
"""Tests for secret_resolver module."""
import os
import pytest
from nanobot.config.secret_resolver import resolve_config, resolve_env_vars
@ -23,9 +21,8 @@ class TestResolveEnvVars:
monkeypatch.setenv("HOST", "example.com")
assert resolve_env_vars("{env:USER}@{env:HOST}") == "alice@example.com"
def test_unresolved_var_kept_unchanged(self) -> None:
# Environment variable that doesn't exist should remain as-is
assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == "{env:NONEXISTENT_VAR_XYZ}"
def test_unresolved_var_becomes_empty(self) -> None:
    """A reference to an unset variable resolves to the empty string."""
    resolved = resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}")
    assert resolved == ""
def test_empty_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("EMPTY_VAR", "")