mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix(heartbeat): skip when HEARTBEAT.md has no tasks and fail closed on delivery (#4111)
This commit is contained in:
parent
2b4c984e9a
commit
e3df310309
@ -4,6 +4,8 @@ from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||
from nanobot.agent.tools.context import ContextAware, RequestContext
|
||||
from nanobot.agent.tools.path_utils import resolve_workspace_path
|
||||
@ -83,6 +85,10 @@ class MessageTool(Tool, ContextAware):
|
||||
"message_record_channel_delivery",
|
||||
default=False,
|
||||
)
|
||||
self._suppress_delivery_var: ContextVar[bool] = ContextVar(
|
||||
"message_suppress_delivery",
|
||||
default=False,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create(cls, ctx: Any) -> Tool:
|
||||
@ -121,6 +127,14 @@ class MessageTool(Tool, ContextAware):
|
||||
"""Restore previous proactive delivery recording state."""
|
||||
self._record_channel_delivery_var.reset(token)
|
||||
|
||||
def set_suppress_delivery(self, active: bool):
    """Switch delivery suppression on or off for the current context.

    While suppression is active, tool sends are acknowledged but not
    delivered (used for the heartbeat's internal check).

    Returns the token produced by ``ContextVar.set``; pass it to
    ``reset_suppress_delivery`` to restore the previous state.
    """
    suppress_var = self._suppress_delivery_var
    token = suppress_var.set(active)
    return token
|
||||
|
||||
def reset_suppress_delivery(self, token) -> None:
    """Undo an earlier ``set_suppress_delivery`` call.

    ``token`` is the value that call returned; resetting with it puts the
    delivery-suppression flag back to what it was before the call.
    """
    suppress_var = self._suppress_delivery_var
    suppress_var.reset(token)
|
||||
|
||||
@property
def _sent_in_turn(self) -> bool:
    """Current value held by ``_sent_in_turn_var``."""
    turn_flag = self._sent_in_turn_var
    return turn_flag.get()
|
||||
@ -241,6 +255,10 @@ class MessageTool(Tool, ContextAware):
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
if self._suppress_delivery_var.get():
|
||||
logger.debug("MessageTool: delivery suppressed during internal check")
|
||||
return f"Message acknowledged for {channel}:{chat_id} (not delivered)"
|
||||
|
||||
try:
|
||||
await self._send_callback(msg)
|
||||
if channel == default_channel and chat_id == default_chat_id:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""CLI commands for nanobot."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import select
|
||||
import signal
|
||||
@ -105,10 +104,23 @@ _HEARTBEAT_PREAMBLE = (
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def _heartbeat_template() -> str | None:
    """Return the bundled ``HEARTBEAT.md`` template.

    The function takes no arguments, so the cache holds a single entry and
    later calls reuse the first result. The helper module is imported at call
    time rather than at module import.
    """
    from nanobot.utils import helpers

    return helpers.load_bundled_template("HEARTBEAT.md")
|
||||
def _heartbeat_has_active_tasks(content: str) -> bool:
|
||||
"""True if HEARTBEAT.md has task lines, ignoring headers, blanks and comments."""
|
||||
in_comment = False
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
if in_comment:
|
||||
if "-->" in stripped:
|
||||
in_comment = False
|
||||
continue
|
||||
if not stripped or stripped.startswith("#"):
|
||||
continue
|
||||
if stripped.startswith("<!--"):
|
||||
if "-->" not in stripped[4:]:
|
||||
in_comment = True
|
||||
continue
|
||||
return True
|
||||
return False
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI input: prompt_toolkit for editing, paste, history, and display
|
||||
@ -978,8 +990,8 @@ def _run_gateway(
|
||||
except OSError:
|
||||
logger.debug("Heartbeat: HEARTBEAT.md missing")
|
||||
return None
|
||||
if not content or content == _heartbeat_template():
|
||||
logger.debug("Heartbeat: HEARTBEAT.md empty or identical to template")
|
||||
if not _heartbeat_has_active_tasks(content):
|
||||
logger.debug("Heartbeat: HEARTBEAT.md has no active tasks")
|
||||
return None
|
||||
|
||||
channel, chat_id = _pick_heartbeat_target()
|
||||
@ -991,13 +1003,22 @@ def _run_gateway(
|
||||
+ f"Review the following HEARTBEAT.md and report any active tasks:\n\n{content}"
|
||||
)
|
||||
|
||||
resp = await agent.process_direct(
|
||||
prompt,
|
||||
session_key="heartbeat",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
on_progress=_silent,
|
||||
)
|
||||
# Internal check: funnel all output through the post-run gate so the
|
||||
# turn can't deliver directly via the message tool and skip it.
|
||||
suppress_token = None
|
||||
if isinstance(message_tool, MessageTool):
|
||||
suppress_token = message_tool.set_suppress_delivery(True)
|
||||
try:
|
||||
resp = await agent.process_direct(
|
||||
prompt,
|
||||
session_key="heartbeat",
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
on_progress=_silent,
|
||||
)
|
||||
finally:
|
||||
if isinstance(message_tool, MessageTool) and suppress_token is not None:
|
||||
message_tool.reset_suppress_delivery(suppress_token)
|
||||
response = resp.content if resp else ""
|
||||
|
||||
# Keep a small tail of heartbeat history so the loop stays bounded.
|
||||
@ -1008,8 +1029,10 @@ def _run_gateway(
|
||||
if not response:
|
||||
return None
|
||||
|
||||
# Fail closed: stay silent on evaluator failure instead of notifying.
|
||||
should_notify = await evaluate_response(
|
||||
response, prompt, agent.provider, agent.model,
|
||||
default_notify=False,
|
||||
)
|
||||
if should_notify:
|
||||
logger.info("Heartbeat: completed, delivering response")
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
# Heartbeat Tasks
|
||||
|
||||
<!--
|
||||
This file is checked periodically by your nanobot agent.
|
||||
Register it as a cron job (e.g. `cron add --name heartbeat --schedule "every 30m" --message "Check HEARTBEAT.md"`) to get the same behavior as the legacy heartbeat service.
|
||||
|
||||
If this file has no tasks (only headers and comments), the agent will skip it.
|
||||
-->
|
||||
|
||||
## Active Tasks
|
||||
|
||||
|
||||
@ -44,12 +44,12 @@ async def evaluate_response(
|
||||
task_context: str,
|
||||
provider: LLMProvider,
|
||||
model: str,
|
||||
default_notify: bool = True,
|
||||
) -> bool:
|
||||
"""Decide whether a background-task result should be delivered to the user.
|
||||
|
||||
Uses a lightweight tool-call LLM request (same pattern as heartbeat
|
||||
``_decide()``). Falls back to ``True`` (notify) on any failure so
|
||||
that important messages are never silently dropped.
|
||||
On any failure, falls back to ``default_notify`` (cron reminders fail open;
|
||||
heartbeat passes ``False`` to fail closed).
|
||||
"""
|
||||
try:
|
||||
llm_response = await provider.chat_with_retry(
|
||||
@ -71,19 +71,24 @@ async def evaluate_response(
|
||||
if not llm_response.should_execute_tools:
|
||||
if llm_response.has_tool_calls:
|
||||
logger.warning(
|
||||
"evaluate_response: ignoring tool calls under finish_reason='{}', defaulting to notify",
|
||||
"evaluate_response: ignoring tool calls under finish_reason='{}', "
|
||||
"defaulting to notify={}",
|
||||
llm_response.finish_reason,
|
||||
default_notify,
|
||||
)
|
||||
else:
|
||||
logger.warning("evaluate_response: no tool call returned, defaulting to notify")
|
||||
return True
|
||||
logger.warning(
|
||||
"evaluate_response: no tool call returned, defaulting to notify={}",
|
||||
default_notify,
|
||||
)
|
||||
return default_notify
|
||||
|
||||
args = llm_response.tool_calls[0].arguments
|
||||
should_notify = args.get("should_notify", True)
|
||||
should_notify = args.get("should_notify", default_notify)
|
||||
reason = args.get("reason", "")
|
||||
logger.info("evaluate_response: should_notify={}, reason={}", should_notify, reason)
|
||||
return bool(should_notify)
|
||||
|
||||
except Exception:
|
||||
logger.exception("evaluate_response failed, defaulting to notify")
|
||||
return True
|
||||
logger.exception("evaluate_response failed, defaulting to notify={}", default_notify)
|
||||
return default_notify
|
||||
|
||||
@ -61,3 +61,21 @@ async def test_no_tool_call_fallback() -> None:
|
||||
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
|
||||
result = await evaluate_response("some response", "some task", provider, "m")
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fail_closed_on_error() -> None:
    """A provider whose ``chat`` raises yields ``False`` when ``default_notify=False``."""

    class FailingProvider(DummyProvider):
        async def chat(self, *args, **kwargs) -> LLMResponse:
            raise RuntimeError("provider down")

    broken_provider = FailingProvider([])
    outcome = await evaluate_response(
        "some",
        "task",
        broken_provider,
        "m",
        default_notify=False,
    )
    assert outcome is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fail_closed_on_no_tool_call() -> None:
    """A text-only reply (no tool calls) yields ``False`` when ``default_notify=False``."""
    text_only_reply = LLMResponse(content="text only", tool_calls=[])
    provider = DummyProvider([text_only_reply])
    outcome = await evaluate_response(
        "some",
        "task",
        provider,
        "m",
        default_notify=False,
    )
    assert outcome is False
|
||||
|
||||
@ -952,6 +952,29 @@ def test_heartbeat_retains_recent_messages_by_default():
|
||||
assert config.gateway.heartbeat.keep_recent_messages == 8
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "content, expected",
    [
        ("", False),
        ("# Title\n\n## Active Tasks\n", False),
        # A block comment spanning several lines is not a task.
        ("<!--\nmulti-line\ncomment\n-->\n", False),
        ("<!-- single line -->\n", False),
        ("## Active Tasks\n\n- water the plants\n", True),
    ],
)
def test_heartbeat_has_active_tasks(content, expected):
    """Headers, blanks and comments are not tasks; a list item is."""
    from nanobot.cli import commands

    verdict = commands._heartbeat_has_active_tasks(content)
    assert verdict is expected
|
||||
|
||||
|
||||
def test_heartbeat_skips_bundled_template():
    """The bundled HEARTBEAT.md template is reported as having no active tasks."""
    from nanobot.cli.commands import _heartbeat_has_active_tasks
    from nanobot.utils.helpers import load_bundled_template

    template_text = load_bundled_template("HEARTBEAT.md")
    has_tasks = _heartbeat_has_active_tasks(template_text)
    assert has_tasks is False
|
||||
|
||||
|
||||
def _write_instance_config(tmp_path: Path) -> Path:
|
||||
config_file = tmp_path / "instance" / "config.json"
|
||||
config_file.parent.mkdir(parents=True)
|
||||
|
||||
@ -38,6 +38,28 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
|
||||
assert result == "Error: buttons must be a list of list of strings"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_message_tool_suppresses_delivery_when_active() -> None:
    """Sends are acknowledged but not delivered while suppressed, then delivered again."""
    delivered: list[OutboundMessage] = []

    async def _capture(msg: OutboundMessage) -> None:
        delivered.append(msg)

    tool = MessageTool(send_callback=_capture)

    # Phase 1: suppression active -- the callback must not be invoked.
    token = tool.set_suppress_delivery(True)
    try:
        suppressed_result = await tool.execute(
            content="all clear",
            channel="telegram",
            chat_id="1",
        )
    finally:
        tool.reset_suppress_delivery(token)
    assert delivered == []
    assert "not delivered" in suppressed_result

    # Phase 2: suppression reset -- the next send reaches the callback.
    await tool.execute(content="real", channel="telegram", chat_id="1")
    assert len(delivered) == 1
    first_message = delivered[0]
    assert first_message.content == "real"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
||||
sent: list[OutboundMessage] = []
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user