fix(agent): refresh llm runtime for background tasks

This commit is contained in:
Xubin Ren 2026-05-18 00:35:12 +08:00
parent 2f323e24c1
commit 112f40ad67
9 changed files with 203 additions and 29 deletions

View File

@ -41,6 +41,7 @@ from nanobot.utils.document import extract_documents
from nanobot.utils.helpers import image_placeholder_text from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn from nanobot.utils.helpers import truncate_text as truncate_text_fn
from nanobot.utils.image_generation_intent import image_generation_prompt from nanobot.utils.image_generation_intent import image_generation_prompt
from nanobot.utils.llm_runtime import LLMRuntime
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
from nanobot.utils.webui_turn_helpers import ( from nanobot.utils.webui_turn_helpers import (
@ -138,6 +139,11 @@ class AgentLoop:
def tool_names(self) -> list[str]: def tool_names(self) -> list[str]:
return self.tools.tool_names return self.tools.tool_names
def llm_runtime(self) -> LLMRuntime:
"""Return the current provider/model pair owned by this loop."""
self._refresh_provider_snapshot()
return LLMRuntime(self.provider, self.model)
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
_PENDING_USER_TURN_KEY = "pending_user_turn" _PENDING_USER_TURN_KEY = "pending_user_turn"
@ -1296,8 +1302,7 @@ class AgentLoop:
self._webui_turns.capture_title_context( self._webui_turns.capture_title_context(
ctx.session_key, ctx.session_key,
ctx.msg, ctx.msg,
self.provider, self.llm_runtime(),
self.model,
) )
ctx.initial_messages = self._build_initial_messages( ctx.initial_messages = self._build_initial_messages(

View File

@ -914,8 +914,7 @@ def _run_gateway(
hb_cfg = config.gateway.heartbeat hb_cfg = config.gateway.heartbeat
heartbeat = HeartbeatService( heartbeat = HeartbeatService(
workspace=config.workspace_path, workspace=config.workspace_path,
provider=agent.provider, llm_runtime=agent.llm_runtime,
model=agent.model,
on_execute=on_heartbeat_execute, on_execute=on_heartbeat_execute,
on_notify=on_heartbeat_notify, on_notify=on_heartbeat_notify,
interval_s=hb_cfg.interval_s, interval_s=hb_cfg.interval_s,

View File

@ -4,12 +4,12 @@ from __future__ import annotations
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Coroutine from typing import Any, Callable, Coroutine
from loguru import logger from loguru import logger
if TYPE_CHECKING: from nanobot.providers.base import LLMProvider
from nanobot.providers.base import LLMProvider from nanobot.utils.llm_runtime import LLMRuntime, LLMRuntimeResolver, static_llm_runtime
_HEARTBEAT_TOOL = [ _HEARTBEAT_TOOL = [
{ {
@ -53,17 +53,21 @@ class HeartbeatService:
def __init__( def __init__(
self, self,
workspace: Path, workspace: Path,
provider: LLMProvider, provider: LLMProvider | None = None,
model: str, model: str | None = None,
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None, on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None, on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
interval_s: int = 30 * 60, interval_s: int = 30 * 60,
enabled: bool = True, enabled: bool = True,
timezone: str | None = None, timezone: str | None = None,
llm_runtime: LLMRuntimeResolver | None = None,
): ):
self.workspace = workspace self.workspace = workspace
self.provider = provider if llm_runtime is None:
self.model = model if provider is None or model is None:
raise ValueError("HeartbeatService requires either llm_runtime or provider/model")
llm_runtime = static_llm_runtime(provider, model)
self._llm_runtime = llm_runtime
self.on_execute = on_execute self.on_execute = on_execute
self.on_notify = on_notify self.on_notify = on_notify
self.interval_s = interval_s self.interval_s = interval_s
@ -91,7 +95,9 @@ class HeartbeatService:
""" """
from nanobot.utils.helpers import current_time_str from nanobot.utils.helpers import current_time_str
response = await self.provider.chat_with_retry( llm = self._llm_runtime()
response = await llm.provider.chat_with_retry(
messages=[ messages=[
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."}, {"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
{"role": "user", "content": ( {"role": "user", "content": (
@ -101,7 +107,7 @@ class HeartbeatService:
)}, )},
], ],
tools=_HEARTBEAT_TOOL, tools=_HEARTBEAT_TOOL,
model=self.model, model=llm.model,
) )
if not response.should_execute_tools: if not response.should_execute_tools:
@ -214,8 +220,9 @@ class HeartbeatService:
) )
return return
llm = self._llm_runtime()
should_notify = await evaluate_response( should_notify = await evaluate_response(
response, tasks, self.provider, self.model, response, tasks, llm.provider, llm.model,
) )
if should_notify and self.on_notify: if should_notify and self.on_notify:
logger.info("Heartbeat: completed, delivering response") logger.info("Heartbeat: completed, delivering response")

View File

@ -0,0 +1,22 @@
"""Small helpers for passing the active LLM provider/model together."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from nanobot.providers.base import LLMProvider
@dataclass(frozen=True)
class LLMRuntime:
provider: LLMProvider
model: str
LLMRuntimeResolver = Callable[[], LLMRuntime]
def static_llm_runtime(provider: LLMProvider, model: str) -> LLMRuntimeResolver:
runtime = LLMRuntime(provider=provider, model=model)
return lambda: runtime

View File

@ -20,6 +20,7 @@ from nanobot.providers.base import LLMProvider
from nanobot.session.goal_state import goal_state_ws_blob from nanobot.session.goal_state import goal_state_ws_blob
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import truncate_text from nanobot.utils.helpers import truncate_text
from nanobot.utils.llm_runtime import LLMRuntime
WEBUI_SESSION_METADATA_KEY = "webui" WEBUI_SESSION_METADATA_KEY = "webui"
WEBUI_TITLE_METADATA_KEY = "title" WEBUI_TITLE_METADATA_KEY = "title"
@ -31,7 +32,6 @@ TITLE_GENERATION_REASONING_EFFORT = "none"
# Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the # Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the
# gateway process stays up; cleared on idle/stop and implicitly dropped on restart. # gateway process stays up; cleared on idle/stop and implicitly dropped on restart.
_WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {} _WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {}
TitleContext = tuple[LLMProvider, str]
def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool: def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool:
@ -241,17 +241,16 @@ class WebuiTurnCoordinator:
bus: MessageBus bus: MessageBus
sessions: SessionManager sessions: SessionManager
schedule_background: Callable[[Awaitable[None]], None] schedule_background: Callable[[Awaitable[None]], None]
_title_contexts: dict[str, TitleContext] = field(default_factory=dict) _title_contexts: dict[str, LLMRuntime] = field(default_factory=dict)
def capture_title_context( def capture_title_context(
self, self,
session_key: str, session_key: str,
msg: InboundMessage, msg: InboundMessage,
provider: LLMProvider, llm: LLMRuntime,
model: str,
) -> None: ) -> None:
if msg.channel == "websocket" and msg.metadata.get("webui") is True: if msg.channel == "websocket" and msg.metadata.get("webui") is True:
self._title_contexts[session_key] = (provider, model) self._title_contexts[session_key] = llm
def discard(self, session_key: str) -> None: def discard(self, session_key: str) -> None:
self._title_contexts.pop(session_key, None) self._title_contexts.pop(session_key, None)
@ -287,19 +286,16 @@ class WebuiTurnCoordinator:
if msg.metadata.get("webui") is not True or title_context is None: if msg.metadata.get("webui") is not True or title_context is None:
return return
title_provider, title_model = title_context
async def _generate_title_and_notify( async def _generate_title_and_notify(
provider: LLMProvider = title_provider, title_llm: LLMRuntime = title_context,
model: str = title_model,
) -> None: ) -> None:
generated = await maybe_generate_webui_title_after_turn( generated = await maybe_generate_webui_title_after_turn(
channel=msg.channel, channel=msg.channel,
metadata=msg.metadata, metadata=msg.metadata,
sessions=self.sessions, sessions=self.sessions,
session_key=session_key, session_key=session_key,
provider=provider, provider=title_llm.provider,
model=model, model=title_llm.model,
) )
if generated: if generated:
await self.bus.publish_outbound(OutboundMessage( await self.bus.publish_outbound(OutboundMessage(

View File

@ -4,6 +4,7 @@ import pytest
from nanobot.heartbeat.service import HeartbeatService from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.utils.llm_runtime import LLMRuntime
class DummyProvider(LLMProvider): class DummyProvider(LLMProvider):
@ -11,9 +12,11 @@ class DummyProvider(LLMProvider):
super().__init__() super().__init__()
self._responses = list(responses) self._responses = list(responses)
self.calls = 0 self.calls = 0
self.models: list[str | None] = []
async def chat(self, *args, **kwargs) -> LLMResponse: async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1 self.calls += 1
self.models.append(kwargs.get("model"))
if self._responses: if self._responses:
return self._responses.pop(0) return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[]) return LLMResponse(content="", tool_calls=[])
@ -215,6 +218,51 @@ async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) ->
assert notified == [] assert notified == []
def test_tick_uses_runtime_provider_and_model(tmp_path, monkeypatch) -> None:
"""Preset changes must apply to heartbeat decision and post-run evaluation."""
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check runtime model", encoding="utf-8")
runtime_provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check runtime model"},
)
],
),
])
runtime_model = "openai/gpt-4.1"
executed: list[str] = []
evaluated: list[tuple[LLMProvider, str]] = []
async def _on_execute(tasks: str) -> str:
executed.append(tasks)
return "runtime model produced a user-facing update"
async def _eval_capture(response, tasks, provider, model):
evaluated.append((provider, model))
return False
service = HeartbeatService(
workspace=tmp_path,
llm_runtime=lambda: LLMRuntime(runtime_provider, runtime_model),
on_execute=_on_execute,
)
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_capture)
asyncio.run(service._tick())
assert runtime_provider.calls == 1
assert runtime_provider.models == [runtime_model]
assert executed == ["check runtime model"]
assert evaluated == [(runtime_provider, runtime_model)]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None: async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
provider = DummyProvider([ provider = DummyProvider([
@ -286,4 +334,3 @@ async def test_decide_prompt_includes_current_time(tmp_path) -> None:
user_msg = captured_messages[1] user_msg = captured_messages[1]
assert user_msg["role"] == "user" assert user_msg["role"] == "user"
assert "Current Time:" in user_msg["content"] assert "Current Time:" in user_msg["content"]

View File

@ -10,14 +10,16 @@ from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse from nanobot.providers.base import LLMResponse
from nanobot.session.goal_state import GOAL_STATE_KEY from nanobot.session.goal_state import GOAL_STATE_KEY
from nanobot.session.manager import Session from nanobot.session.manager import Session, SessionManager
from nanobot.utils.webui_turn_helpers import ( from nanobot.utils.webui_turn_helpers import (
TITLE_GENERATION_MAX_TOKENS, TITLE_GENERATION_MAX_TOKENS,
TITLE_GENERATION_REASONING_EFFORT, TITLE_GENERATION_REASONING_EFFORT,
WEBUI_SESSION_METADATA_KEY, WEBUI_SESSION_METADATA_KEY,
WEBUI_TITLE_METADATA_KEY, WEBUI_TITLE_METADATA_KEY,
WebuiTurnCoordinator,
maybe_generate_webui_title, maybe_generate_webui_title,
) )
from nanobot.utils.llm_runtime import LLMRuntime
def _mk_loop() -> AgentLoop: def _mk_loop() -> AgentLoop:
@ -35,6 +37,22 @@ def _make_full_loop(tmp_path: Path) -> AgentLoop:
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
def test_agent_loop_llm_runtime_reflects_current_provider_and_model(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
runtime = loop.llm_runtime()
assert runtime.provider is loop.provider
assert runtime.model == "test-model"
next_provider = MagicMock()
loop.provider = next_provider
loop.model = "next-model"
runtime = loop.llm_runtime()
assert runtime.provider is next_provider
assert runtime.model == "next-model"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None: async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path) loop = _make_full_loop(tmp_path)
@ -111,6 +129,55 @@ async def test_generate_webui_title_ignores_command_only_sessions(tmp_path: Path
loop.provider.chat_with_retry.assert_not_awaited() loop.provider.chat_with_retry.assert_not_awaited()
def test_webui_title_update_uses_captured_llm_runtime(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
bus = MessageBus()
sessions = SessionManager(tmp_path)
scheduled: list[object] = []
captured: dict[str, object] = {}
async def fake_title_after_turn(**kwargs: object) -> bool:
captured.update(kwargs)
return False
monkeypatch.setattr(
"nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn",
fake_title_after_turn,
)
coordinator = WebuiTurnCoordinator(
bus=bus,
sessions=sessions,
schedule_background=lambda coro: scheduled.append(coro),
)
provider = MagicMock()
msg = InboundMessage(
channel="websocket",
sender_id="u1",
chat_id="chat1",
content="say hello",
metadata={"webui": True},
)
coordinator.capture_title_context(
"websocket:chat1",
msg,
LLMRuntime(provider, "turn-model"),
)
asyncio.run(coordinator.handle_turn_end(
msg,
session_key="websocket:chat1",
latency_ms=None,
))
assert len(scheduled) == 1
asyncio.run(scheduled[0]) # type: ignore[arg-type]
assert captured["provider"] is provider
assert captured["model"] == "turn-model"
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None: def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop() loop = _mk_loop()
session = Session(key="test:runtime-only") session = Session(key="test:runtime-only")

View File

@ -47,3 +47,28 @@ def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
assert loop.dream.provider is new_provider assert loop.dream.provider is new_provider
assert loop.dream.model == "new-model" assert loop.dream.model == "new-model"
assert loop.dream._runner.provider is new_provider assert loop.dream._runner.provider is new_provider
def test_llm_runtime_refreshes_provider_snapshot(tmp_path: Path) -> None:
old_provider = _provider("old-model")
new_provider = _provider("new-model", max_tokens=456)
loop = AgentLoop(
bus=MessageBus(),
provider=old_provider,
workspace=tmp_path,
model="old-model",
context_window_tokens=1000,
provider_snapshot_loader=lambda: ProviderSnapshot(
provider=new_provider,
model="new-model",
context_window_tokens=2000,
signature=("new-model",),
),
)
runtime = loop.llm_runtime()
assert runtime.provider is new_provider
assert runtime.model == "new-model"
assert loop.provider is new_provider
assert loop.runner.provider is new_provider

View File

@ -1170,6 +1170,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
self.model = "test-model" self.model = "test-model"
self.provider = kwargs.get("provider", object()) self.provider = kwargs.get("provider", object())
self.tools = {} self.tools = {}
seen["agent"] = self
async def process_direct(self, *_args, **_kwargs): async def process_direct(self, *_args, **_kwargs):
return OutboundMessage( return OutboundMessage(
@ -1218,6 +1219,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
assert isinstance(cron, _FakeCron) assert isinstance(cron, _FakeCron)
assert cron.on_job is not None assert cron.on_job is not None
runtime_provider = object()
agent = seen["agent"]
agent.provider = runtime_provider
agent.model = "runtime-model"
job = CronJob( job = CronJob(
id="cron-1", id="cron-1",
name="stretch", name="stretch",
@ -1233,8 +1239,8 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
assert response == "Time to stretch." assert response == "Time to stretch."
assert seen["response"] == "Time to stretch." assert seen["response"] == "Time to stretch."
assert seen["provider"] is provider assert seen["provider"] is runtime_provider
assert seen["model"] == "test-model" assert seen["model"] == "runtime-model"
assert seen["task_context"] == ( assert seen["task_context"] == (
"The scheduled time has arrived. Deliver this reminder to the user now, " "The scheduled time has arrived. Deliver this reminder to the user now, "
"as a brief and natural message in their language. Speak directly to them — " "as a brief and natural message in their language. Speak directly to them — "