mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
refactor: move bound cron execution out of gateway
This commit is contained in:
parent
5ae907bc2f
commit
af8192dc38
@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
@ -57,6 +58,29 @@ class CronTurnCoordinator:
|
||||
and session_key in active_session_keys
|
||||
)
|
||||
|
||||
def defer_if_active(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
*,
|
||||
session_key: str,
|
||||
active_session_keys: Iterable[str],
|
||||
) -> bool:
|
||||
"""Defer a cron turn when its target session is already active."""
|
||||
if not self.should_defer(
|
||||
msg,
|
||||
session_key=session_key,
|
||||
active_session_keys=active_session_keys,
|
||||
):
|
||||
return False
|
||||
pending_msg = msg
|
||||
if session_key != msg.session_key:
|
||||
pending_msg = dataclasses.replace(
|
||||
msg,
|
||||
session_key_override=session_key,
|
||||
)
|
||||
self.defer(session_key, pending_msg)
|
||||
return True
|
||||
|
||||
def complete(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
|
||||
@ -41,8 +41,7 @@ from nanobot.bus.runtime_events import (
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.config.schema import AgentDefaults, ModelPresetConfig
|
||||
from nanobot.cron.session_turns import (
|
||||
CRON_HISTORY_META,
|
||||
cron_trigger,
|
||||
cron_history_overrides,
|
||||
)
|
||||
from nanobot.providers.base import LLMProvider
|
||||
from nanobot.providers.factory import ProviderSnapshot
|
||||
@ -593,17 +592,10 @@ class AgentLoop:
|
||||
extra: dict[str, Any] = ({"media": list(media_paths)} if media_paths else {}) | agent_context.session_extra(msg.metadata)
|
||||
extra.update(kwargs)
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
if trigger := cron_trigger(msg.metadata):
|
||||
persist_content = trigger.get("persist_content")
|
||||
if isinstance(persist_content, str) and persist_content.strip():
|
||||
text = persist_content
|
||||
extra.update({
|
||||
CRON_HISTORY_META: True,
|
||||
"cron_job_id": trigger.get("job_id"),
|
||||
"cron_job_name": trigger.get("job_name"),
|
||||
"cron_run_id": trigger.get("run_id"),
|
||||
"cron_prompt_ref": trigger.get("prompt_ref"),
|
||||
})
|
||||
text_override, cron_extra = cron_history_overrides(msg.metadata)
|
||||
if text_override is not None:
|
||||
text = text_override
|
||||
extra.update(cron_extra)
|
||||
session.add_message("user", text, **extra)
|
||||
self._mark_pending_user_turn(session)
|
||||
self.sessions.save(session)
|
||||
@ -904,18 +896,11 @@ class AgentLoop:
|
||||
self.commands.dispatch_priority,
|
||||
)
|
||||
continue
|
||||
if self._cron_turns.should_defer(
|
||||
if self._cron_turns.defer_if_active(
|
||||
msg,
|
||||
session_key=effective_key,
|
||||
active_session_keys=self._pending_queues.keys(),
|
||||
):
|
||||
pending_msg = msg
|
||||
if effective_key != msg.session_key:
|
||||
pending_msg = dataclasses.replace(
|
||||
msg,
|
||||
session_key_override=effective_key,
|
||||
)
|
||||
self._cron_turns.defer(effective_key, pending_msg)
|
||||
logger.info(
|
||||
"Deferred cron turn for active session {}",
|
||||
effective_key,
|
||||
|
||||
@ -1,13 +1,10 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from contextlib import nullcontext, suppress
|
||||
from contextvars import ContextVar
|
||||
@ -57,6 +54,7 @@ from nanobot.agent.loop import AgentLoop # noqa: E402
|
||||
from nanobot.cli.stream import StreamRenderer, ThinkingSpinner # noqa: E402
|
||||
from nanobot.config.paths import get_workspace_path, is_default_workspace # noqa: E402
|
||||
from nanobot.config.schema import Config # noqa: E402
|
||||
from nanobot.cron.webui_metadata import cron_proactive_delivery_metadata # noqa: E402
|
||||
from nanobot.utils.evaluator import evaluate_response # noqa: E402
|
||||
from nanobot.utils.helpers import sync_workspace_templates # noqa: E402
|
||||
from nanobot.utils.restart import ( # noqa: E402
|
||||
@ -64,10 +62,6 @@ from nanobot.utils.restart import ( # noqa: E402
|
||||
format_restart_completed_message,
|
||||
should_show_cli_restart_notice,
|
||||
)
|
||||
from nanobot.webui.metadata import ( # noqa: E402
|
||||
WEBUI_MESSAGE_SOURCE_METADATA_KEY,
|
||||
WEBUI_TURN_METADATA_KEY,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_surrogates(text: str) -> str:
|
||||
@ -99,24 +93,6 @@ _PROACTIVE_WEBUI_METADATA: ContextVar[dict[str, Any] | None] = ContextVar(
|
||||
)
|
||||
|
||||
|
||||
def _proactive_delivery_metadata(
|
||||
channel: str,
|
||||
metadata: dict[str, Any] | None,
|
||||
*,
|
||||
turn_seed: str,
|
||||
source_label: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return channel metadata for a fresh proactive delivery turn."""
|
||||
out = dict(metadata or {})
|
||||
out.pop(WEBUI_TURN_METADATA_KEY, None)
|
||||
if channel == "websocket":
|
||||
out[WEBUI_TURN_METADATA_KEY] = f"{turn_seed}:{uuid.uuid4().hex}"
|
||||
source: dict[str, str] = {"kind": "cron"}
|
||||
if source_label:
|
||||
source["label"] = source_label
|
||||
out[WEBUI_MESSAGE_SOURCE_METADATA_KEY] = source
|
||||
return out
|
||||
|
||||
app = typer.Typer(
|
||||
name="nanobot",
|
||||
context_settings={"help_option_names": ["-h", "--help"]},
|
||||
@ -979,19 +955,14 @@ def _run_gateway(
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.bus.runtime_events import RuntimeEventBus
|
||||
from nanobot.channels.manager import ChannelManager
|
||||
from nanobot.cron.bound_runner import run_bound_cron_job
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.session_delivery import origin_delivery_context
|
||||
from nanobot.cron.session_turns import (
|
||||
CRON_DEFER_UNTIL_IDLE_META,
|
||||
CRON_TRIGGER_META,
|
||||
is_bound_cron_job,
|
||||
)
|
||||
from nanobot.cron.session_turns import is_bound_cron_job
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot
|
||||
from nanobot.providers.image_generation import image_gen_provider_configs
|
||||
from nanobot.session.manager import SessionManager
|
||||
from nanobot.session.webui_turns import WebuiTurnCoordinator
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
from nanobot.webui.token_usage import TokenUsageHook
|
||||
|
||||
port = port if port is not None else config.gateway.port
|
||||
@ -1035,7 +1006,7 @@ def _run_gateway(
|
||||
schedule_background=lambda coro: agent._schedule_background(coro),
|
||||
).subscribe(runtime_events)
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.events import OutboundMessage
|
||||
from nanobot.session.keys import session_key_for_channel
|
||||
|
||||
def _channel_session_key(channel: str, chat_id: str) -> str:
|
||||
@ -1045,119 +1016,6 @@ def _run_gateway(
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
)
|
||||
|
||||
def _bound_session_delivery_context(
|
||||
job: CronJob,
|
||||
*,
|
||||
turn_seed: str,
|
||||
source_label: str | None,
|
||||
) -> tuple[str, str, dict[str, Any]]:
|
||||
channel, chat_id, metadata = origin_delivery_context(job)
|
||||
|
||||
if channel == "websocket":
|
||||
metadata["webui"] = True
|
||||
metadata.update(
|
||||
_proactive_delivery_metadata(
|
||||
"websocket",
|
||||
metadata,
|
||||
turn_seed=turn_seed,
|
||||
source_label=source_label,
|
||||
)
|
||||
)
|
||||
|
||||
return channel, chat_id, metadata
|
||||
|
||||
def _cron_prompt_ref(prompt: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": "cron.agent_turn.reminder",
|
||||
"version": 1,
|
||||
"sha256": hashlib.sha256(prompt.encode("utf-8")).hexdigest(),
|
||||
}
|
||||
|
||||
async def _run_bound_cron_job(job: CronJob) -> str | None:
|
||||
session_key = job.payload.session_key
|
||||
if not session_key:
|
||||
raise ValueError(f"cron job {job.id} is missing payload.session_key")
|
||||
|
||||
prompt = render_template(
|
||||
"agent/cron_reminder.md",
|
||||
strip=True,
|
||||
message=job.payload.message,
|
||||
)
|
||||
prompt_ref = _cron_prompt_ref(prompt)
|
||||
run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}"
|
||||
channel, chat_id, metadata = _bound_session_delivery_context(
|
||||
job,
|
||||
turn_seed=f"cron:{job.id}",
|
||||
source_label=job.name,
|
||||
)
|
||||
metadata[CRON_TRIGGER_META] = {
|
||||
"job_id": job.id,
|
||||
"job_name": job.name,
|
||||
"run_id": run_id,
|
||||
"prompt_ref": prompt_ref,
|
||||
"persist_content": (
|
||||
f"Scheduled cron job triggered: {job.name}\n\n{job.payload.message}"
|
||||
),
|
||||
}
|
||||
metadata[CRON_DEFER_UNTIL_IDLE_META] = True
|
||||
run_record_base: dict[str, Any] = {
|
||||
"job_id": job.id,
|
||||
"job_name": job.name,
|
||||
"session_key": session_key,
|
||||
"prompt_ref": prompt_ref,
|
||||
"prompt_vars": {"message": job.payload.message},
|
||||
"rendered_prompt": prompt,
|
||||
}
|
||||
|
||||
cron.write_run_record(
|
||||
run_id,
|
||||
{
|
||||
**run_record_base,
|
||||
"status": "queued",
|
||||
},
|
||||
)
|
||||
|
||||
cron_tool = agent.tools.get("cron")
|
||||
cron_token = None
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_token = cron_tool.set_cron_context(True)
|
||||
try:
|
||||
resp = await agent.submit_cron_turn(
|
||||
InboundMessage(
|
||||
channel=channel,
|
||||
sender_id="cron",
|
||||
chat_id=chat_id,
|
||||
content=prompt,
|
||||
metadata=metadata,
|
||||
session_key_override=session_key,
|
||||
)
|
||||
)
|
||||
except (Exception, asyncio.CancelledError) as exc:
|
||||
error_text = str(exc) or exc.__class__.__name__
|
||||
cron.write_run_record(
|
||||
run_id,
|
||||
{
|
||||
**run_record_base,
|
||||
"status": "error",
|
||||
"error": error_text,
|
||||
},
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
|
||||
response = resp.content if resp else ""
|
||||
cron.write_run_record(
|
||||
run_id,
|
||||
{
|
||||
**run_record_base,
|
||||
"status": "ok",
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
async def _deliver_to_channel(
|
||||
msg: OutboundMessage, *, record: bool = False, session_key: str | None = None,
|
||||
) -> None:
|
||||
@ -1319,7 +1177,7 @@ def _run_gateway(
|
||||
return response
|
||||
|
||||
if is_bound_cron_job(job):
|
||||
return await _run_bound_cron_job(job)
|
||||
return await run_bound_cron_job(job, agent=agent, cron=cron)
|
||||
|
||||
reminder_note = (
|
||||
"The scheduled time has arrived. Deliver this reminder to the user now, "
|
||||
@ -1338,7 +1196,7 @@ def _run_gateway(
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_record_token = message_tool.set_record_channel_delivery(True)
|
||||
|
||||
proactive_webui_metadata = _proactive_delivery_metadata(
|
||||
proactive_webui_metadata = cron_proactive_delivery_metadata(
|
||||
"websocket",
|
||||
None,
|
||||
turn_seed=f"cron:{job.id}",
|
||||
@ -1371,7 +1229,7 @@ def _run_gateway(
|
||||
response, reminder_note, agent.provider, agent.model,
|
||||
)
|
||||
if should_notify:
|
||||
proactive_metadata = _proactive_delivery_metadata(
|
||||
proactive_metadata = cron_proactive_delivery_metadata(
|
||||
job.payload.channel or "cli",
|
||||
job.payload.channel_meta,
|
||||
turn_seed=f"cron:{job.id}",
|
||||
|
||||
151
nanobot/cron/bound_runner.py
Normal file
151
nanobot/cron/bound_runner.py
Normal file
@ -0,0 +1,151 @@
|
||||
"""Execution helpers for session-bound cron jobs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Protocol
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.cron.session_delivery import origin_delivery_context
|
||||
from nanobot.cron.session_turns import CRON_DEFER_UNTIL_IDLE_META, CRON_TRIGGER_META
|
||||
from nanobot.cron.types import CronJob
|
||||
from nanobot.cron.webui_metadata import cron_proactive_delivery_metadata
|
||||
from nanobot.utils.prompt_templates import render_template
|
||||
|
||||
|
||||
class BoundCronAgent(Protocol):
|
||||
tools: Any
|
||||
|
||||
async def submit_cron_turn(self, msg: InboundMessage) -> OutboundMessage | None:
|
||||
...
|
||||
|
||||
|
||||
class CronRunRecorder(Protocol):
|
||||
def write_run_record(self, run_id: str, record: dict[str, Any]) -> None:
|
||||
...
|
||||
|
||||
|
||||
def _cron_prompt_ref(prompt: str) -> dict[str, Any]:
|
||||
return {
|
||||
"id": "cron.agent_turn.reminder",
|
||||
"version": 1,
|
||||
"sha256": hashlib.sha256(prompt.encode("utf-8")).hexdigest(),
|
||||
}
|
||||
|
||||
|
||||
def _bound_session_delivery_context(
|
||||
job: CronJob,
|
||||
*,
|
||||
turn_seed: str,
|
||||
source_label: str | None,
|
||||
) -> tuple[str, str, dict[str, Any]]:
|
||||
channel, chat_id, metadata = origin_delivery_context(job)
|
||||
|
||||
if channel == "websocket":
|
||||
metadata["webui"] = True
|
||||
metadata.update(
|
||||
cron_proactive_delivery_metadata(
|
||||
"websocket",
|
||||
metadata,
|
||||
turn_seed=turn_seed,
|
||||
source_label=source_label,
|
||||
)
|
||||
)
|
||||
|
||||
return channel, chat_id, metadata
|
||||
|
||||
|
||||
async def run_bound_cron_job(
|
||||
job: CronJob,
|
||||
*,
|
||||
agent: BoundCronAgent,
|
||||
cron: CronRunRecorder,
|
||||
) -> str | None:
|
||||
"""Execute a session-bound cron job as a normal agent session turn."""
|
||||
session_key = job.payload.session_key
|
||||
if not session_key:
|
||||
raise ValueError(f"cron job {job.id} is missing payload.session_key")
|
||||
|
||||
prompt = render_template(
|
||||
"agent/cron_reminder.md",
|
||||
strip=True,
|
||||
message=job.payload.message,
|
||||
)
|
||||
prompt_ref = _cron_prompt_ref(prompt)
|
||||
run_id = f"{job.id}:{int(time.time() * 1000)}:{uuid.uuid4().hex[:8]}"
|
||||
channel, chat_id, metadata = _bound_session_delivery_context(
|
||||
job,
|
||||
turn_seed=f"cron:{job.id}",
|
||||
source_label=job.name,
|
||||
)
|
||||
metadata[CRON_TRIGGER_META] = {
|
||||
"job_id": job.id,
|
||||
"job_name": job.name,
|
||||
"run_id": run_id,
|
||||
"prompt_ref": prompt_ref,
|
||||
"persist_content": (
|
||||
f"Scheduled cron job triggered: {job.name}\n\n{job.payload.message}"
|
||||
),
|
||||
}
|
||||
metadata[CRON_DEFER_UNTIL_IDLE_META] = True
|
||||
run_record_base: dict[str, Any] = {
|
||||
"job_id": job.id,
|
||||
"job_name": job.name,
|
||||
"session_key": session_key,
|
||||
"prompt_ref": prompt_ref,
|
||||
"prompt_vars": {"message": job.payload.message},
|
||||
"rendered_prompt": prompt,
|
||||
}
|
||||
|
||||
cron.write_run_record(
|
||||
run_id,
|
||||
{
|
||||
**run_record_base,
|
||||
"status": "queued",
|
||||
},
|
||||
)
|
||||
|
||||
cron_tool = agent.tools.get("cron")
|
||||
cron_token = None
|
||||
if isinstance(cron_tool, CronTool):
|
||||
cron_token = cron_tool.set_cron_context(True)
|
||||
try:
|
||||
resp = await agent.submit_cron_turn(
|
||||
InboundMessage(
|
||||
channel=channel,
|
||||
sender_id="cron",
|
||||
chat_id=chat_id,
|
||||
content=prompt,
|
||||
metadata=metadata,
|
||||
session_key_override=session_key,
|
||||
)
|
||||
)
|
||||
except (Exception, asyncio.CancelledError) as exc:
|
||||
error_text = str(exc) or exc.__class__.__name__
|
||||
cron.write_run_record(
|
||||
run_id,
|
||||
{
|
||||
**run_record_base,
|
||||
"status": "error",
|
||||
"error": error_text,
|
||||
},
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if isinstance(cron_tool, CronTool) and cron_token is not None:
|
||||
cron_tool.reset_cron_context(cron_token)
|
||||
|
||||
response = resp.content if resp else ""
|
||||
cron.write_run_record(
|
||||
run_id,
|
||||
{
|
||||
**run_record_base,
|
||||
"status": "ok",
|
||||
"response": response,
|
||||
},
|
||||
)
|
||||
return response
|
||||
@ -36,6 +36,26 @@ def cron_run_id(metadata: Mapping[str, Any] | None) -> str | None:
|
||||
return value if isinstance(value, str) and value else None
|
||||
|
||||
|
||||
def cron_history_overrides(metadata: Mapping[str, Any] | None) -> tuple[str | None, dict[str, Any]]:
|
||||
"""Return session-history text/metadata overrides for a cron turn."""
|
||||
trigger = cron_trigger(metadata)
|
||||
if not trigger:
|
||||
return None, {}
|
||||
persist_content = trigger.get("persist_content")
|
||||
text = (
|
||||
persist_content
|
||||
if isinstance(persist_content, str) and persist_content.strip()
|
||||
else None
|
||||
)
|
||||
return text, {
|
||||
CRON_HISTORY_META: True,
|
||||
"cron_job_id": trigger.get("job_id"),
|
||||
"cron_job_name": trigger.get("job_name"),
|
||||
"cron_run_id": trigger.get("run_id"),
|
||||
"cron_prompt_ref": trigger.get("prompt_ref"),
|
||||
}
|
||||
|
||||
|
||||
def is_bound_cron_job(job: CronJob) -> bool:
|
||||
"""True for new session-bound cron jobs, excluding legacy delivery payloads."""
|
||||
payload = job.payload
|
||||
|
||||
27
nanobot/cron/webui_metadata.py
Normal file
27
nanobot/cron/webui_metadata.py
Normal file
@ -0,0 +1,27 @@
|
||||
"""WebUI metadata helpers for cron deliveries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from nanobot.webui.metadata import WEBUI_MESSAGE_SOURCE_METADATA_KEY, WEBUI_TURN_METADATA_KEY
|
||||
|
||||
|
||||
def cron_proactive_delivery_metadata(
|
||||
channel: str,
|
||||
metadata: dict[str, Any] | None,
|
||||
*,
|
||||
turn_seed: str,
|
||||
source_label: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Return channel metadata for a fresh proactive cron delivery turn."""
|
||||
out = dict(metadata or {})
|
||||
out.pop(WEBUI_TURN_METADATA_KEY, None)
|
||||
if channel == "websocket":
|
||||
out[WEBUI_TURN_METADATA_KEY] = f"{turn_seed}:{uuid.uuid4().hex}"
|
||||
source: dict[str, str] = {"kind": "cron"}
|
||||
if source_label:
|
||||
source["label"] = source_label
|
||||
out[WEBUI_MESSAGE_SOURCE_METADATA_KEY] = source
|
||||
return out
|
||||
@ -9,14 +9,18 @@ import pytest
|
||||
from typer.testing import CliRunner
|
||||
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.cli.commands import _proactive_delivery_metadata, app
|
||||
from nanobot.cli.commands import app
|
||||
from nanobot.config.schema import Config
|
||||
from nanobot.cron.session_turns import CRON_DEFER_UNTIL_IDLE_META, CRON_TRIGGER_META
|
||||
from nanobot.cron.types import CronJob, CronPayload
|
||||
from nanobot.cron.webui_metadata import cron_proactive_delivery_metadata
|
||||
from nanobot.providers.factory import ProviderSnapshot, make_provider
|
||||
from nanobot.providers.openai_codex_provider import _strip_model_prefix
|
||||
from nanobot.providers.registry import find_by_name
|
||||
from nanobot.webui.metadata import WEBUI_MESSAGE_SOURCE_METADATA_KEY, WEBUI_TURN_METADATA_KEY
|
||||
from nanobot.webui.metadata import (
|
||||
WEBUI_MESSAGE_SOURCE_METADATA_KEY,
|
||||
WEBUI_TURN_METADATA_KEY,
|
||||
)
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
@ -28,7 +32,7 @@ def test_proactive_websocket_delivery_gets_fresh_turn_id() -> None:
|
||||
"workspace_scope": {"mode": "default"},
|
||||
}
|
||||
|
||||
out = _proactive_delivery_metadata(
|
||||
out = cron_proactive_delivery_metadata(
|
||||
"websocket",
|
||||
metadata,
|
||||
turn_seed="cron:drink-water",
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user