From 112f40ad67e14afc553520b4e9e5aece58d59549 Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Mon, 18 May 2026 00:35:12 +0800 Subject: [PATCH] fix(agent): refresh llm runtime for background tasks --- nanobot/agent/loop.py | 9 +++- nanobot/cli/commands.py | 3 +- nanobot/heartbeat/service.py | 27 +++++++---- nanobot/utils/llm_runtime.py | 22 +++++++++ nanobot/utils/webui_turn_helpers.py | 18 +++---- tests/agent/test_heartbeat_service.py | 49 ++++++++++++++++++- tests/agent/test_loop_save_turn.py | 69 ++++++++++++++++++++++++++- tests/agent/test_runtime_refresh.py | 25 ++++++++++ tests/cli/test_commands.py | 10 +++- 9 files changed, 203 insertions(+), 29 deletions(-) create mode 100644 nanobot/utils/llm_runtime.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 81cc393b8..c1f521170 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -41,6 +41,7 @@ from nanobot.utils.document import extract_documents from nanobot.utils.helpers import image_placeholder_text from nanobot.utils.helpers import truncate_text as truncate_text_fn from nanobot.utils.image_generation_intent import image_generation_prompt +from nanobot.utils.llm_runtime import LLMRuntime from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant from nanobot.utils.webui_turn_helpers import ( @@ -138,6 +139,11 @@ class AgentLoop: def tool_names(self) -> list[str]: return self.tools.tool_names + def llm_runtime(self) -> LLMRuntime: + """Return the current provider/model pair owned by this loop.""" + self._refresh_provider_snapshot() + return LLMRuntime(self.provider, self.model) + _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" _PENDING_USER_TURN_KEY = "pending_user_turn" @@ -1296,8 +1302,7 @@ class AgentLoop: self._webui_turns.capture_title_context( ctx.session_key, ctx.msg, - self.provider, - self.model, + self.llm_runtime(), ) ctx.initial_messages = self._build_initial_messages( diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index cc14f52c1..f5d8ddc3d 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -914,8 +914,7 @@ def _run_gateway( hb_cfg = config.gateway.heartbeat heartbeat = HeartbeatService( workspace=config.workspace_path, - provider=agent.provider, - model=agent.model, + llm_runtime=agent.llm_runtime, on_execute=on_heartbeat_execute, on_notify=on_heartbeat_notify, interval_s=hb_cfg.interval_s, diff --git a/nanobot/heartbeat/service.py b/nanobot/heartbeat/service.py index b41ee7a1e..4506b5806 100644 --- a/nanobot/heartbeat/service.py +++ b/nanobot/heartbeat/service.py @@ -4,12 +4,12 @@ from __future__ import annotations import asyncio from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Coroutine +from typing import Any, Callable, Coroutine from loguru import logger -if TYPE_CHECKING: - from nanobot.providers.base import LLMProvider +from nanobot.providers.base import LLMProvider +from nanobot.utils.llm_runtime import LLMRuntime, LLMRuntimeResolver, static_llm_runtime _HEARTBEAT_TOOL = [ { @@ -53,17 +53,21 @@ class HeartbeatService: def __init__( self, workspace: Path, - provider: LLMProvider, - model: str, + provider: LLMProvider | None = None, + model: str | None = None, on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None, on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None, interval_s: int = 30 * 60, enabled: bool = True, timezone: str | None = None, + llm_runtime: LLMRuntimeResolver | None = None, ): self.workspace = workspace - self.provider = provider - self.model = model + if llm_runtime is None: + if provider is None or model is None: + raise ValueError("HeartbeatService requires either llm_runtime or provider/model") + llm_runtime = static_llm_runtime(provider, model) + self._llm_runtime = llm_runtime self.on_execute = on_execute self.on_notify = on_notify self.interval_s = interval_s @@ -91,7 +95,9 @@ class HeartbeatService: """ from nanobot.utils.helpers import current_time_str - response = await self.provider.chat_with_retry( + llm = self._llm_runtime() + + response = await llm.provider.chat_with_retry( messages=[ {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."}, {"role": "user", "content": ( @@ -101,7 +107,7 @@ class HeartbeatService: )}, ], tools=_HEARTBEAT_TOOL, - model=self.model, + model=llm.model, ) if not response.should_execute_tools: @@ -214,8 +220,9 @@ class HeartbeatService: ) return + llm = self._llm_runtime() should_notify = await evaluate_response( - response, tasks, self.provider, self.model, + response, tasks, llm.provider, llm.model, ) if should_notify and self.on_notify: logger.info("Heartbeat: completed, delivering response") diff --git a/nanobot/utils/llm_runtime.py b/nanobot/utils/llm_runtime.py new file mode 100644 index 000000000..a74f0d8c0 --- /dev/null +++ b/nanobot/utils/llm_runtime.py @@ -0,0 +1,22 @@ +"""Small helpers for passing the active LLM provider/model together.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass + +from nanobot.providers.base import LLMProvider + + +@dataclass(frozen=True) +class LLMRuntime: + provider: LLMProvider + model: str + + +LLMRuntimeResolver = Callable[[], LLMRuntime] + + +def static_llm_runtime(provider: LLMProvider, model: str) -> LLMRuntimeResolver: + runtime = LLMRuntime(provider=provider, model=model) + return lambda: runtime diff --git a/nanobot/utils/webui_turn_helpers.py b/nanobot/utils/webui_turn_helpers.py index 10403852f..9ef4612f9 100644 --- a/nanobot/utils/webui_turn_helpers.py +++ b/nanobot/utils/webui_turn_helpers.py @@ -20,6 +20,7 @@ from nanobot.providers.base import LLMProvider from nanobot.session.goal_state import goal_state_ws_blob from nanobot.session.manager import Session, SessionManager from nanobot.utils.helpers import truncate_text +from nanobot.utils.llm_runtime import LLMRuntime WEBUI_SESSION_METADATA_KEY = "webui" WEBUI_TITLE_METADATA_KEY = "title" @@ -31,7 +32,6 @@ TITLE_GENERATION_REASONING_EFFORT = "none" # Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the # gateway process stays up; cleared on idle/stop and implicitly dropped on restart. _WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {} -TitleContext = tuple[LLMProvider, str] def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool: @@ -241,17 +241,16 @@ class WebuiTurnCoordinator: bus: MessageBus sessions: SessionManager schedule_background: Callable[[Awaitable[None]], None] - _title_contexts: dict[str, TitleContext] = field(default_factory=dict) + _title_contexts: dict[str, LLMRuntime] = field(default_factory=dict) def capture_title_context( self, session_key: str, msg: InboundMessage, - provider: LLMProvider, - model: str, + llm: LLMRuntime, ) -> None: if msg.channel == "websocket" and msg.metadata.get("webui") is True: - self._title_contexts[session_key] = (provider, model) + self._title_contexts[session_key] = llm def discard(self, session_key: str) -> None: self._title_contexts.pop(session_key, None) @@ -287,19 +286,16 @@ class WebuiTurnCoordinator: if msg.metadata.get("webui") is not True or title_context is None: return - title_provider, title_model = title_context - async def _generate_title_and_notify( - provider: LLMProvider = title_provider, - model: str = title_model, + title_llm: LLMRuntime = title_context, ) -> None: generated = await maybe_generate_webui_title_after_turn( channel=msg.channel, metadata=msg.metadata, sessions=self.sessions, session_key=session_key, - provider=provider, - model=model, + provider=title_llm.provider, + model=title_llm.model, ) if generated: await self.bus.publish_outbound(OutboundMessage( diff --git a/tests/agent/test_heartbeat_service.py b/tests/agent/test_heartbeat_service.py index 8f563cff4..fe7b54256 100644 --- a/tests/agent/test_heartbeat_service.py +++ b/tests/agent/test_heartbeat_service.py @@ -4,6 +4,7 @@ import pytest from nanobot.heartbeat.service import HeartbeatService from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest +from nanobot.utils.llm_runtime import LLMRuntime class DummyProvider(LLMProvider): @@ -11,9 +12,11 @@ class DummyProvider(LLMProvider): super().__init__() self._responses = list(responses) self.calls = 0 + self.models: list[str | None] = [] async def chat(self, *args, **kwargs) -> LLMResponse: self.calls += 1 + self.models.append(kwargs.get("model")) if self._responses: return self._responses.pop(0) return LLMResponse(content="", tool_calls=[]) @@ -215,6 +218,51 @@ async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) -> assert notified == [] +def test_tick_uses_runtime_provider_and_model(tmp_path, monkeypatch) -> None: + """Preset changes must apply to heartbeat decision and post-run evaluation.""" + (tmp_path / "HEARTBEAT.md").write_text("- [ ] check runtime model", encoding="utf-8") + + runtime_provider = DummyProvider([ + LLMResponse( + content="", + tool_calls=[ + ToolCallRequest( + id="hb_1", + name="heartbeat", + arguments={"action": "run", "tasks": "check runtime model"}, + ) + ], + ), + ]) + runtime_model = "openai/gpt-4.1" + + executed: list[str] = [] + evaluated: list[tuple[LLMProvider, str]] = [] + + async def _on_execute(tasks: str) -> str: + executed.append(tasks) + return "runtime model produced a user-facing update" + + async def _eval_capture(response, tasks, provider, model): + evaluated.append((provider, model)) + return False + + service = HeartbeatService( + workspace=tmp_path, + llm_runtime=lambda: LLMRuntime(runtime_provider, runtime_model), + on_execute=_on_execute, + ) + + monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_capture) + + asyncio.run(service._tick()) + + assert runtime_provider.calls == 1 + assert runtime_provider.models == [runtime_model] + assert executed == ["check runtime model"] + assert evaluated == [(runtime_provider, runtime_model)] + + @pytest.mark.asyncio async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None: provider = DummyProvider([ @@ -286,4 +334,3 @@ async def test_decide_prompt_includes_current_time(tmp_path) -> None: user_msg = captured_messages[1] assert user_msg["role"] == "user" assert "Current Time:" in user_msg["content"] - diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index 105291347..9814c386d 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -10,14 +10,16 @@ from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse from nanobot.session.goal_state import GOAL_STATE_KEY -from nanobot.session.manager import Session +from nanobot.session.manager import Session, SessionManager from nanobot.utils.webui_turn_helpers import ( TITLE_GENERATION_MAX_TOKENS, TITLE_GENERATION_REASONING_EFFORT, WEBUI_SESSION_METADATA_KEY, WEBUI_TITLE_METADATA_KEY, + WebuiTurnCoordinator, maybe_generate_webui_title, ) +from nanobot.utils.llm_runtime import LLMRuntime def _mk_loop() -> AgentLoop: @@ -35,6 +37,22 @@ def _make_full_loop(tmp_path: Path) -> AgentLoop: return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") +def test_agent_loop_llm_runtime_reflects_current_provider_and_model(tmp_path: Path) -> None: + loop = _make_full_loop(tmp_path) + runtime = loop.llm_runtime() + + assert runtime.provider is loop.provider + assert runtime.model == "test-model" + + next_provider = MagicMock() + loop.provider = next_provider + loop.model = "next-model" + runtime = loop.llm_runtime() + + assert runtime.provider is next_provider + assert runtime.model == "next-model" + + @pytest.mark.asyncio async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None: loop = _make_full_loop(tmp_path) @@ -111,6 +129,55 @@ async def test_generate_webui_title_ignores_command_only_sessions(tmp_path: Path loop.provider.chat_with_retry.assert_not_awaited() +def test_webui_title_update_uses_captured_llm_runtime( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + bus = MessageBus() + sessions = SessionManager(tmp_path) + scheduled: list[object] = [] + captured: dict[str, object] = {} + + async def fake_title_after_turn(**kwargs: object) -> bool: + captured.update(kwargs) + return False + + monkeypatch.setattr( + "nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn", + fake_title_after_turn, + ) + coordinator = WebuiTurnCoordinator( + bus=bus, + sessions=sessions, + schedule_background=lambda coro: scheduled.append(coro), + ) + provider = MagicMock() + msg = InboundMessage( + channel="websocket", + sender_id="u1", + chat_id="chat1", + content="say hello", + metadata={"webui": True}, + ) + + coordinator.capture_title_context( + "websocket:chat1", + msg, + LLMRuntime(provider, "turn-model"), + ) + asyncio.run(coordinator.handle_turn_end( + msg, + session_key="websocket:chat1", + latency_ms=None, + )) + + assert len(scheduled) == 1 + asyncio.run(scheduled[0]) # type: ignore[arg-type] + + assert captured["provider"] is provider + assert captured["model"] == "turn-model" + + def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None: loop = _mk_loop() session = Session(key="test:runtime-only") diff --git a/tests/agent/test_runtime_refresh.py b/tests/agent/test_runtime_refresh.py index a6b19a9d8..b36b1899b 100644 --- a/tests/agent/test_runtime_refresh.py +++ b/tests/agent/test_runtime_refresh.py @@ -47,3 +47,28 @@ def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None: assert loop.dream.provider is new_provider assert loop.dream.model == "new-model" assert loop.dream._runner.provider is new_provider + + +def test_llm_runtime_refreshes_provider_snapshot(tmp_path: Path) -> None: + old_provider = _provider("old-model") + new_provider = _provider("new-model", max_tokens=456) + loop = AgentLoop( + bus=MessageBus(), + provider=old_provider, + workspace=tmp_path, + model="old-model", + context_window_tokens=1000, + provider_snapshot_loader=lambda: ProviderSnapshot( + provider=new_provider, + model="new-model", + context_window_tokens=2000, + signature=("new-model",), + ), + ) + + runtime = loop.llm_runtime() + + assert runtime.provider is new_provider + assert runtime.model == "new-model" + assert loop.provider is new_provider + assert loop.runner.provider is new_provider diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 90c2ce877..8baa5d2f8 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -1170,6 +1170,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( self.model = "test-model" self.provider = kwargs.get("provider", object()) self.tools = {} + seen["agent"] = self async def process_direct(self, *_args, **_kwargs): return OutboundMessage( @@ -1218,6 +1219,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( assert isinstance(cron, _FakeCron) assert cron.on_job is not None + runtime_provider = object() + agent = seen["agent"] + agent.provider = runtime_provider + agent.model = "runtime-model" + job = CronJob( id="cron-1", name="stretch", @@ -1233,8 +1239,8 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context( assert response == "Time to stretch." assert seen["response"] == "Time to stretch." - assert seen["provider"] is provider - assert seen["model"] == "test-model" + assert seen["provider"] is runtime_provider + assert seen["model"] == "runtime-model" assert seen["task_context"] == ( "The scheduled time has arrived. Deliver this reminder to the user now, " "as a brief and natural message in their language. Speak directly to them — "