fix(agent): refresh llm runtime for background tasks

This commit is contained in:
Xubin Ren 2026-05-18 00:35:12 +08:00
parent 2f323e24c1
commit 112f40ad67
9 changed files with 203 additions and 29 deletions

View File

@ -41,6 +41,7 @@ from nanobot.utils.document import extract_documents
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
from nanobot.utils.image_generation_intent import image_generation_prompt
from nanobot.utils.llm_runtime import LLMRuntime
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
from nanobot.utils.webui_turn_helpers import (
@ -138,6 +139,11 @@ class AgentLoop:
def tool_names(self) -> list[str]:
return self.tools.tool_names
def llm_runtime(self) -> LLMRuntime:
"""Return the current provider/model pair owned by this loop."""
self._refresh_provider_snapshot()
return LLMRuntime(self.provider, self.model)
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
_PENDING_USER_TURN_KEY = "pending_user_turn"
@ -1296,8 +1302,7 @@ class AgentLoop:
self._webui_turns.capture_title_context(
ctx.session_key,
ctx.msg,
self.provider,
self.model,
self.llm_runtime(),
)
ctx.initial_messages = self._build_initial_messages(

View File

@ -914,8 +914,7 @@ def _run_gateway(
hb_cfg = config.gateway.heartbeat
heartbeat = HeartbeatService(
workspace=config.workspace_path,
provider=agent.provider,
model=agent.model,
llm_runtime=agent.llm_runtime,
on_execute=on_heartbeat_execute,
on_notify=on_heartbeat_notify,
interval_s=hb_cfg.interval_s,

View File

@ -4,12 +4,12 @@ from __future__ import annotations
import asyncio
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Coroutine
from typing import Any, Callable, Coroutine
from loguru import logger
if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider
from nanobot.providers.base import LLMProvider
from nanobot.utils.llm_runtime import LLMRuntime, LLMRuntimeResolver, static_llm_runtime
_HEARTBEAT_TOOL = [
{
@ -53,17 +53,21 @@ class HeartbeatService:
def __init__(
self,
workspace: Path,
provider: LLMProvider,
model: str,
provider: LLMProvider | None = None,
model: str | None = None,
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
interval_s: int = 30 * 60,
enabled: bool = True,
timezone: str | None = None,
llm_runtime: LLMRuntimeResolver | None = None,
):
self.workspace = workspace
self.provider = provider
self.model = model
if llm_runtime is None:
if provider is None or model is None:
raise ValueError("HeartbeatService requires either llm_runtime or provider/model")
llm_runtime = static_llm_runtime(provider, model)
self._llm_runtime = llm_runtime
self.on_execute = on_execute
self.on_notify = on_notify
self.interval_s = interval_s
@ -91,7 +95,9 @@ class HeartbeatService:
"""
from nanobot.utils.helpers import current_time_str
response = await self.provider.chat_with_retry(
llm = self._llm_runtime()
response = await llm.provider.chat_with_retry(
messages=[
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
{"role": "user", "content": (
@ -101,7 +107,7 @@ class HeartbeatService:
)},
],
tools=_HEARTBEAT_TOOL,
model=self.model,
model=llm.model,
)
if not response.should_execute_tools:
@ -214,8 +220,9 @@ class HeartbeatService:
)
return
llm = self._llm_runtime()
should_notify = await evaluate_response(
response, tasks, self.provider, self.model,
response, tasks, llm.provider, llm.model,
)
if should_notify and self.on_notify:
logger.info("Heartbeat: completed, delivering response")

View File

@ -0,0 +1,22 @@
"""Small helpers for passing the active LLM provider/model together."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from nanobot.providers.base import LLMProvider
@dataclass(frozen=True)
class LLMRuntime:
provider: LLMProvider
model: str
LLMRuntimeResolver = Callable[[], LLMRuntime]
def static_llm_runtime(provider: LLMProvider, model: str) -> LLMRuntimeResolver:
runtime = LLMRuntime(provider=provider, model=model)
return lambda: runtime

View File

@ -20,6 +20,7 @@ from nanobot.providers.base import LLMProvider
from nanobot.session.goal_state import goal_state_ws_blob
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import truncate_text
from nanobot.utils.llm_runtime import LLMRuntime
WEBUI_SESSION_METADATA_KEY = "webui"
WEBUI_TITLE_METADATA_KEY = "title"
@ -31,7 +32,6 @@ TITLE_GENERATION_REASONING_EFFORT = "none"
# Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the
# gateway process stays up; cleared on idle/stop and implicitly dropped on restart.
_WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {}
TitleContext = tuple[LLMProvider, str]
def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool:
@ -241,17 +241,16 @@ class WebuiTurnCoordinator:
bus: MessageBus
sessions: SessionManager
schedule_background: Callable[[Awaitable[None]], None]
_title_contexts: dict[str, TitleContext] = field(default_factory=dict)
_title_contexts: dict[str, LLMRuntime] = field(default_factory=dict)
def capture_title_context(
self,
session_key: str,
msg: InboundMessage,
provider: LLMProvider,
model: str,
llm: LLMRuntime,
) -> None:
if msg.channel == "websocket" and msg.metadata.get("webui") is True:
self._title_contexts[session_key] = (provider, model)
self._title_contexts[session_key] = llm
def discard(self, session_key: str) -> None:
self._title_contexts.pop(session_key, None)
@ -287,19 +286,16 @@ class WebuiTurnCoordinator:
if msg.metadata.get("webui") is not True or title_context is None:
return
title_provider, title_model = title_context
async def _generate_title_and_notify(
provider: LLMProvider = title_provider,
model: str = title_model,
title_llm: LLMRuntime = title_context,
) -> None:
generated = await maybe_generate_webui_title_after_turn(
channel=msg.channel,
metadata=msg.metadata,
sessions=self.sessions,
session_key=session_key,
provider=provider,
model=model,
provider=title_llm.provider,
model=title_llm.model,
)
if generated:
await self.bus.publish_outbound(OutboundMessage(

View File

@ -4,6 +4,7 @@ import pytest
from nanobot.heartbeat.service import HeartbeatService
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
from nanobot.utils.llm_runtime import LLMRuntime
class DummyProvider(LLMProvider):
@ -11,9 +12,11 @@ class DummyProvider(LLMProvider):
super().__init__()
self._responses = list(responses)
self.calls = 0
self.models: list[str | None] = []
async def chat(self, *args, **kwargs) -> LLMResponse:
self.calls += 1
self.models.append(kwargs.get("model"))
if self._responses:
return self._responses.pop(0)
return LLMResponse(content="", tool_calls=[])
@ -215,6 +218,51 @@ async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) ->
assert notified == []
def test_tick_uses_runtime_provider_and_model(tmp_path, monkeypatch) -> None:
"""Preset changes must apply to heartbeat decision and post-run evaluation."""
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check runtime model", encoding="utf-8")
runtime_provider = DummyProvider([
LLMResponse(
content="",
tool_calls=[
ToolCallRequest(
id="hb_1",
name="heartbeat",
arguments={"action": "run", "tasks": "check runtime model"},
)
],
),
])
runtime_model = "openai/gpt-4.1"
executed: list[str] = []
evaluated: list[tuple[LLMProvider, str]] = []
async def _on_execute(tasks: str) -> str:
executed.append(tasks)
return "runtime model produced a user-facing update"
async def _eval_capture(response, tasks, provider, model):
evaluated.append((provider, model))
return False
service = HeartbeatService(
workspace=tmp_path,
llm_runtime=lambda: LLMRuntime(runtime_provider, runtime_model),
on_execute=_on_execute,
)
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_capture)
asyncio.run(service._tick())
assert runtime_provider.calls == 1
assert runtime_provider.models == [runtime_model]
assert executed == ["check runtime model"]
assert evaluated == [(runtime_provider, runtime_model)]
@pytest.mark.asyncio
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
provider = DummyProvider([
@ -286,4 +334,3 @@ async def test_decide_prompt_includes_current_time(tmp_path) -> None:
user_msg = captured_messages[1]
assert user_msg["role"] == "user"
assert "Current Time:" in user_msg["content"]

View File

@ -10,14 +10,16 @@ from nanobot.bus.events import InboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
from nanobot.session.goal_state import GOAL_STATE_KEY
from nanobot.session.manager import Session
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.webui_turn_helpers import (
TITLE_GENERATION_MAX_TOKENS,
TITLE_GENERATION_REASONING_EFFORT,
WEBUI_SESSION_METADATA_KEY,
WEBUI_TITLE_METADATA_KEY,
WebuiTurnCoordinator,
maybe_generate_webui_title,
)
from nanobot.utils.llm_runtime import LLMRuntime
def _mk_loop() -> AgentLoop:
@ -35,6 +37,22 @@ def _make_full_loop(tmp_path: Path) -> AgentLoop:
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
def test_agent_loop_llm_runtime_reflects_current_provider_and_model(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
runtime = loop.llm_runtime()
assert runtime.provider is loop.provider
assert runtime.model == "test-model"
next_provider = MagicMock()
loop.provider = next_provider
loop.model = "next-model"
runtime = loop.llm_runtime()
assert runtime.provider is next_provider
assert runtime.model == "next-model"
@pytest.mark.asyncio
async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None:
loop = _make_full_loop(tmp_path)
@ -111,6 +129,55 @@ async def test_generate_webui_title_ignores_command_only_sessions(tmp_path: Path
loop.provider.chat_with_retry.assert_not_awaited()
def test_webui_title_update_uses_captured_llm_runtime(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
bus = MessageBus()
sessions = SessionManager(tmp_path)
scheduled: list[object] = []
captured: dict[str, object] = {}
async def fake_title_after_turn(**kwargs: object) -> bool:
captured.update(kwargs)
return False
monkeypatch.setattr(
"nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn",
fake_title_after_turn,
)
coordinator = WebuiTurnCoordinator(
bus=bus,
sessions=sessions,
schedule_background=lambda coro: scheduled.append(coro),
)
provider = MagicMock()
msg = InboundMessage(
channel="websocket",
sender_id="u1",
chat_id="chat1",
content="say hello",
metadata={"webui": True},
)
coordinator.capture_title_context(
"websocket:chat1",
msg,
LLMRuntime(provider, "turn-model"),
)
asyncio.run(coordinator.handle_turn_end(
msg,
session_key="websocket:chat1",
latency_ms=None,
))
assert len(scheduled) == 1
asyncio.run(scheduled[0]) # type: ignore[arg-type]
assert captured["provider"] is provider
assert captured["model"] == "turn-model"
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
loop = _mk_loop()
session = Session(key="test:runtime-only")

View File

@ -47,3 +47,28 @@ def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
assert loop.dream.provider is new_provider
assert loop.dream.model == "new-model"
assert loop.dream._runner.provider is new_provider
def test_llm_runtime_refreshes_provider_snapshot(tmp_path: Path) -> None:
old_provider = _provider("old-model")
new_provider = _provider("new-model", max_tokens=456)
loop = AgentLoop(
bus=MessageBus(),
provider=old_provider,
workspace=tmp_path,
model="old-model",
context_window_tokens=1000,
provider_snapshot_loader=lambda: ProviderSnapshot(
provider=new_provider,
model="new-model",
context_window_tokens=2000,
signature=("new-model",),
),
)
runtime = loop.llm_runtime()
assert runtime.provider is new_provider
assert runtime.model == "new-model"
assert loop.provider is new_provider
assert loop.runner.provider is new_provider

View File

@ -1170,6 +1170,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
self.model = "test-model"
self.provider = kwargs.get("provider", object())
self.tools = {}
seen["agent"] = self
async def process_direct(self, *_args, **_kwargs):
return OutboundMessage(
@ -1218,6 +1219,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
assert isinstance(cron, _FakeCron)
assert cron.on_job is not None
runtime_provider = object()
agent = seen["agent"]
agent.provider = runtime_provider
agent.model = "runtime-model"
job = CronJob(
id="cron-1",
name="stretch",
@ -1233,8 +1239,8 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
assert response == "Time to stretch."
assert seen["response"] == "Time to stretch."
assert seen["provider"] is provider
assert seen["model"] == "test-model"
assert seen["provider"] is runtime_provider
assert seen["model"] == "runtime-model"
assert seen["task_context"] == (
"The scheduled time has arrived. Deliver this reminder to the user now, "
"as a brief and natural message in their language. Speak directly to them — "