From 11b60c84df42ac67d58bcfbd82abe60be933b789 Mon Sep 17 00:00:00 2001 From: Evan Luo Date: Thu, 19 Mar 2026 15:45:21 +0000 Subject: [PATCH] fix(config): harden env ref round-trip behavior --- nanobot/config/loader.py | 120 +++++++++++++++++++++++------- nanobot/config/secret_resolver.py | 16 ++-- tests/test_config_migration.py | 112 +++++++++++++++++++++++++++- tests/test_secret_resolver.py | 7 +- 4 files changed, 217 insertions(+), 38 deletions(-) diff --git a/nanobot/config/loader.py b/nanobot/config/loader.py index 6fb6919c6..aa8efcfc7 100644 --- a/nanobot/config/loader.py +++ b/nanobot/config/loader.py @@ -1,13 +1,11 @@ """Configuration loading utilities.""" -import copy import json from pathlib import Path from typing import Any from nanobot.config.schema import Config -from nanobot.config.secret_resolver import resolve_config, _REF_PATTERN - +from nanobot.config.secret_resolver import has_env_ref, resolve_config, resolve_env_vars # Global variable to store current config path (for multi-instance support) _current_config_path: Path | None = None @@ -45,10 +43,10 @@ def load_config(config_path: Path | None = None) -> Config: raw_data = _migrate_config(raw_data) # Record original values of fields containing {env:VAR} references - env_refs: dict[str, Any] = {} + env_refs: dict[str, dict[str, str]] = {} _collect_env_refs(raw_data, "", env_refs) - resolved_data = resolve_config(copy.deepcopy(raw_data)) # Resolve {env:VAR} references + resolved_data = resolve_config(raw_data) config = Config.model_validate(resolved_data) config._env_refs = env_refs # Preserve original {env:VAR} values for save_config return config @@ -70,8 +68,7 @@ def save_config(config: Config, config_path: Path | None = None) -> None: path = config_path or get_config_path() path.parent.mkdir(parents=True, exist_ok=True) - # Use raw unresolved data if available to preserve {env:VAR} placeholders - # Use model_dump as base, but restore {env:VAR} references from original values + # Preserve original {env:VAR} placeholders only for values unchanged since load. data = config.model_dump(by_alias=True) if config._env_refs: _restore_env_refs(data, config._env_refs) @@ -90,8 +87,8 @@ def _migrate_config(data: dict) -> dict: return data -def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None: - """Collect field paths and original values for fields containing {env:VAR}.""" +def _collect_env_refs(obj: Any, path: str, refs: dict[str, dict[str, str]]) -> None: + """Collect field paths with original and resolved values for {env:VAR} strings.""" if isinstance(obj, dict): for key, value in obj.items(): child_path = f"{path}.{key}" if path else key @@ -99,24 +96,97 @@ def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None: elif isinstance(obj, list): for idx, item in enumerate(obj): _collect_env_refs(item, f"{path}[{idx}]", refs) - elif isinstance(obj, str) and _REF_PATTERN.search(obj): - refs[path] = obj + elif isinstance(obj, str) and has_env_ref(obj): + refs[path] = { + "original": obj, + "resolved": resolve_env_vars(obj), + } -def _restore_env_refs(data: dict, refs: dict[str, Any]) -> None: - """Restore original {env:VAR} values into data dict.""" - for path, original_value in refs.items(): - _set_by_path(data, path, original_value) +def _restore_env_refs(data: dict, refs: dict[str, dict[str, str]]) -> None: + """Restore original {env:VAR} values into unchanged fields.""" + for path, record in refs.items(): + found, current_value = _get_by_path(data, path) + if not found: + continue + if current_value == record["resolved"]: + _set_by_path(data, path, record["original"]) -def _set_by_path(data: dict, path: str, value: Any) -> None: - """Set a value in nested dict by dot-notation path like 'providers.zhipu.apiKey'.""" - parts = path.split(".") +def _parse_path(path: str) -> list[str | int]: + """Parse dotted/list path like providers.openai.apiKey or args[0].""" + tokens: list[str | int] = [] + buf = "" + i = 0 + while i < len(path): + ch = path[i] + if ch == ".": + if buf: + tokens.append(buf) + buf = "" + i += 1 + continue + if ch == "[": + if buf: + tokens.append(buf) + buf = "" + close = path.find("]", i + 1) + if close == -1: + return [] + idx = path[i + 1 : close] + if not idx.isdigit(): + return [] + tokens.append(int(idx)) + i = close + 1 + continue + buf += ch + i += 1 + if buf: + tokens.append(buf) + return tokens + + +def _get_by_path(data: Any, path: str) -> tuple[bool, Any]: + """Get value from nested dict/list path.""" + tokens = _parse_path(path) + if not tokens: + return False, None + current = data - for part in parts[:-1]: - if part not in current: - return - current = current[part] - last_key = parts[-1] - if isinstance(current, dict) and last_key in current: - current[last_key] = value + for token in tokens: + if isinstance(token, int): + if not isinstance(current, list) or token >= len(current): + return False, None + current = current[token] + else: + if not isinstance(current, dict) or token not in current: + return False, None + current = current[token] + return True, current + + +def _set_by_path(data: Any, path: str, value: Any) -> None: + """Set value in nested dict/list path.""" + tokens = _parse_path(path) + if not tokens: + return + + current = data + for token in tokens[:-1]: + if isinstance(token, int): + if not isinstance(current, list) or token >= len(current): + return + current = current[token] + else: + if not isinstance(current, dict) or token not in current: + return + current = current[token] + + last = tokens[-1] + if isinstance(last, int): + if isinstance(current, list) and last < len(current): + current[last] = value + return + + if isinstance(current, dict) and last in current: + current[last] = value diff --git a/nanobot/config/secret_resolver.py b/nanobot/config/secret_resolver.py index dac19c1b5..c81a69923 100644 --- a/nanobot/config/secret_resolver.py +++ b/nanobot/config/secret_resolver.py @@ -5,6 +5,7 @@ Supports {env:VARIABLE_NAME} syntax to reference environment variables. import os import re +from typing import Any # Pattern matches {env:VAR_NAME} where VAR_NAME follows env var naming conventions _REF_PATTERN = re.compile(r"\{env:([A-Z_][A-Z0-9_]*)\}") @@ -18,19 +19,22 @@ def resolve_env_vars(value: str) -> str: Returns: String with all {env:VAR} references replaced by their values. - Unresolved references are left unchanged. + Missing env vars resolve to empty string. """ + def replacer(match: re.Match[str]) -> str: var_name = match.group(1) - env_value = os.environ.get(var_name) - if env_value is None: - return match.group(0) # Keep original if env var doesn't exist - return env_value + return os.environ.get(var_name, "") return _REF_PATTERN.sub(replacer, value) -def resolve_config(obj): +def has_env_ref(value: str) -> bool: + """Return True if string contains at least one {env:VAR} reference.""" + return bool(_REF_PATTERN.search(value)) + + +def resolve_config(obj: Any) -> Any: """Recursively resolve {env:VAR} references in a configuration object. Args: diff --git a/tests/test_config_migration.py b/tests/test_config_migration.py index 2a446b774..42b19806d 100644 --- a/tests/test_config_migration.py +++ b/tests/test_config_migration.py @@ -76,7 +76,9 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch) ) monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) - monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) + monkeypatch.setattr( + "nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace + ) result = runner.invoke(app, ["onboard"], input="n\n") @@ -109,7 +111,9 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) ) monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path) - monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace) + monkeypatch.setattr( + "nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace + ) monkeypatch.setattr( "nanobot.channels.registry.discover_all", lambda: { @@ -130,3 +134,107 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch) assert result.exit_code == 0 saved = json.loads(config_path.read_text(encoding="utf-8")) assert saved["channels"]["qq"]["msgFormat"] == "plain" + + +def test_env_ref_round_trip_preserves_placeholder_after_save(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "providers": { + "openai": { + "apiKey": "{env:OPENAI_API_KEY}", + } + } + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("OPENAI_API_KEY", "sk-runtime") + + config = load_config(config_path) + assert config.providers.openai.api_key == "sk-runtime" + + save_config(config, config_path) + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert saved["providers"]["openai"]["apiKey"] == "{env:OPENAI_API_KEY}" + + +def test_env_ref_in_list_round_trip_preserves_placeholder(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "tools": { + "mcpServers": { + "demo": { + "command": "npx", + "args": [ + "-y", + "run-tool", + "--token", + "{env:MCP_TOKEN}", + ], + } + } + } + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("MCP_TOKEN", "runtime-token") + + config = load_config(config_path) + assert config.tools.mcp_servers["demo"].args[3] == "runtime-token" + + save_config(config, config_path) + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert saved["tools"]["mcpServers"]["demo"]["args"][3] == "{env:MCP_TOKEN}" + + +def test_save_keeps_intentional_in_memory_override_of_env_ref(tmp_path, monkeypatch) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "providers": { + "openai": { + "apiKey": "{env:OPENAI_API_KEY}", + } + } + } + ), + encoding="utf-8", + ) + monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env") + + config = load_config(config_path) + config.providers.openai.api_key = "sk-manual-override" + + save_config(config, config_path) + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert saved["providers"]["openai"]["apiKey"] == "sk-manual-override" + + +def test_missing_env_ref_resolves_empty_at_runtime_but_persists_placeholder(tmp_path) -> None: + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "providers": { + "openai": { + "apiKey": "{env:MISSING_OPENAI_KEY}", + } + } + } + ), + encoding="utf-8", + ) + + config = load_config(config_path) + assert config.providers.openai.api_key == "" + assert config.get_provider_name("openai/gpt-4.1") is None + + save_config(config, config_path) + saved = json.loads(config_path.read_text(encoding="utf-8")) + assert saved["providers"]["openai"]["apiKey"] == "{env:MISSING_OPENAI_KEY}" diff --git a/tests/test_secret_resolver.py b/tests/test_secret_resolver.py index ef6282dda..bc079ed18 100644 --- a/tests/test_secret_resolver.py +++ b/tests/test_secret_resolver.py @@ -1,7 +1,5 @@ """Tests for secret_resolver module.""" -import os - import pytest from nanobot.config.secret_resolver import resolve_config, resolve_env_vars @@ -23,9 +21,8 @@ class TestResolveEnvVars: monkeypatch.setenv("HOST", "example.com") assert resolve_env_vars("{env:USER}@{env:HOST}") == "alice@example.com" - def test_unresolved_var_kept_unchanged(self) -> None: - # Environment variable that doesn't exist should remain as-is - assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == "{env:NONEXISTENT_VAR_XYZ}" + def test_unresolved_var_becomes_empty(self) -> None: + assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == "" def test_empty_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("EMPTY_VAR", "")