mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
fix(config): reconcile presets with settings reload
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
c450d6fd3f
commit
b61c6304c3
@ -293,7 +293,7 @@ class AgentLoop:
|
|||||||
provider_signature: tuple[object, ...] | None = None,
|
provider_signature: tuple[object, ...] | None = None,
|
||||||
model_presets: dict[str, ModelPresetConfig] | None = None,
|
model_presets: dict[str, ModelPresetConfig] | None = None,
|
||||||
model_preset: str | None = None,
|
model_preset: str | None = None,
|
||||||
model_preset_snapshot_builder: Callable[[ModelPresetConfig], ProviderSnapshot] | None = None,
|
model_preset_snapshot_builder: Callable[[str], ProviderSnapshot] | None = None,
|
||||||
):
|
):
|
||||||
from nanobot.config.schema import ToolsConfig
|
from nanobot.config.schema import ToolsConfig
|
||||||
|
|
||||||
@ -304,6 +304,10 @@ class AgentLoop:
|
|||||||
self.provider = provider
|
self.provider = provider
|
||||||
self._provider_snapshot_loader = provider_snapshot_loader
|
self._provider_snapshot_loader = provider_snapshot_loader
|
||||||
self._provider_signature = provider_signature
|
self._provider_signature = provider_signature
|
||||||
|
self._config_provider_signature = provider_signature
|
||||||
|
self._config_default_selection_signature = (
|
||||||
|
provider_signature[:2] if provider_signature is not None else None
|
||||||
|
)
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.model = model or provider.get_default_model()
|
self.model = model or provider.get_default_model()
|
||||||
self.max_iterations = (
|
self.max_iterations = (
|
||||||
@ -431,6 +435,7 @@ class AgentLoop:
|
|||||||
resolved = config.resolve_preset()
|
resolved = config.resolve_preset()
|
||||||
model = extra.pop("model", None) or resolved.model
|
model = extra.pop("model", None) or resolved.model
|
||||||
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
context_window_tokens = extra.pop("context_window_tokens", None) or resolved.context_window_tokens
|
||||||
|
model_preset_snapshot_builder = extra.pop("model_preset_snapshot_builder", None)
|
||||||
model_presets = dict(config.model_presets)
|
model_presets = dict(config.model_presets)
|
||||||
if "default" not in model_presets:
|
if "default" not in model_presets:
|
||||||
model_presets["default"] = resolved
|
model_presets["default"] = resolved
|
||||||
@ -457,7 +462,10 @@ class AgentLoop:
|
|||||||
tools_config=config.tools,
|
tools_config=config.tools,
|
||||||
model_presets=model_presets,
|
model_presets=model_presets,
|
||||||
model_preset=defaults.model_preset,
|
model_preset=defaults.model_preset,
|
||||||
model_preset_snapshot_builder=lambda preset: build_provider_snapshot(config, preset=preset),
|
model_preset_snapshot_builder=(
|
||||||
|
model_preset_snapshot_builder
|
||||||
|
or (lambda name: build_provider_snapshot(config, preset_name=name))
|
||||||
|
),
|
||||||
**extra,
|
**extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -489,8 +497,32 @@ class AgentLoop:
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to refresh provider config")
|
logger.exception("Failed to refresh provider config")
|
||||||
return
|
return
|
||||||
|
if self._active_preset:
|
||||||
|
default_selection = snapshot.signature[:2]
|
||||||
|
if (
|
||||||
|
self._config_default_selection_signature is not None
|
||||||
|
and default_selection != self._config_default_selection_signature
|
||||||
|
):
|
||||||
|
self._active_preset = None
|
||||||
|
self._config_provider_signature = snapshot.signature
|
||||||
|
self._config_default_selection_signature = default_selection
|
||||||
|
self._apply_provider_snapshot(snapshot)
|
||||||
|
return
|
||||||
|
self._config_provider_signature = snapshot.signature
|
||||||
|
self._config_default_selection_signature = default_selection
|
||||||
|
try:
|
||||||
|
snapshot = self._build_model_preset_snapshot(self._active_preset)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to refresh active model preset")
|
||||||
|
return
|
||||||
|
if snapshot.signature == self._provider_signature:
|
||||||
|
return
|
||||||
|
self._apply_provider_snapshot(snapshot)
|
||||||
|
return
|
||||||
if snapshot.signature == self._provider_signature:
|
if snapshot.signature == self._provider_signature:
|
||||||
return
|
return
|
||||||
|
self._config_provider_signature = snapshot.signature
|
||||||
|
self._config_default_selection_signature = snapshot.signature[:2]
|
||||||
self._apply_provider_snapshot(snapshot)
|
self._apply_provider_snapshot(snapshot)
|
||||||
|
|
||||||
# -- model_preset property --
|
# -- model_preset property --
|
||||||
@ -506,7 +538,7 @@ class AgentLoop:
|
|||||||
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
def _build_model_preset_snapshot(self, name: str) -> ProviderSnapshot:
|
||||||
preset = self.model_presets[name]
|
preset = self.model_presets[name]
|
||||||
if self._model_preset_snapshot_builder is not None:
|
if self._model_preset_snapshot_builder is not None:
|
||||||
return self._model_preset_snapshot_builder(preset)
|
return self._model_preset_snapshot_builder(name)
|
||||||
self.provider.generation = preset.to_generation_settings()
|
self.provider.generation = preset.to_generation_settings()
|
||||||
return ProviderSnapshot(
|
return ProviderSnapshot(
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
|
|||||||
@ -672,6 +672,7 @@ def _run_gateway(
|
|||||||
"aihubmix": config.providers.aihubmix,
|
"aihubmix": config.providers.aihubmix,
|
||||||
},
|
},
|
||||||
provider_snapshot_loader=load_provider_snapshot,
|
provider_snapshot_loader=load_provider_snapshot,
|
||||||
|
model_preset_snapshot_builder=lambda name: load_provider_snapshot(preset_name=name),
|
||||||
provider_signature=provider_snapshot.signature,
|
provider_signature=provider_snapshot.signature,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -143,7 +143,14 @@ def build_provider_snapshot(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_provider_snapshot(config_path: Path | None = None) -> ProviderSnapshot:
|
def load_provider_snapshot(
|
||||||
|
config_path: Path | None = None,
|
||||||
|
*,
|
||||||
|
preset_name: str | None = None,
|
||||||
|
) -> ProviderSnapshot:
|
||||||
from nanobot.config.loader import load_config, resolve_config_env_vars
|
from nanobot.config.loader import load_config, resolve_config_env_vars
|
||||||
|
|
||||||
return build_provider_snapshot(resolve_config_env_vars(load_config(config_path)))
|
return build_provider_snapshot(
|
||||||
|
resolve_config_env_vars(load_config(config_path)),
|
||||||
|
preset_name=preset_name,
|
||||||
|
)
|
||||||
|
|||||||
@ -80,11 +80,11 @@ def test_model_preset_setter_replaces_provider_from_snapshot(tmp_path) -> None:
|
|||||||
model="base-model",
|
model="base-model",
|
||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
model_presets={"deep": preset},
|
model_presets={"deep": preset},
|
||||||
model_preset_snapshot_builder=lambda _preset: ProviderSnapshot(
|
model_preset_snapshot_builder=lambda _name: ProviderSnapshot(
|
||||||
provider=new_provider,
|
provider=new_provider,
|
||||||
model=_preset.model,
|
model=preset.model,
|
||||||
context_window_tokens=_preset.context_window_tokens,
|
context_window_tokens=preset.context_window_tokens,
|
||||||
signature=("deep", _preset.model),
|
signature=("deep", preset.model),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
|
|||||||
model="base-model",
|
model="base-model",
|
||||||
context_window_tokens=1000,
|
context_window_tokens=1000,
|
||||||
model_presets={"fast": preset},
|
model_presets={"fast": preset},
|
||||||
model_preset_snapshot_builder=lambda _preset: (_ for _ in ()).throw(
|
model_preset_snapshot_builder=lambda _name: (_ for _ in ()).throw(
|
||||||
RuntimeError("provider unavailable")
|
RuntimeError("provider unavailable")
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -128,6 +128,78 @@ def test_model_preset_setter_failure_leaves_old_state(tmp_path) -> None:
|
|||||||
assert loop.consolidator.max_completion_tokens == 123
|
assert loop.consolidator.max_completion_tokens == 123
|
||||||
|
|
||||||
|
|
||||||
|
def test_active_model_preset_survives_unchanged_config_refresh(tmp_path) -> None:
|
||||||
|
base_provider = _provider("base-model", max_tokens=123)
|
||||||
|
fast_provider = _provider("openai/gpt-4.1", max_tokens=4096)
|
||||||
|
default_snapshot = ProviderSnapshot(
|
||||||
|
provider=base_provider,
|
||||||
|
model="base-model",
|
||||||
|
context_window_tokens=1000,
|
||||||
|
signature=("base-model", "auto", "openai", "sk-old"),
|
||||||
|
)
|
||||||
|
fast_snapshot = ProviderSnapshot(
|
||||||
|
provider=fast_provider,
|
||||||
|
model="openai/gpt-4.1",
|
||||||
|
context_window_tokens=32_768,
|
||||||
|
signature=("openai/gpt-4.1", "auto", "openai", "sk-old"),
|
||||||
|
)
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=base_provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="base-model",
|
||||||
|
context_window_tokens=1000,
|
||||||
|
provider_snapshot_loader=lambda: default_snapshot,
|
||||||
|
provider_signature=default_snapshot.signature,
|
||||||
|
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||||
|
model_preset_snapshot_builder=lambda _name: fast_snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.set_model_preset("fast")
|
||||||
|
loop._refresh_provider_snapshot()
|
||||||
|
|
||||||
|
assert loop.model_preset == "fast"
|
||||||
|
assert loop.provider is fast_provider
|
||||||
|
assert loop.model == "openai/gpt-4.1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_config_model_refresh_clears_active_model_preset(tmp_path) -> None:
|
||||||
|
base_provider = _provider("base-model", max_tokens=123)
|
||||||
|
fast_provider = _provider("openai/gpt-4.1", max_tokens=4096)
|
||||||
|
webui_provider = _provider("anthropic/claude-opus-4-5", max_tokens=2048)
|
||||||
|
webui_snapshot = ProviderSnapshot(
|
||||||
|
provider=webui_provider,
|
||||||
|
model="anthropic/claude-opus-4-5",
|
||||||
|
context_window_tokens=200_000,
|
||||||
|
signature=("anthropic/claude-opus-4-5", "anthropic", "anthropic", "sk-old"),
|
||||||
|
)
|
||||||
|
fast_snapshot = ProviderSnapshot(
|
||||||
|
provider=fast_provider,
|
||||||
|
model="openai/gpt-4.1",
|
||||||
|
context_window_tokens=32_768,
|
||||||
|
signature=("openai/gpt-4.1", "auto", "openai", "sk-old"),
|
||||||
|
)
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=base_provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="base-model",
|
||||||
|
context_window_tokens=1000,
|
||||||
|
provider_snapshot_loader=lambda: webui_snapshot,
|
||||||
|
provider_signature=("base-model", "auto", "openai", "sk-old"),
|
||||||
|
model_presets={"fast": ModelPresetConfig(model="openai/gpt-4.1")},
|
||||||
|
model_preset_snapshot_builder=lambda _name: fast_snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop.set_model_preset("fast")
|
||||||
|
loop._refresh_provider_snapshot()
|
||||||
|
|
||||||
|
assert loop.model_preset is None
|
||||||
|
assert loop.provider is webui_provider
|
||||||
|
assert loop.model == "anthropic/claude-opus-4-5"
|
||||||
|
assert loop.context_window_tokens == 200_000
|
||||||
|
|
||||||
|
|
||||||
def test_model_preset_setter_raises_on_unknown(tmp_path) -> None:
|
def test_model_preset_setter_raises_on_unknown(tmp_path) -> None:
|
||||||
loop = _make_loop(tmp_path)
|
loop = _make_loop(tmp_path)
|
||||||
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
|
with pytest.raises(KeyError, match="model_preset 'missing' not found"):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user