mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-19 01:04:04 +00:00
fix(heartbeat): skip when HEARTBEAT.md has no tasks and fail closed on delivery (#4111)
This commit is contained in:
parent
2b4c984e9a
commit
e3df310309
@ -4,6 +4,8 @@ from contextvars import ContextVar
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable
|
from typing import Any, Awaitable, Callable
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.tools.base import Tool, tool_parameters
|
from nanobot.agent.tools.base import Tool, tool_parameters
|
||||||
from nanobot.agent.tools.context import ContextAware, RequestContext
|
from nanobot.agent.tools.context import ContextAware, RequestContext
|
||||||
from nanobot.agent.tools.path_utils import resolve_workspace_path
|
from nanobot.agent.tools.path_utils import resolve_workspace_path
|
||||||
@ -83,6 +85,10 @@ class MessageTool(Tool, ContextAware):
|
|||||||
"message_record_channel_delivery",
|
"message_record_channel_delivery",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
|
self._suppress_delivery_var: ContextVar[bool] = ContextVar(
|
||||||
|
"message_suppress_delivery",
|
||||||
|
default=False,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, ctx: Any) -> Tool:
|
def create(cls, ctx: Any) -> Tool:
|
||||||
@ -121,6 +127,14 @@ class MessageTool(Tool, ContextAware):
|
|||||||
"""Restore previous proactive delivery recording state."""
|
"""Restore previous proactive delivery recording state."""
|
||||||
self._record_channel_delivery_var.reset(token)
|
self._record_channel_delivery_var.reset(token)
|
||||||
|
|
||||||
|
def set_suppress_delivery(self, active: bool):
    """Switch delivery suppression on or off (heartbeat internal check).

    While suppression is active, a tool send is acknowledged but not
    delivered.

    Args:
        active: ``True`` to suppress delivery, ``False`` to allow it.

    Returns:
        The token returned by ``ContextVar.set``; hand it to
        ``reset_suppress_delivery`` to restore the previous state.
    """
    return self._suppress_delivery_var.set(active)
|
||||||
|
|
||||||
|
def reset_suppress_delivery(self, token) -> None:
    """Restore the delivery-suppression state that was in effect before.

    Args:
        token: The value previously returned by ``set_suppress_delivery``.
    """
    self._suppress_delivery_var.reset(token)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _sent_in_turn(self) -> bool:
|
def _sent_in_turn(self) -> bool:
|
||||||
return self._sent_in_turn_var.get()
|
return self._sent_in_turn_var.get()
|
||||||
@ -241,6 +255,10 @@ class MessageTool(Tool, ContextAware):
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._suppress_delivery_var.get():
|
||||||
|
logger.debug("MessageTool: delivery suppressed during internal check")
|
||||||
|
return f"Message acknowledged for {channel}:{chat_id} (not delivered)"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._send_callback(msg)
|
await self._send_callback(msg)
|
||||||
if channel == default_channel and chat_id == default_chat_id:
|
if channel == default_channel and chat_id == default_chat_id:
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
"""CLI commands for nanobot."""
|
"""CLI commands for nanobot."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
|
||||||
import os
|
import os
|
||||||
import select
|
import select
|
||||||
import signal
|
import signal
|
||||||
@ -105,10 +104,23 @@ _HEARTBEAT_PREAMBLE = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
def _heartbeat_has_active_tasks(content: str) -> bool:
|
||||||
def _heartbeat_template() -> str | None:
|
"""True if HEARTBEAT.md has task lines, ignoring headers, blanks and comments."""
|
||||||
from nanobot.utils.helpers import load_bundled_template
|
in_comment = False
|
||||||
return load_bundled_template("HEARTBEAT.md")
|
for line in content.splitlines():
|
||||||
|
stripped = line.strip()
|
||||||
|
if in_comment:
|
||||||
|
if "-->" in stripped:
|
||||||
|
in_comment = False
|
||||||
|
continue
|
||||||
|
if not stripped or stripped.startswith("#"):
|
||||||
|
continue
|
||||||
|
if stripped.startswith("<!--"):
|
||||||
|
if "-->" not in stripped[4:]:
|
||||||
|
in_comment = True
|
||||||
|
continue
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# CLI input: prompt_toolkit for editing, paste, history, and display
|
# CLI input: prompt_toolkit for editing, paste, history, and display
|
||||||
@ -978,8 +990,8 @@ def _run_gateway(
|
|||||||
except OSError:
|
except OSError:
|
||||||
logger.debug("Heartbeat: HEARTBEAT.md missing")
|
logger.debug("Heartbeat: HEARTBEAT.md missing")
|
||||||
return None
|
return None
|
||||||
if not content or content == _heartbeat_template():
|
if not _heartbeat_has_active_tasks(content):
|
||||||
logger.debug("Heartbeat: HEARTBEAT.md empty or identical to template")
|
logger.debug("Heartbeat: HEARTBEAT.md has no active tasks")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
channel, chat_id = _pick_heartbeat_target()
|
channel, chat_id = _pick_heartbeat_target()
|
||||||
@ -991,13 +1003,22 @@ def _run_gateway(
|
|||||||
+ f"Review the following HEARTBEAT.md and report any active tasks:\n\n{content}"
|
+ f"Review the following HEARTBEAT.md and report any active tasks:\n\n{content}"
|
||||||
)
|
)
|
||||||
|
|
||||||
resp = await agent.process_direct(
|
# Internal check: funnel all output through the post-run gate so the
|
||||||
prompt,
|
# turn can't deliver directly via the message tool and skip it.
|
||||||
session_key="heartbeat",
|
suppress_token = None
|
||||||
channel=channel,
|
if isinstance(message_tool, MessageTool):
|
||||||
chat_id=chat_id,
|
suppress_token = message_tool.set_suppress_delivery(True)
|
||||||
on_progress=_silent,
|
try:
|
||||||
)
|
resp = await agent.process_direct(
|
||||||
|
prompt,
|
||||||
|
session_key="heartbeat",
|
||||||
|
channel=channel,
|
||||||
|
chat_id=chat_id,
|
||||||
|
on_progress=_silent,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if isinstance(message_tool, MessageTool) and suppress_token is not None:
|
||||||
|
message_tool.reset_suppress_delivery(suppress_token)
|
||||||
response = resp.content if resp else ""
|
response = resp.content if resp else ""
|
||||||
|
|
||||||
# Keep a small tail of heartbeat history so the loop stays bounded.
|
# Keep a small tail of heartbeat history so the loop stays bounded.
|
||||||
@ -1008,8 +1029,10 @@ def _run_gateway(
|
|||||||
if not response:
|
if not response:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# Fail closed: stay silent on evaluator failure instead of notifying.
|
||||||
should_notify = await evaluate_response(
|
should_notify = await evaluate_response(
|
||||||
response, prompt, agent.provider, agent.model,
|
response, prompt, agent.provider, agent.model,
|
||||||
|
default_notify=False,
|
||||||
)
|
)
|
||||||
if should_notify:
|
if should_notify:
|
||||||
logger.info("Heartbeat: completed, delivering response")
|
logger.info("Heartbeat: completed, delivering response")
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
# Heartbeat Tasks
|
# Heartbeat Tasks
|
||||||
|
|
||||||
|
<!--
|
||||||
This file is checked periodically by your nanobot agent.
|
This file is checked periodically by your nanobot agent.
|
||||||
Register it as a cron job (e.g. `cron add --name heartbeat --schedule "every 30m" --message "Check HEARTBEAT.md"`) to get the same behavior as the legacy heartbeat service.
|
Register it as a cron job (e.g. `cron add --name heartbeat --schedule "every 30m" --message "Check HEARTBEAT.md"`) to get the same behavior as the legacy heartbeat service.
|
||||||
|
|
||||||
If this file has no tasks (only headers and comments), the agent will skip it.
|
If this file has no tasks (only headers and comments), the agent will skip it.
|
||||||
|
-->
|
||||||
|
|
||||||
## Active Tasks
|
## Active Tasks
|
||||||
|
|
||||||
|
|||||||
@ -44,12 +44,12 @@ async def evaluate_response(
|
|||||||
task_context: str,
|
task_context: str,
|
||||||
provider: LLMProvider,
|
provider: LLMProvider,
|
||||||
model: str,
|
model: str,
|
||||||
|
default_notify: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Decide whether a background-task result should be delivered to the user.
|
"""Decide whether a background-task result should be delivered to the user.
|
||||||
|
|
||||||
Uses a lightweight tool-call LLM request (same pattern as heartbeat
|
On any failure, falls back to ``default_notify`` (cron reminders fail open;
|
||||||
``_decide()``). Falls back to ``True`` (notify) on any failure so
|
heartbeat passes ``False`` to fail closed).
|
||||||
that important messages are never silently dropped.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
llm_response = await provider.chat_with_retry(
|
llm_response = await provider.chat_with_retry(
|
||||||
@ -71,19 +71,24 @@ async def evaluate_response(
|
|||||||
if not llm_response.should_execute_tools:
|
if not llm_response.should_execute_tools:
|
||||||
if llm_response.has_tool_calls:
|
if llm_response.has_tool_calls:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"evaluate_response: ignoring tool calls under finish_reason='{}', defaulting to notify",
|
"evaluate_response: ignoring tool calls under finish_reason='{}', "
|
||||||
|
"defaulting to notify={}",
|
||||||
llm_response.finish_reason,
|
llm_response.finish_reason,
|
||||||
|
default_notify,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("evaluate_response: no tool call returned, defaulting to notify")
|
logger.warning(
|
||||||
return True
|
"evaluate_response: no tool call returned, defaulting to notify={}",
|
||||||
|
default_notify,
|
||||||
|
)
|
||||||
|
return default_notify
|
||||||
|
|
||||||
args = llm_response.tool_calls[0].arguments
|
args = llm_response.tool_calls[0].arguments
|
||||||
should_notify = args.get("should_notify", True)
|
should_notify = args.get("should_notify", default_notify)
|
||||||
reason = args.get("reason", "")
|
reason = args.get("reason", "")
|
||||||
logger.info("evaluate_response: should_notify={}, reason={}", should_notify, reason)
|
logger.info("evaluate_response: should_notify={}, reason={}", should_notify, reason)
|
||||||
return bool(should_notify)
|
return bool(should_notify)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("evaluate_response failed, defaulting to notify")
|
logger.exception("evaluate_response failed, defaulting to notify={}", default_notify)
|
||||||
return True
|
return default_notify
|
||||||
|
|||||||
@ -61,3 +61,21 @@ async def test_no_tool_call_fallback() -> None:
|
|||||||
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
|
provider = DummyProvider([LLMResponse(content="I think you should notify", tool_calls=[])])
|
||||||
result = await evaluate_response("some response", "some task", provider, "m")
|
result = await evaluate_response("some response", "some task", provider, "m")
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_fail_closed_on_error() -> None:
    """With ``default_notify=False`` a provider error must not notify."""

    class FailingProvider(DummyProvider):
        async def chat(self, *args, **kwargs) -> LLMResponse:
            raise RuntimeError("provider down")

    provider = FailingProvider([])
    result = await evaluate_response("some", "task", provider, "m", default_notify=False)
    assert result is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_fail_closed_on_no_tool_call() -> None:
    """With ``default_notify=False`` a text-only reply must not notify."""
    provider = DummyProvider([LLMResponse(content="text only", tool_calls=[])])
    result = await evaluate_response("some", "task", provider, "m", default_notify=False)
    assert result is False
|
||||||
|
|||||||
@ -952,6 +952,29 @@ def test_heartbeat_retains_recent_messages_by_default():
|
|||||||
assert config.gateway.heartbeat.keep_recent_messages == 8
|
assert config.gateway.heartbeat.keep_recent_messages == 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "content, expected",
    [
        ("", False),
        ("# Title\n\n## Active Tasks\n", False),
        ("<!--\nmulti-line\ncomment\n-->\n", False),  # block comment, not tasks
        ("<!-- single line -->\n", False),
        ("<!-- never closed\n- hidden task\n", False),  # unterminated comment hides the rest
        ("## Active Tasks\n\n- water the plants\n", True),
        ("<!--\nnote\n-->\n\n- water the plants\n", True),  # task after a comment block
    ],
)
def test_heartbeat_has_active_tasks(content, expected):
    """Headers, blanks and HTML comments are not tasks; any other line is."""
    from nanobot.cli.commands import _heartbeat_has_active_tasks

    assert _heartbeat_has_active_tasks(content) is expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_heartbeat_skips_bundled_template():
    """The bundled HEARTBEAT.md template must count as having no tasks."""
    from nanobot.cli.commands import _heartbeat_has_active_tasks
    from nanobot.utils.helpers import load_bundled_template

    assert _heartbeat_has_active_tasks(load_bundled_template("HEARTBEAT.md")) is False
|
||||||
|
|
||||||
|
|
||||||
def _write_instance_config(tmp_path: Path) -> Path:
|
def _write_instance_config(tmp_path: Path) -> Path:
|
||||||
config_file = tmp_path / "instance" / "config.json"
|
config_file = tmp_path / "instance" / "config.json"
|
||||||
config_file.parent.mkdir(parents=True)
|
config_file.parent.mkdir(parents=True)
|
||||||
|
|||||||
@ -38,6 +38,28 @@ async def test_message_tool_rejects_malformed_buttons(bad) -> None:
|
|||||||
assert result == "Error: buttons must be a list of list of strings"
|
assert result == "Error: buttons must be a list of list of strings"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_message_tool_suppresses_delivery_when_active() -> None:
    """Sends are acknowledged but not delivered while suppression is set."""
    sent: list[OutboundMessage] = []

    async def _send(msg: OutboundMessage) -> None:
        sent.append(msg)

    tool = MessageTool(send_callback=_send)

    token = tool.set_suppress_delivery(True)
    try:
        result = await tool.execute(content="all clear", channel="telegram", chat_id="1")
    finally:
        # Always restore the previous state so the second send is a real one.
        tool.reset_suppress_delivery(token)
    assert sent == []
    assert "not delivered" in result

    # After the reset the same tool delivers normally again.
    await tool.execute(content="real", channel="telegram", chat_id="1")
    assert len(sent) == 1
    assert sent[0].content == "real"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
async def test_message_tool_marks_channel_delivery_only_when_enabled() -> None:
|
||||||
sent: list[OutboundMessage] = []
|
sent: list[OutboundMessage] = []
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user