fix(session): prevent duplicate archive and message loss in enforce_file_cap

When retain_recent_legal_suffix hits the else branch (tail has no user
messages), it keeps a window taken from the middle of the session rather
than a true suffix. enforce_file_cap incorrectly assumed the dropped
messages were always a prefix (before[:dropped_count]), causing user
messages to be both archived and retained, and some messages to silently
disappear.

Fix by having retain_recent_legal_suffix return the actual dropped message
list using identity-based diff, so enforce_file_cap no longer needs to
guess which messages were removed.
This commit is contained in:
yorkhellen 2026-06-01 09:47:26 +08:00 committed by Xubin Ren
parent b886b4a566
commit 72fb642ef7
2 changed files with 124 additions and 14 deletions

View File

@ -269,13 +269,19 @@ class Session:
self.updated_at = datetime.now() self.updated_at = datetime.now()
self.metadata.pop("_last_summary", None) self.metadata.pop("_last_summary", None)
def retain_recent_legal_suffix(self, max_messages: int) -> None: def retain_recent_legal_suffix(self, max_messages: int) -> list[dict]:
"""Keep a legal recent suffix constrained by a hard message cap.""" """Keep a legal recent suffix constrained by a hard message cap.
Returns the list of messages that were removed.
"""
if max_messages <= 0: if max_messages <= 0:
dropped = list(self.messages)
self.clear() self.clear()
return return dropped
if len(self.messages) <= max_messages: if len(self.messages) <= max_messages:
return return []
original = list(self.messages)
retained = list(self.messages[-max_messages:]) retained = list(self.messages[-max_messages:])
@ -306,10 +312,16 @@ class Session:
if start: if start:
retained = retained[start:] retained = retained[start:]
dropped = len(self.messages) - len(retained) # Compute actually-dropped messages using identity comparison so that
# even when retained is a non-contiguous slice of original (the else
# branch above), we never duplicate or lose messages.
retained_ids = set(id(m) for m in retained)
dropped = [m for m in original if id(m) not in retained_ids]
self.messages = retained self.messages = retained
self.last_consolidated = max(0, self.last_consolidated - dropped) self.last_consolidated = max(0, self.last_consolidated - len(dropped))
self.updated_at = datetime.now() self.updated_at = datetime.now()
return dropped
def enforce_file_cap( def enforce_file_cap(
self, self,
@ -320,23 +332,19 @@ class Session:
if limit <= 0 or len(self.messages) <= limit: if limit <= 0 or len(self.messages) <= limit:
return return
before = list(self.messages)
before_last_consolidated = self.last_consolidated before_last_consolidated = self.last_consolidated
before_count = len(before) dropped = self.retain_recent_legal_suffix(limit)
self.retain_recent_legal_suffix(limit) if not dropped:
dropped_count = before_count - len(self.messages)
if dropped_count <= 0:
return return
dropped = before[:dropped_count] already_consolidated = min(before_last_consolidated, len(dropped))
already_consolidated = min(before_last_consolidated, dropped_count)
archive_chunk = dropped[already_consolidated:] archive_chunk = dropped[already_consolidated:]
if archive_chunk and on_archive: if archive_chunk and on_archive:
on_archive(archive_chunk) on_archive(archive_chunk)
logger.info( logger.info(
"Session file cap hit for {}: dropped {}, raw-archived {}, kept {}", "Session file cap hit for {}: dropped {}, raw-archived {}, kept {}",
self.key, self.key,
dropped_count, len(dropped),
len(archive_chunk), len(archive_chunk),
len(self.messages), len(self.messages),
) )

View File

@ -538,3 +538,105 @@ def test_retain_recent_legal_suffix_hard_cap_with_long_non_user_chain():
session.retain_recent_legal_suffix(6) session.retain_recent_legal_suffix(6)
assert len(session.messages) <= 6 assert len(session.messages) <= 6
# --- enforce_file_cap archive correctness (issue #4128) ---
def test_retain_recent_legal_suffix_returns_dropped_messages():
    """The messages trimmed by retain_recent_legal_suffix are handed back in order."""
    session = Session(key="test:return-dropped")
    session.messages.extend(
        {"role": "user", "content": f"msg{i}"} for i in range(10)
    )

    removed = session.retain_recent_legal_suffix(4)

    # Ten user messages capped at four: the six oldest must be reported,
    # in their original order, and only four may remain in the session.
    removed_contents = [message["content"] for message in removed]
    assert len(removed) == 6
    assert removed_contents == ["msg0", "msg1", "msg2", "msg3", "msg4", "msg5"]
    assert len(session.messages) == 4
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
    """A session already under the cap is left intact and reports nothing removed."""
    session = Session(key="test:no-drop")
    session.messages.extend(
        {"role": "user", "content": f"msg{i}"} for i in range(3)
    )

    removed = session.retain_recent_legal_suffix(4)

    # Three messages against a cap of four: nothing to trim.
    assert removed == []
    assert len(session.messages) == 3
def test_retain_recent_legal_suffix_returns_all_on_zero():
    """A cap of zero empties the session and returns every message it held."""
    session = Session(key="test:zero-return")
    session.messages.extend(
        {"role": "user", "content": f"msg{i}"} for i in range(5)
    )

    removed = session.retain_recent_legal_suffix(0)

    assert len(removed) == 5
    assert session.messages == []
def test_enforce_file_cap_no_duplicate_archive_in_else_branch():
    """Archived and retained messages must be disjoint when the tail is assistant-only.

    Regression test for issue #4128: enforce_file_cap must not hand a message
    to ``on_archive`` while also keeping it in the session.
    """
    from unittest.mock import MagicMock

    session = Session(key="test:else-archive")
    # Nothing has been consolidated, so every dropped message is eligible for
    # archiving. Set explicitly so the archive assertion below does not depend
    # on the constructor default.
    session.last_consolidated = 0

    # 15 user messages followed by 10 assistant messages, so the most recent
    # messages contain no user turn.
    for i in range(15):
        session.messages.append({"role": "user", "content": f"u{i}"})
    for i in range(10):
        session.messages.append({"role": "assistant", "content": f"a{i}"})

    archive_fn = MagicMock()
    session.enforce_file_cap(on_archive=archive_fn, limit=6)

    assert len(session.messages) <= 6

    # 25 messages against a cap of 6 must drop (and therefore archive)
    # something. Asserting the call keeps this test from passing vacuously
    # when the archive callback is never invoked.
    archive_fn.assert_called_once()
    archived = archive_fn.call_args.args[0]

    # Compare by identity: the bug was the very same dict objects being both
    # archived and retained.
    retained_ids = {id(m) for m in session.messages}
    overlap = [m["content"] for m in archived if id(m) in retained_ids]
    assert not overlap, (
        f"Duplicate messages in archive and retained: overlap contents = {overlap}"
    )
def test_enforce_file_cap_no_message_loss_in_else_branch():
    """Every message is either retained or archived when the tail is assistant-only.

    Builds 15 user messages followed by 10 assistant messages, applies a cap
    of 6, and checks that no message silently disappears.
    """
    from unittest.mock import MagicMock

    session = Session(key="test:else-no-loss")

    # Keep our own references so object identities stay valid for the
    # comparison after the session has been trimmed.
    all_messages = [{"role": "user", "content": f"u{i}"} for i in range(15)]
    all_messages += [{"role": "assistant", "content": f"a{i}"} for i in range(10)]
    session.messages.extend(all_messages)

    archive_fn = MagicMock()
    session.enforce_file_cap(on_archive=archive_fn, limit=6)

    # Accounted for = still in the session, or passed to on_archive in ANY
    # call (not just the last one, so a multi-call archive is not misread as
    # message loss).
    accounted = {id(m) for m in session.messages}
    for archive_call in archive_fn.call_args_list:
        accounted.update(id(m) for m in archive_call.args[0])

    missing = [m for m in all_messages if id(m) not in accounted]
    assert not missing, (
        f"Lost {len(missing)} message(s) — neither retained nor archived"
    )