mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix(session): correct last_consolidated tracking in non-contiguous retention
The previous fix made retain_recent_legal_suffix return the actual dropped message list, but already_consolidated was still computed with min(before_last_consolidated, len(dropped)), which assumes dropped messages are always a prefix. In the else branch (tail has no user messages), dropped may include messages from after the consolidated prefix, causing already_consolidated to skip too many and leaving tail messages neither retained nor raw-archived. Fix by having retain_recent_legal_suffix return (dropped, already_consolidated_count) where already_consolidated_count is computed against original message indices. Also fix last_consolidated update to count how many retained messages were inside the old consolidated prefix.
This commit is contained in:
parent
72fb642ef7
commit
baffd6ef92
@ -269,19 +269,25 @@ class Session:
|
||||
self.updated_at = datetime.now()
|
||||
self.metadata.pop("_last_summary", None)
|
||||
|
||||
def retain_recent_legal_suffix(self, max_messages: int) -> list[dict]:
|
||||
def retain_recent_legal_suffix(self, max_messages: int) -> tuple[list[dict], int]:
|
||||
"""Keep a legal recent suffix constrained by a hard message cap.
|
||||
|
||||
Returns the list of messages that were removed.
|
||||
Returns ``(dropped, already_consolidated_count)`` where *dropped* is
|
||||
the list of removed messages (in original order) and
|
||||
*already_consolidated_count* is how many of those were inside the
|
||||
pre-existing ``last_consolidated`` prefix and therefore do not need
|
||||
raw archiving.
|
||||
"""
|
||||
if max_messages <= 0:
|
||||
dropped = list(self.messages)
|
||||
lc = self.last_consolidated
|
||||
self.clear()
|
||||
return dropped
|
||||
return dropped, min(lc, len(dropped))
|
||||
if len(self.messages) <= max_messages:
|
||||
return []
|
||||
return [], 0
|
||||
|
||||
original = list(self.messages)
|
||||
before_lc = self.last_consolidated
|
||||
|
||||
retained = list(self.messages[-max_messages:])
|
||||
|
||||
@ -318,10 +324,26 @@ class Session:
|
||||
retained_ids = set(id(m) for m in retained)
|
||||
dropped = [m for m in original if id(m) not in retained_ids]
|
||||
|
||||
# Count how many dropped messages were in the already-consolidated
|
||||
# prefix of the original list. This cannot be a simple min() because
|
||||
# dropped may include messages from *after* the consolidated prefix
|
||||
# (e.g. in the else branch).
|
||||
already_consolidated = sum(
|
||||
1 for i, m in enumerate(original)
|
||||
if i < before_lc and id(m) not in retained_ids
|
||||
)
|
||||
|
||||
# New last_consolidated = count of retained messages that were inside
|
||||
# the old consolidated prefix.
|
||||
new_lc = sum(
|
||||
1 for i, m in enumerate(original)
|
||||
if i < before_lc and id(m) in retained_ids
|
||||
)
|
||||
|
||||
self.messages = retained
|
||||
self.last_consolidated = max(0, self.last_consolidated - len(dropped))
|
||||
self.last_consolidated = new_lc
|
||||
self.updated_at = datetime.now()
|
||||
return dropped
|
||||
return dropped, already_consolidated
|
||||
|
||||
def enforce_file_cap(
|
||||
self,
|
||||
@ -332,12 +354,10 @@ class Session:
|
||||
if limit <= 0 or len(self.messages) <= limit:
|
||||
return
|
||||
|
||||
before_last_consolidated = self.last_consolidated
|
||||
dropped = self.retain_recent_legal_suffix(limit)
|
||||
dropped, already_consolidated = self.retain_recent_legal_suffix(limit)
|
||||
if not dropped:
|
||||
return
|
||||
|
||||
already_consolidated = min(before_last_consolidated, len(dropped))
|
||||
archive_chunk = dropped[already_consolidated:]
|
||||
if archive_chunk and on_archive:
|
||||
on_archive(archive_chunk)
|
||||
|
||||
@ -549,11 +549,12 @@ def test_retain_recent_legal_suffix_returns_dropped_messages():
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
|
||||
dropped = session.retain_recent_legal_suffix(4)
|
||||
dropped, already_cons = session.retain_recent_legal_suffix(4)
|
||||
|
||||
assert len(dropped) == 6
|
||||
assert [m["content"] for m in dropped] == [f"msg{i}" for i in range(6)]
|
||||
assert len(session.messages) == 4
|
||||
assert already_cons == 0
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
||||
@ -562,9 +563,10 @@ def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
||||
for i in range(3):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
|
||||
dropped = session.retain_recent_legal_suffix(4)
|
||||
dropped, already_cons = session.retain_recent_legal_suffix(4)
|
||||
|
||||
assert dropped == []
|
||||
assert already_cons == 0
|
||||
assert len(session.messages) == 3
|
||||
|
||||
|
||||
@ -573,10 +575,12 @@ def test_retain_recent_legal_suffix_returns_all_on_zero():
|
||||
session = Session(key="test:zero-return")
|
||||
for i in range(5):
|
||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||
session.last_consolidated = 3
|
||||
|
||||
dropped = session.retain_recent_legal_suffix(0)
|
||||
dropped, already_cons = session.retain_recent_legal_suffix(0)
|
||||
|
||||
assert len(dropped) == 5
|
||||
assert already_cons == 3
|
||||
assert session.messages == []
|
||||
|
||||
|
||||
@ -640,3 +644,53 @@ def test_enforce_file_cap_no_message_loss_in_else_branch():
|
||||
assert not missing, (
|
||||
f"Lost {len(missing)} message(s) — neither retained nor archived"
|
||||
)
|
||||
|
||||
|
||||
def test_enforce_file_cap_correct_archive_with_last_consolidated_in_else_branch():
|
||||
"""When last_consolidated > 0 and the else branch fires, only the
|
||||
unconsolidated dropped messages should be raw-archived. Messages in the
|
||||
consolidated prefix that are dropped do NOT need raw archiving."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
session = Session(key="test:else-lc-archive")
|
||||
# 20 messages total: u0..u9 (user), a0..a9 (assistant)
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"u{i}"})
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||
# First 8 messages already consolidated
|
||||
session.last_consolidated = 8
|
||||
|
||||
archive_fn = MagicMock()
|
||||
session.enforce_file_cap(on_archive=archive_fn, limit=4)
|
||||
|
||||
if archive_fn.called:
|
||||
archived = archive_fn.call_args.args[0]
|
||||
# Archived messages should NOT include any from the consolidated prefix
|
||||
# (u0..u7). They should only be unconsolidated dropped messages.
|
||||
archived_contents = [m["content"] for m in archived]
|
||||
for c in archived_contents:
|
||||
assert c not in [f"u{i}" for i in range(8)], (
|
||||
f"Consolidated message {c!r} should not be raw-archived"
|
||||
)
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_last_consolidated_correct_in_else_branch():
|
||||
"""last_consolidated after retain_recent_legal_suffix should reflect how
|
||||
many retained messages were inside the old consolidated prefix."""
|
||||
session = Session(key="test:else-lc-correct")
|
||||
# 20 messages: u0..u9, a0..a9
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "user", "content": f"u{i}"})
|
||||
for i in range(10):
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||
session.last_consolidated = 12 # u0..u9, a0, a1 consolidated
|
||||
|
||||
dropped, already_cons = session.retain_recent_legal_suffix(4)
|
||||
|
||||
# Retained messages start from latest user (u9) + max_messages forward
|
||||
# so retained = [u9, a0..a9][:4] → but these are from original indices 9..12
|
||||
# Of those, indices 9,10,11 are < 12 (before_lc), so new_lc = 3
|
||||
assert session.last_consolidated <= 12
|
||||
# already_cons should count dropped messages with original index < 12
|
||||
assert already_cons >= 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user