fix(session): correct last_consolidated tracking in non-contiguous retention

The previous fix made retain_recent_legal_suffix return the actual dropped
message list, but the caller still computed already_consolidated as
min(before_last_consolidated, len(dropped)), which is only correct when the
dropped messages form a contiguous prefix of the original message list. In
the else branch (tail has no user messages), retained messages can lie inside
the consolidated prefix, so dropped also includes messages from after that
prefix and min() overcounts. The archive slice dropped[already_consolidated:]
then skips too many entries, leaving some unconsolidated tail messages
neither retained nor raw-archived.

Fix by having retain_recent_legal_suffix return (dropped,
already_consolidated_count) where already_consolidated_count is computed
against original message indices. Also fix last_consolidated update to count
how many retained messages were inside the old consolidated prefix.
This commit is contained in:
yorkhellen 2026-06-01 11:47:37 +08:00 committed by Xubin Ren
parent 72fb642ef7
commit baffd6ef92
2 changed files with 86 additions and 12 deletions

View File

@ -269,19 +269,25 @@ class Session:
self.updated_at = datetime.now()
self.metadata.pop("_last_summary", None)
def retain_recent_legal_suffix(self, max_messages: int) -> list[dict]:
def retain_recent_legal_suffix(self, max_messages: int) -> tuple[list[dict], int]:
"""Keep a legal recent suffix constrained by a hard message cap.
Returns the list of messages that were removed.
Returns ``(dropped, already_consolidated_count)`` where *dropped* is
the list of removed messages (in original order) and
*already_consolidated_count* is how many of those were inside the
pre-existing ``last_consolidated`` prefix and therefore do not need
raw archiving.
"""
if max_messages <= 0:
dropped = list(self.messages)
lc = self.last_consolidated
self.clear()
return dropped
return dropped, min(lc, len(dropped))
if len(self.messages) <= max_messages:
return []
return [], 0
original = list(self.messages)
before_lc = self.last_consolidated
retained = list(self.messages[-max_messages:])
@ -318,10 +324,26 @@ class Session:
retained_ids = set(id(m) for m in retained)
dropped = [m for m in original if id(m) not in retained_ids]
# Count how many dropped messages were in the already-consolidated
# prefix of the original list. This cannot be a simple min() because
# dropped may include messages from *after* the consolidated prefix
# (e.g. in the else branch).
already_consolidated = sum(
1 for i, m in enumerate(original)
if i < before_lc and id(m) not in retained_ids
)
# New last_consolidated = count of retained messages that were inside
# the old consolidated prefix.
new_lc = sum(
1 for i, m in enumerate(original)
if i < before_lc and id(m) in retained_ids
)
self.messages = retained
self.last_consolidated = max(0, self.last_consolidated - len(dropped))
self.last_consolidated = new_lc
self.updated_at = datetime.now()
return dropped
return dropped, already_consolidated
def enforce_file_cap(
self,
@ -332,12 +354,10 @@ class Session:
if limit <= 0 or len(self.messages) <= limit:
return
before_last_consolidated = self.last_consolidated
dropped = self.retain_recent_legal_suffix(limit)
dropped, already_consolidated = self.retain_recent_legal_suffix(limit)
if not dropped:
return
already_consolidated = min(before_last_consolidated, len(dropped))
archive_chunk = dropped[already_consolidated:]
if archive_chunk and on_archive:
on_archive(archive_chunk)

View File

@ -549,11 +549,12 @@ def test_retain_recent_legal_suffix_returns_dropped_messages():
for i in range(10):
session.messages.append({"role": "user", "content": f"msg{i}"})
dropped = session.retain_recent_legal_suffix(4)
dropped, already_cons = session.retain_recent_legal_suffix(4)
assert len(dropped) == 6
assert [m["content"] for m in dropped] == [f"msg{i}" for i in range(6)]
assert len(session.messages) == 4
assert already_cons == 0
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
@ -562,9 +563,10 @@ def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
for i in range(3):
session.messages.append({"role": "user", "content": f"msg{i}"})
dropped = session.retain_recent_legal_suffix(4)
dropped, already_cons = session.retain_recent_legal_suffix(4)
assert dropped == []
assert already_cons == 0
assert len(session.messages) == 3
@ -573,10 +575,12 @@ def test_retain_recent_legal_suffix_returns_all_on_zero():
session = Session(key="test:zero-return")
for i in range(5):
session.messages.append({"role": "user", "content": f"msg{i}"})
session.last_consolidated = 3
dropped = session.retain_recent_legal_suffix(0)
dropped, already_cons = session.retain_recent_legal_suffix(0)
assert len(dropped) == 5
assert already_cons == 3
assert session.messages == []
@ -640,3 +644,53 @@ def test_enforce_file_cap_no_message_loss_in_else_branch():
assert not missing, (
f"Lost {len(missing)} message(s) — neither retained nor archived"
)
def test_enforce_file_cap_correct_archive_with_last_consolidated_in_else_branch():
    """Only unconsolidated dropped messages are raw-archived in the else branch.

    With ``last_consolidated > 0`` and a tail that contains no user messages,
    dropped messages that sit inside the consolidated prefix must not be
    passed to ``on_archive``.
    """
    from unittest.mock import MagicMock

    session = Session(key="test:else-lc-archive")
    # 20 messages total: u0..u9 (user) followed by a0..a9 (assistant).
    for i in range(10):
        session.messages.append({"role": "user", "content": f"u{i}"})
    for i in range(10):
        session.messages.append({"role": "assistant", "content": f"a{i}"})
    # First 8 messages (u0..u7) are already consolidated.
    session.last_consolidated = 8
    consolidated_contents = {f"u{i}" for i in range(8)}

    archive_fn = MagicMock()
    session.enforce_file_cap(on_archive=archive_fn, limit=4)

    # With 20 messages, a cap of 4 and only 8 consolidated, some dropped
    # messages are expected to be unconsolidated, so the callback should fire.
    # Asserting it keeps the content checks below from passing vacuously.
    assert archive_fn.called, "on_archive was never invoked"

    # Check every archived batch, not just the last call.
    for archive_call in archive_fn.call_args_list:
        archived = archive_call.args[0]
        for message in archived:
            content = message["content"]
            assert content not in consolidated_contents, (
                f"Consolidated message {content!r} should not be raw-archived"
            )
def test_retain_recent_legal_suffix_last_consolidated_correct_in_else_branch():
    """``last_consolidated`` reflects retained messages from the old prefix.

    After ``retain_recent_legal_suffix`` the new ``last_consolidated`` counts
    the retained messages that were inside the old consolidated prefix, and
    ``already_consolidated_count`` counts the dropped ones, so together they
    account for every previously consolidated message.
    """
    session = Session(key="test:else-lc-correct")
    # 20 messages: u0..u9 (user) followed by a0..a9 (assistant).
    for i in range(10):
        session.messages.append({"role": "user", "content": f"u{i}"})
    for i in range(10):
        session.messages.append({"role": "assistant", "content": f"a{i}"})
    old_lc = 12  # u0..u9, a0, a1 consolidated
    session.last_consolidated = old_lc

    dropped, already_cons = session.retain_recent_legal_suffix(4)

    # Each message of the old consolidated prefix is either retained (counted
    # in last_consolidated) or dropped (counted in already_cons).
    assert session.last_consolidated + already_cons == old_lc
    # The counters can never exceed the lists they index into.
    assert 0 <= already_cons <= len(dropped)
    assert 0 <= session.last_consolidated <= len(session.messages)
    # NOTE(review): the original walkthrough expects the retained window to be
    # original indices 9..12 (u9, a0, a1, a2), i.e. last_consolidated == 3.
    # The window-selection logic is not visible in this diff, so that exact
    # value is not asserted here; confirm and tighten.