mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
fix(agent): refresh llm runtime for background tasks
This commit is contained in:
parent
2f323e24c1
commit
112f40ad67
@ -41,6 +41,7 @@ from nanobot.utils.document import extract_documents
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
from nanobot.utils.image_generation_intent import image_generation_prompt
|
||||
from nanobot.utils.llm_runtime import LLMRuntime
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
|
||||
from nanobot.utils.webui_turn_helpers import (
|
||||
@ -138,6 +139,11 @@ class AgentLoop:
|
||||
def tool_names(self) -> list[str]:
|
||||
return self.tools.tool_names
|
||||
|
||||
def llm_runtime(self) -> LLMRuntime:
|
||||
"""Return the current provider/model pair owned by this loop."""
|
||||
self._refresh_provider_snapshot()
|
||||
return LLMRuntime(self.provider, self.model)
|
||||
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||
|
||||
@ -1296,8 +1302,7 @@ class AgentLoop:
|
||||
self._webui_turns.capture_title_context(
|
||||
ctx.session_key,
|
||||
ctx.msg,
|
||||
self.provider,
|
||||
self.model,
|
||||
self.llm_runtime(),
|
||||
)
|
||||
|
||||
ctx.initial_messages = self._build_initial_messages(
|
||||
|
||||
@ -914,8 +914,7 @@ def _run_gateway(
|
||||
hb_cfg = config.gateway.heartbeat
|
||||
heartbeat = HeartbeatService(
|
||||
workspace=config.workspace_path,
|
||||
provider=agent.provider,
|
||||
model=agent.model,
|
||||
llm_runtime=agent.llm_runtime,
|
||||
on_execute=on_heartbeat_execute,
|
||||
on_notify=on_heartbeat_notify,
|
||||
interval_s=hb_cfg.interval_s,
|
||||
|
||||
@ -4,12 +4,12 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.utils.llm_runtime import LLMRuntime, LLMRuntimeResolver, static_llm_runtime
|
||||
|
||||
_HEARTBEAT_TOOL = [
|
||||
{
|
||||
@ -53,17 +53,21 @@ class HeartbeatService:
|
||||
def __init__(
|
||||
self,
|
||||
workspace: Path,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
provider: LLMProvider | None = None,
|
||||
model: str | None = None,
|
||||
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
||||
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
||||
interval_s: int = 30 * 60,
|
||||
enabled: bool = True,
|
||||
timezone: str | None = None,
|
||||
llm_runtime: LLMRuntimeResolver | None = None,
|
||||
):
|
||||
self.workspace = workspace
|
||||
self.provider = provider
|
||||
self.model = model
|
||||
if llm_runtime is None:
|
||||
if provider is None or model is None:
|
||||
raise ValueError("HeartbeatService requires either llm_runtime or provider/model")
|
||||
llm_runtime = static_llm_runtime(provider, model)
|
||||
self._llm_runtime = llm_runtime
|
||||
self.on_execute = on_execute
|
||||
self.on_notify = on_notify
|
||||
self.interval_s = interval_s
|
||||
@ -91,7 +95,9 @@ class HeartbeatService:
|
||||
"""
|
||||
from nanobot.utils.helpers import current_time_str
|
||||
|
||||
response = await self.provider.chat_with_retry(
|
||||
llm = self._llm_runtime()
|
||||
|
||||
response = await llm.provider.chat_with_retry(
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||
{"role": "user", "content": (
|
||||
@ -101,7 +107,7 @@ class HeartbeatService:
|
||||
)},
|
||||
],
|
||||
tools=_HEARTBEAT_TOOL,
|
||||
model=self.model,
|
||||
model=llm.model,
|
||||
)
|
||||
|
||||
if not response.should_execute_tools:
|
||||
@ -214,8 +220,9 @@ class HeartbeatService:
|
||||
)
|
||||
return
|
||||
|
||||
llm = self._llm_runtime()
|
||||
should_notify = await evaluate_response(
|
||||
response, tasks, self.provider, self.model,
|
||||
response, tasks, llm.provider, llm.model,
|
||||
)
|
||||
if should_notify and self.on_notify:
|
||||
logger.info("Heartbeat: completed, delivering response")
|
||||
|
||||
22
nanobot/utils/llm_runtime.py
Normal file
22
nanobot/utils/llm_runtime.py
Normal file
@ -0,0 +1,22 @@
|
||||
"""Small helpers for passing the active LLM provider/model together."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from nanobot.providers.base import LLMProvider
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LLMRuntime:
|
||||
provider: LLMProvider
|
||||
model: str
|
||||
|
||||
|
||||
LLMRuntimeResolver = Callable[[], LLMRuntime]
|
||||
|
||||
|
||||
def static_llm_runtime(provider: LLMProvider, model: str) -> LLMRuntimeResolver:
|
||||
runtime = LLMRuntime(provider=provider, model=model)
|
||||
return lambda: runtime
|
||||
@ -20,6 +20,7 @@ from nanobot.providers.base import LLMProvider
|
||||
from nanobot.session.goal_state import goal_state_ws_blob
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.helpers import truncate_text
|
||||
from nanobot.utils.llm_runtime import LLMRuntime
|
||||
|
||||
WEBUI_SESSION_METADATA_KEY = "webui"
|
||||
WEBUI_TITLE_METADATA_KEY = "title"
|
||||
@ -31,7 +32,6 @@ TITLE_GENERATION_REASONING_EFFORT = "none"
|
||||
# Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the
|
||||
# gateway process stays up; cleared on idle/stop and implicitly dropped on restart.
|
||||
_WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {}
|
||||
TitleContext = tuple[LLMProvider, str]
|
||||
|
||||
|
||||
def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool:
|
||||
@ -241,17 +241,16 @@ class WebuiTurnCoordinator:
|
||||
bus: MessageBus
|
||||
sessions: SessionManager
|
||||
schedule_background: Callable[[Awaitable[None]], None]
|
||||
_title_contexts: dict[str, TitleContext] = field(default_factory=dict)
|
||||
_title_contexts: dict[str, LLMRuntime] = field(default_factory=dict)
|
||||
|
||||
def capture_title_context(
|
||||
self,
|
||||
session_key: str,
|
||||
msg: InboundMessage,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
llm: LLMRuntime,
|
||||
) -> None:
|
||||
if msg.channel == "websocket" and msg.metadata.get("webui") is True:
|
||||
self._title_contexts[session_key] = (provider, model)
|
||||
self._title_contexts[session_key] = llm
|
||||
|
||||
def discard(self, session_key: str) -> None:
|
||||
self._title_contexts.pop(session_key, None)
|
||||
@ -287,19 +286,16 @@ class WebuiTurnCoordinator:
|
||||
if msg.metadata.get("webui") is not True or title_context is None:
|
||||
return
|
||||
|
||||
title_provider, title_model = title_context
|
||||
|
||||
async def _generate_title_and_notify(
|
||||
provider: LLMProvider = title_provider,
|
||||
model: str = title_model,
|
||||
title_llm: LLMRuntime = title_context,
|
||||
) -> None:
|
||||
generated = await maybe_generate_webui_title_after_turn(
|
||||
channel=msg.channel,
|
||||
metadata=msg.metadata,
|
||||
sessions=self.sessions,
|
||||
session_key=session_key,
|
||||
provider=provider,
|
||||
model=model,
|
||||
provider=title_llm.provider,
|
||||
model=title_llm.model,
|
||||
)
|
||||
if generated:
|
||||
await self.bus.publish_outbound(OutboundMessage(
|
||||
|
||||
@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from nanobot.heartbeat.service import HeartbeatService
|
||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||
from nanobot.utils.llm_runtime import LLMRuntime
|
||||
|
||||
|
||||
class DummyProvider(LLMProvider):
|
||||
@ -11,9 +12,11 @@ class DummyProvider(LLMProvider):
|
||||
super().__init__()
|
||||
self._responses = list(responses)
|
||||
self.calls = 0
|
||||
self.models: list[str | None] = []
|
||||
|
||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||
self.calls += 1
|
||||
self.models.append(kwargs.get("model"))
|
||||
if self._responses:
|
||||
return self._responses.pop(0)
|
||||
return LLMResponse(content="", tool_calls=[])
|
||||
@ -215,6 +218,51 @@ async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) ->
|
||||
assert notified == []
|
||||
|
||||
|
||||
def test_tick_uses_runtime_provider_and_model(tmp_path, monkeypatch) -> None:
|
||||
"""Preset changes must apply to heartbeat decision and post-run evaluation."""
|
||||
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check runtime model", encoding="utf-8")
|
||||
|
||||
runtime_provider = DummyProvider([
|
||||
LLMResponse(
|
||||
content="",
|
||||
tool_calls=[
|
||||
ToolCallRequest(
|
||||
id="hb_1",
|
||||
name="heartbeat",
|
||||
arguments={"action": "run", "tasks": "check runtime model"},
|
||||
)
|
||||
],
|
||||
),
|
||||
])
|
||||
runtime_model = "openai/gpt-4.1"
|
||||
|
||||
executed: list[str] = []
|
||||
evaluated: list[tuple[LLMProvider, str]] = []
|
||||
|
||||
async def _on_execute(tasks: str) -> str:
|
||||
executed.append(tasks)
|
||||
return "runtime model produced a user-facing update"
|
||||
|
||||
async def _eval_capture(response, tasks, provider, model):
|
||||
evaluated.append((provider, model))
|
||||
return False
|
||||
|
||||
service = HeartbeatService(
|
||||
workspace=tmp_path,
|
||||
llm_runtime=lambda: LLMRuntime(runtime_provider, runtime_model),
|
||||
on_execute=_on_execute,
|
||||
)
|
||||
|
||||
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_capture)
|
||||
|
||||
asyncio.run(service._tick())
|
||||
|
||||
assert runtime_provider.calls == 1
|
||||
assert runtime_provider.models == [runtime_model]
|
||||
assert executed == ["check runtime model"]
|
||||
assert evaluated == [(runtime_provider, runtime_model)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||
provider = DummyProvider([
|
||||
@ -286,4 +334,3 @@ async def test_decide_prompt_includes_current_time(tmp_path) -> None:
|
||||
user_msg = captured_messages[1]
|
||||
assert user_msg["role"] == "user"
|
||||
assert "Current Time:" in user_msg["content"]
|
||||
|
||||
|
||||
@ -10,14 +10,16 @@ from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
from nanobot.session.goal_state import GOAL_STATE_KEY
|
||||
from nanobot.session.manager import Session
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.webui_turn_helpers import (
|
||||
TITLE_GENERATION_MAX_TOKENS,
|
||||
TITLE_GENERATION_REASONING_EFFORT,
|
||||
WEBUI_SESSION_METADATA_KEY,
|
||||
WEBUI_TITLE_METADATA_KEY,
|
||||
WebuiTurnCoordinator,
|
||||
maybe_generate_webui_title,
|
||||
)
|
||||
from nanobot.utils.llm_runtime import LLMRuntime
|
||||
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
@ -35,6 +37,22 @@ def _make_full_loop(tmp_path: Path) -> AgentLoop:
|
||||
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
||||
|
||||
|
||||
def test_agent_loop_llm_runtime_reflects_current_provider_and_model(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
runtime = loop.llm_runtime()
|
||||
|
||||
assert runtime.provider is loop.provider
|
||||
assert runtime.model == "test-model"
|
||||
|
||||
next_provider = MagicMock()
|
||||
loop.provider = next_provider
|
||||
loop.model = "next-model"
|
||||
runtime = loop.llm_runtime()
|
||||
|
||||
assert runtime.provider is next_provider
|
||||
assert runtime.model == "next-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
@ -111,6 +129,55 @@ async def test_generate_webui_title_ignores_command_only_sessions(tmp_path: Path
|
||||
loop.provider.chat_with_retry.assert_not_awaited()
|
||||
|
||||
|
||||
def test_webui_title_update_uses_captured_llm_runtime(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
bus = MessageBus()
|
||||
sessions = SessionManager(tmp_path)
|
||||
scheduled: list[object] = []
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
async def fake_title_after_turn(**kwargs: object) -> bool:
|
||||
captured.update(kwargs)
|
||||
return False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn",
|
||||
fake_title_after_turn,
|
||||
)
|
||||
coordinator = WebuiTurnCoordinator(
|
||||
bus=bus,
|
||||
sessions=sessions,
|
||||
schedule_background=lambda coro: scheduled.append(coro),
|
||||
)
|
||||
provider = MagicMock()
|
||||
msg = InboundMessage(
|
||||
channel="websocket",
|
||||
sender_id="u1",
|
||||
chat_id="chat1",
|
||||
content="say hello",
|
||||
metadata={"webui": True},
|
||||
)
|
||||
|
||||
coordinator.capture_title_context(
|
||||
"websocket:chat1",
|
||||
msg,
|
||||
LLMRuntime(provider, "turn-model"),
|
||||
)
|
||||
asyncio.run(coordinator.handle_turn_end(
|
||||
msg,
|
||||
session_key="websocket:chat1",
|
||||
latency_ms=None,
|
||||
))
|
||||
|
||||
assert len(scheduled) == 1
|
||||
asyncio.run(scheduled[0]) # type: ignore[arg-type]
|
||||
|
||||
assert captured["provider"] is provider
|
||||
assert captured["model"] == "turn-model"
|
||||
|
||||
|
||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||
loop = _mk_loop()
|
||||
session = Session(key="test:runtime-only")
|
||||
|
||||
@ -47,3 +47,28 @@ def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
|
||||
assert loop.dream.provider is new_provider
|
||||
assert loop.dream.model == "new-model"
|
||||
assert loop.dream._runner.provider is new_provider
|
||||
|
||||
|
||||
def test_llm_runtime_refreshes_provider_snapshot(tmp_path: Path) -> None:
|
||||
old_provider = _provider("old-model")
|
||||
new_provider = _provider("new-model", max_tokens=456)
|
||||
loop = AgentLoop(
|
||||
bus=MessageBus(),
|
||||
provider=old_provider,
|
||||
workspace=tmp_path,
|
||||
model="old-model",
|
||||
context_window_tokens=1000,
|
||||
provider_snapshot_loader=lambda: ProviderSnapshot(
|
||||
provider=new_provider,
|
||||
model="new-model",
|
||||
context_window_tokens=2000,
|
||||
signature=("new-model",),
|
||||
),
|
||||
)
|
||||
|
||||
runtime = loop.llm_runtime()
|
||||
|
||||
assert runtime.provider is new_provider
|
||||
assert runtime.model == "new-model"
|
||||
assert loop.provider is new_provider
|
||||
assert loop.runner.provider is new_provider
|
||||
|
||||
@ -1170,6 +1170,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
self.model = "test-model"
|
||||
self.provider = kwargs.get("provider", object())
|
||||
self.tools = {}
|
||||
seen["agent"] = self
|
||||
|
||||
async def process_direct(self, *_args, **_kwargs):
|
||||
return OutboundMessage(
|
||||
@ -1218,6 +1219,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
assert isinstance(cron, _FakeCron)
|
||||
assert cron.on_job is not None
|
||||
|
||||
runtime_provider = object()
|
||||
agent = seen["agent"]
|
||||
agent.provider = runtime_provider
|
||||
agent.model = "runtime-model"
|
||||
|
||||
job = CronJob(
|
||||
id="cron-1",
|
||||
name="stretch",
|
||||
@ -1233,8 +1239,8 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
||||
|
||||
assert response == "Time to stretch."
|
||||
assert seen["response"] == "Time to stretch."
|
||||
assert seen["provider"] is provider
|
||||
assert seen["model"] == "test-model"
|
||||
assert seen["provider"] is runtime_provider
|
||||
assert seen["model"] == "runtime-model"
|
||||
assert seen["task_context"] == (
|
||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||
"as a brief and natural message in their language. Speak directly to them — "
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user