mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 14:23:58 +00:00
fix(session): prevent duplicate archive and message loss in enforce_file_cap
When `retain_recent_legal_suffix` hits the else branch (the tail has no user messages), it retains a slice taken from the middle of the session rather than a true suffix, so the dropped messages are no longer a contiguous prefix. `enforce_file_cap` incorrectly assumed the dropped messages were always a prefix (`before[:dropped_count]`), which caused user messages to be both archived and retained, and some messages to disappear silently. Fix this by having `retain_recent_legal_suffix` return the actual list of dropped messages, computed with an identity-based diff, so that `enforce_file_cap` no longer needs to guess which messages were removed.
This commit is contained in:
parent
b886b4a566
commit
72fb642ef7
@ -269,13 +269,19 @@ class Session:
|
|||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
self.metadata.pop("_last_summary", None)
|
self.metadata.pop("_last_summary", None)
|
||||||
|
|
||||||
def retain_recent_legal_suffix(self, max_messages: int) -> None:
|
def retain_recent_legal_suffix(self, max_messages: int) -> list[dict]:
|
||||||
"""Keep a legal recent suffix constrained by a hard message cap."""
|
"""Keep a legal recent suffix constrained by a hard message cap.
|
||||||
|
|
||||||
|
Returns the list of messages that were removed.
|
||||||
|
"""
|
||||||
if max_messages <= 0:
|
if max_messages <= 0:
|
||||||
|
dropped = list(self.messages)
|
||||||
self.clear()
|
self.clear()
|
||||||
return
|
return dropped
|
||||||
if len(self.messages) <= max_messages:
|
if len(self.messages) <= max_messages:
|
||||||
return
|
return []
|
||||||
|
|
||||||
|
original = list(self.messages)
|
||||||
|
|
||||||
retained = list(self.messages[-max_messages:])
|
retained = list(self.messages[-max_messages:])
|
||||||
|
|
||||||
@ -306,10 +312,16 @@ class Session:
|
|||||||
if start:
|
if start:
|
||||||
retained = retained[start:]
|
retained = retained[start:]
|
||||||
|
|
||||||
dropped = len(self.messages) - len(retained)
|
# Compute actually-dropped messages using identity comparison so that
|
||||||
|
# even when retained is a non-contiguous slice of original (the else
|
||||||
|
# branch above), we never duplicate or lose messages.
|
||||||
|
retained_ids = set(id(m) for m in retained)
|
||||||
|
dropped = [m for m in original if id(m) not in retained_ids]
|
||||||
|
|
||||||
self.messages = retained
|
self.messages = retained
|
||||||
self.last_consolidated = max(0, self.last_consolidated - dropped)
|
self.last_consolidated = max(0, self.last_consolidated - len(dropped))
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
|
return dropped
|
||||||
|
|
||||||
def enforce_file_cap(
|
def enforce_file_cap(
|
||||||
self,
|
self,
|
||||||
@ -320,23 +332,19 @@ class Session:
|
|||||||
if limit <= 0 or len(self.messages) <= limit:
|
if limit <= 0 or len(self.messages) <= limit:
|
||||||
return
|
return
|
||||||
|
|
||||||
before = list(self.messages)
|
|
||||||
before_last_consolidated = self.last_consolidated
|
before_last_consolidated = self.last_consolidated
|
||||||
before_count = len(before)
|
dropped = self.retain_recent_legal_suffix(limit)
|
||||||
self.retain_recent_legal_suffix(limit)
|
if not dropped:
|
||||||
dropped_count = before_count - len(self.messages)
|
|
||||||
if dropped_count <= 0:
|
|
||||||
return
|
return
|
||||||
|
|
||||||
dropped = before[:dropped_count]
|
already_consolidated = min(before_last_consolidated, len(dropped))
|
||||||
already_consolidated = min(before_last_consolidated, dropped_count)
|
|
||||||
archive_chunk = dropped[already_consolidated:]
|
archive_chunk = dropped[already_consolidated:]
|
||||||
if archive_chunk and on_archive:
|
if archive_chunk and on_archive:
|
||||||
on_archive(archive_chunk)
|
on_archive(archive_chunk)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Session file cap hit for {}: dropped {}, raw-archived {}, kept {}",
|
"Session file cap hit for {}: dropped {}, raw-archived {}, kept {}",
|
||||||
self.key,
|
self.key,
|
||||||
dropped_count,
|
len(dropped),
|
||||||
len(archive_chunk),
|
len(archive_chunk),
|
||||||
len(self.messages),
|
len(self.messages),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -538,3 +538,105 @@ def test_retain_recent_legal_suffix_hard_cap_with_long_non_user_chain():
|
|||||||
session.retain_recent_legal_suffix(6)
|
session.retain_recent_legal_suffix(6)
|
||||||
|
|
||||||
assert len(session.messages) <= 6
|
assert len(session.messages) <= 6
|
||||||
|
|
||||||
|
|
||||||
|
# --- enforce_file_cap archive correctness (issue #4128) ---
|
||||||
|
|
||||||
|
|
||||||
|
def test_retain_recent_legal_suffix_returns_dropped_messages():
|
||||||
|
"""retain_recent_legal_suffix returns the actually-dropped messages."""
|
||||||
|
session = Session(key="test:return-dropped")
|
||||||
|
for i in range(10):
|
||||||
|
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||||
|
|
||||||
|
dropped = session.retain_recent_legal_suffix(4)
|
||||||
|
|
||||||
|
assert len(dropped) == 6
|
||||||
|
assert [m["content"] for m in dropped] == [f"msg{i}" for i in range(6)]
|
||||||
|
assert len(session.messages) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
||||||
|
"""No messages dropped → empty list returned."""
|
||||||
|
session = Session(key="test:no-drop")
|
||||||
|
for i in range(3):
|
||||||
|
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||||
|
|
||||||
|
dropped = session.retain_recent_legal_suffix(4)
|
||||||
|
|
||||||
|
assert dropped == []
|
||||||
|
assert len(session.messages) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_retain_recent_legal_suffix_returns_all_on_zero():
|
||||||
|
"""max_messages=0 clears session and returns all messages."""
|
||||||
|
session = Session(key="test:zero-return")
|
||||||
|
for i in range(5):
|
||||||
|
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||||
|
|
||||||
|
dropped = session.retain_recent_legal_suffix(0)
|
||||||
|
|
||||||
|
assert len(dropped) == 5
|
||||||
|
assert session.messages == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_file_cap_no_duplicate_archive_in_else_branch():
|
||||||
|
"""When the tail is assistant-only, enforce_file_cap must not archive
|
||||||
|
messages that are also retained (the bug from issue #4128)."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
session = Session(key="test:else-archive")
|
||||||
|
# Build: 15 user messages, then 10 assistant messages (no user in tail)
|
||||||
|
for i in range(15):
|
||||||
|
session.messages.append({"role": "user", "content": f"u{i}"})
|
||||||
|
for i in range(10):
|
||||||
|
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||||
|
|
||||||
|
archive_fn = MagicMock()
|
||||||
|
session.enforce_file_cap(on_archive=archive_fn, limit=6)
|
||||||
|
|
||||||
|
# Verify retained messages
|
||||||
|
retained_contents = [m["content"] for m in session.messages]
|
||||||
|
assert len(session.messages) <= 6
|
||||||
|
|
||||||
|
# Verify archived messages have NO overlap with retained
|
||||||
|
if archive_fn.called:
|
||||||
|
archived = archive_fn.call_args.args[0]
|
||||||
|
archived_ids = set(id(m) for m in archived)
|
||||||
|
retained_ids = set(id(m) for m in session.messages)
|
||||||
|
assert not archived_ids & retained_ids, (
|
||||||
|
f"Duplicate messages in archive and retained: "
|
||||||
|
f"overlap contents = {[m['content'] for m in archived if id(m) in retained_ids]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_file_cap_no_message_loss_in_else_branch():
|
||||||
|
"""In the else branch, no messages should silently disappear — every
|
||||||
|
message must be either retained or archived."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
session = Session(key="test:else-no-loss")
|
||||||
|
all_messages = []
|
||||||
|
for i in range(15):
|
||||||
|
msg = {"role": "user", "content": f"u{i}"}
|
||||||
|
session.messages.append(msg)
|
||||||
|
all_messages.append(msg)
|
||||||
|
for i in range(10):
|
||||||
|
msg = {"role": "assistant", "content": f"a{i}"}
|
||||||
|
session.messages.append(msg)
|
||||||
|
all_messages.append(msg)
|
||||||
|
|
||||||
|
archive_fn = MagicMock()
|
||||||
|
session.enforce_file_cap(on_archive=archive_fn, limit=6)
|
||||||
|
|
||||||
|
# Collect all messages accounted for (retained + archived)
|
||||||
|
accounted = set(id(m) for m in session.messages)
|
||||||
|
if archive_fn.called:
|
||||||
|
for m in archive_fn.call_args.args[0]:
|
||||||
|
accounted.add(id(m))
|
||||||
|
|
||||||
|
all_ids = set(id(m) for m in all_messages)
|
||||||
|
missing = all_ids - accounted
|
||||||
|
assert not missing, (
|
||||||
|
f"Lost {len(missing)} message(s) — neither retained nor archived"
|
||||||
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user