mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
fix(config): harden env ref round-trip behavior
This commit is contained in:
parent
dfba8e3248
commit
11b60c84df
@ -1,13 +1,11 @@
|
||||
"""Configuration loading utilities."""
|
||||
|
||||
import copy
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.config.secret_resolver import resolve_config, _REF_PATTERN
|
||||
|
||||
from nanobot.config.secret_resolver import has_env_ref, resolve_config, resolve_env_vars
|
||||
|
||||
# Global variable to store current config path (for multi-instance support)
|
||||
_current_config_path: Path | None = None
|
||||
@ -45,10 +43,10 @@ def load_config(config_path: Path | None = None) -> Config:
|
||||
raw_data = _migrate_config(raw_data)
|
||||
|
||||
# Record original values of fields containing {env:VAR} references
|
||||
env_refs: dict[str, Any] = {}
|
||||
env_refs: dict[str, dict[str, str]] = {}
|
||||
_collect_env_refs(raw_data, "", env_refs)
|
||||
|
||||
resolved_data = resolve_config(copy.deepcopy(raw_data)) # Resolve {env:VAR} references
|
||||
resolved_data = resolve_config(raw_data)
|
||||
config = Config.model_validate(resolved_data)
|
||||
config._env_refs = env_refs # Preserve original {env:VAR} values for save_config
|
||||
return config
|
||||
@ -70,8 +68,7 @@ def save_config(config: Config, config_path: Path | None = None) -> None:
|
||||
path = config_path or get_config_path()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use raw unresolved data if available to preserve {env:VAR} placeholders
|
||||
# Use model_dump as base, but restore {env:VAR} references from original values
|
||||
# Preserve original {env:VAR} placeholders only for values unchanged since load.
|
||||
data = config.model_dump(by_alias=True)
|
||||
if config._env_refs:
|
||||
_restore_env_refs(data, config._env_refs)
|
||||
@ -90,8 +87,8 @@ def _migrate_config(data: dict) -> dict:
|
||||
return data
|
||||
|
||||
|
||||
def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None:
|
||||
"""Collect field paths and original values for fields containing {env:VAR}."""
|
||||
def _collect_env_refs(obj: Any, path: str, refs: dict[str, dict[str, str]]) -> None:
|
||||
"""Collect field paths with original and resolved values for {env:VAR} strings."""
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
child_path = f"{path}.{key}" if path else key
|
||||
@ -99,24 +96,97 @@ def _collect_env_refs(obj: Any, path: str, refs: dict[str, Any]) -> None:
|
||||
elif isinstance(obj, list):
|
||||
for idx, item in enumerate(obj):
|
||||
_collect_env_refs(item, f"{path}[{idx}]", refs)
|
||||
elif isinstance(obj, str) and _REF_PATTERN.search(obj):
|
||||
refs[path] = obj
|
||||
elif isinstance(obj, str) and has_env_ref(obj):
|
||||
refs[path] = {
|
||||
"original": obj,
|
||||
"resolved": resolve_env_vars(obj),
|
||||
}
|
||||
|
||||
|
||||
def _restore_env_refs(data: dict, refs: dict[str, dict[str, str]]) -> None:
    """Restore original {env:VAR} values into unchanged fields.

    For each recorded path, the placeholder is written back only when the
    current dumped value still equals the resolved value captured at load
    time — an intentional in-memory override made after load is kept as-is.

    Args:
        data: Config dict produced by ``model_dump`` that will be persisted.
        refs: Mapping of field path -> {"original": placeholder string,
            "resolved": value it resolved to at load time}.
    """
    for path, record in refs.items():
        found, current_value = _get_by_path(data, path)
        if not found:
            # Path no longer exists in the dumped config; nothing to restore.
            continue
        if current_value == record["resolved"]:
            _set_by_path(data, path, record["original"])
|
||||
def _parse_path(path: str) -> list[str | int]:
|
||||
"""Parse dotted/list path like providers.openai.apiKey or args[0]."""
|
||||
tokens: list[str | int] = []
|
||||
buf = ""
|
||||
i = 0
|
||||
while i < len(path):
|
||||
ch = path[i]
|
||||
if ch == ".":
|
||||
if buf:
|
||||
tokens.append(buf)
|
||||
buf = ""
|
||||
i += 1
|
||||
continue
|
||||
if ch == "[":
|
||||
if buf:
|
||||
tokens.append(buf)
|
||||
buf = ""
|
||||
close = path.find("]", i + 1)
|
||||
if close == -1:
|
||||
return []
|
||||
idx = path[i + 1 : close]
|
||||
if not idx.isdigit():
|
||||
return []
|
||||
tokens.append(int(idx))
|
||||
i = close + 1
|
||||
continue
|
||||
buf += ch
|
||||
i += 1
|
||||
if buf:
|
||||
tokens.append(buf)
|
||||
return tokens
|
||||
|
||||
|
||||
def _get_by_path(data: Any, path: str) -> tuple[bool, Any]:
    """Get value from nested dict/list path.

    Args:
        data: Root container (typically the dumped config dict).
        path: Dotted/bracketed path understood by ``_parse_path``.

    Returns:
        ``(True, value)`` when every segment resolves; ``(False, None)``
        when the path is empty/unparseable or any segment is missing.
    """
    tokens = _parse_path(path)
    if not tokens:
        return False, None

    current = data
    for token in tokens:
        if isinstance(token, int):
            # List index: container must be a list with the index in range.
            if not isinstance(current, list) or token >= len(current):
                return False, None
            current = current[token]
        else:
            # Dict key: container must be a dict holding the key.
            if not isinstance(current, dict) or token not in current:
                return False, None
            current = current[token]
    return True, current
|
||||
|
||||
|
||||
def _set_by_path(data: Any, path: str, value: Any) -> None:
    """Assign *value* at a nested dict/list path; silently no-op if absent.

    Only overwrites an existing slot — it never creates new keys or grows
    lists, so an invalid or stale path leaves *data* untouched.
    """
    tokens = _parse_path(path)
    if not tokens:
        return

    node = data
    for step in tokens[:-1]:
        # Walk intermediate containers, bailing out on any type/key mismatch.
        if isinstance(step, int):
            if not (isinstance(node, list) and step < len(node)):
                return
        elif not (isinstance(node, dict) and step in node):
            return
        node = node[step]

    final = tokens[-1]
    if isinstance(final, int):
        if isinstance(node, list) and final < len(node):
            node[final] = value
    elif isinstance(node, dict) and final in node:
        node[final] = value
|
||||
|
||||
@ -5,6 +5,7 @@ Supports {env:VARIABLE_NAME} syntax to reference environment variables.
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
# Pattern matches {env:VAR_NAME} where VAR_NAME follows env var naming conventions
|
||||
_REF_PATTERN = re.compile(r"\{env:([A-Z_][A-Z0-9_]*)\}")
|
||||
@ -18,19 +19,22 @@ def resolve_env_vars(value: str) -> str:
|
||||
|
||||
Returns:
|
||||
String with all {env:VAR} references replaced by their values.
|
||||
Unresolved references are left unchanged.
|
||||
Missing env vars resolve to empty string.
|
||||
"""
|
||||
|
||||
def replacer(match: re.Match[str]) -> str:
|
||||
var_name = match.group(1)
|
||||
env_value = os.environ.get(var_name)
|
||||
if env_value is None:
|
||||
return match.group(0) # Keep original if env var doesn't exist
|
||||
return env_value
|
||||
return os.environ.get(var_name, "")
|
||||
|
||||
return _REF_PATTERN.sub(replacer, value)
|
||||
|
||||
|
||||
def resolve_config(obj):
|
||||
def has_env_ref(value: str) -> bool:
    """Tell whether *value* contains any {env:VAR} reference."""
    return _REF_PATTERN.search(value) is not None
|
||||
|
||||
|
||||
def resolve_config(obj: Any) -> Any:
|
||||
"""Recursively resolve {env:VAR} references in a configuration object.
|
||||
|
||||
Args:
|
||||
|
||||
@ -76,7 +76,9 @@ def test_onboard_refresh_rewrites_legacy_config_template(tmp_path, monkeypatch)
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace
|
||||
)
|
||||
|
||||
result = runner.invoke(app, ["onboard"], input="n\n")
|
||||
|
||||
@ -109,7 +111,9 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.config.loader.get_config_path", lambda: config_path)
|
||||
monkeypatch.setattr("nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.cli.commands.get_workspace_path", lambda _workspace=None: workspace
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"nanobot.channels.registry.discover_all",
|
||||
lambda: {
|
||||
@ -130,3 +134,107 @@ def test_onboard_refresh_backfills_missing_channel_fields(tmp_path, monkeypatch)
|
||||
assert result.exit_code == 0
|
||||
saved = json.loads(config_path.read_text(encoding="utf-8"))
|
||||
assert saved["channels"]["qq"]["msgFormat"] == "plain"
|
||||
|
||||
|
||||
def test_env_ref_round_trip_preserves_placeholder_after_save(tmp_path, monkeypatch) -> None:
    """Saving after load writes back the {env:VAR} placeholder, not the secret."""
    config_path = tmp_path / "config.json"
    payload = {
        "providers": {
            "openai": {
                "apiKey": "{env:OPENAI_API_KEY}",
            }
        }
    }
    config_path.write_text(json.dumps(payload), encoding="utf-8")
    monkeypatch.setenv("OPENAI_API_KEY", "sk-runtime")

    config = load_config(config_path)
    # In memory the reference resolves to the live environment value.
    assert config.providers.openai.api_key == "sk-runtime"

    save_config(config, config_path)
    saved = json.loads(config_path.read_text(encoding="utf-8"))
    # On disk the original placeholder is preserved.
    assert saved["providers"]["openai"]["apiKey"] == "{env:OPENAI_API_KEY}"
|
||||
|
||||
|
||||
def test_env_ref_in_list_round_trip_preserves_placeholder(tmp_path, monkeypatch) -> None:
    """Placeholders inside list items (MCP server args) survive a save round-trip."""
    config_path = tmp_path / "config.json"
    payload = {
        "tools": {
            "mcpServers": {
                "demo": {
                    "command": "npx",
                    "args": [
                        "-y",
                        "run-tool",
                        "--token",
                        "{env:MCP_TOKEN}",
                    ],
                }
            }
        }
    }
    config_path.write_text(json.dumps(payload), encoding="utf-8")
    monkeypatch.setenv("MCP_TOKEN", "runtime-token")

    config = load_config(config_path)
    # Loaded config exposes the resolved token at the list position.
    assert config.tools.mcp_servers["demo"].args[3] == "runtime-token"

    save_config(config, config_path)
    saved = json.loads(config_path.read_text(encoding="utf-8"))
    # The persisted file restores the placeholder inside the list.
    assert saved["tools"]["mcpServers"]["demo"]["args"][3] == "{env:MCP_TOKEN}"
|
||||
|
||||
|
||||
def test_save_keeps_intentional_in_memory_override_of_env_ref(tmp_path, monkeypatch) -> None:
    """A value changed after load is saved as-is, not reverted to the placeholder."""
    config_path = tmp_path / "config.json"
    payload = {
        "providers": {
            "openai": {
                "apiKey": "{env:OPENAI_API_KEY}",
            }
        }
    }
    config_path.write_text(json.dumps(payload), encoding="utf-8")
    monkeypatch.setenv("OPENAI_API_KEY", "sk-from-env")

    config = load_config(config_path)
    # Simulate a deliberate runtime edit of the key.
    config.providers.openai.api_key = "sk-manual-override"

    save_config(config, config_path)
    saved = json.loads(config_path.read_text(encoding="utf-8"))
    assert saved["providers"]["openai"]["apiKey"] == "sk-manual-override"
|
||||
|
||||
|
||||
def test_missing_env_ref_resolves_empty_at_runtime_but_persists_placeholder(tmp_path) -> None:
    """An unset env var yields an empty key at runtime; the file keeps the reference."""
    config_path = tmp_path / "config.json"
    payload = {
        "providers": {
            "openai": {
                "apiKey": "{env:MISSING_OPENAI_KEY}",
            }
        }
    }
    config_path.write_text(json.dumps(payload), encoding="utf-8")

    config = load_config(config_path)
    # Missing variable resolves to "" so the provider is treated as unconfigured.
    assert config.providers.openai.api_key == ""
    assert config.get_provider_name("openai/gpt-4.1") is None

    save_config(config, config_path)
    saved = json.loads(config_path.read_text(encoding="utf-8"))
    assert saved["providers"]["openai"]["apiKey"] == "{env:MISSING_OPENAI_KEY}"
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""Tests for secret_resolver module."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.config.secret_resolver import resolve_config, resolve_env_vars
|
||||
@ -23,9 +21,8 @@ class TestResolveEnvVars:
|
||||
monkeypatch.setenv("HOST", "example.com")
|
||||
assert resolve_env_vars("{env:USER}@{env:HOST}") == "alice@example.com"
|
||||
|
||||
def test_unresolved_var_kept_unchanged(self) -> None:
|
||||
# Environment variable that doesn't exist should remain as-is
|
||||
assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == "{env:NONEXISTENT_VAR_XYZ}"
|
||||
def test_unresolved_var_becomes_empty(self) -> None:
|
||||
assert resolve_env_vars("{env:NONEXISTENT_VAR_XYZ}") == ""
|
||||
|
||||
def test_empty_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("EMPTY_VAR", "")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user