fix(agent): move archived summary into system prompt for KV cache stability

- Append [Archived Context Summary] to system prompt instead of injecting
  it into the user message runtime context, improving KV cache reuse across
  turns and avoiding consecutive same-role messages.
- _last_summary persists in metadata (no pop) for restart survival;
  summary is re-injected every turn via the stable system prompt.
- Remove dynamic "Inactive for X minutes" from _format_summary — use
  static last_active timestamp instead to preserve KV cache stability.
- Pass session_summary through build_messages() so both normal and
  ask_user paths receive the archived summary in the system prompt.
- estimate_session_prompt_tokens now reads _last_summary from metadata
  to include the summary in token budget estimation.
- Remove obsolete session_summary parameter from
  maybe_consolidate_by_tokens and estimate_session_prompt_tokens
  call sites in loop.py (summary flows through build_messages instead).
- Ensure /new (session.clear()) clears _last_summary from metadata.
This commit is contained in:
chengyongru 2026-05-07 16:17:27 +08:00 committed by Xubin Ren
parent 73a8d8a875
commit a6e993df25
8 changed files with 119 additions and 41 deletions

View File

@ -7,6 +7,7 @@ from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Coroutine from typing import TYPE_CHECKING, Any, Callable, Coroutine
from loguru import logger from loguru import logger
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
if TYPE_CHECKING: if TYPE_CHECKING:
@ -34,8 +35,7 @@ class AutoCompact:
@staticmethod @staticmethod
def _format_summary(text: str, last_active: datetime) -> str: def _format_summary(text: str, last_active: datetime) -> str:
idle_min = int((datetime.now() - last_active).total_seconds() / 60) return f"Previous conversation summary (last active {last_active.isoformat()}):\n{text}"
return f"Inactive for {idle_min} minutes.\nPrevious conversation summary: {text}"
def _split_unconsolidated( def _split_unconsolidated(
self, session: Session, self, session: Session,
@ -111,13 +111,11 @@ class AutoCompact:
logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving) logger.info("Auto-compact: reloading session {} (archiving={})", key, key in self._archiving)
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
# Hot path: summary from in-memory dict (process hasn't restarted). # Hot path: summary from in-memory dict (process hasn't restarted).
# Also clean metadata copy so stale _last_summary never leaks to disk.
entry = self._summaries.pop(key, None) entry = self._summaries.pop(key, None)
if entry: if entry:
session.metadata.pop("_last_summary", None)
return session, self._format_summary(entry[0], entry[1]) return session, self._format_summary(entry[0], entry[1])
if "_last_summary" in session.metadata: # Cold path: summary persisted in session metadata (process restarted).
meta = session.metadata.pop("_last_summary") meta = session.metadata.get("_last_summary")
self.sessions.save(session) if isinstance(meta, dict):
return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"])) return session, self._format_summary(meta["text"], datetime.fromisoformat(meta["last_active"]))
return session, None return session, None

View File

@ -10,7 +10,12 @@ from typing import Any
from nanobot.agent.memory import MemoryStore from nanobot.agent.memory import MemoryStore
from nanobot.agent.skills import SkillsLoader from nanobot.agent.skills import SkillsLoader
from nanobot.utils.helpers import build_assistant_message, current_time_str, detect_image_mime, truncate_text from nanobot.utils.helpers import (
build_assistant_message,
current_time_str,
detect_image_mime,
truncate_text,
)
from nanobot.utils.prompt_templates import render_template from nanobot.utils.prompt_templates import render_template
@ -33,6 +38,7 @@ class ContextBuilder:
self, self,
skill_names: list[str] | None = None, skill_names: list[str] | None = None,
channel: str | None = None, channel: str | None = None,
session_summary: str | None = None,
) -> str: ) -> str:
"""Build the system prompt from identity, bootstrap files, memory, and skills.""" """Build the system prompt from identity, bootstrap files, memory, and skills."""
parts = [self._get_identity(channel=channel)] parts = [self._get_identity(channel=channel)]
@ -64,6 +70,9 @@ class ContextBuilder:
history_text = truncate_text(history_text, self._MAX_HISTORY_CHARS) history_text = truncate_text(history_text, self._MAX_HISTORY_CHARS)
parts.append("# Recent History\n\n" + history_text) parts.append("# Recent History\n\n" + history_text)
if session_summary:
parts.append(f"[Archived Context Summary]\n\n{session_summary}")
return "\n\n---\n\n".join(parts) return "\n\n---\n\n".join(parts)
def _get_identity(self, channel: str | None = None) -> str: def _get_identity(self, channel: str | None = None) -> str:
@ -83,7 +92,7 @@ class ContextBuilder:
@staticmethod @staticmethod
def _build_runtime_context( def _build_runtime_context(
channel: str | None, chat_id: str | None, timezone: str | None = None, channel: str | None, chat_id: str | None, timezone: str | None = None,
session_summary: str | None = None, sender_id: str | None = None, sender_id: str | None = None,
) -> str: ) -> str:
"""Build untrusted runtime metadata block for injection before the user message.""" """Build untrusted runtime metadata block for injection before the user message."""
lines = [f"Current Time: {current_time_str(timezone)}"] lines = [f"Current Time: {current_time_str(timezone)}"]
@ -91,8 +100,6 @@ class ContextBuilder:
lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"] lines += [f"Channel: {channel}", f"Chat ID: {chat_id}"]
if sender_id: if sender_id:
lines += [f"Sender ID: {sender_id}"] lines += [f"Sender ID: {sender_id}"]
if session_summary:
lines += ["", "[Resumed Session]", session_summary]
return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END return ContextBuilder._RUNTIME_CONTEXT_TAG + "\n" + "\n".join(lines) + "\n" + ContextBuilder._RUNTIME_CONTEXT_END
@staticmethod @staticmethod
@ -139,11 +146,11 @@ class ContextBuilder:
channel: str | None = None, channel: str | None = None,
chat_id: str | None = None, chat_id: str | None = None,
current_role: str = "user", current_role: str = "user",
session_summary: str | None = None,
sender_id: str | None = None, sender_id: str | None = None,
session_summary: str | None = None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Build the complete message list for an LLM call.""" """Build the complete message list for an LLM call."""
runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, session_summary=session_summary, sender_id=sender_id) runtime_ctx = self._build_runtime_context(channel, chat_id, self.timezone, sender_id=sender_id)
user_content = self._build_user_content(current_message, media) user_content = self._build_user_content(current_message, media)
# Merge runtime context and user content into a single user message # Merge runtime context and user content into a single user message
@ -153,7 +160,7 @@ class ContextBuilder:
else: else:
merged = [{"type": "text", "text": runtime_ctx}] + user_content merged = [{"type": "text", "text": runtime_ctx}] + user_content
messages = [ messages = [
{"role": "system", "content": self.build_system_prompt(skill_names, channel=channel)}, {"role": "system", "content": self.build_system_prompt(skill_names, channel=channel, session_summary=session_summary)},
*history, *history,
] ]
if messages[-1].get("role") == current_role: if messages[-1].get("role") == current_role:

View File

@ -706,12 +706,16 @@ class AgentLoop:
session: Session, session: Session,
history: list[dict[str, Any]], history: list[dict[str, Any]],
pending_ask_id: str | None, pending_ask_id: str | None,
pending_summary: Any, pending_summary: str | None,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Build the initial message list for the LLM turn.""" """Build the initial message list for the LLM turn."""
if pending_ask_id: if pending_ask_id:
system_prompt = self.context.build_system_prompt(
channel=msg.channel,
session_summary=pending_summary,
)
return ask_user_tool_result_messages( return ask_user_tool_result_messages(
self.context.build_system_prompt(channel=msg.channel), system_prompt,
history, history,
pending_ask_id, pending_ask_id,
image_generation_prompt(msg.content, msg.metadata), image_generation_prompt(msg.content, msg.metadata),
@ -719,11 +723,11 @@ class AgentLoop:
return self.context.build_messages( return self.context.build_messages(
history=history, history=history,
current_message=image_generation_prompt(msg.content, msg.metadata), current_message=image_generation_prompt(msg.content, msg.metadata),
session_summary=pending_summary,
media=msg.media if msg.media else None, media=msg.media if msg.media else None,
channel=msg.channel, channel=msg.channel,
chat_id=self._runtime_chat_id(msg), chat_id=self._runtime_chat_id(msg),
sender_id=msg.sender_id, sender_id=msg.sender_id,
session_summary=pending_summary,
) )
async def _dispatch_command_inline( async def _dispatch_command_inline(
@ -1179,7 +1183,6 @@ class AgentLoop:
await self.consolidator.maybe_consolidate_by_tokens( await self.consolidator.maybe_consolidate_by_tokens(
session, session,
session_summary=pending,
replay_max_messages=self._max_messages, replay_max_messages=self._max_messages,
) )
is_subagent = msg.sender_id == "subagent" is_subagent = msg.sender_id == "subagent"
@ -1203,9 +1206,9 @@ class AgentLoop:
current_message="" if is_subagent else msg.content, current_message="" if is_subagent else msg.content,
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
session_summary=pending,
current_role=current_role, current_role=current_role,
sender_id=msg.sender_id, sender_id=msg.sender_id,
session_summary=pending,
) )
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop( final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id, messages, session=session, channel=channel, chat_id=chat_id,
@ -1413,7 +1416,6 @@ class AgentLoop:
async def _state_build(self, ctx: TurnContext) -> str: async def _state_build(self, ctx: TurnContext) -> str:
await self.consolidator.maybe_consolidate_by_tokens( await self.consolidator.maybe_consolidate_by_tokens(
ctx.session, ctx.session,
session_summary=ctx.pending_summary,
replay_max_messages=self._max_messages, replay_max_messages=self._max_messages,
) )
self._set_tool_context( self._set_tool_context(

View File

@ -590,19 +590,20 @@ class Consolidator:
def estimate_session_prompt_tokens( def estimate_session_prompt_tokens(
self, self,
session: Session, session: Session,
*,
session_summary: str | None = None,
) -> tuple[int, str]: ) -> tuple[int, str]:
"""Estimate prompt size from the full unconsolidated session tail.""" """Estimate prompt size from the full unconsolidated session tail."""
history = self._full_unconsolidated_history(session, include_timestamps=True) history = self._full_unconsolidated_history(session, include_timestamps=True)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None)) channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
# Include archived summary in estimation so the budget accounts for it.
meta = session.metadata.get("_last_summary")
summary = meta.get("text") if isinstance(meta, dict) else (meta if isinstance(meta, str) else None)
probe_messages = self._build_messages( probe_messages = self._build_messages(
history=history, history=history,
current_message="[token-probe]", current_message="[token-probe]",
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
session_summary=session_summary,
sender_id=None, sender_id=None,
session_summary=summary,
) )
return estimate_prompt_tokens_chain( return estimate_prompt_tokens_chain(
self.provider, self.provider,
@ -669,7 +670,6 @@ class Consolidator:
self, self,
session: Session, session: Session,
*, *,
session_summary: str | None = None,
replay_max_messages: int | None = None, replay_max_messages: int | None = None,
) -> None: ) -> None:
"""Loop: archive old messages until prompt fits within safe budget. """Loop: archive old messages until prompt fits within safe budget.
@ -691,7 +691,6 @@ class Consolidator:
try: try:
estimated, source = self.estimate_session_prompt_tokens( estimated, source = self.estimate_session_prompt_tokens(
session, session,
session_summary=session_summary,
) )
except Exception: except Exception:
logger.exception("Token estimation failed for {}", session.key) logger.exception("Token estimation failed for {}", session.key)
@ -757,7 +756,6 @@ class Consolidator:
try: try:
estimated, source = self.estimate_session_prompt_tokens( estimated, source = self.estimate_session_prompt_tokens(
session, session,
session_summary=session_summary,
) )
except Exception: except Exception:
logger.exception("Token estimation failed for {}", session.key) logger.exception("Token estimation failed for {}", session.key)

View File

@ -181,6 +181,7 @@ class Session:
self.messages = [] self.messages = []
self.last_consolidated = 0 self.last_consolidated = 0
self.updated_at = datetime.now() self.updated_at = datetime.now()
self.metadata.pop("_last_summary", None)
def retain_recent_legal_suffix(self, max_messages: int) -> None: def retain_recent_legal_suffix(self, max_messages: int) -> None:
"""Keep a legal recent suffix constrained by a hard message cap.""" """Keep a legal recent suffix constrained by a hard message cap."""

View File

@ -1020,14 +1020,14 @@ class TestSummaryPersistence:
assert summary is not None assert summary is not None
assert "User said hello." in summary assert "User said hello." in summary
assert "Inactive for" in summary assert "Previous conversation summary" in summary
# Metadata should be cleaned up after consumption # _last_summary persists in metadata for restart survival.
assert "_last_summary" not in reloaded.metadata assert "_last_summary" in reloaded.metadata
await loop.close_mcp() await loop.close_mcp()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_metadata_cleanup_no_leak(self, tmp_path): async def test_metadata_persists_for_restart(self, tmp_path):
"""_last_summary should be removed from metadata after being consumed.""" """_last_summary stays in metadata so it survives process restarts."""
loop = _make_loop(tmp_path, session_ttl_minutes=15) loop = _make_loop(tmp_path, session_ttl_minutes=15)
session = loop.sessions.get_or_create("cli:test") session = loop.sessions.get_or_create("cli:test")
_add_turns(session, 6, prefix="hello") _add_turns(session, 6, prefix="hello")
@ -1046,14 +1046,14 @@ class TestSummaryPersistence:
loop.sessions.invalidate("cli:test") loop.sessions.invalidate("cli:test")
reloaded = loop.sessions.get_or_create("cli:test") reloaded = loop.sessions.get_or_create("cli:test")
# First call: consumes from metadata # Every call returns the summary from metadata (no _consumed_keys gate)
_, summary = loop.auto_compact.prepare_session(reloaded, "cli:test") _, summary = loop.auto_compact.prepare_session(reloaded, "cli:test")
assert summary is not None assert summary is not None
# Second call: no summary (already consumed)
_, summary2 = loop.auto_compact.prepare_session(reloaded, "cli:test") _, summary2 = loop.auto_compact.prepare_session(reloaded, "cli:test")
assert summary2 is None assert summary2 is not None
assert "_last_summary" not in reloaded.metadata assert "Summary." in summary2
# _last_summary persists in metadata for restart survival.
assert "_last_summary" in reloaded.metadata
await loop.close_mcp() await loop.close_mcp()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1081,6 +1081,79 @@ class TestSummaryPersistence:
# In-memory path is taken (no restart) # In-memory path is taken (no restart)
_, summary = loop.auto_compact.prepare_session(reloaded, "cli:test") _, summary = loop.auto_compact.prepare_session(reloaded, "cli:test")
assert summary is not None assert summary is not None
# Metadata should also be cleaned up # _last_summary persists in metadata for restart survival.
assert "_last_summary" not in reloaded.metadata assert "_last_summary" in reloaded.metadata
await loop.close_mcp()
@pytest.mark.asyncio
async def test_new_summary_overrides_old(self, tmp_path):
"""A fresh archive writes a new summary that replaces the old one."""
loop = _make_loop(tmp_path, session_ttl_minutes=15)
session = loop.sessions.get_or_create("cli:test")
_add_turns(session, 6, prefix="hello")
session.updated_at = datetime.now() - timedelta(minutes=20)
loop.sessions.save(session)
async def _fake_archive(messages):
return "First summary."
loop.consolidator.archive = _fake_archive
await loop.auto_compact._archive("cli:test")
# Consume the first summary via hot path
_, summary1 = loop.auto_compact.prepare_session(
loop.sessions.get_or_create("cli:test"), "cli:test"
)
assert summary1 is not None
assert "First summary." in summary1
assert "cli:test" not in loop.auto_compact._summaries # popped by hot path
# Add new messages and archive again (simulating a later turn)
_add_turns(session, 4, prefix="world")
session.updated_at = datetime.now() - timedelta(minutes=20)
loop.sessions.save(session)
async def _fake_archive2(messages):
return "Second summary."
loop.consolidator.archive = _fake_archive2
await loop.auto_compact._archive("cli:test")
# The second archive writes a new summary
assert "cli:test" in loop.auto_compact._summaries
# prepare_session must return the new summary
reloaded = loop.sessions.get_or_create("cli:test")
_, summary2 = loop.auto_compact.prepare_session(reloaded, "cli:test")
assert summary2 is not None
assert "Second summary." in summary2
await loop.close_mcp()
@pytest.mark.asyncio
async def test_new_command_clears_last_summary(self, tmp_path):
"""/new should clear _last_summary so the new session starts fresh."""
loop = _make_loop(tmp_path, session_ttl_minutes=15)
session = loop.sessions.get_or_create("cli:test")
_add_turns(session, 6, prefix="hello")
session.updated_at = datetime.now() - timedelta(minutes=20)
loop.sessions.save(session)
async def _fake_archive(messages):
return "Old summary."
loop.consolidator.archive = _fake_archive
await loop.auto_compact._archive("cli:test")
# Verify summary exists before /new
reloaded = loop.sessions.get_or_create("cli:test")
assert "_last_summary" in reloaded.metadata
# Simulate /new command
session.clear()
loop.sessions.save(session)
loop.sessions.invalidate(session.key)
# After /new, metadata should no longer contain _last_summary
fresh = loop.sessions.get_or_create("cli:test")
assert "_last_summary" not in fresh.metadata
await loop.close_mcp() await loop.close_mcp()

View File

@ -190,7 +190,8 @@ async def test_consolidation_persists_summary_for_next_prepare_session(tmp_path,
reloaded, pending = loop.auto_compact.prepare_session(reloaded, "cli:test") reloaded, pending = loop.auto_compact.prepare_session(reloaded, "cli:test")
assert pending is not None assert pending is not None
assert "User discussed project status." in pending assert "User discussed project status." in pending
assert "_last_summary" not in reloaded.metadata # _last_summary persists for restart survival.
assert "_last_summary" in reloaded.metadata
@pytest.mark.asyncio @pytest.mark.asyncio
@ -207,7 +208,6 @@ async def test_preflight_consolidation_receives_pending_summary(tmp_path) -> Non
loop.consolidator.maybe_consolidate_by_tokens.assert_any_await( loop.consolidator.maybe_consolidate_by_tokens.assert_any_await(
session, session,
session_summary="Previous conversation summary: earlier context",
replay_max_messages=loop._max_messages, replay_max_messages=loop._max_messages,
) )

View File

@ -399,7 +399,6 @@ class TestConsolidationUnaffectedByUnifiedSession:
# estimate was called (consolidation was attempted) # estimate was called (consolidation was attempted)
consolidator.estimate_session_prompt_tokens.assert_called_once_with( consolidator.estimate_session_prompt_tokens.assert_called_once_with(
session, session,
session_summary=None,
) )
# but archive was not called (no valid boundary) # but archive was not called (no valid boundary)
consolidator.archive.assert_not_called() consolidator.archive.assert_not_called()