fix(memory): consolidate history hidden by replay window

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-07 14:19:22 +00:00 committed by Xubin Ren
parent 2c830ca817
commit 91ade9eaac
3 changed files with 146 additions and 16 deletions

View File

@ -969,6 +969,7 @@ class AgentLoop:
await self.consolidator.maybe_consolidate_by_tokens( await self.consolidator.maybe_consolidate_by_tokens(
session, session,
session_summary=pending, session_summary=pending,
replay_max_messages=self._max_messages,
) )
# Persist subagent follow-ups into durable history BEFORE prompt # Persist subagent follow-ups into durable history BEFORE prompt
# assembly. ContextBuilder merges adjacent same-role messages for # assembly. ContextBuilder merges adjacent same-role messages for
@ -1013,7 +1014,12 @@ class AgentLoop:
session.enforce_file_cap(on_archive=self.context.memory.raw_archive) session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_runtime_checkpoint(session) self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=self._max_messages,
)
)
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
content, buttons = ask_user_outbound( content, buttons = ask_user_outbound(
final_content or "Background task completed.", final_content or "Background task completed.",
@ -1066,6 +1072,7 @@ class AgentLoop:
await self.consolidator.maybe_consolidate_by_tokens( await self.consolidator.maybe_consolidate_by_tokens(
session, session,
session_summary=pending, session_summary=pending,
replay_max_messages=self._max_messages,
) )
self._set_tool_context( self._set_tool_context(
@ -1179,7 +1186,12 @@ class AgentLoop:
self._clear_pending_user_turn(session) self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session) self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=self._max_messages,
)
)
# When follow-up messages were injected mid-turn, a later natural # When follow-up messages were injected mid-turn, a later natural
# language reply may address those follow-ups and should not be # language reply may address those follow-ups and should not be

View File

@ -17,6 +17,7 @@ from loguru import logger
from nanobot.agent.runner import AgentRunner, AgentRunSpec from nanobot.agent.runner import AgentRunner, AgentRunSpec
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.session.manager import Session
from nanobot.utils.gitstore import GitStore from nanobot.utils.gitstore import GitStore
from nanobot.utils.helpers import ( from nanobot.utils.helpers import (
ensure_dir, ensure_dir,
@ -29,7 +30,7 @@ from nanobot.utils.prompt_templates import render_template
if TYPE_CHECKING: if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider from nanobot.providers.base import LLMProvider
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import SessionManager
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -508,14 +509,85 @@ class Consolidator:
return last_boundary return last_boundary
@staticmethod
def _full_unconsolidated_history(
session: Session,
*,
include_timestamps: bool = False,
) -> list[dict[str, Any]]:
"""Return the whole unconsolidated tail for consolidation decisions."""
unconsolidated_count = len(session.messages) - session.last_consolidated
if unconsolidated_count <= 0:
return []
return session.get_history(
max_messages=unconsolidated_count,
include_timestamps=include_timestamps,
)
@staticmethod
def _replay_overflow_boundary(
session: Session,
replay_max_messages: int | None,
) -> int | None:
if not replay_max_messages or replay_max_messages <= 0:
return None
tail = list(session.messages[session.last_consolidated:])
if len(tail) <= replay_max_messages:
return None
probe = Session(
key=session.key,
messages=[dict(message) for message in tail],
created_at=session.created_at,
updated_at=session.updated_at,
metadata={},
last_consolidated=0,
)
probe.retain_recent_legal_suffix(replay_max_messages)
cut = len(tail) - len(probe.messages)
if cut <= 0:
return None
return session.last_consolidated + cut
async def _consolidate_replay_overflow(
self,
session: Session,
replay_max_messages: int | None,
) -> str | None:
"""Archive messages that would be hidden by the replay message window."""
end_idx = self._replay_overflow_boundary(session, replay_max_messages)
if end_idx is None:
return None
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
return None
logger.info(
"Replay-window consolidation for {}: chunk={} msgs, replay_max={}",
session.key,
len(chunk),
replay_max_messages,
)
summary = await self.archive(chunk)
session.last_consolidated = end_idx
self.sessions.save(session)
return summary
def _persist_last_summary(self, session: Session, summary: str | None) -> None:
if summary and summary != "(nothing)":
session.metadata["_last_summary"] = {
"text": summary,
"last_active": session.updated_at.isoformat(),
}
self.sessions.save(session)
def estimate_session_prompt_tokens( def estimate_session_prompt_tokens(
self, self,
session: Session, session: Session,
*, *,
session_summary: str | None = None, session_summary: str | None = None,
) -> tuple[int, str]: ) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view.""" """Estimate prompt size from the full unconsolidated session tail."""
history = session.get_history(max_messages=0, include_timestamps=True) history = self._full_unconsolidated_history(session, include_timestamps=True)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
probe_messages = self._build_messages( probe_messages = self._build_messages(
history=history, history=history,
@ -591,6 +663,7 @@ class Consolidator:
session: Session, session: Session,
*, *,
session_summary: str | None = None, session_summary: str | None = None,
replay_max_messages: int | None = None,
) -> None: ) -> None:
"""Loop: archive old messages until prompt fits within safe budget. """Loop: archive old messages until prompt fits within safe budget.
@ -604,6 +677,10 @@ class Consolidator:
async with lock: async with lock:
budget = self._input_token_budget budget = self._input_token_budget
target = int(budget * self.consolidation_ratio) target = int(budget * self.consolidation_ratio)
last_summary = await self._consolidate_replay_overflow(
session,
replay_max_messages,
)
try: try:
estimated, source = self.estimate_session_prompt_tokens( estimated, source = self.estimate_session_prompt_tokens(
session, session,
@ -613,6 +690,7 @@ class Consolidator:
logger.exception("Token estimation failed for {}", session.key) logger.exception("Token estimation failed for {}", session.key)
estimated, source = 0, "error" estimated, source = 0, "error"
if estimated <= 0: if estimated <= 0:
self._persist_last_summary(session, last_summary)
return return
if estimated < budget: if estimated < budget:
unconsolidated_count = len(session.messages) - session.last_consolidated unconsolidated_count = len(session.messages) - session.last_consolidated
@ -624,9 +702,9 @@ class Consolidator:
source, source,
unconsolidated_count, unconsolidated_count,
) )
self._persist_last_summary(session, last_summary)
return return
last_summary = None
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS): for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
if estimated <= target: if estimated <= target:
break break
@ -683,12 +761,7 @@ class Consolidator:
# Persist the last summary to session metadata so it can be injected # Persist the last summary to session metadata so it can be injected
# into the runtime context on the next prepare_session() call, aligning # into the runtime context on the next prepare_session() call, aligning
# the summary injection strategy with AutoCompact._archive(). # the summary injection strategy with AutoCompact._archive().
if last_summary and last_summary != "(nothing)": self._persist_last_summary(session, last_summary)
session.metadata["_last_summary"] = {
"text": last_summary,
"last_active": session.updated_at.isoformat(),
}
self.sessions.save(session)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@ -1,15 +1,15 @@
"""Tests for the lightweight Consolidator — append-only to HISTORY.md.""" """Tests for the lightweight Consolidator — append-only to HISTORY.md."""
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from nanobot.agent.memory import ( from nanobot.agent.memory import (
_ARCHIVE_SUMMARY_MAX_CHARS,
Consolidator, Consolidator,
MemoryStore, MemoryStore,
_ARCHIVE_SUMMARY_MAX_CHARS,
_RAW_ARCHIVE_MAX_CHARS,
) )
from nanobot.session.manager import Session
@pytest.fixture @pytest.fixture
@ -122,6 +122,51 @@ class TestConsolidatorTokenBudget:
await consolidator.maybe_consolidate_by_tokens(session) await consolidator.maybe_consolidate_by_tokens(session)
consolidator.archive.assert_not_called() consolidator.archive.assert_not_called()
async def test_estimate_uses_full_unconsolidated_tail(self, consolidator):
"""Consolidation pressure must see messages hidden by the replay window."""
session = Session(key="test:full-tail")
for i in range(160):
session.add_message("user", f"msg-{i}")
captured: dict[str, list[dict]] = {}
def build_messages(**kwargs):
captured["history"] = kwargs["history"]
return kwargs["history"]
consolidator._build_messages = build_messages
consolidator.estimate_session_prompt_tokens(session)
assert len(captured["history"]) == 160
assert captured["history"][0]["content"].endswith("msg-0")
async def test_replay_window_overflow_is_archived_even_under_token_budget(
self,
consolidator,
):
"""Old messages that cannot be replayed should be materialized first."""
consolidator._SAFETY_BUFFER = 0
session = Session(key="test:replay-overflow")
for i in range(10):
session.add_message("user", f"u{i}")
session.add_message("assistant", f"a{i}")
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
consolidator.archive = AsyncMock(return_value="old conversation summary")
await consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=6,
)
archived_chunk = consolidator.archive.await_args.args[0]
assert archived_chunk[0]["content"] == "u0"
assert archived_chunk[-1]["content"] == "a6"
assert session.last_consolidated == 14
assert session.metadata["_last_summary"]["text"] == "old conversation summary"
consolidator.sessions.save.assert_called()
async def test_large_chunk_archived_without_cap(self, consolidator): async def test_large_chunk_archived_without_cap(self, consolidator):
"""Without chunk cap, the full range from pick_consolidation_boundary is archived.""" """Without chunk cap, the full range from pick_consolidation_boundary is archived."""
consolidator._SAFETY_BUFFER = 0 consolidator._SAFETY_BUFFER = 0