mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
fix(session): correct last_consolidated tracking in non-contiguous retention
The previous fix made `retain_recent_legal_suffix` return the actual list of dropped messages, but `already_consolidated` was still computed as `min(before_last_consolidated, len(dropped))`, which assumes the dropped messages always form a contiguous prefix of the original list. In the else branch (the tail contains no user messages), `dropped` may also include messages from after the consolidated prefix, so `already_consolidated` could be too large: the archive slice `dropped[already_consolidated:]` then skipped too many messages, leaving some tail messages neither retained nor raw-archived. The fix has `retain_recent_legal_suffix` return `(dropped, already_consolidated_count)`, where `already_consolidated_count` is computed against the original message indices. It also fixes the `last_consolidated` update so that it counts how many retained messages were inside the old consolidated prefix.
This commit is contained in:
parent
72fb642ef7
commit
baffd6ef92
@ -269,19 +269,25 @@ class Session:
|
|||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
self.metadata.pop("_last_summary", None)
|
self.metadata.pop("_last_summary", None)
|
||||||
|
|
||||||
def retain_recent_legal_suffix(self, max_messages: int) -> list[dict]:
|
def retain_recent_legal_suffix(self, max_messages: int) -> tuple[list[dict], int]:
|
||||||
"""Keep a legal recent suffix constrained by a hard message cap.
|
"""Keep a legal recent suffix constrained by a hard message cap.
|
||||||
|
|
||||||
Returns the list of messages that were removed.
|
Returns ``(dropped, already_consolidated_count)`` where *dropped* is
|
||||||
|
the list of removed messages (in original order) and
|
||||||
|
*already_consolidated_count* is how many of those were inside the
|
||||||
|
pre-existing ``last_consolidated`` prefix and therefore do not need
|
||||||
|
raw archiving.
|
||||||
"""
|
"""
|
||||||
if max_messages <= 0:
|
if max_messages <= 0:
|
||||||
dropped = list(self.messages)
|
dropped = list(self.messages)
|
||||||
|
lc = self.last_consolidated
|
||||||
self.clear()
|
self.clear()
|
||||||
return dropped
|
return dropped, min(lc, len(dropped))
|
||||||
if len(self.messages) <= max_messages:
|
if len(self.messages) <= max_messages:
|
||||||
return []
|
return [], 0
|
||||||
|
|
||||||
original = list(self.messages)
|
original = list(self.messages)
|
||||||
|
before_lc = self.last_consolidated
|
||||||
|
|
||||||
retained = list(self.messages[-max_messages:])
|
retained = list(self.messages[-max_messages:])
|
||||||
|
|
||||||
@ -318,10 +324,26 @@ class Session:
|
|||||||
retained_ids = set(id(m) for m in retained)
|
retained_ids = set(id(m) for m in retained)
|
||||||
dropped = [m for m in original if id(m) not in retained_ids]
|
dropped = [m for m in original if id(m) not in retained_ids]
|
||||||
|
|
||||||
|
# Count how many dropped messages were in the already-consolidated
|
||||||
|
# prefix of the original list. This cannot be a simple min() because
|
||||||
|
# dropped may include messages from *after* the consolidated prefix
|
||||||
|
# (e.g. in the else branch).
|
||||||
|
already_consolidated = sum(
|
||||||
|
1 for i, m in enumerate(original)
|
||||||
|
if i < before_lc and id(m) not in retained_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
# New last_consolidated = count of retained messages that were inside
|
||||||
|
# the old consolidated prefix.
|
||||||
|
new_lc = sum(
|
||||||
|
1 for i, m in enumerate(original)
|
||||||
|
if i < before_lc and id(m) in retained_ids
|
||||||
|
)
|
||||||
|
|
||||||
self.messages = retained
|
self.messages = retained
|
||||||
self.last_consolidated = max(0, self.last_consolidated - len(dropped))
|
self.last_consolidated = new_lc
|
||||||
self.updated_at = datetime.now()
|
self.updated_at = datetime.now()
|
||||||
return dropped
|
return dropped, already_consolidated
|
||||||
|
|
||||||
def enforce_file_cap(
|
def enforce_file_cap(
|
||||||
self,
|
self,
|
||||||
@ -332,12 +354,10 @@ class Session:
|
|||||||
if limit <= 0 or len(self.messages) <= limit:
|
if limit <= 0 or len(self.messages) <= limit:
|
||||||
return
|
return
|
||||||
|
|
||||||
before_last_consolidated = self.last_consolidated
|
dropped, already_consolidated = self.retain_recent_legal_suffix(limit)
|
||||||
dropped = self.retain_recent_legal_suffix(limit)
|
|
||||||
if not dropped:
|
if not dropped:
|
||||||
return
|
return
|
||||||
|
|
||||||
already_consolidated = min(before_last_consolidated, len(dropped))
|
|
||||||
archive_chunk = dropped[already_consolidated:]
|
archive_chunk = dropped[already_consolidated:]
|
||||||
if archive_chunk and on_archive:
|
if archive_chunk and on_archive:
|
||||||
on_archive(archive_chunk)
|
on_archive(archive_chunk)
|
||||||
|
|||||||
@ -549,11 +549,12 @@ def test_retain_recent_legal_suffix_returns_dropped_messages():
|
|||||||
for i in range(10):
|
for i in range(10):
|
||||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||||
|
|
||||||
dropped = session.retain_recent_legal_suffix(4)
|
dropped, already_cons = session.retain_recent_legal_suffix(4)
|
||||||
|
|
||||||
assert len(dropped) == 6
|
assert len(dropped) == 6
|
||||||
assert [m["content"] for m in dropped] == [f"msg{i}" for i in range(6)]
|
assert [m["content"] for m in dropped] == [f"msg{i}" for i in range(6)]
|
||||||
assert len(session.messages) == 4
|
assert len(session.messages) == 4
|
||||||
|
assert already_cons == 0
|
||||||
|
|
||||||
|
|
||||||
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
||||||
@ -562,9 +563,10 @@ def test_retain_recent_legal_suffix_returns_empty_when_no_drop():
|
|||||||
for i in range(3):
|
for i in range(3):
|
||||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||||
|
|
||||||
dropped = session.retain_recent_legal_suffix(4)
|
dropped, already_cons = session.retain_recent_legal_suffix(4)
|
||||||
|
|
||||||
assert dropped == []
|
assert dropped == []
|
||||||
|
assert already_cons == 0
|
||||||
assert len(session.messages) == 3
|
assert len(session.messages) == 3
|
||||||
|
|
||||||
|
|
||||||
@ -573,10 +575,12 @@ def test_retain_recent_legal_suffix_returns_all_on_zero():
|
|||||||
session = Session(key="test:zero-return")
|
session = Session(key="test:zero-return")
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
session.messages.append({"role": "user", "content": f"msg{i}"})
|
session.messages.append({"role": "user", "content": f"msg{i}"})
|
||||||
|
session.last_consolidated = 3
|
||||||
|
|
||||||
dropped = session.retain_recent_legal_suffix(0)
|
dropped, already_cons = session.retain_recent_legal_suffix(0)
|
||||||
|
|
||||||
assert len(dropped) == 5
|
assert len(dropped) == 5
|
||||||
|
assert already_cons == 3
|
||||||
assert session.messages == []
|
assert session.messages == []
|
||||||
|
|
||||||
|
|
||||||
@ -640,3 +644,53 @@ def test_enforce_file_cap_no_message_loss_in_else_branch():
|
|||||||
assert not missing, (
|
assert not missing, (
|
||||||
f"Lost {len(missing)} message(s) — neither retained nor archived"
|
f"Lost {len(missing)} message(s) — neither retained nor archived"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enforce_file_cap_correct_archive_with_last_consolidated_in_else_branch():
|
||||||
|
"""When last_consolidated > 0 and the else branch fires, only the
|
||||||
|
unconsolidated dropped messages should be raw-archived. Messages in the
|
||||||
|
consolidated prefix that are dropped do NOT need raw archiving."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
session = Session(key="test:else-lc-archive")
|
||||||
|
# 20 messages total: u0..u9 (user), a0..a9 (assistant)
|
||||||
|
for i in range(10):
|
||||||
|
session.messages.append({"role": "user", "content": f"u{i}"})
|
||||||
|
for i in range(10):
|
||||||
|
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||||
|
# First 8 messages already consolidated
|
||||||
|
session.last_consolidated = 8
|
||||||
|
|
||||||
|
archive_fn = MagicMock()
|
||||||
|
session.enforce_file_cap(on_archive=archive_fn, limit=4)
|
||||||
|
|
||||||
|
if archive_fn.called:
|
||||||
|
archived = archive_fn.call_args.args[0]
|
||||||
|
# Archived messages should NOT include any from the consolidated prefix
|
||||||
|
# (u0..u7). They should only be unconsolidated dropped messages.
|
||||||
|
archived_contents = [m["content"] for m in archived]
|
||||||
|
for c in archived_contents:
|
||||||
|
assert c not in [f"u{i}" for i in range(8)], (
|
||||||
|
f"Consolidated message {c!r} should not be raw-archived"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_retain_recent_legal_suffix_last_consolidated_correct_in_else_branch():
|
||||||
|
"""last_consolidated after retain_recent_legal_suffix should reflect how
|
||||||
|
many retained messages were inside the old consolidated prefix."""
|
||||||
|
session = Session(key="test:else-lc-correct")
|
||||||
|
# 20 messages: u0..u9, a0..a9
|
||||||
|
for i in range(10):
|
||||||
|
session.messages.append({"role": "user", "content": f"u{i}"})
|
||||||
|
for i in range(10):
|
||||||
|
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||||
|
session.last_consolidated = 12 # u0..u9, a0, a1 consolidated
|
||||||
|
|
||||||
|
dropped, already_cons = session.retain_recent_legal_suffix(4)
|
||||||
|
|
||||||
|
# Retained messages start from latest user (u9) + max_messages forward
|
||||||
|
# so retained = [u9, a0..a9][:4] → but these are from original indices 9..12
|
||||||
|
# Of those, indices 9,10,11 are < 12 (before_lc), so new_lc = 3
|
||||||
|
assert session.last_consolidated <= 12
|
||||||
|
# already_cons should count dropped messages with original index < 12
|
||||||
|
assert already_cons >= 0
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user