mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(agent): extend sustained goal iteration budget
This commit is contained in:
parent
cba9ff1f57
commit
be2e0172d1
@ -45,6 +45,7 @@ from nanobot.session.goal_state import (
|
||||
sustained_goal_active,
|
||||
)
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.session import turn_continuation
|
||||
from nanobot.session.webui_turns import (
|
||||
WebuiTurnCoordinator,
|
||||
build_bus_progress_callback,
|
||||
@ -112,6 +113,7 @@ class TurnContext:
|
||||
save_skip: int = 0
|
||||
|
||||
outbound: OutboundMessage | None = None
|
||||
suppress_response: bool = False
|
||||
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None
|
||||
@ -121,6 +123,7 @@ class TurnContext:
|
||||
pending_queue: asyncio.Queue | None = None
|
||||
pending_summary: str | None = None
|
||||
turn_wall_started_at: float = field(default_factory=time.time)
|
||||
visible_run_started_at: float | None = None
|
||||
turn_latency_ms: int | None = None
|
||||
|
||||
trace: list[StateTraceEntry] = field(default_factory=list)
|
||||
@ -565,6 +568,8 @@ class AgentLoop:
|
||||
|
||||
Returns True if the message was persisted.
|
||||
"""
|
||||
if not turn_continuation.should_persist_user_message(msg.metadata):
|
||||
return False
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if has_text or media_paths:
|
||||
@ -771,6 +776,7 @@ class AgentLoop:
|
||||
+ "\n\nPlease continue working toward the objective using your tools, "
|
||||
"or call complete_goal if the work is truly finished."
|
||||
) if _goal_lines else SUSTAINED_GOAL_CONTINUE_PROMPT
|
||||
session_metadata = session.metadata if session is not None else None
|
||||
try:
|
||||
result = await self.runner.run(AgentRunSpec(
|
||||
initial_messages=initial_messages,
|
||||
@ -796,7 +802,8 @@ class AgentLoop:
|
||||
llm_timeout_s=runner_wall_llm_timeout_s(
|
||||
self.sessions,
|
||||
session.key if session is not None else session_key,
|
||||
metadata=(session.metadata if session is not None else None),
|
||||
metadata=session_metadata,
|
||||
message_metadata=metadata,
|
||||
),
|
||||
goal_active_predicate=lambda: sustained_goal_active(session.metadata) if session is not None else False,
|
||||
goal_continue_message=_goal_continue,
|
||||
@ -808,9 +815,15 @@ class AgentLoop:
|
||||
self._last_usage = result.usage
|
||||
if result.stop_reason == "max_iterations":
|
||||
logger.warning("Max iterations ({}) reached", self.max_iterations)
|
||||
should_stream = turn_continuation.should_stream_budget_response(
|
||||
stop_reason=result.stop_reason,
|
||||
pending_queue_available=pending_queue is not None and session is not None,
|
||||
session_metadata=session_metadata,
|
||||
message_metadata=metadata,
|
||||
)
|
||||
# Push final content through stream so streaming channels (e.g. Feishu)
|
||||
# update the card instead of leaving it empty.
|
||||
if on_stream and on_stream_end:
|
||||
if on_stream and on_stream_end and should_stream:
|
||||
await on_stream(result.final_content or "")
|
||||
await on_stream_end(resuming=False)
|
||||
elif result.stop_reason == "error":
|
||||
@ -953,7 +966,8 @@ class AgentLoop:
|
||||
channel=msg.channel, chat_id=msg.chat_id,
|
||||
content="", metadata=msg.metadata or {},
|
||||
))
|
||||
if msg.channel == "websocket":
|
||||
continuing = turn_continuation.internal_continuation_pending(msg.metadata)
|
||||
if msg.channel == "websocket" and not continuing:
|
||||
turn_lat = self._pending_turn_latency_ms.pop(session_key, None)
|
||||
await self._webui_turns.handle_turn_end(
|
||||
msg,
|
||||
@ -1017,9 +1031,10 @@ class AgentLoop:
|
||||
"Re-published {} leftover message(s) to bus for session {}",
|
||||
leftover, session_key,
|
||||
)
|
||||
await self._webui_turns.publish_run_status(msg, "idle")
|
||||
self._pending_turn_latency_ms.pop(session_key, None)
|
||||
self._webui_turns.discard(session_key)
|
||||
if not turn_continuation.internal_continuation_pending(msg.metadata):
|
||||
await self._webui_turns.publish_run_status(msg, "idle")
|
||||
self._pending_turn_latency_ms.pop(session_key, None)
|
||||
self._webui_turns.discard(session_key)
|
||||
finally:
|
||||
if pending is None:
|
||||
await self._webui_turns.publish_run_status(msg, "idle")
|
||||
@ -1167,12 +1182,17 @@ class AgentLoop:
|
||||
)
|
||||
|
||||
key = session_key or msg.session_key
|
||||
t0 = time.time()
|
||||
ctx = TurnContext(
|
||||
msg=msg,
|
||||
session=None,
|
||||
session_key=key,
|
||||
state=TurnState.RESTORE,
|
||||
turn_id=f"{key}:{time.time_ns()}",
|
||||
turn_wall_started_at=t0,
|
||||
visible_run_started_at=turn_continuation.internal_continuation_run_started_at(
|
||||
msg.metadata,
|
||||
),
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
@ -1378,7 +1398,13 @@ class AgentLoop:
|
||||
return "ok"
|
||||
|
||||
async def _state_run(self, ctx: TurnContext) -> str:
|
||||
await self._webui_turns.publish_run_status(ctx.msg, "running")
|
||||
if ctx.visible_run_started_at is None:
|
||||
ctx.visible_run_started_at = time.time()
|
||||
await self._webui_turns.publish_run_status(
|
||||
ctx.msg,
|
||||
"running",
|
||||
started_at=ctx.visible_run_started_at,
|
||||
)
|
||||
result = await self._run_agent_loop(
|
||||
ctx.initial_messages,
|
||||
on_progress=ctx.on_progress,
|
||||
@ -1399,15 +1425,25 @@ class AgentLoop:
|
||||
ctx.all_messages = all_msgs
|
||||
ctx.stop_reason = stop_reason
|
||||
ctx.had_injections = had_injections
|
||||
await turn_continuation.maybe_continue_turn(ctx)
|
||||
return "ok"
|
||||
|
||||
async def _state_save(self, ctx: TurnContext) -> str:
|
||||
if ctx.final_content is None or not ctx.final_content.strip():
|
||||
turn_continuation.prepare_save_boundary(ctx)
|
||||
|
||||
if (
|
||||
(ctx.final_content is None or not ctx.final_content.strip())
|
||||
and not ctx.suppress_response
|
||||
):
|
||||
ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
ctx.save_skip = 1 + len(ctx.history) + (1 if ctx.user_persisted_early else 0)
|
||||
|
||||
ctx.turn_latency_ms = max(0, int((time.time() - ctx.turn_wall_started_at) * 1000))
|
||||
latency_started_at = (
|
||||
ctx.visible_run_started_at
|
||||
if turn_continuation.internal_continuation_inbound(ctx.msg.metadata)
|
||||
and ctx.visible_run_started_at is not None
|
||||
else ctx.turn_wall_started_at
|
||||
)
|
||||
ctx.turn_latency_ms = max(0, int((time.time() - latency_started_at) * 1000))
|
||||
self._save_turn(
|
||||
ctx.session, ctx.all_messages, ctx.save_skip,
|
||||
turn_latency_ms=ctx.turn_latency_ms,
|
||||
@ -1427,6 +1463,9 @@ class AgentLoop:
|
||||
return "ok"
|
||||
|
||||
async def _state_respond(self, ctx: TurnContext) -> str:
|
||||
if ctx.suppress_response:
|
||||
ctx.outbound = None
|
||||
return "ok"
|
||||
ctx.outbound = self._assemble_outbound(
|
||||
ctx.msg,
|
||||
ctx.final_content,
|
||||
|
||||
@ -43,6 +43,19 @@ def sustained_goal_active(metadata: Mapping[str, Any] | None) -> bool:
|
||||
return isinstance(goal, dict) and goal.get("status") == "active"
|
||||
|
||||
|
||||
def sustained_goal_turn(
|
||||
metadata: Mapping[str, Any] | None,
|
||||
*,
|
||||
message_metadata: Mapping[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""True when this turn should use sustained-goal runtime limits."""
|
||||
if sustained_goal_active(metadata):
|
||||
return True
|
||||
if not message_metadata:
|
||||
return False
|
||||
return str(message_metadata.get("original_command") or "").strip() == "/goal"
|
||||
|
||||
|
||||
def parse_goal_state(blob: Any) -> dict[str, Any] | None:
|
||||
if blob is None:
|
||||
return None
|
||||
@ -98,14 +111,16 @@ def runner_wall_llm_timeout_s(
|
||||
session_key: str | None,
|
||||
*,
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
message_metadata: Mapping[str, Any] | None = None,
|
||||
) -> float | None:
|
||||
"""Wall-clock cap for :class:`~nanobot.agent.runner.AgentRunner` when streaming an LLM.
|
||||
|
||||
Returns ``0.0`` to disable ``asyncio.wait_for`` around the request when a sustained goal is
|
||||
active; ``None`` means use ``NANOBOT_LLM_TIMEOUT_S``. Pass in-memory ``metadata`` when the
|
||||
caller already holds :attr:`~nanobot.session.manager.Session.metadata` for this turn.
|
||||
Returns ``0.0`` to disable ``asyncio.wait_for`` around the request when this is a
|
||||
sustained-goal turn; ``None`` means use ``NANOBOT_LLM_TIMEOUT_S``. Pass in-memory
|
||||
``metadata`` when the caller already holds :attr:`~nanobot.session.manager.Session.metadata`
|
||||
for this turn.
|
||||
"""
|
||||
meta: Mapping[str, Any] | None = metadata
|
||||
if meta is None and session_key:
|
||||
meta = sessions.get_or_create(session_key).metadata
|
||||
return 0.0 if sustained_goal_active(meta) else None
|
||||
return 0.0 if sustained_goal_turn(meta, message_metadata=message_metadata) else None
|
||||
|
||||
240
nanobot/session/turn_continuation.py
Normal file
240
nanobot/session/turn_continuation.py
Normal file
@ -0,0 +1,240 @@
|
||||
"""Internal turn continuation helpers.
|
||||
|
||||
This module keeps budget-boundary continuation policy out of ``AgentLoop``.
|
||||
The loop calls a small set of helpers; those helpers decide whether an internal
|
||||
continuation is allowed and, when it is, queue the next turn directly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Mapping, MutableMapping
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.session.goal_state import (
|
||||
goal_state_runtime_lines,
|
||||
sustained_goal_active,
|
||||
sustained_goal_turn,
|
||||
)
|
||||
|
||||
INTERNAL_CONTINUATION_META = "_internal_continuation"
|
||||
INTERNAL_CONTINUATION_KIND_META = "_internal_continuation_kind"
|
||||
INTERNAL_CONTINUATION_PENDING_META = "_internal_continuation_pending"
|
||||
INTERNAL_CONTINUATION_RUN_STARTED_AT_META = "_internal_continuation_run_started_at"
|
||||
|
||||
_GOAL_CONTINUATION_KIND = "sustained_goal"
|
||||
_GOAL_CONTINUATION_SENDER = "system:continuation"
|
||||
_GOAL_CONTINUATION_ROUNDS_KEY = "_sustained_goal_continuation_rounds"
|
||||
_MAX_GOAL_CONTINUATION_ROUNDS = 12
|
||||
_STRIPPED_INBOUND_META_KEYS = {
|
||||
"_stream_id",
|
||||
"_stream_delta",
|
||||
"_stream_end",
|
||||
"_resuming",
|
||||
INTERNAL_CONTINUATION_PENDING_META,
|
||||
}
|
||||
|
||||
|
||||
def internal_continuation_inbound(metadata: Mapping[str, Any] | None) -> bool:
|
||||
"""True for an inbound message created by an internal continuation policy."""
|
||||
return bool(metadata and metadata.get(INTERNAL_CONTINUATION_META) is True)
|
||||
|
||||
|
||||
def internal_continuation_pending(metadata: Mapping[str, Any] | None) -> bool:
|
||||
"""True when the current turn scheduled an invisible continuation slice."""
|
||||
return bool(metadata and metadata.get(INTERNAL_CONTINUATION_PENDING_META) is True)
|
||||
|
||||
|
||||
def internal_continuation_run_started_at(metadata: Mapping[str, Any] | None) -> float | None:
|
||||
"""Return the user-visible run start propagated across continuation slices."""
|
||||
if not metadata:
|
||||
return None
|
||||
value = metadata.get(INTERNAL_CONTINUATION_RUN_STARTED_AT_META)
|
||||
if not isinstance(value, int | float):
|
||||
return None
|
||||
started_at = float(value)
|
||||
return started_at if started_at > 0 else None
|
||||
|
||||
|
||||
def should_persist_user_message(metadata: Mapping[str, Any] | None) -> bool:
|
||||
"""Return whether this inbound message should be persisted as user input."""
|
||||
return not internal_continuation_inbound(metadata)
|
||||
|
||||
|
||||
def should_stream_budget_response(
|
||||
*,
|
||||
stop_reason: str,
|
||||
pending_queue_available: bool,
|
||||
session_metadata: Mapping[str, Any] | None,
|
||||
message_metadata: Mapping[str, Any] | None = None,
|
||||
) -> bool:
|
||||
"""Return whether the budget-boundary response should be sent to the user."""
|
||||
return not _continuation_available(
|
||||
stop_reason=stop_reason,
|
||||
pending_queue_available=pending_queue_available,
|
||||
session_metadata=session_metadata,
|
||||
message_metadata=message_metadata,
|
||||
)
|
||||
|
||||
|
||||
async def maybe_continue_turn(ctx: Any) -> bool:
|
||||
"""Queue an internal continuation for *ctx* when policy allows it."""
|
||||
if ctx.session is None or ctx.pending_queue is None:
|
||||
return False
|
||||
if not _continuation_available(
|
||||
stop_reason=ctx.stop_reason,
|
||||
pending_queue_available=True,
|
||||
session_metadata=ctx.session.metadata,
|
||||
message_metadata=ctx.msg.metadata,
|
||||
):
|
||||
return False
|
||||
|
||||
metadata = _internal_continuation_metadata(
|
||||
ctx.msg.metadata,
|
||||
run_started_at=getattr(ctx, "visible_run_started_at", None),
|
||||
)
|
||||
content = _goal_continuation_prompt(ctx.session.metadata)
|
||||
messages = _strip_terminal_assistant(ctx.all_messages, ctx.final_content)
|
||||
_increment_goal_continuation_round(ctx.session.metadata)
|
||||
|
||||
logger.info("Turn budget reached; scheduling internal continuation")
|
||||
ctx.msg.metadata[INTERNAL_CONTINUATION_PENDING_META] = True
|
||||
ctx.final_content = ""
|
||||
ctx.all_messages = messages
|
||||
ctx.suppress_response = True
|
||||
await ctx.pending_queue.put(
|
||||
dataclasses.replace(
|
||||
ctx.msg,
|
||||
sender_id=_GOAL_CONTINUATION_SENDER,
|
||||
content=content,
|
||||
media=[],
|
||||
metadata=metadata,
|
||||
session_key_override=ctx.session_key,
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def prepare_save_boundary(ctx: Any) -> None:
|
||||
"""Prepare continuation bookkeeping and the history append boundary."""
|
||||
if ctx.session is not None:
|
||||
clear_internal_continuation_state(ctx.session.metadata)
|
||||
|
||||
ctx.save_skip = _save_skip_for_turn(
|
||||
message_metadata=ctx.msg.metadata,
|
||||
initial_message_count=len(ctx.initial_messages),
|
||||
history_count=len(ctx.history),
|
||||
user_persisted_early=ctx.user_persisted_early,
|
||||
)
|
||||
|
||||
|
||||
def _continuation_available(
|
||||
*,
|
||||
stop_reason: str,
|
||||
pending_queue_available: bool,
|
||||
session_metadata: Mapping[str, Any] | None,
|
||||
message_metadata: Mapping[str, Any] | None = None,
|
||||
) -> bool:
|
||||
if stop_reason != "max_iterations" or not pending_queue_available:
|
||||
return False
|
||||
return _goal_continuation_available(
|
||||
session_metadata,
|
||||
message_metadata=message_metadata,
|
||||
)
|
||||
|
||||
|
||||
def clear_internal_continuation_state(metadata: MutableMapping[str, Any]) -> None:
|
||||
"""Reset policy bookkeeping once its owning runtime mode is inactive."""
|
||||
if not sustained_goal_active(metadata):
|
||||
metadata.pop(_GOAL_CONTINUATION_ROUNDS_KEY, None)
|
||||
|
||||
|
||||
def _save_skip_for_turn(
|
||||
*,
|
||||
message_metadata: Mapping[str, Any] | None,
|
||||
initial_message_count: int,
|
||||
history_count: int,
|
||||
user_persisted_early: bool,
|
||||
) -> int:
|
||||
"""Return the persisted-message append boundary for this turn."""
|
||||
if internal_continuation_inbound(message_metadata):
|
||||
return initial_message_count
|
||||
return 1 + history_count + (1 if user_persisted_early else 0)
|
||||
|
||||
|
||||
def _goal_continuation_available(
|
||||
session_metadata: Mapping[str, Any] | None,
|
||||
*,
|
||||
message_metadata: Mapping[str, Any] | None = None,
|
||||
max_rounds: int = _MAX_GOAL_CONTINUATION_ROUNDS,
|
||||
) -> bool:
|
||||
if not sustained_goal_turn(session_metadata, message_metadata=message_metadata):
|
||||
return False
|
||||
if not sustained_goal_active(session_metadata):
|
||||
return False
|
||||
try:
|
||||
rounds = int((session_metadata or {}).get(_GOAL_CONTINUATION_ROUNDS_KEY) or 0)
|
||||
except (TypeError, ValueError):
|
||||
rounds = 0
|
||||
return rounds < max(0, max_rounds)
|
||||
|
||||
|
||||
def _increment_goal_continuation_round(session_metadata: MutableMapping[str, Any]) -> None:
|
||||
try:
|
||||
rounds = int(session_metadata.get(_GOAL_CONTINUATION_ROUNDS_KEY) or 0)
|
||||
except (TypeError, ValueError):
|
||||
rounds = 0
|
||||
session_metadata[_GOAL_CONTINUATION_ROUNDS_KEY] = rounds + 1
|
||||
|
||||
|
||||
def _internal_continuation_metadata(
|
||||
message_metadata: Mapping[str, Any] | None,
|
||||
*,
|
||||
run_started_at: float | None = None,
|
||||
) -> dict[str, Any]:
|
||||
metadata = dict(message_metadata or {})
|
||||
metadata[INTERNAL_CONTINUATION_META] = True
|
||||
metadata[INTERNAL_CONTINUATION_KIND_META] = _GOAL_CONTINUATION_KIND
|
||||
if run_started_at is not None:
|
||||
metadata[INTERNAL_CONTINUATION_RUN_STARTED_AT_META] = float(run_started_at)
|
||||
for key in _STRIPPED_INBOUND_META_KEYS:
|
||||
metadata.pop(key, None)
|
||||
return metadata
|
||||
|
||||
|
||||
def _goal_continuation_prompt(metadata: Mapping[str, Any] | None) -> str:
|
||||
lines = goal_state_runtime_lines(metadata)
|
||||
if lines:
|
||||
goal = "\n".join(lines)
|
||||
return (
|
||||
"Continue the active sustained goal after the previous turn reached "
|
||||
"its tool-call budget.\n\n"
|
||||
f"{goal}\n\n"
|
||||
"Continue from the saved context. Do not mention the continuation "
|
||||
"boundary to the user. Use tools as needed, and call complete_goal "
|
||||
"when the objective is truly finished."
|
||||
)
|
||||
return (
|
||||
"Continue the active sustained goal after the previous turn reached "
|
||||
"its tool-call budget. Continue from the saved context. Do not mention "
|
||||
"the continuation boundary to the user. Use tools as needed, and call "
|
||||
"complete_goal when the objective is truly finished."
|
||||
)
|
||||
|
||||
|
||||
def _strip_terminal_assistant(
|
||||
messages: list[dict[str, Any]],
|
||||
final_content: str | None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Drop the synthetic max-iteration assistant message before saving history."""
|
||||
if not messages:
|
||||
return messages
|
||||
last = messages[-1]
|
||||
if last.get("role") != "assistant":
|
||||
return messages
|
||||
if final_content is None or last.get("content") != final_content:
|
||||
return messages
|
||||
if last.get("tool_calls"):
|
||||
return messages
|
||||
return messages[:-1]
|
||||
@ -178,7 +178,13 @@ def websocket_turn_wall_started_at(chat_id: str) -> float | None:
|
||||
return _WEBSOCKET_TURN_WALL_STARTED_AT.get(chat_id)
|
||||
|
||||
|
||||
async def publish_turn_run_status(bus: MessageBus, msg: InboundMessage, status: str) -> None:
|
||||
async def publish_turn_run_status(
|
||||
bus: MessageBus,
|
||||
msg: InboundMessage,
|
||||
status: str,
|
||||
*,
|
||||
started_at: float | None = None,
|
||||
) -> None:
|
||||
"""Notify WebSocket clients while a user turn is executing (timing strip)."""
|
||||
if msg.channel != "websocket":
|
||||
return
|
||||
@ -189,7 +195,10 @@ async def publish_turn_run_status(bus: MessageBus, msg: InboundMessage, status:
|
||||
"goal_status": status,
|
||||
}
|
||||
if status == "running":
|
||||
t0 = time.time()
|
||||
if isinstance(started_at, int | float) and started_at > 0:
|
||||
t0 = float(started_at)
|
||||
else:
|
||||
t0 = time.time()
|
||||
meta["started_at"] = t0
|
||||
_WEBSOCKET_TURN_WALL_STARTED_AT[cid] = t0
|
||||
else:
|
||||
@ -300,8 +309,14 @@ class WebuiTurnCoordinator:
|
||||
def discard(self, session_key: str) -> None:
|
||||
self._title_contexts.pop(session_key, None)
|
||||
|
||||
async def publish_run_status(self, msg: InboundMessage, status: str) -> None:
|
||||
await publish_turn_run_status(self.bus, msg, status)
|
||||
async def publish_run_status(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
status: str,
|
||||
*,
|
||||
started_at: float | None = None,
|
||||
) -> None:
|
||||
await publish_turn_run_status(self.bus, msg, status, started_at=started_at)
|
||||
|
||||
async def handle_turn_end(
|
||||
self,
|
||||
|
||||
@ -2,7 +2,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@ -48,6 +47,30 @@ async def test_loop_max_iterations_message_stays_stable(tmp_path):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_goal_turn_uses_standard_iteration_budget(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.provider.chat_with_retry = AsyncMock(return_value=LLMResponse(
|
||||
content="working",
|
||||
tool_calls=[ToolCallRequest(id="call_1", name="list_dir", arguments={})],
|
||||
))
|
||||
loop.tools.get_definitions = MagicMock(return_value=[])
|
||||
loop.tools.execute = AsyncMock(return_value="ok")
|
||||
loop.max_iterations = 2
|
||||
|
||||
final_content, _, _, stop_reason, _ = await loop._run_agent_loop(
|
||||
[],
|
||||
metadata={"original_command": "/goal"},
|
||||
)
|
||||
|
||||
assert stop_reason == "max_iterations"
|
||||
assert loop.provider.chat_with_retry.await_count == 2
|
||||
assert final_content == (
|
||||
"I reached the maximum number of tool call iterations (2) "
|
||||
"without completing the task. You can try breaking the task into smaller steps."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loop_stream_filter_handles_think_only_prefix_without_crashing(tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
|
||||
@ -11,6 +11,10 @@ from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
from nanobot.session.goal_state import GOAL_STATE_KEY
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.session.turn_continuation import (
|
||||
INTERNAL_CONTINUATION_META,
|
||||
INTERNAL_CONTINUATION_RUN_STARTED_AT_META,
|
||||
)
|
||||
from nanobot.session.webui_turns import (
|
||||
TITLE_GENERATION_MAX_TOKENS,
|
||||
TITLE_GENERATION_REASONING_EFFORT,
|
||||
@ -560,6 +564,226 @@ async def test_process_message_does_not_duplicate_early_persisted_user_message(t
|
||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_continuation_queues_turn_without_fake_user_history(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("feishu:c-auto")
|
||||
session.metadata[GOAL_STATE_KEY] = {
|
||||
"status": "active",
|
||||
"objective": "Finish the long goal.",
|
||||
}
|
||||
loop.sessions.save(session)
|
||||
|
||||
calls: list[dict] = []
|
||||
|
||||
async def fake_run_agent_loop(initial_messages, *, metadata=None, **_kwargs):
|
||||
calls.append({"initial_messages": initial_messages, "metadata": metadata})
|
||||
if len(calls) == 1:
|
||||
return (
|
||||
"paused",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "paused"}],
|
||||
"max_iterations",
|
||||
False,
|
||||
)
|
||||
return (
|
||||
"done",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "done"}],
|
||||
"completed",
|
||||
False,
|
||||
)
|
||||
|
||||
loop._run_agent_loop = fake_run_agent_loop # type: ignore[method-assign]
|
||||
pending: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
|
||||
first = await loop._process_message(
|
||||
InboundMessage(
|
||||
channel="feishu",
|
||||
sender_id="u1",
|
||||
chat_id="c-auto",
|
||||
content="start the goal",
|
||||
),
|
||||
pending_queue=pending,
|
||||
)
|
||||
|
||||
assert first is None
|
||||
queued = pending.get_nowait()
|
||||
assert queued.sender_id == "system:continuation"
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_META] is True
|
||||
assert "Finish the long goal." in queued.content
|
||||
|
||||
session = loop.sessions.get_or_create("feishu:c-auto")
|
||||
assert [
|
||||
{k: v for k, v in m.items() if k in {"role", "content"}}
|
||||
for m in session.messages
|
||||
] == [{"role": "user", "content": "start the goal"}]
|
||||
|
||||
second = await loop._process_message(queued, pending_queue=asyncio.Queue())
|
||||
|
||||
assert second is not None
|
||||
assert second.content == "done"
|
||||
session = loop.sessions.get_or_create("feishu:c-auto")
|
||||
assert [
|
||||
{k: v for k, v in m.items() if k in {"role", "content"}}
|
||||
for m in session.messages
|
||||
] == [
|
||||
{"role": "user", "content": "start the goal"},
|
||||
{"role": "assistant", "content": "done"},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_continuation_preserves_streaming_route_metadata(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("feishu:c-stream")
|
||||
session.metadata[GOAL_STATE_KEY] = {
|
||||
"status": "active",
|
||||
"objective": "Finish the streamed long goal.",
|
||||
}
|
||||
loop.sessions.save(session)
|
||||
|
||||
calls = 0
|
||||
|
||||
async def fake_run_agent_loop(initial_messages, *, on_stream=None, on_stream_end=None, **_kwargs):
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
return (
|
||||
"paused",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "paused"}],
|
||||
"max_iterations",
|
||||
False,
|
||||
)
|
||||
assert on_stream is not None
|
||||
assert on_stream_end is not None
|
||||
await on_stream("done")
|
||||
await on_stream_end(resuming=False)
|
||||
return (
|
||||
"done",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "done"}],
|
||||
"completed",
|
||||
False,
|
||||
)
|
||||
|
||||
loop._run_agent_loop = fake_run_agent_loop # type: ignore[method-assign]
|
||||
|
||||
await loop._dispatch(InboundMessage(
|
||||
channel="feishu",
|
||||
sender_id="u1",
|
||||
chat_id="c-stream",
|
||||
content="start the goal",
|
||||
metadata={
|
||||
"_wants_stream": True,
|
||||
"message_id": "om_001",
|
||||
"origin_message_id": "root_001",
|
||||
"_stream_id": "old-stream",
|
||||
},
|
||||
))
|
||||
|
||||
assert loop.bus.outbound_size == 0
|
||||
queued = await asyncio.wait_for(loop.bus.consume_inbound(), timeout=0.5)
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_META] is True
|
||||
assert queued.metadata["_wants_stream"] is True
|
||||
assert queued.metadata["message_id"] == "om_001"
|
||||
assert queued.metadata["origin_message_id"] == "root_001"
|
||||
assert "_stream_id" not in queued.metadata
|
||||
|
||||
await loop._dispatch(queued)
|
||||
|
||||
outbound = []
|
||||
while loop.bus.outbound_size:
|
||||
outbound.append(await loop.bus.consume_outbound())
|
||||
deltas = [m for m in outbound if m.metadata.get("_stream_delta")]
|
||||
ends = [m for m in outbound if m.metadata.get("_stream_end")]
|
||||
streamed_markers = [m for m in outbound if m.metadata.get("_streamed")]
|
||||
|
||||
assert [m.content for m in deltas] == ["done"]
|
||||
assert len(ends) == 1
|
||||
assert ends[0].metadata["_resuming"] is False
|
||||
assert ends[0].metadata["message_id"] == "om_001"
|
||||
assert ends[0].metadata["origin_message_id"] == "root_001"
|
||||
assert isinstance(ends[0].metadata.get("_stream_id"), str)
|
||||
assert streamed_markers and streamed_markers[-1].content == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_websocket_internal_continuation_keeps_single_visible_run(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||
session = loop.sessions.get_or_create("websocket:c-auto")
|
||||
session.metadata[GOAL_STATE_KEY] = {
|
||||
"status": "active",
|
||||
"objective": "Finish the long goal.",
|
||||
}
|
||||
loop.sessions.save(session)
|
||||
|
||||
calls = 0
|
||||
|
||||
async def fake_run_agent_loop(initial_messages, **_kwargs):
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
if calls == 1:
|
||||
return (
|
||||
"paused",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "paused"}],
|
||||
"max_iterations",
|
||||
False,
|
||||
)
|
||||
return (
|
||||
"done",
|
||||
[],
|
||||
[*initial_messages, {"role": "assistant", "content": "done"}],
|
||||
"completed",
|
||||
False,
|
||||
)
|
||||
|
||||
loop._run_agent_loop = fake_run_agent_loop # type: ignore[method-assign]
|
||||
|
||||
await loop._dispatch(InboundMessage(
|
||||
channel="websocket",
|
||||
sender_id="u1",
|
||||
chat_id="c-auto",
|
||||
content="start the goal",
|
||||
metadata={"webui": True},
|
||||
))
|
||||
|
||||
first_outbound = []
|
||||
while loop.bus.outbound_size:
|
||||
first_outbound.append(await loop.bus.consume_outbound())
|
||||
first_statuses = [m.metadata for m in first_outbound if m.metadata.get("_goal_status")]
|
||||
assert [m["goal_status"] for m in first_statuses] == ["running"]
|
||||
assert not [m for m in first_outbound if m.metadata.get("_turn_end")]
|
||||
started_at = first_statuses[0]["started_at"]
|
||||
|
||||
queued = await asyncio.wait_for(loop.bus.consume_inbound(), timeout=0.5)
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_META] is True
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_RUN_STARTED_AT_META] == started_at
|
||||
|
||||
await loop._dispatch(queued)
|
||||
|
||||
second_outbound = []
|
||||
while loop.bus.outbound_size:
|
||||
second_outbound.append(await loop.bus.consume_outbound())
|
||||
second_statuses = [m.metadata for m in second_outbound if m.metadata.get("_goal_status")]
|
||||
assert [m["goal_status"] for m in second_statuses] == ["running", "idle"]
|
||||
assert second_statuses[0]["started_at"] == started_at
|
||||
turn_end = [m for m in second_outbound if m.metadata.get("_turn_end")]
|
||||
assert len(turn_end) == 1
|
||||
assert isinstance(turn_end[0].metadata.get("latency_ms"), int)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_uses_context_chat_id_for_runtime_prompt(tmp_path: Path) -> None:
|
||||
loop = _make_full_loop(tmp_path)
|
||||
|
||||
127
tests/session/test_turn_continuation.py
Normal file
127
tests/session/test_turn_continuation.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Tests for internal turn continuation policy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.session.goal_state import GOAL_STATE_KEY
|
||||
from nanobot.session.turn_continuation import (
|
||||
INTERNAL_CONTINUATION_KIND_META,
|
||||
INTERNAL_CONTINUATION_META,
|
||||
INTERNAL_CONTINUATION_PENDING_META,
|
||||
INTERNAL_CONTINUATION_RUN_STARTED_AT_META,
|
||||
internal_continuation_pending,
|
||||
internal_continuation_run_started_at,
|
||||
maybe_continue_turn,
|
||||
should_stream_budget_response,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_maybe_continue_turn_queues_internal_message():
|
||||
meta = {
|
||||
GOAL_STATE_KEY: {
|
||||
"status": "active",
|
||||
"objective": "Finish the migration.",
|
||||
"ui_summary": "migration",
|
||||
},
|
||||
}
|
||||
messages = [
|
||||
{"role": "system", "content": "system"},
|
||||
{"role": "user", "content": "start"},
|
||||
{"role": "assistant", "content": "paused"},
|
||||
]
|
||||
pending: asyncio.Queue[InboundMessage] = asyncio.Queue()
|
||||
ctx = SimpleNamespace(
|
||||
session=SimpleNamespace(metadata=meta),
|
||||
msg=InboundMessage(
|
||||
channel="feishu",
|
||||
sender_id="u1",
|
||||
chat_id="c1",
|
||||
content="start",
|
||||
metadata={
|
||||
"message_id": "msg-1",
|
||||
"origin_message_id": "msg-0",
|
||||
"_wants_stream": True,
|
||||
"_stream_id": "stream-1",
|
||||
"_stream_delta": True,
|
||||
"_stream_end": True,
|
||||
"_resuming": True,
|
||||
"webui": True,
|
||||
},
|
||||
),
|
||||
session_key="feishu:c1",
|
||||
pending_queue=pending,
|
||||
stop_reason="max_iterations",
|
||||
final_content="paused",
|
||||
all_messages=messages,
|
||||
suppress_response=False,
|
||||
visible_run_started_at=1234.5,
|
||||
)
|
||||
|
||||
assert await maybe_continue_turn(ctx) is True
|
||||
|
||||
queued = pending.get_nowait()
|
||||
assert queued.sender_id == "system:continuation"
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_META] is True
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_KIND_META] == "sustained_goal"
|
||||
assert queued.metadata[INTERNAL_CONTINUATION_RUN_STARTED_AT_META] == 1234.5
|
||||
assert internal_continuation_run_started_at(queued.metadata) == 1234.5
|
||||
assert internal_continuation_pending(ctx.msg.metadata)
|
||||
assert queued.metadata["webui"] is True
|
||||
assert queued.metadata["message_id"] == "msg-1"
|
||||
assert queued.metadata["origin_message_id"] == "msg-0"
|
||||
assert queued.metadata["_wants_stream"] is True
|
||||
assert "_stream_id" not in queued.metadata
|
||||
assert "_stream_delta" not in queued.metadata
|
||||
assert "_stream_end" not in queued.metadata
|
||||
assert "_resuming" not in queued.metadata
|
||||
assert "Finish the migration." in queued.content
|
||||
assert ctx.all_messages == messages[:-1]
|
||||
assert ctx.final_content == ""
|
||||
assert ctx.suppress_response is True
|
||||
assert ctx.msg.metadata[INTERNAL_CONTINUATION_PENDING_META] is True
|
||||
assert meta["_sustained_goal_continuation_rounds"] == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_internal_continuation_respects_round_limit():
|
||||
meta = {
|
||||
GOAL_STATE_KEY: {"status": "active", "objective": "x"},
|
||||
"_sustained_goal_continuation_rounds": 12,
|
||||
}
|
||||
ctx = SimpleNamespace(
|
||||
session=SimpleNamespace(metadata=meta),
|
||||
msg=InboundMessage(channel="feishu", sender_id="u1", chat_id="c1", content="start"),
|
||||
session_key="feishu:c1",
|
||||
pending_queue=asyncio.Queue(),
|
||||
stop_reason="max_iterations",
|
||||
final_content="paused",
|
||||
all_messages=[],
|
||||
)
|
||||
|
||||
assert should_stream_budget_response(
|
||||
stop_reason="max_iterations",
|
||||
pending_queue_available=True,
|
||||
session_metadata=meta,
|
||||
)
|
||||
assert await maybe_continue_turn(ctx) is False
|
||||
|
||||
|
||||
def test_internal_continuation_requires_budget_boundary_and_queue():
|
||||
meta = {GOAL_STATE_KEY: {"status": "active", "objective": "x"}}
|
||||
|
||||
assert should_stream_budget_response(
|
||||
stop_reason="completed",
|
||||
pending_queue_available=True,
|
||||
session_metadata=meta,
|
||||
)
|
||||
assert should_stream_budget_response(
|
||||
stop_reason="max_iterations",
|
||||
pending_queue_available=False,
|
||||
session_metadata=meta,
|
||||
)
|
||||
@ -31,6 +31,19 @@ async def test_publish_turn_run_status_running_records_wall_clock() -> None:
|
||||
assert call.metadata.get("started_at") == t0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_turn_run_status_reuses_explicit_wall_clock() -> None:
|
||||
bus = MagicMock()
|
||||
bus.publish_outbound = AsyncMock()
|
||||
msg = InboundMessage(channel="websocket", sender_id="u", chat_id="chat-a", content="hi")
|
||||
|
||||
await wth.publish_turn_run_status(bus, msg, "running", started_at=1234.5)
|
||||
|
||||
assert wth.websocket_turn_wall_started_at("chat-a") == 1234.5
|
||||
call = bus.publish_outbound.await_args[0][0]
|
||||
assert call.metadata.get("started_at") == 1234.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_turn_run_status_idle_clears_wall_clock() -> None:
|
||||
bus = MagicMock()
|
||||
|
||||
@ -393,6 +393,15 @@ function mcpPresetMentionPayload(preset: McpPresetInfo): OutboundMcpPresetMentio
|
||||
};
|
||||
}
|
||||
|
||||
function RunPulseIcon() {
|
||||
return (
|
||||
<span className="run-pulse-icon relative flex h-4 w-4 shrink-0 items-center justify-center" aria-hidden>
|
||||
<span className="run-pulse-icon__ring" />
|
||||
<span className="run-pulse-icon__dot" />
|
||||
</span>
|
||||
);
|
||||
}
|
||||
|
||||
function RunElapsedStrip({
|
||||
startedAt,
|
||||
goalState,
|
||||
@ -586,7 +595,7 @@ function RunElapsedStrip({
|
||||
aria-label={ariaLabel}
|
||||
>
|
||||
{displayShowTimer ? (
|
||||
<Activity className="h-4 w-4 shrink-0 text-primary/80" aria-hidden />
|
||||
<RunPulseIcon />
|
||||
) : (
|
||||
<Target className="h-4 w-4 shrink-0 text-primary/75" aria-hidden />
|
||||
)}
|
||||
|
||||
@ -301,6 +301,50 @@
|
||||
.composer-status-strip[data-state="exit"] {
|
||||
animation: composer-status-strip-exit 180ms ease-in both;
|
||||
}
|
||||
@keyframes run-pulse-dot {
|
||||
0%,
|
||||
100% {
|
||||
transform: scale(0.9);
|
||||
opacity: 0.76;
|
||||
}
|
||||
50% {
|
||||
transform: scale(1.08);
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
@keyframes run-pulse-ring {
|
||||
0% {
|
||||
transform: scale(0.42);
|
||||
opacity: 0.34;
|
||||
}
|
||||
100% {
|
||||
transform: scale(1.28);
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
.run-pulse-icon {
|
||||
color: hsl(204 82% 46%);
|
||||
}
|
||||
.run-pulse-icon__ring,
|
||||
.run-pulse-icon__dot {
|
||||
display: block;
|
||||
border-radius: 999px;
|
||||
pointer-events: none;
|
||||
}
|
||||
.run-pulse-icon__ring {
|
||||
position: absolute;
|
||||
height: 12px;
|
||||
width: 12px;
|
||||
background: hsl(204 82% 46% / 0.22);
|
||||
animation: run-pulse-ring 1.55s ease-out infinite;
|
||||
}
|
||||
.run-pulse-icon__dot {
|
||||
height: 6px;
|
||||
width: 6px;
|
||||
background: currentColor;
|
||||
box-shadow: 0 0 0 1px hsl(204 82% 46% / 0.14);
|
||||
animation: run-pulse-dot 1.55s ease-in-out infinite;
|
||||
}
|
||||
@keyframes queued-prompt-row-enter {
|
||||
0% {
|
||||
opacity: 0;
|
||||
@ -321,6 +365,18 @@
|
||||
.composer-status-strip[data-state] {
|
||||
animation: none;
|
||||
}
|
||||
.run-pulse-icon,
|
||||
.run-pulse-icon * {
|
||||
animation: none;
|
||||
}
|
||||
.run-pulse-icon__ring {
|
||||
opacity: 0.18;
|
||||
transform: scale(1);
|
||||
}
|
||||
.run-pulse-icon__dot {
|
||||
opacity: 1;
|
||||
transform: scale(1);
|
||||
}
|
||||
.queued-prompt-row {
|
||||
animation: none;
|
||||
}
|
||||
|
||||
@ -386,6 +386,7 @@ describe("ThreadComposer", () => {
|
||||
expect(status).toHaveTextContent(/2:05/);
|
||||
expect(status.parentElement).toHaveClass("composer-status-strip");
|
||||
expect(status.parentElement).toHaveAttribute("data-state", "enter");
|
||||
expect(status.querySelector(".run-pulse-icon")).not.toBeNull();
|
||||
|
||||
vi.useRealTimers();
|
||||
});
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user