diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 3350c447b..53cb49d71 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -648,7 +648,10 @@ class AgentLoop: session, pending = self.auto_compact.prepare_session(session, key) - await self.consolidator.maybe_consolidate_by_tokens(session) + await self.consolidator.maybe_consolidate_by_tokens( + session, + session_summary=pending, + ) # Persist subagent follow-ups into durable history BEFORE prompt # assembly. ContextBuilder merges adjacent same-role messages for # provider compatibility, which previously caused the follow-up to @@ -709,7 +712,10 @@ class AgentLoop: if result := await self.commands.dispatch(ctx): return result - await self.consolidator.maybe_consolidate_by_tokens(session) + await self.consolidator.maybe_consolidate_by_tokens( + session, + session_summary=pending, + ) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index e99590538..fb630ce13 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -416,7 +416,12 @@ class Consolidator: return idx return None - def estimate_session_prompt_tokens(self, session: Session) -> tuple[int, str]: + def estimate_session_prompt_tokens( + self, + session: Session, + *, + session_summary: str | None = None, + ) -> tuple[int, str]: """Estimate current prompt size for the normal session history view.""" history = session.get_history(max_messages=0) channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) @@ -425,6 +430,7 @@ class Consolidator: current_message="[token-probe]", channel=channel, chat_id=chat_id, + session_summary=session_summary, ) return estimate_prompt_tokens_chain( self.provider, @@ -467,7 +473,12 @@ class Consolidator: self.store.raw_archive(messages) return None - async def maybe_consolidate_by_tokens(self, session: Session) -> None: + async def maybe_consolidate_by_tokens( + self, + session: Session, + *, + session_summary: str | None = None, + ) -> None: """Loop: archive old messages until prompt fits within safe budget. The budget reserves space for completion tokens and a safety buffer @@ -481,7 +492,10 @@ class Consolidator: budget = self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER target = budget // 2 try: - estimated, source = self.estimate_session_prompt_tokens(session) + estimated, source = self.estimate_session_prompt_tokens( + session, + session_summary=session_summary, + ) except Exception: logger.exception("Token estimation failed for {}", session.key) estimated, source = 0, "error" @@ -545,7 +559,10 @@ class Consolidator: self.sessions.save(session) try: - estimated, source = self.estimate_session_prompt_tokens(session) + estimated, source = self.estimate_session_prompt_tokens( + session, + session_summary=session_summary, + ) except Exception: logger.exception("Token estimation failed for {}", session.key) estimated, source = 0, "error" diff --git a/tests/agent/test_loop_consolidation_tokens.py b/tests/agent/test_loop_consolidation_tokens.py index a5d0ce9b0..347cab1e9 100644 --- a/tests/agent/test_loop_consolidation_tokens.py +++ b/tests/agent/test_loop_consolidation_tokens.py @@ -102,7 +102,7 @@ async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> No loop.sessions.save(session) call_count = [0] - def mock_estimate(_session): + def mock_estimate(_session, *, session_summary=None): call_count[0] += 1 if call_count[0] == 1: return (500, "test") @@ -139,7 +139,7 @@ async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, call_count = [0] - def mock_estimate(_session): + def mock_estimate(_session, *, session_summary=None): call_count[0] += 1 if call_count[0] == 1: return (500, "test") @@ -171,7 +171,7 @@ async def test_consolidation_persists_summary_for_next_prepare_session(tmp_path, call_count = [0] - def mock_estimate(_session): + def mock_estimate(_session, *, session_summary=None): call_count[0] += 1 if call_count[0] == 1: return (500, "test") @@ -193,6 +193,24 @@ async def test_consolidation_persists_summary_for_next_prepare_session(tmp_path, assert "_last_summary" not in reloaded.metadata +@pytest.mark.asyncio +async def test_preflight_consolidation_receives_pending_summary(tmp_path) -> None: + loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200) + session = loop.sessions.get_or_create("cli:test") + loop.auto_compact.prepare_session = MagicMock( + return_value=(session, "Previous conversation summary: earlier context") + ) # type: ignore[method-assign] + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=None) # type: ignore[method-assign] + loop._schedule_background = lambda coro: coro.close() # type: ignore[method-assign] + + await loop.process_direct("hello", session_key="cli:test") + + loop.consolidator.maybe_consolidate_by_tokens.assert_awaited_once_with( + session, + session_summary="Previous conversation summary: earlier context", + ) + + @pytest.mark.asyncio async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> None: """Verify preflight consolidation runs before the LLM call in process_direct.""" @@ -210,6 +228,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> return LLMResponse(content="ok", tool_calls=[]) loop.provider.chat_with_retry = track_llm loop.provider.chat_stream_with_retry = track_llm + loop._schedule_background = lambda coro: coro.close() # type: ignore[method-assign] session = loop.sessions.get_or_create("cli:test") session.messages = [ @@ -221,7 +240,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 500) call_count = [0] - def mock_estimate(_session): + def mock_estimate(_session, *, session_summary=None): call_count[0] += 1 return (1000 if call_count[0] <= 1 else 80, "test") loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] diff --git a/tests/agent/test_unified_session.py b/tests/agent/test_unified_session.py index 557beaca7..acf0b9d6d 100644 --- a/tests/agent/test_unified_session.py +++ b/tests/agent/test_unified_session.py @@ -395,7 +395,10 @@ class TestConsolidationUnaffectedByUnifiedSession: await consolidator.maybe_consolidate_by_tokens(session) # estimate was called (consolidation was attempted) - consolidator.estimate_session_prompt_tokens.assert_called_once_with(session) + consolidator.estimate_session_prompt_tokens.assert_called_once_with( + session, + session_summary=None, + ) # but archive was not called (no valid boundary) consolidator.archive.assert_not_called()