mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
fix(memory): align replay overflow with history trimming
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
24daf9a51c
commit
cbd5b06075
@ -23,6 +23,7 @@ from nanobot.utils.helpers import (
|
||||
ensure_dir,
|
||||
estimate_message_tokens,
|
||||
estimate_prompt_tokens_chain,
|
||||
find_legal_message_start,
|
||||
strip_think,
|
||||
truncate_text,
|
||||
)
|
||||
@ -531,23 +532,29 @@ class Consolidator:
|
||||
) -> int | None:
|
||||
if not replay_max_messages or replay_max_messages <= 0:
|
||||
return None
|
||||
tail = list(session.messages[session.last_consolidated:])
|
||||
tail = list(enumerate(session.messages[session.last_consolidated:], session.last_consolidated))
|
||||
if len(tail) <= replay_max_messages:
|
||||
return None
|
||||
|
||||
probe = Session(
|
||||
key=session.key,
|
||||
messages=[dict(message) for message in tail],
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at,
|
||||
metadata={},
|
||||
last_consolidated=0,
|
||||
)
|
||||
probe.retain_recent_legal_suffix(replay_max_messages)
|
||||
cut = len(tail) - len(probe.messages)
|
||||
if cut <= 0:
|
||||
sliced = tail[-replay_max_messages:]
|
||||
for i, (_idx, message) in enumerate(sliced):
|
||||
if message.get("role") == "user":
|
||||
start = i
|
||||
if i > 0 and sliced[i - 1][1].get("_channel_delivery"):
|
||||
start = i - 1
|
||||
sliced = sliced[start:]
|
||||
break
|
||||
|
||||
legal_start = find_legal_message_start([message for _idx, message in sliced])
|
||||
if legal_start:
|
||||
sliced = sliced[legal_start:]
|
||||
if not sliced:
|
||||
return len(session.messages)
|
||||
|
||||
first_visible_idx = sliced[0][0]
|
||||
if first_visible_idx <= session.last_consolidated:
|
||||
return None
|
||||
return session.last_consolidated + cut
|
||||
return first_visible_idx
|
||||
|
||||
async def _consolidate_replay_overflow(
|
||||
self,
|
||||
|
||||
@ -167,6 +167,36 @@ class TestConsolidatorTokenBudget:
|
||||
assert session.metadata["_last_summary"]["text"] == "old conversation summary"
|
||||
consolidator.sessions.save.assert_called()
|
||||
|
||||
async def test_replay_window_overflow_matches_history_tool_boundary(self, consolidator):
    """Replay overflow archives exactly the prefix that get_history hides.

    The session holds one complete tool turn (user -> assistant tool call ->
    tool result) followed by a plain assistant reply. With a replay window of
    two messages, the raw tail would begin on the tool result; the test
    expects the whole tool turn to be archived so that only the final
    assistant reply remains visible.
    """
    tool_call = {
        "id": "call-1",
        "type": "function",
        "function": {"name": "x", "arguments": "{}"},
    }

    session = Session(key="test:replay-tool-boundary")
    session.add_message("user", "run the tool")
    session.add_message("assistant", "", tool_calls=[tool_call])
    session.add_message("tool", "tool result", tool_call_id="call-1", name="x")
    session.add_message("assistant", "final answer")

    # Stub the token estimator and the archiver so no real work is done.
    # NOTE(review): 100 tokens is presumably under the fixture's budget, so
    # only the replay-window overflow triggers consolidation -- confirm.
    consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
    consolidator.archive = AsyncMock(return_value="tool turn summary")

    await consolidator.maybe_consolidate_by_tokens(session, replay_max_messages=2)

    # First positional argument of the awaited archive call is the archived chunk.
    archived = consolidator.archive.await_args.args[0]
    archived_roles = [entry["role"] for entry in archived]
    assert archived_roles == ["user", "assistant", "tool"]

    # The consolidation cursor sits just past the archived tool turn.
    assert session.last_consolidated == 3

    visible_history = session.get_history(max_messages=2)
    assert visible_history == [{"role": "assistant", "content": "final answer"}]
|
||||
|
||||
async def test_large_chunk_archived_without_cap(self, consolidator):
|
||||
"""Without chunk cap, the full range from pick_consolidation_boundary is archived."""
|
||||
consolidator._SAFETY_BUFFER = 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user