mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
241 lines
8.3 KiB
Python
241 lines
8.3 KiB
Python
"""Internal turn continuation helpers.
|
|
|
|
This module keeps budget-boundary continuation policy out of ``AgentLoop``.
|
|
The loop calls a small set of helpers; those helpers decide whether an internal
|
|
continuation is allowed and, when it is, queue the next turn directly.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from typing import Any, Mapping, MutableMapping
|
|
|
|
from loguru import logger
|
|
|
|
from nanobot.session.goal_state import (
|
|
goal_state_runtime_lines,
|
|
sustained_goal_active,
|
|
sustained_goal_turn,
|
|
)
|
|
|
|
INTERNAL_CONTINUATION_META = "_internal_continuation"
|
|
INTERNAL_CONTINUATION_KIND_META = "_internal_continuation_kind"
|
|
INTERNAL_CONTINUATION_PENDING_META = "_internal_continuation_pending"
|
|
INTERNAL_CONTINUATION_RUN_STARTED_AT_META = "_internal_continuation_run_started_at"
|
|
|
|
_GOAL_CONTINUATION_KIND = "sustained_goal"
|
|
_GOAL_CONTINUATION_SENDER = "system:continuation"
|
|
_GOAL_CONTINUATION_ROUNDS_KEY = "_sustained_goal_continuation_rounds"
|
|
_MAX_GOAL_CONTINUATION_ROUNDS = 12
|
|
_STRIPPED_INBOUND_META_KEYS = {
|
|
"_stream_id",
|
|
"_stream_delta",
|
|
"_stream_end",
|
|
"_resuming",
|
|
INTERNAL_CONTINUATION_PENDING_META,
|
|
}
|
|
|
|
|
|
def internal_continuation_inbound(metadata: Mapping[str, Any] | None) -> bool:
|
|
"""True for an inbound message created by an internal continuation policy."""
|
|
return bool(metadata and metadata.get(INTERNAL_CONTINUATION_META) is True)
|
|
|
|
|
|
def internal_continuation_pending(metadata: Mapping[str, Any] | None) -> bool:
|
|
"""True when the current turn scheduled an invisible continuation slice."""
|
|
return bool(metadata and metadata.get(INTERNAL_CONTINUATION_PENDING_META) is True)
|
|
|
|
|
|
def internal_continuation_run_started_at(metadata: Mapping[str, Any] | None) -> float | None:
|
|
"""Return the user-visible run start propagated across continuation slices."""
|
|
if not metadata:
|
|
return None
|
|
value = metadata.get(INTERNAL_CONTINUATION_RUN_STARTED_AT_META)
|
|
if not isinstance(value, int | float):
|
|
return None
|
|
started_at = float(value)
|
|
return started_at if started_at > 0 else None
|
|
|
|
|
|
def should_persist_user_message(metadata: Mapping[str, Any] | None) -> bool:
    """Decide whether an inbound message is stored in history as user input.

    Messages flagged as internal continuations are not user input, so this
    returns ``False`` for them and ``True`` for everything else.
    """
    is_internal = internal_continuation_inbound(metadata)
    return not is_internal
|
|
|
|
|
|
def should_stream_budget_response(
    *,
    stop_reason: str,
    pending_queue_available: bool,
    session_metadata: Mapping[str, Any] | None,
    message_metadata: Mapping[str, Any] | None = None,
) -> bool:
    """Decide whether the budget-boundary response is shown to the user.

    The response is withheld exactly when the continuation policy
    (``_continuation_available``) would allow an internal continuation, so
    this is the negation of that check.
    """
    will_continue = _continuation_available(
        stop_reason=stop_reason,
        pending_queue_available=pending_queue_available,
        session_metadata=session_metadata,
        message_metadata=message_metadata,
    )
    return not will_continue
|
|
|
|
|
|
async def maybe_continue_turn(ctx: Any) -> bool:
    """Queue an internal continuation for *ctx* when policy allows it.

    *ctx* is the loop's per-turn context. This reads ``session``,
    ``pending_queue``, ``stop_reason``, ``msg``, ``all_messages``,
    ``final_content``, ``session_key`` and (optionally)
    ``visible_run_started_at``. When a continuation is scheduled it mutates
    ``msg.metadata``, ``final_content``, ``all_messages`` and
    ``suppress_response``, and bumps the session's continuation round counter.

    Returns:
        ``True`` if a continuation message was put on ``ctx.pending_queue``;
        ``False`` if there is no session/queue or the policy declines.
    """
    if ctx.session is None or ctx.pending_queue is None:
        return False
    if not _continuation_available(
        stop_reason=ctx.stop_reason,
        # The queue's presence was already checked above.
        pending_queue_available=True,
        session_metadata=ctx.session.metadata,
        message_metadata=ctx.msg.metadata,
    ):
        return False

    # Snapshot the outgoing metadata before this turn is flagged as pending
    # below; the pending key is in the set the helper strips anyway.
    metadata = _internal_continuation_metadata(
        ctx.msg.metadata,
        run_started_at=getattr(ctx, "visible_run_started_at", None),
    )
    content = _goal_continuation_prompt(ctx.session.metadata)
    # Drop the synthetic max-iteration assistant reply before history is saved.
    messages = _strip_terminal_assistant(ctx.all_messages, ctx.final_content)
    # Count this continuation against the per-goal round cap.
    _increment_goal_continuation_round(ctx.session.metadata)

    logger.info("Turn budget reached; scheduling internal continuation")
    # NOTE(review): assumes ctx.msg.metadata is a mutable mapping (not None)
    # here — confirm against the message type.
    ctx.msg.metadata[INTERNAL_CONTINUATION_PENDING_META] = True
    # Blank and suppress this turn's reply; the queued slice carries on.
    ctx.final_content = ""
    ctx.all_messages = messages
    ctx.suppress_response = True
    # Re-enqueue a copy of the message as a system-sent continuation pinned to
    # the same session key, with no media and the continuation metadata.
    await ctx.pending_queue.put(
        dataclasses.replace(
            ctx.msg,
            sender_id=_GOAL_CONTINUATION_SENDER,
            content=content,
            media=[],
            metadata=metadata,
            session_key_override=ctx.session_key,
        )
    )
    return True
|
|
|
|
|
|
def prepare_save_boundary(ctx: Any) -> None:
    """Reset stale continuation bookkeeping and set ``ctx.save_skip``.

    Session-level continuation state is cleared (when a session exists), then
    the history append boundary for this turn is stored on ``ctx.save_skip``.
    """
    if (session := ctx.session) is not None:
        clear_internal_continuation_state(session.metadata)

    boundary = _save_skip_for_turn(
        message_metadata=ctx.msg.metadata,
        initial_message_count=len(ctx.initial_messages),
        history_count=len(ctx.history),
        user_persisted_early=ctx.user_persisted_early,
    )
    ctx.save_skip = boundary
|
|
|
|
|
|
def _continuation_available(
    *,
    stop_reason: str,
    pending_queue_available: bool,
    session_metadata: Mapping[str, Any] | None,
    message_metadata: Mapping[str, Any] | None = None,
) -> bool:
    """Return whether an internal continuation may be scheduled.

    Only a turn that stopped on ``"max_iterations"`` and has a pending queue
    to put the continuation on is eligible; the sustained-goal policy then
    makes the final decision.
    """
    hit_budget = stop_reason == "max_iterations"
    if hit_budget and pending_queue_available:
        return _goal_continuation_available(
            session_metadata,
            message_metadata=message_metadata,
        )
    return False
|
|
|
|
|
|
def clear_internal_continuation_state(metadata: MutableMapping[str, Any]) -> None:
    """Drop the continuation round counter when no sustained goal is active.

    While a sustained goal is active the counter is left untouched, so the
    round cap keeps applying across turns.
    """
    if sustained_goal_active(metadata):
        return
    metadata.pop(_GOAL_CONTINUATION_ROUNDS_KEY, None)
|
|
|
|
|
|
def _save_skip_for_turn(
    *,
    message_metadata: Mapping[str, Any] | None,
    initial_message_count: int,
    history_count: int,
    user_persisted_early: bool,
) -> int:
    """Compute how many leading turn messages to skip when appending history.

    Internal continuation slices skip everything they started with
    (*initial_message_count*). Other turns skip one leading message
    (presumably the system prompt — confirm), the prior history, and the user
    message as well when it was already persisted early.
    """
    if not internal_continuation_inbound(message_metadata):
        already_saved_user = 1 if user_persisted_early else 0
        return 1 + history_count + already_saved_user
    return initial_message_count
|
|
|
|
|
|
def _stored_goal_continuation_rounds(session_metadata: Mapping[str, Any] | None) -> int:
    """Read the stored sustained-goal continuation counter as a non-negative int.

    Missing, malformed, non-finite, or negative stored values read as 0. Bad
    session metadata therefore can neither crash the loop nor push the counter
    below zero (which would grant rounds beyond the cap).
    """
    try:
        rounds = int((session_metadata or {}).get(_GOAL_CONTINUATION_ROUNDS_KEY) or 0)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")) cannot be represented as an int.
        return 0
    return max(0, rounds)


def _goal_continuation_available(
    session_metadata: Mapping[str, Any] | None,
    *,
    message_metadata: Mapping[str, Any] | None = None,
    max_rounds: int = _MAX_GOAL_CONTINUATION_ROUNDS,
) -> bool:
    """Return whether the sustained-goal policy allows another continuation.

    Requires a sustained-goal turn with an active goal, and fewer recorded
    continuation rounds than *max_rounds* (negative *max_rounds* acts as 0).
    """
    if not sustained_goal_turn(session_metadata, message_metadata=message_metadata):
        return False
    if not sustained_goal_active(session_metadata):
        return False
    return _stored_goal_continuation_rounds(session_metadata) < max(0, max_rounds)


def _increment_goal_continuation_round(session_metadata: MutableMapping[str, Any]) -> None:
    """Record one more continuation round in *session_metadata*.

    An unreadable or negative stored count restarts from 0, so the counter
    always moves toward the cap.
    """
    rounds = _stored_goal_continuation_rounds(session_metadata)
    session_metadata[_GOAL_CONTINUATION_ROUNDS_KEY] = rounds + 1
|
|
|
|
|
|
def _internal_continuation_metadata(
    message_metadata: Mapping[str, Any] | None,
    *,
    run_started_at: float | None = None,
) -> dict[str, Any]:
    """Build metadata for the queued continuation message.

    Copies *message_metadata* without the per-message stream/resume/pending
    keys, then stamps the internal-continuation marker and kind. It also
    stamps the visible run start as a float when one is given.
    """
    source = message_metadata or {}
    result = {
        key: value
        for key, value in source.items()
        if key not in _STRIPPED_INBOUND_META_KEYS
    }
    result[INTERNAL_CONTINUATION_META] = True
    result[INTERNAL_CONTINUATION_KIND_META] = _GOAL_CONTINUATION_KIND
    if run_started_at is not None:
        result[INTERNAL_CONTINUATION_RUN_STARTED_AT_META] = float(run_started_at)
    return result
|
|
|
|
|
|
def _goal_continuation_prompt(metadata: Mapping[str, Any] | None) -> str:
    """Build the content text of the queued sustained-goal continuation message.

    When goal-state runtime lines are available they are embedded between the
    opening and closing instructions, separated by blank lines. Otherwise the
    two instructions are joined with a single space.
    """
    goal_lines = goal_state_runtime_lines(metadata)
    opening = (
        "Continue the active sustained goal after the previous turn reached "
        "its tool-call budget."
    )
    closing = (
        "Continue from the saved context. Do not mention the continuation "
        "boundary to the user. Use tools as needed, and call complete_goal "
        "when the objective is truly finished."
    )
    if not goal_lines:
        return f"{opening} {closing}"
    goal_block = "\n".join(goal_lines)
    return f"{opening}\n\n{goal_block}\n\n{closing}"
|
|
|
|
|
|
def _strip_terminal_assistant(
    messages: list[dict[str, Any]],
    final_content: str | None,
) -> list[dict[str, Any]]:
    """Drop the synthetic max-iteration assistant message before saving history.

    The last message is removed only when it is an assistant message whose
    content equals *final_content* (which must not be ``None``) and that
    carries no tool calls. Otherwise *messages* itself is returned unchanged.
    """
    if messages:
        tail = messages[-1]
        is_synthetic_reply = (
            tail.get("role") == "assistant"
            and final_content is not None
            and tail.get("content") == final_content
            and not tail.get("tool_calls")
        )
        if is_synthetic_reply:
            return messages[:-1]
    return messages
|