fix(memory): preserve consolidation turn boundaries under chunk cap

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-10 04:57:02 +00:00 committed by Xubin Ren
parent bfe53ebb10
commit c579d67887
2 changed files with 74 additions and 4 deletions

View File

@ -400,6 +400,22 @@ class Consolidator:
return last_boundary
def _cap_consolidation_boundary(
self,
session: Session,
end_idx: int,
) -> int | None:
"""Clamp the chunk size without breaking the user-turn boundary."""
start = session.last_consolidated
if end_idx - start <= self._MAX_CHUNK_MESSAGES:
return end_idx
capped_end = start + self._MAX_CHUNK_MESSAGES
for idx in range(capped_end, start, -1):
if session.messages[idx].get("role") == "user":
return idx
return None
def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]:
"""Estimate current prompt size for the normal session history view."""
history = session.get_history(max_messages=0)
@ -495,14 +511,19 @@ class Consolidator:
return
end_idx = boundary[0]
end_idx = self._cap_consolidation_boundary(session, end_idx)
if end_idx is None:
logger.debug(
"Token consolidation: no capped boundary for {} (round {})",
session.key,
round_num,
)
return
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
return
if len(chunk) > self._MAX_CHUNK_MESSAGES:
chunk = chunk[:self._MAX_CHUNK_MESSAGES]
end_idx = session.last_consolidated + len(chunk)
logger.info(
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
round_num,

View File

@ -76,3 +76,52 @@ class TestConsolidatorTokenBudget:
consolidator.archive = AsyncMock(return_value=True)
await consolidator.maybe_consolidate_by_tokens(session)
consolidator.archive.assert_not_called()
async def test_chunk_cap_preserves_user_turn_boundary(self, consolidator):
"""Chunk cap should rewind to the last user boundary within the cap."""
consolidator._SAFETY_BUFFER = 0
session = MagicMock()
session.last_consolidated = 0
session.key = "test:key"
session.messages = [
{
"role": "user" if i in {0, 50, 61} else "assistant",
"content": f"m{i}",
}
for i in range(70)
]
consolidator.estimate_session_prompt_tokens = MagicMock(
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
)
consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
consolidator.archive = AsyncMock(return_value=True)
await consolidator.maybe_consolidate_by_tokens(session)
archived_chunk = consolidator.archive.await_args.args[0]
assert len(archived_chunk) == 50
assert archived_chunk[0]["content"] == "m0"
assert archived_chunk[-1]["content"] == "m49"
assert session.last_consolidated == 50
async def test_chunk_cap_skips_when_no_user_boundary_within_cap(self, consolidator):
"""If the cap would cut mid-turn, consolidation should skip that round."""
consolidator._SAFETY_BUFFER = 0
session = MagicMock()
session.last_consolidated = 0
session.key = "test:key"
session.messages = [
{
"role": "user" if i in {0, 61} else "assistant",
"content": f"m{i}",
}
for i in range(70)
]
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(1200, "tiktoken"))
consolidator.pick_consolidation_boundary = MagicMock(return_value=(61, 999))
consolidator.archive = AsyncMock(return_value=True)
await consolidator.maybe_consolidate_by_tokens(session)
consolidator.archive.assert_not_awaited()
assert session.last_consolidated == 0