mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 16:42:25 +00:00
fix(memory): add session-refresh guard to maybe_consolidate_by_tokens
When background consolidation runs with a stale session reference (captured before AutoCompact replaced the session via compact_idle_session), it could operate on outdated data. Now, after acquiring the per-session lock, the method refreshes its session reference from SessionManager.get_or_create(). If the session was replaced, it swaps in the fresh reference before doing any consolidation work. This prevents a race where AutoCompact truncates an idle session while a background maybe_consolidate_by_tokens call is in flight with the old session object.
This commit is contained in:
parent
48d35bd2d9
commit
888d54790d
@ -683,6 +683,11 @@ class Consolidator:
|
|||||||
|
|
||||||
lock = self.get_lock(session.key)
|
lock = self.get_lock(session.key)
|
||||||
async with lock:
|
async with lock:
|
||||||
|
# Refresh session reference: AutoCompact may have replaced it.
|
||||||
|
fresh = self.sessions.get_or_create(session.key)
|
||||||
|
if fresh is not session:
|
||||||
|
session = fresh
|
||||||
|
|
||||||
budget = self._input_token_budget
|
budget = self._input_token_budget
|
||||||
target = int(budget * self.consolidation_ratio)
|
target = int(budget * self.consolidation_ratio)
|
||||||
last_summary = await self._consolidate_replay_overflow(
|
last_summary = await self._consolidate_replay_overflow(
|
||||||
|
|||||||
@ -28,6 +28,12 @@ def mock_provider():
|
|||||||
def consolidator(store, mock_provider):
|
def consolidator(store, mock_provider):
|
||||||
sessions = MagicMock()
|
sessions = MagicMock()
|
||||||
sessions.save = MagicMock()
|
sessions.save = MagicMock()
|
||||||
|
# When maybe_consolidate_by_tokens refreshes the session reference via
|
||||||
|
# get_or_create(session.key), it should get back the same object the test
|
||||||
|
# passed in. Store sessions by key so the lookup is transparent.
|
||||||
|
_session_cache: dict[str, MagicMock] = {}
|
||||||
|
sessions.get_or_create = MagicMock(side_effect=lambda key: _session_cache.get(key, MagicMock()))
|
||||||
|
sessions._session_cache = _session_cache
|
||||||
return Consolidator(
|
return Consolidator(
|
||||||
store=store,
|
store=store,
|
||||||
provider=mock_provider,
|
provider=mock_provider,
|
||||||
@ -117,6 +123,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
session.last_consolidated = 0
|
session.last_consolidated = 0
|
||||||
session.messages = [{"role": "user", "content": "hi"}]
|
session.messages = [{"role": "user", "content": "hi"}]
|
||||||
session.key = "test:key"
|
session.key = "test:key"
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
|
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
|
||||||
consolidator.archive = AsyncMock(return_value=True)
|
consolidator.archive = AsyncMock(return_value=True)
|
||||||
await consolidator.maybe_consolidate_by_tokens(session)
|
await consolidator.maybe_consolidate_by_tokens(session)
|
||||||
@ -152,6 +159,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
session.add_message("user", f"u{i}")
|
session.add_message("user", f"u{i}")
|
||||||
session.add_message("assistant", f"a{i}")
|
session.add_message("assistant", f"a{i}")
|
||||||
|
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
|
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
|
||||||
consolidator.archive = AsyncMock(return_value="old conversation summary")
|
consolidator.archive = AsyncMock(return_value="old conversation summary")
|
||||||
|
|
||||||
@ -184,6 +192,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
session.add_message("tool", "tool result", tool_call_id="call-1", name="x")
|
session.add_message("tool", "tool result", tool_call_id="call-1", name="x")
|
||||||
session.add_message("assistant", "final answer")
|
session.add_message("assistant", "final answer")
|
||||||
|
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
|
consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken"))
|
||||||
consolidator.archive = AsyncMock(return_value="tool turn summary")
|
consolidator.archive = AsyncMock(return_value="tool turn summary")
|
||||||
|
|
||||||
@ -210,6 +219,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
}
|
}
|
||||||
for i in range(70)
|
for i in range(70)
|
||||||
]
|
]
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(
|
consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||||
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
||||||
)
|
)
|
||||||
@ -238,6 +248,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
for i in range(70)
|
for i in range(70)
|
||||||
]
|
]
|
||||||
session.metadata = {}
|
session.metadata = {}
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(
|
consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||||
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
||||||
)
|
)
|
||||||
@ -263,6 +274,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
for i in range(70)
|
for i in range(70)
|
||||||
]
|
]
|
||||||
session.metadata = {}
|
session.metadata = {}
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
# Keep estimates high so the loop would otherwise run multiple rounds.
|
# Keep estimates high so the loop would otherwise run multiple rounds.
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(
|
consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||||
return_value=(1200, "tiktoken")
|
return_value=(1200, "tiktoken")
|
||||||
@ -287,6 +299,7 @@ class TestConsolidatorTokenBudget:
|
|||||||
}
|
}
|
||||||
for i in range(70)
|
for i in range(70)
|
||||||
]
|
]
|
||||||
|
consolidator.sessions._session_cache[session.key] = session
|
||||||
consolidator.estimate_session_prompt_tokens = MagicMock(
|
consolidator.estimate_session_prompt_tokens = MagicMock(
|
||||||
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
side_effect=[(1200, "tiktoken"), (400, "tiktoken")]
|
||||||
)
|
)
|
||||||
@ -461,6 +474,57 @@ class TestCompactIdleSession:
|
|||||||
assert not lock.locked()
|
assert not lock.locked()
|
||||||
|
|
||||||
|
|
||||||
|
class TestConsolidatorSessionRefresh:
|
||||||
|
"""Background consolidation must detect stale session references."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reloads_stale_session_after_compact(self, tmp_path):
|
||||||
|
"""After compact_idle_session replaces the session, a concurrent
|
||||||
|
maybe_consolidate_by_tokens with the old reference should use the
|
||||||
|
fresh session from cache instead of overwriting."""
|
||||||
|
from nanobot.agent.memory import Consolidator, MemoryStore
|
||||||
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
|
store = MemoryStore(tmp_path)
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.chat_with_retry = AsyncMock(
|
||||||
|
return_value=MagicMock(content="summary", finish_reason="stop")
|
||||||
|
)
|
||||||
|
provider.generation.max_tokens = 4096
|
||||||
|
provider.estimate_prompt_tokens = MagicMock(return_value=(10, "test"))
|
||||||
|
sessions = SessionManager(tmp_path)
|
||||||
|
consolidator = Consolidator(
|
||||||
|
store=store,
|
||||||
|
provider=provider,
|
||||||
|
model="test-model",
|
||||||
|
sessions=sessions,
|
||||||
|
context_window_tokens=128_000,
|
||||||
|
build_messages=MagicMock(return_value=[]),
|
||||||
|
get_tool_definitions=MagicMock(return_value=[]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Populate session with many messages
|
||||||
|
session = sessions.get_or_create("cli:test")
|
||||||
|
for i in range(20):
|
||||||
|
session.add_message("user", f"u{i}")
|
||||||
|
session.add_message("assistant", f"a{i}")
|
||||||
|
sessions.save(session)
|
||||||
|
|
||||||
|
# Simulate: background consolidation captures old reference
|
||||||
|
old_ref = session
|
||||||
|
|
||||||
|
# AutoCompact runs first and truncates to 8
|
||||||
|
await consolidator.compact_idle_session("cli:test", max_suffix=8)
|
||||||
|
|
||||||
|
# Background consolidation runs with stale reference —
|
||||||
|
# should detect the session was replaced and not undo the compact.
|
||||||
|
await consolidator.maybe_consolidate_by_tokens(old_ref)
|
||||||
|
|
||||||
|
session_after = sessions.get_or_create("cli:test")
|
||||||
|
# Messages should still be truncated (not restored to 40)
|
||||||
|
assert len(session_after.messages) <= 8
|
||||||
|
|
||||||
|
|
||||||
class TestRawArchiveTruncation:
|
class TestRawArchiveTruncation:
|
||||||
"""raw_archive() must cap entry size to avoid bloating history.jsonl."""
|
"""raw_archive() must cap entry size to avoid bloating history.jsonl."""
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user