mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 21:23:39 +00:00
fix
This commit is contained in:
parent
d084d10dc2
commit
0f1e3aa151
@ -4,10 +4,12 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Coroutine, Literal
|
||||
|
||||
from filelock import FileLock
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronRunRecord, CronSchedule, CronStore
|
||||
@ -69,28 +71,25 @@ class CronService:
|
||||
self,
|
||||
store_path: Path,
|
||||
on_job: Callable[[CronJob], Coroutine[Any, Any, str | None]] | None = None,
|
||||
max_sleep_ms: int = 300_000 # 5 minutes
|
||||
):
|
||||
self.store_path = store_path
|
||||
self._action_path = store_path.parent / "action.jsonl"
|
||||
self._lock = FileLock(str(self._action_path.parent) + ".lock")
|
||||
self.on_job = on_job
|
||||
self._store: CronStore | None = None
|
||||
self._last_mtime: float = 0.0
|
||||
self._timer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self.max_sleep_ms = max_sleep_ms
|
||||
|
||||
def _load_store(self) -> CronStore:
|
||||
"""Load jobs from disk. Reloads automatically if file was modified externally."""
|
||||
if self._store and self.store_path.exists():
|
||||
mtime = self.store_path.stat().st_mtime
|
||||
if mtime != self._last_mtime:
|
||||
logger.info("Cron: jobs.json modified externally, reloading")
|
||||
self._store = None
|
||||
if self._store:
|
||||
return self._store
|
||||
|
||||
def _load_jobs(self) -> tuple[list[CronJob], int]:
|
||||
jobs = []
|
||||
version = 1
|
||||
if self.store_path.exists():
|
||||
try:
|
||||
data = json.loads(self.store_path.read_text(encoding="utf-8"))
|
||||
jobs = []
|
||||
version = data.get("version", 1)
|
||||
for j in data.get("jobs", []):
|
||||
jobs.append(CronJob(
|
||||
id=j["id"],
|
||||
@ -129,13 +128,53 @@ class CronService:
|
||||
updated_at_ms=j.get("updatedAtMs", 0),
|
||||
delete_after_run=j.get("deleteAfterRun", False),
|
||||
))
|
||||
self._store = CronStore(jobs=jobs)
|
||||
self._last_mtime = self.store_path.stat().st_mtime
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load cron store: {}", e)
|
||||
self._store = CronStore()
|
||||
else:
|
||||
self._store = CronStore()
|
||||
return jobs, version
|
||||
|
||||
def _merge_action(self):
|
||||
if not self._action_path.exists():
|
||||
return
|
||||
|
||||
jobs_map = {j.id: j for j in self._store.jobs}
|
||||
def _update(params: dict):
|
||||
j = CronJob.from_dict(params)
|
||||
jobs_map[j.id] = j
|
||||
|
||||
def _del(params: dict):
|
||||
if job_id := params.get("job_id"):
|
||||
jobs_map.pop(job_id)
|
||||
|
||||
with self._lock:
|
||||
with open(self._action_path, "r", encoding="utf-8") as f:
|
||||
changed = False
|
||||
for line in f:
|
||||
try:
|
||||
line = line.strip()
|
||||
action = json.loads(line)
|
||||
if "action" not in action:
|
||||
continue
|
||||
if action["action"] == "del":
|
||||
_del(action.get("params", {}))
|
||||
else:
|
||||
_update(action.get("params", {}))
|
||||
changed = True
|
||||
except Exception as exp:
|
||||
logger.debug(f"load action line error: {exp}")
|
||||
continue
|
||||
self._store.jobs = list(jobs_map.values())
|
||||
if self._running and changed:
|
||||
self._action_path.write_text("", encoding="utf-8")
|
||||
self._save_store()
|
||||
return
|
||||
|
||||
def _load_store(self) -> CronStore:
|
||||
"""Load jobs from disk. Reloads automatically if file was modified externally.
|
||||
- Reload every time because it needs to merge operations on the jobs object from other instances.
|
||||
"""
|
||||
jobs, version = self._load_jobs()
|
||||
self._store = CronStore(version=version, jobs=jobs)
|
||||
self._merge_action()
|
||||
|
||||
return self._store
|
||||
|
||||
@ -230,11 +269,11 @@ class CronService:
|
||||
if self._timer_task:
|
||||
self._timer_task.cancel()
|
||||
|
||||
next_wake = self._get_next_wake_ms()
|
||||
if not next_wake or not self._running:
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
delay_ms = max(0, next_wake - _now_ms())
|
||||
next_wake = self._get_next_wake_ms() or 0
|
||||
delay_ms = min(self.max_sleep_ms ,max(1000, next_wake - _now_ms()))
|
||||
delay_s = delay_ms / 1000
|
||||
|
||||
async def tick():
|
||||
@ -303,6 +342,13 @@ class CronService:
|
||||
# Compute next run
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
|
||||
def _append_action(self, action: Literal["add", "del", "update"], params: dict):
|
||||
self.store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self._lock:
|
||||
with open(self._action_path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps({"action": action, "params": params}, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
# ========== Public API ==========
|
||||
|
||||
def list_jobs(self, include_disabled: bool = False) -> list[CronJob]:
|
||||
@ -322,7 +368,6 @@ class CronService:
|
||||
delete_after_run: bool = False,
|
||||
) -> CronJob:
|
||||
"""Add a new job."""
|
||||
store = self._load_store()
|
||||
_validate_schedule_for_add(schedule)
|
||||
now = _now_ms()
|
||||
|
||||
@ -343,10 +388,13 @@ class CronService:
|
||||
updated_at_ms=now,
|
||||
delete_after_run=delete_after_run,
|
||||
)
|
||||
|
||||
store.jobs.append(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
if self._running:
|
||||
store = self._load_store()
|
||||
store.jobs.append(job)
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("add", asdict(job))
|
||||
|
||||
logger.info("Cron: added job '{}' ({})", name, job.id)
|
||||
return job
|
||||
@ -380,8 +428,11 @@ class CronService:
|
||||
removed = len(store.jobs) < before
|
||||
|
||||
if removed:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
if self._running:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("del", {"job_id": job_id})
|
||||
logger.info("Cron: removed job {}", job_id)
|
||||
return "removed"
|
||||
|
||||
@ -398,13 +449,20 @@ class CronService:
|
||||
job.state.next_run_at_ms = _compute_next_run(job.schedule, _now_ms())
|
||||
else:
|
||||
job.state.next_run_at_ms = None
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
if self._running:
|
||||
self._save_store()
|
||||
self._arm_timer()
|
||||
else:
|
||||
self._append_action("update", asdict(job))
|
||||
return job
|
||||
return None
|
||||
|
||||
async def run_job(self, job_id: str, force: bool = False) -> bool:
|
||||
"""Manually run a job."""
|
||||
"""Manually run a job. For testing purposes
|
||||
- It's not that the gateway instance cannot run because it doesn't have the on_job method.
|
||||
- There may be concurrency issues.
|
||||
"""
|
||||
self._running = True
|
||||
store = self._load_store()
|
||||
for job in store.jobs:
|
||||
if job.id == job_id:
|
||||
@ -412,8 +470,10 @@ class CronService:
|
||||
return False
|
||||
await self._execute_job(job)
|
||||
self._save_store()
|
||||
self._running = False
|
||||
self._arm_timer()
|
||||
return True
|
||||
self._running = False
|
||||
return False
|
||||
|
||||
def get_job(self, job_id: str) -> CronJob | None:
|
||||
|
||||
@ -61,6 +61,13 @@ class CronJob:
|
||||
updated_at_ms: int = 0
|
||||
delete_after_run: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kwargs: dict):
|
||||
kwargs["schedule"] = CronSchedule(**kwargs.get("schedule", {"kind": "every"}))
|
||||
kwargs["payload"] = CronPayload(**kwargs.get("payload", {}))
|
||||
kwargs["state"] = CronJobState(**kwargs.get("state", {}))
|
||||
return cls(**kwargs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CronStore:
|
||||
|
||||
@ -50,6 +50,7 @@ dependencies = [
|
||||
"tiktoken>=0.12.0,<1.0.0",
|
||||
"jinja2>=3.1.0,<4.0.0",
|
||||
"dulwich>=0.22.0,<1.0.0",
|
||||
"filelock>=3.25.2",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
@ -158,24 +159,27 @@ def test_remove_job_refuses_system_jobs(tmp_path) -> None:
|
||||
assert service.get_job("dream") is not None
|
||||
|
||||
|
||||
def test_reload_jobs(tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_server_not_jobs(tmp_path):
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
service = CronService(store_path, on_job=lambda _: asyncio.sleep(0))
|
||||
service.add_job(
|
||||
name="hist",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello",
|
||||
)
|
||||
called = []
|
||||
async def on_job(job):
|
||||
called.append(job.name)
|
||||
|
||||
assert len(service.list_jobs()) == 1
|
||||
service = CronService(store_path, on_job=on_job, max_sleep_ms=1000)
|
||||
await service.start()
|
||||
assert len(service.list_jobs()) == 0
|
||||
|
||||
service2 = CronService(tmp_path / "cron" / "jobs.json")
|
||||
service2.add_job(
|
||||
name="hist2",
|
||||
schedule=CronSchedule(kind="every", every_ms=60_000),
|
||||
message="hello2",
|
||||
name="hist",
|
||||
schedule=CronSchedule(kind="every", every_ms=500),
|
||||
message="hello",
|
||||
)
|
||||
assert len(service.list_jobs()) == 2
|
||||
assert len(service.list_jobs()) == 1
|
||||
await asyncio.sleep(2)
|
||||
assert len(called) != 0
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -204,7 +208,40 @@ async def test_running_service_picks_up_external_add(tmp_path):
|
||||
message="ping",
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.6)
|
||||
await asyncio.sleep(2)
|
||||
assert "external" in called
|
||||
finally:
|
||||
service.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_job_during_jobs_exec(tmp_path):
|
||||
store_path = tmp_path / "cron" / "jobs.json"
|
||||
run_once = True
|
||||
|
||||
async def on_job(job):
|
||||
nonlocal run_once
|
||||
if run_once:
|
||||
service2 = CronService(store_path, on_job=lambda x: asyncio.sleep(0))
|
||||
service2.add_job(
|
||||
name="test",
|
||||
schedule=CronSchedule(kind="every", every_ms=150),
|
||||
message="tick",
|
||||
)
|
||||
run_once = False
|
||||
|
||||
service = CronService(store_path, on_job=on_job)
|
||||
service.add_job(
|
||||
name="heartbeat",
|
||||
schedule=CronSchedule(kind="every", every_ms=150),
|
||||
message="tick",
|
||||
)
|
||||
assert len(service.list_jobs()) == 1
|
||||
await service.start()
|
||||
try:
|
||||
await asyncio.sleep(3)
|
||||
jobs = service.list_jobs()
|
||||
assert len(jobs) == 2
|
||||
assert "test" in [j.name for j in jobs]
|
||||
finally:
|
||||
service.stop()
|
||||
|
||||
@ -2,9 +2,12 @@
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob, CronJobState, CronPayload, CronSchedule
|
||||
from tests.test_openai_api import pytest_plugins
|
||||
|
||||
|
||||
def _make_tool(tmp_path) -> CronTool:
|
||||
@ -215,8 +218,10 @@ def test_list_at_job_shows_iso_timestamp(tmp_path) -> None:
|
||||
assert "Asia/Shanghai" in result
|
||||
|
||||
|
||||
def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron._running = True
|
||||
job = tool._cron.add_job(
|
||||
name="Stateful job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
@ -232,9 +237,10 @@ def test_list_shows_last_run_state(tmp_path) -> None:
|
||||
assert "ok" in result
|
||||
assert "(UTC)" in result
|
||||
|
||||
|
||||
def test_list_shows_error_message(tmp_path) -> None:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_shows_error_message(tmp_path) -> None:
|
||||
tool = _make_tool(tmp_path)
|
||||
tool._cron._running = True
|
||||
job = tool._cron.add_job(
|
||||
name="Failed job",
|
||||
schedule=CronSchedule(kind="cron", expr="0 9 * * *", tz="UTC"),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user