mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 08:32:25 +00:00
fix(agent): refresh llm runtime for background tasks
This commit is contained in:
parent
2f323e24c1
commit
112f40ad67
@ -41,6 +41,7 @@ from nanobot.utils.document import extract_documents
|
|||||||
from nanobot.utils.helpers import image_placeholder_text
|
from nanobot.utils.helpers import image_placeholder_text
|
||||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||||
from nanobot.utils.image_generation_intent import image_generation_prompt
|
from nanobot.utils.image_generation_intent import image_generation_prompt
|
||||||
|
from nanobot.utils.llm_runtime import LLMRuntime
|
||||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||||
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
|
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
|
||||||
from nanobot.utils.webui_turn_helpers import (
|
from nanobot.utils.webui_turn_helpers import (
|
||||||
@ -138,6 +139,11 @@ class AgentLoop:
|
|||||||
def tool_names(self) -> list[str]:
|
def tool_names(self) -> list[str]:
|
||||||
return self.tools.tool_names
|
return self.tools.tool_names
|
||||||
|
|
||||||
|
def llm_runtime(self) -> LLMRuntime:
|
||||||
|
"""Return the current provider/model pair owned by this loop."""
|
||||||
|
self._refresh_provider_snapshot()
|
||||||
|
return LLMRuntime(self.provider, self.model)
|
||||||
|
|
||||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||||
|
|
||||||
@ -1296,8 +1302,7 @@ class AgentLoop:
|
|||||||
self._webui_turns.capture_title_context(
|
self._webui_turns.capture_title_context(
|
||||||
ctx.session_key,
|
ctx.session_key,
|
||||||
ctx.msg,
|
ctx.msg,
|
||||||
self.provider,
|
self.llm_runtime(),
|
||||||
self.model,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ctx.initial_messages = self._build_initial_messages(
|
ctx.initial_messages = self._build_initial_messages(
|
||||||
|
|||||||
@ -914,8 +914,7 @@ def _run_gateway(
|
|||||||
hb_cfg = config.gateway.heartbeat
|
hb_cfg = config.gateway.heartbeat
|
||||||
heartbeat = HeartbeatService(
|
heartbeat = HeartbeatService(
|
||||||
workspace=config.workspace_path,
|
workspace=config.workspace_path,
|
||||||
provider=agent.provider,
|
llm_runtime=agent.llm_runtime,
|
||||||
model=agent.model,
|
|
||||||
on_execute=on_heartbeat_execute,
|
on_execute=on_heartbeat_execute,
|
||||||
on_notify=on_heartbeat_notify,
|
on_notify=on_heartbeat_notify,
|
||||||
interval_s=hb_cfg.interval_s,
|
interval_s=hb_cfg.interval_s,
|
||||||
|
|||||||
@ -4,12 +4,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from nanobot.providers.base import LLMProvider
|
from nanobot.providers.base import LLMProvider
|
||||||
|
from nanobot.utils.llm_runtime import LLMRuntime, LLMRuntimeResolver, static_llm_runtime
|
||||||
|
|
||||||
_HEARTBEAT_TOOL = [
|
_HEARTBEAT_TOOL = [
|
||||||
{
|
{
|
||||||
@ -53,17 +53,21 @@ class HeartbeatService:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workspace: Path,
|
workspace: Path,
|
||||||
provider: LLMProvider,
|
provider: LLMProvider | None = None,
|
||||||
model: str,
|
model: str | None = None,
|
||||||
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
on_execute: Callable[[str], Coroutine[Any, Any, str]] | None = None,
|
||||||
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
on_notify: Callable[[str], Coroutine[Any, Any, None]] | None = None,
|
||||||
interval_s: int = 30 * 60,
|
interval_s: int = 30 * 60,
|
||||||
enabled: bool = True,
|
enabled: bool = True,
|
||||||
timezone: str | None = None,
|
timezone: str | None = None,
|
||||||
|
llm_runtime: LLMRuntimeResolver | None = None,
|
||||||
):
|
):
|
||||||
self.workspace = workspace
|
self.workspace = workspace
|
||||||
self.provider = provider
|
if llm_runtime is None:
|
||||||
self.model = model
|
if provider is None or model is None:
|
||||||
|
raise ValueError("HeartbeatService requires either llm_runtime or provider/model")
|
||||||
|
llm_runtime = static_llm_runtime(provider, model)
|
||||||
|
self._llm_runtime = llm_runtime
|
||||||
self.on_execute = on_execute
|
self.on_execute = on_execute
|
||||||
self.on_notify = on_notify
|
self.on_notify = on_notify
|
||||||
self.interval_s = interval_s
|
self.interval_s = interval_s
|
||||||
@ -91,7 +95,9 @@ class HeartbeatService:
|
|||||||
"""
|
"""
|
||||||
from nanobot.utils.helpers import current_time_str
|
from nanobot.utils.helpers import current_time_str
|
||||||
|
|
||||||
response = await self.provider.chat_with_retry(
|
llm = self._llm_runtime()
|
||||||
|
|
||||||
|
response = await llm.provider.chat_with_retry(
|
||||||
messages=[
|
messages=[
|
||||||
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
{"role": "system", "content": "You are a heartbeat agent. Call the heartbeat tool to report your decision."},
|
||||||
{"role": "user", "content": (
|
{"role": "user", "content": (
|
||||||
@ -101,7 +107,7 @@ class HeartbeatService:
|
|||||||
)},
|
)},
|
||||||
],
|
],
|
||||||
tools=_HEARTBEAT_TOOL,
|
tools=_HEARTBEAT_TOOL,
|
||||||
model=self.model,
|
model=llm.model,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not response.should_execute_tools:
|
if not response.should_execute_tools:
|
||||||
@ -214,8 +220,9 @@ class HeartbeatService:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
llm = self._llm_runtime()
|
||||||
should_notify = await evaluate_response(
|
should_notify = await evaluate_response(
|
||||||
response, tasks, self.provider, self.model,
|
response, tasks, llm.provider, llm.model,
|
||||||
)
|
)
|
||||||
if should_notify and self.on_notify:
|
if should_notify and self.on_notify:
|
||||||
logger.info("Heartbeat: completed, delivering response")
|
logger.info("Heartbeat: completed, delivering response")
|
||||||
|
|||||||
22
nanobot/utils/llm_runtime.py
Normal file
22
nanobot/utils/llm_runtime.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
"""Small helpers for passing the active LLM provider/model together."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from nanobot.providers.base import LLMProvider
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LLMRuntime:
|
||||||
|
provider: LLMProvider
|
||||||
|
model: str
|
||||||
|
|
||||||
|
|
||||||
|
LLMRuntimeResolver = Callable[[], LLMRuntime]
|
||||||
|
|
||||||
|
|
||||||
|
def static_llm_runtime(provider: LLMProvider, model: str) -> LLMRuntimeResolver:
|
||||||
|
runtime = LLMRuntime(provider=provider, model=model)
|
||||||
|
return lambda: runtime
|
||||||
@ -20,6 +20,7 @@ from nanobot.providers.base import LLMProvider
|
|||||||
from nanobot.session.goal_state import goal_state_ws_blob
|
from nanobot.session.goal_state import goal_state_ws_blob
|
||||||
from nanobot.session.manager import Session, SessionManager
|
from nanobot.session.manager import Session, SessionManager
|
||||||
from nanobot.utils.helpers import truncate_text
|
from nanobot.utils.helpers import truncate_text
|
||||||
|
from nanobot.utils.llm_runtime import LLMRuntime
|
||||||
|
|
||||||
WEBUI_SESSION_METADATA_KEY = "webui"
|
WEBUI_SESSION_METADATA_KEY = "webui"
|
||||||
WEBUI_TITLE_METADATA_KEY = "title"
|
WEBUI_TITLE_METADATA_KEY = "title"
|
||||||
@ -31,7 +32,6 @@ TITLE_GENERATION_REASONING_EFFORT = "none"
|
|||||||
# Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the
|
# Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the
|
||||||
# gateway process stays up; cleared on idle/stop and implicitly dropped on restart.
|
# gateway process stays up; cleared on idle/stop and implicitly dropped on restart.
|
||||||
_WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {}
|
_WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {}
|
||||||
TitleContext = tuple[LLMProvider, str]
|
|
||||||
|
|
||||||
|
|
||||||
def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool:
|
def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool:
|
||||||
@ -241,17 +241,16 @@ class WebuiTurnCoordinator:
|
|||||||
bus: MessageBus
|
bus: MessageBus
|
||||||
sessions: SessionManager
|
sessions: SessionManager
|
||||||
schedule_background: Callable[[Awaitable[None]], None]
|
schedule_background: Callable[[Awaitable[None]], None]
|
||||||
_title_contexts: dict[str, TitleContext] = field(default_factory=dict)
|
_title_contexts: dict[str, LLMRuntime] = field(default_factory=dict)
|
||||||
|
|
||||||
def capture_title_context(
|
def capture_title_context(
|
||||||
self,
|
self,
|
||||||
session_key: str,
|
session_key: str,
|
||||||
msg: InboundMessage,
|
msg: InboundMessage,
|
||||||
provider: LLMProvider,
|
llm: LLMRuntime,
|
||||||
model: str,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
if msg.channel == "websocket" and msg.metadata.get("webui") is True:
|
if msg.channel == "websocket" and msg.metadata.get("webui") is True:
|
||||||
self._title_contexts[session_key] = (provider, model)
|
self._title_contexts[session_key] = llm
|
||||||
|
|
||||||
def discard(self, session_key: str) -> None:
|
def discard(self, session_key: str) -> None:
|
||||||
self._title_contexts.pop(session_key, None)
|
self._title_contexts.pop(session_key, None)
|
||||||
@ -287,19 +286,16 @@ class WebuiTurnCoordinator:
|
|||||||
if msg.metadata.get("webui") is not True or title_context is None:
|
if msg.metadata.get("webui") is not True or title_context is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
title_provider, title_model = title_context
|
|
||||||
|
|
||||||
async def _generate_title_and_notify(
|
async def _generate_title_and_notify(
|
||||||
provider: LLMProvider = title_provider,
|
title_llm: LLMRuntime = title_context,
|
||||||
model: str = title_model,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
generated = await maybe_generate_webui_title_after_turn(
|
generated = await maybe_generate_webui_title_after_turn(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
metadata=msg.metadata,
|
metadata=msg.metadata,
|
||||||
sessions=self.sessions,
|
sessions=self.sessions,
|
||||||
session_key=session_key,
|
session_key=session_key,
|
||||||
provider=provider,
|
provider=title_llm.provider,
|
||||||
model=model,
|
model=title_llm.model,
|
||||||
)
|
)
|
||||||
if generated:
|
if generated:
|
||||||
await self.bus.publish_outbound(OutboundMessage(
|
await self.bus.publish_outbound(OutboundMessage(
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import pytest
|
|||||||
|
|
||||||
from nanobot.heartbeat.service import HeartbeatService
|
from nanobot.heartbeat.service import HeartbeatService
|
||||||
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest
|
||||||
|
from nanobot.utils.llm_runtime import LLMRuntime
|
||||||
|
|
||||||
|
|
||||||
class DummyProvider(LLMProvider):
|
class DummyProvider(LLMProvider):
|
||||||
@ -11,9 +12,11 @@ class DummyProvider(LLMProvider):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self._responses = list(responses)
|
self._responses = list(responses)
|
||||||
self.calls = 0
|
self.calls = 0
|
||||||
|
self.models: list[str | None] = []
|
||||||
|
|
||||||
async def chat(self, *args, **kwargs) -> LLMResponse:
|
async def chat(self, *args, **kwargs) -> LLMResponse:
|
||||||
self.calls += 1
|
self.calls += 1
|
||||||
|
self.models.append(kwargs.get("model"))
|
||||||
if self._responses:
|
if self._responses:
|
||||||
return self._responses.pop(0)
|
return self._responses.pop(0)
|
||||||
return LLMResponse(content="", tool_calls=[])
|
return LLMResponse(content="", tool_calls=[])
|
||||||
@ -215,6 +218,51 @@ async def test_tick_suppresses_when_evaluator_says_no(tmp_path, monkeypatch) ->
|
|||||||
assert notified == []
|
assert notified == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_tick_uses_runtime_provider_and_model(tmp_path, monkeypatch) -> None:
|
||||||
|
"""Preset changes must apply to heartbeat decision and post-run evaluation."""
|
||||||
|
(tmp_path / "HEARTBEAT.md").write_text("- [ ] check runtime model", encoding="utf-8")
|
||||||
|
|
||||||
|
runtime_provider = DummyProvider([
|
||||||
|
LLMResponse(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCallRequest(
|
||||||
|
id="hb_1",
|
||||||
|
name="heartbeat",
|
||||||
|
arguments={"action": "run", "tasks": "check runtime model"},
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
])
|
||||||
|
runtime_model = "openai/gpt-4.1"
|
||||||
|
|
||||||
|
executed: list[str] = []
|
||||||
|
evaluated: list[tuple[LLMProvider, str]] = []
|
||||||
|
|
||||||
|
async def _on_execute(tasks: str) -> str:
|
||||||
|
executed.append(tasks)
|
||||||
|
return "runtime model produced a user-facing update"
|
||||||
|
|
||||||
|
async def _eval_capture(response, tasks, provider, model):
|
||||||
|
evaluated.append((provider, model))
|
||||||
|
return False
|
||||||
|
|
||||||
|
service = HeartbeatService(
|
||||||
|
workspace=tmp_path,
|
||||||
|
llm_runtime=lambda: LLMRuntime(runtime_provider, runtime_model),
|
||||||
|
on_execute=_on_execute,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.utils.evaluator.evaluate_response", _eval_capture)
|
||||||
|
|
||||||
|
asyncio.run(service._tick())
|
||||||
|
|
||||||
|
assert runtime_provider.calls == 1
|
||||||
|
assert runtime_provider.models == [runtime_model]
|
||||||
|
assert executed == ["check runtime model"]
|
||||||
|
assert evaluated == [(runtime_provider, runtime_model)]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
async def test_decide_retries_transient_error_then_succeeds(tmp_path, monkeypatch) -> None:
|
||||||
provider = DummyProvider([
|
provider = DummyProvider([
|
||||||
@ -286,4 +334,3 @@ async def test_decide_prompt_includes_current_time(tmp_path) -> None:
|
|||||||
user_msg = captured_messages[1]
|
user_msg = captured_messages[1]
|
||||||
assert user_msg["role"] == "user"
|
assert user_msg["role"] == "user"
|
||||||
assert "Current Time:" in user_msg["content"]
|
assert "Current Time:" in user_msg["content"]
|
||||||
|
|
||||||
|
|||||||
@ -10,14 +10,16 @@ from nanobot.bus.events import InboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.providers.base import LLMResponse
|
from nanobot.providers.base import LLMResponse
|
||||||
from nanobot.session.goal_state import GOAL_STATE_KEY
|
from nanobot.session.goal_state import GOAL_STATE_KEY
|
||||||
from nanobot.session.manager import Session
|
from nanobot.session.manager import Session, SessionManager
|
||||||
from nanobot.utils.webui_turn_helpers import (
|
from nanobot.utils.webui_turn_helpers import (
|
||||||
TITLE_GENERATION_MAX_TOKENS,
|
TITLE_GENERATION_MAX_TOKENS,
|
||||||
TITLE_GENERATION_REASONING_EFFORT,
|
TITLE_GENERATION_REASONING_EFFORT,
|
||||||
WEBUI_SESSION_METADATA_KEY,
|
WEBUI_SESSION_METADATA_KEY,
|
||||||
WEBUI_TITLE_METADATA_KEY,
|
WEBUI_TITLE_METADATA_KEY,
|
||||||
|
WebuiTurnCoordinator,
|
||||||
maybe_generate_webui_title,
|
maybe_generate_webui_title,
|
||||||
)
|
)
|
||||||
|
from nanobot.utils.llm_runtime import LLMRuntime
|
||||||
|
|
||||||
|
|
||||||
def _mk_loop() -> AgentLoop:
|
def _mk_loop() -> AgentLoop:
|
||||||
@ -35,6 +37,22 @@ def _make_full_loop(tmp_path: Path) -> AgentLoop:
|
|||||||
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model")
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_loop_llm_runtime_reflects_current_provider_and_model(tmp_path: Path) -> None:
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
runtime = loop.llm_runtime()
|
||||||
|
|
||||||
|
assert runtime.provider is loop.provider
|
||||||
|
assert runtime.model == "test-model"
|
||||||
|
|
||||||
|
next_provider = MagicMock()
|
||||||
|
loop.provider = next_provider
|
||||||
|
loop.model = "next-model"
|
||||||
|
runtime = loop.llm_runtime()
|
||||||
|
|
||||||
|
assert runtime.provider is next_provider
|
||||||
|
assert runtime.model == "next-model"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None:
|
async def test_generate_webui_title_only_for_marked_webui_sessions(tmp_path: Path) -> None:
|
||||||
loop = _make_full_loop(tmp_path)
|
loop = _make_full_loop(tmp_path)
|
||||||
@ -111,6 +129,55 @@ async def test_generate_webui_title_ignores_command_only_sessions(tmp_path: Path
|
|||||||
loop.provider.chat_with_retry.assert_not_awaited()
|
loop.provider.chat_with_retry.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_webui_title_update_uses_captured_llm_runtime(
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
bus = MessageBus()
|
||||||
|
sessions = SessionManager(tmp_path)
|
||||||
|
scheduled: list[object] = []
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
async def fake_title_after_turn(**kwargs: object) -> bool:
|
||||||
|
captured.update(kwargs)
|
||||||
|
return False
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"nanobot.utils.webui_turn_helpers.maybe_generate_webui_title_after_turn",
|
||||||
|
fake_title_after_turn,
|
||||||
|
)
|
||||||
|
coordinator = WebuiTurnCoordinator(
|
||||||
|
bus=bus,
|
||||||
|
sessions=sessions,
|
||||||
|
schedule_background=lambda coro: scheduled.append(coro),
|
||||||
|
)
|
||||||
|
provider = MagicMock()
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="websocket",
|
||||||
|
sender_id="u1",
|
||||||
|
chat_id="chat1",
|
||||||
|
content="say hello",
|
||||||
|
metadata={"webui": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
coordinator.capture_title_context(
|
||||||
|
"websocket:chat1",
|
||||||
|
msg,
|
||||||
|
LLMRuntime(provider, "turn-model"),
|
||||||
|
)
|
||||||
|
asyncio.run(coordinator.handle_turn_end(
|
||||||
|
msg,
|
||||||
|
session_key="websocket:chat1",
|
||||||
|
latency_ms=None,
|
||||||
|
))
|
||||||
|
|
||||||
|
assert len(scheduled) == 1
|
||||||
|
asyncio.run(scheduled[0]) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert captured["provider"] is provider
|
||||||
|
assert captured["model"] == "turn-model"
|
||||||
|
|
||||||
|
|
||||||
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
def test_save_turn_skips_multimodal_user_when_only_runtime_context() -> None:
|
||||||
loop = _mk_loop()
|
loop = _mk_loop()
|
||||||
session = Session(key="test:runtime-only")
|
session = Session(key="test:runtime-only")
|
||||||
|
|||||||
@ -47,3 +47,28 @@ def test_provider_refresh_updates_all_model_dependents(tmp_path: Path) -> None:
|
|||||||
assert loop.dream.provider is new_provider
|
assert loop.dream.provider is new_provider
|
||||||
assert loop.dream.model == "new-model"
|
assert loop.dream.model == "new-model"
|
||||||
assert loop.dream._runner.provider is new_provider
|
assert loop.dream._runner.provider is new_provider
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_runtime_refreshes_provider_snapshot(tmp_path: Path) -> None:
|
||||||
|
old_provider = _provider("old-model")
|
||||||
|
new_provider = _provider("new-model", max_tokens=456)
|
||||||
|
loop = AgentLoop(
|
||||||
|
bus=MessageBus(),
|
||||||
|
provider=old_provider,
|
||||||
|
workspace=tmp_path,
|
||||||
|
model="old-model",
|
||||||
|
context_window_tokens=1000,
|
||||||
|
provider_snapshot_loader=lambda: ProviderSnapshot(
|
||||||
|
provider=new_provider,
|
||||||
|
model="new-model",
|
||||||
|
context_window_tokens=2000,
|
||||||
|
signature=("new-model",),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime = loop.llm_runtime()
|
||||||
|
|
||||||
|
assert runtime.provider is new_provider
|
||||||
|
assert runtime.model == "new-model"
|
||||||
|
assert loop.provider is new_provider
|
||||||
|
assert loop.runner.provider is new_provider
|
||||||
|
|||||||
@ -1170,6 +1170,7 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
self.model = "test-model"
|
self.model = "test-model"
|
||||||
self.provider = kwargs.get("provider", object())
|
self.provider = kwargs.get("provider", object())
|
||||||
self.tools = {}
|
self.tools = {}
|
||||||
|
seen["agent"] = self
|
||||||
|
|
||||||
async def process_direct(self, *_args, **_kwargs):
|
async def process_direct(self, *_args, **_kwargs):
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
@ -1218,6 +1219,11 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
assert isinstance(cron, _FakeCron)
|
assert isinstance(cron, _FakeCron)
|
||||||
assert cron.on_job is not None
|
assert cron.on_job is not None
|
||||||
|
|
||||||
|
runtime_provider = object()
|
||||||
|
agent = seen["agent"]
|
||||||
|
agent.provider = runtime_provider
|
||||||
|
agent.model = "runtime-model"
|
||||||
|
|
||||||
job = CronJob(
|
job = CronJob(
|
||||||
id="cron-1",
|
id="cron-1",
|
||||||
name="stretch",
|
name="stretch",
|
||||||
@ -1233,8 +1239,8 @@ def test_gateway_cron_evaluator_receives_scheduled_reminder_context(
|
|||||||
|
|
||||||
assert response == "Time to stretch."
|
assert response == "Time to stretch."
|
||||||
assert seen["response"] == "Time to stretch."
|
assert seen["response"] == "Time to stretch."
|
||||||
assert seen["provider"] is provider
|
assert seen["provider"] is runtime_provider
|
||||||
assert seen["model"] == "test-model"
|
assert seen["model"] == "runtime-model"
|
||||||
assert seen["task_context"] == (
|
assert seen["task_context"] == (
|
||||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||||
"as a brief and natural message in their language. Speak directly to them — "
|
"as a brief and natural message in their language. Speak directly to them — "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user