mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix(session): prevent duplicate archive and message loss in enforce_file_cap
When retain_recent_legal_suffix hits the else branch (tail has no user messages), it takes a non-contiguous slice from the middle of the session. enforce_file_cap incorrectly assumed dropped messages were always a prefix (before[:dropped_count]), causing user messages to be both archived and retained, and some messages to silently disappear. Fix by having retain_recent_legal_suffix return the actual dropped message list using identity-based diff, so enforce_file_cap no longer needs to guess which messages were removed.
This commit is contained in:
parent
b886b4a566
commit
72fb642ef7
@ -269,13 +269,19 @@ class Session:
|
||||
self.updated_at = datetime.now()
|
||||
self.metadata.pop("_last_summary", None)
|
||||
|
||||
def retain_recent_legal_suffix(self, max_messages: int) -> None:
|
||||
"""Keep a legal recent suffix constrained by a hard message cap."""
|
||||
def retain_recent_legal_suffix(self, max_messages: int) -> list[dict]:
|
||||
"""Keep a legal recent suffix constrained by a hard message cap.
|
||||
|
||||
Returns the list of messages that were removed.
|
||||
"""
|
||||
if max_messages <= 0:
|
||||
dropped = list(self.messages)
|
||||
self.clear()
|
||||
return
|
||||
return dropped
|
||||
if len(self.messages) <= max_messages:
|
||||
return
|
||||
return []
|
||||
|
||||
original = list(self.messages)
|
||||
|
||||
retained = list(self.messages[-max_messages:])
|
||||
|
||||
@ -306,10 +312,16 @@ class Session:
|
||||
if start:
|
||||
retained = retained[start:]
|
||||
|
||||
dropped = len(self.messages) - len(retained)
|
||||
# Compute actually-dropped messages using identity comparison so that
|
||||
# even when retained is a non-contiguous slice of original (the else
|
||||
# branch above), we never duplicate or lose messages.
|
||||
retained_ids = set(id(m) for m in retained)
|
||||
dropped = [m for m in original if id(m) not in retained_ids]
|
||||
|
||||
self.messages = retained
|
||||
self.last_consolidated = max(0, self.last_consolidated - dropped)
|
||||
self.last_consolidated = max(0, self.last_consolidated - len(dropped))
|
||||
self.updated_at = datetime.now()
|
||||
return dropped
|
||||
|
||||
def enforce_file_cap(
|
||||
self,
|
||||
@ -320,23 +332,19 @@ class Session:
|
||||
if limit <= 0 or len(self.messages) <= limit:
|
||||
return
|
||||
|
||||
before = list(self.messages)
|
||||
before_last_consolidated = self.last_consolidated
|
||||
before_count = len(before)
|
||||
self.retain_recent_legal_suffix(limit)
|
||||
dropped_count = before_count - len(self.messages)
|
||||
if dropped_count <= 0:
|
||||
dropped = self.retain_recent_legal_suffix(limit)
|
||||
if not dropped:
|
||||
return
|
||||
|
||||
dropped = before[:dropped_count]
|
||||
already_consolidated = min(before_last_consolidated, dropped_count)
|
||||
already_consolidated = min(before_last_consolidated, len(dropped))
|
||||
archive_chunk = dropped[already_consolidated:]
|
||||
if archive_chunk and on_archive:
|
||||
on_archive(archive_chunk)
|
||||
logger.info(
|
||||
"Session file cap hit for {}: dropped {}, raw-archived {}, kept {}",
|
||||
self.key,
|
||||
dropped_count,
|
||||
len(dropped),
|
||||
len(archive_chunk),
|
||||
len(self.messages),
|
||||
)
|
||||
|
||||
@ -538,3 +538,105 @@ def test_retain_recent_legal_suffix_hard_cap_with_long_non_user_chain():
|
||||
session.retain_recent_legal_suffix(6)
|
||||
|
||||
assert len(session.messages) <= 6
|
||||
|
||||
|
||||
# --- enforce_file_cap archive correctness (issue #4128) ---
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_returns_dropped_messages():
|
||||
"""retain_recent_legal_suffix returns the actually-dropped messages."""
|
||||
session = Session(key="test:return-dropped")
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
|
||||
dropped = session.retain_recent_legal_suffix(4)
|
||||
|
||||
assert len(dropped) == 6
|
||||
assert [m["content"] for m in dropped] == [f"msg{i}" for i in range(6)]
|
||||
assert len(session.messages) == 4
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
||||
"""No messages dropped → empty list returned."""
|
||||
session = Session(key="test:no-drop")
|
||||
for i in range(3):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
|
||||
dropped = session.retain_recent_legal_suffix(4)
|
||||
|
||||
assert dropped == []
|
||||
assert len(session.messages) == 3
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_returns_all_on_zero():
|
||||
"""max_messages=0 clears session and returns all messages."""
|
||||
session = Session(key="test:zero-return")
|
||||
for i in range(5):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
|
||||
dropped = session.retain_recent_legal_suffix(0)
|
||||
|
||||
assert len(dropped) == 5
|
||||
assert session.messages == []
|
||||
|
||||
|
||||
def test_enforce_file_cap_no_duplicate_archive_in_else_branch():
|
||||
"""When the tail is assistant-only, enforce_file_cap must not archive
|
||||
messages that are also retained (the bug from issue #4128)."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
session = Session(key="test:else-archive")
|
||||
# Build: 15 user messages, then 10 assistant messages (no user in tail)
|
||||
for i in range(15):
|
||||
session.messages.append({"role": "user", "content": f"u{i}"})
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||
|
||||
archive_fn = MagicMock()
|
||||
session.enforce_file_cap(on_archive=archive_fn, limit=6)
|
||||
|
||||
# Verify retained messages
|
||||
retained_contents = [m["content"] for m in session.messages]
|
||||
assert len(session.messages) <= 6
|
||||
|
||||
# Verify archived messages have NO overlap with retained
|
||||
if archive_fn.called:
|
||||
archived = archive_fn.call_args.args[0]
|
||||
archived_ids = set(id(m) for m in archived)
|
||||
retained_ids = set(id(m) for m in session.messages)
|
||||
assert not archived_ids & retained_ids, (
|
||||
f"Duplicate messages in archive and retained: "
|
||||
f"overlap contents = {[m['content'] for m in archived if id(m) in retained_ids]}"
|
||||
)
|
||||
|
||||
|
||||
def test_enforce_file_cap_no_message_loss_in_else_branch():
|
||||
"""In the else branch, no messages should silently disappear — every
|
||||
message must be either retained or archived."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
session = Session(key="test:else-no-loss")
|
||||
all_messages = []
|
||||
for i in range(15):
|
||||
msg = {"role": "user", "content": f"u{i}"}
|
||||
session.messages.append(msg)
|
||||
all_messages.append(msg)
|
||||
for i in range(10):
|
||||
msg = {"role": "assistant", "content": f"a{i}"}
|
||||
session.messages.append(msg)
|
||||
all_messages.append(msg)
|
||||
|
||||
archive_fn = MagicMock()
|
||||
session.enforce_file_cap(on_archive=archive_fn, limit=6)
|
||||
|
||||
# Collect all messages accounted for (retained + archived)
|
||||
accounted = set(id(m) for m in session.messages)
|
||||
if archive_fn.called:
|
||||
for m in archive_fn.call_args.args[0]:
|
||||
accounted.add(id(m))
|
||||
|
||||
all_ids = set(id(m) for m in all_messages)
|
||||
missing = all_ids - accounted
|
||||
assert not missing, (
|
||||
f"Lost {len(missing)} message(s) — neither retained nor archived"
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user