nanobot/nanobot/agent/memory.py
2026-06-10 18:09:45 +08:00

1016 lines
39 KiB
Python

"""Memory system: pure file I/O store and lightweight Consolidator."""
from __future__ import annotations
import asyncio
import json
import os
import re
import threading
import weakref
from contextlib import suppress
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterator
import tiktoken
from loguru import logger
from nanobot.session.manager import Session
from nanobot.utils.gitstore import GitStore
from nanobot.utils.helpers import (
ensure_dir,
estimate_message_tokens,
estimate_prompt_tokens_chain,
find_legal_message_start,
strip_think,
truncate_text,
)
from nanobot.utils.prompt_templates import render_template
if TYPE_CHECKING:
from nanobot.providers.base import LLMProvider
from nanobot.session.manager import SessionManager
# ---------------------------------------------------------------------------
# MemoryStore — pure file I/O layer
# ---------------------------------------------------------------------------
class MemoryStore:
"""Pure file I/O for memory files: MEMORY.md, history.jsonl, SOUL.md, USER.md."""
_DEFAULT_MAX_HISTORY = 1000
_INTERNAL_HISTORY_SESSION_PREFIXES = ("cron:", "dream:")
_INTERNAL_HISTORY_SESSION_KEYS = {"heartbeat"}
_LEGACY_ENTRY_START_RE = re.compile(r"^\[(\d{4}-\d{2}-\d{2}[^\]]*)\]\s*")
_LEGACY_TIMESTAMP_RE = re.compile(r"^\[(\d{4}-\d{2}-\d{2} \d{2}:\d{2})\]\s*")
_LEGACY_RAW_MESSAGE_RE = re.compile(
r"^\[\d{4}-\d{2}-\d{2}[^\]]*\]\s+[A-Z][A-Z0-9_]*(?:\s+\[tools:\s*[^\]]+\])?:"
)
def __init__(self, workspace: Path, max_history_entries: int = _DEFAULT_MAX_HISTORY):
self.workspace = workspace
self.max_history_entries = max_history_entries
self.memory_dir = ensure_dir(workspace / "memory")
self.memory_file = self.memory_dir / "MEMORY.md"
self.history_file = self.memory_dir / "history.jsonl"
self.legacy_history_file = self.memory_dir / "HISTORY.md"
self.soul_file = workspace / "SOUL.md"
self.user_file = workspace / "USER.md"
self._cursor_file = self.memory_dir / ".cursor"
self._dream_cursor_file = self.memory_dir / ".dream_cursor"
self._corruption_logged = False # rate-limit non-int cursor warning
self._oversize_logged = False # rate-limit oversized-entry warning
self._append_lock = threading.Lock() # serialize cursor allocation + append
self._git = GitStore(workspace, tracked_files=[
"SOUL.md", "USER.md", "memory/MEMORY.md", "memory/.dream_cursor",
])
self._maybe_migrate_legacy_history()
@property
def git(self) -> GitStore:
return self._git
# -- generic helpers -----------------------------------------------------
@staticmethod
def read_file(path: Path) -> str:
try:
return path.read_text(encoding="utf-8")
except FileNotFoundError:
return ""
def _maybe_migrate_legacy_history(self) -> None:
"""One-time upgrade from legacy HISTORY.md to history.jsonl.
The migration is best-effort and prioritizes preserving as much content
as possible over perfect parsing.
"""
if not self.legacy_history_file.exists():
return
if self.history_file.exists() and self.history_file.stat().st_size > 0:
return
try:
legacy_text = self.legacy_history_file.read_text(
encoding="utf-8",
errors="replace",
)
except OSError:
logger.exception("Failed to read legacy HISTORY.md for migration")
return
entries = self._parse_legacy_history(legacy_text)
try:
if entries:
self._write_entries(entries)
last_cursor = entries[-1]["cursor"]
self._cursor_file.write_text(str(last_cursor), encoding="utf-8")
# Default to "already processed" so upgrades do not replay the
# user's entire historical archive into Dream on first start.
self._dream_cursor_file.write_text(str(last_cursor), encoding="utf-8")
backup_path = self._next_legacy_backup_path()
self.legacy_history_file.replace(backup_path)
logger.info(
"Migrated legacy HISTORY.md to history.jsonl ({} entries)",
len(entries),
)
except Exception:
logger.exception("Failed to migrate legacy HISTORY.md")
def _parse_legacy_history(self, text: str) -> list[dict[str, Any]]:
normalized = text.replace("\r\n", "\n").replace("\r", "\n").strip()
if not normalized:
return []
fallback_timestamp = self._legacy_fallback_timestamp()
entries: list[dict[str, Any]] = []
chunks = self._split_legacy_history_chunks(normalized)
for cursor, chunk in enumerate(chunks, start=1):
timestamp = fallback_timestamp
content = chunk
match = self._LEGACY_TIMESTAMP_RE.match(chunk)
if match:
timestamp = match.group(1)
remainder = chunk[match.end():].lstrip()
if remainder:
content = remainder
entries.append({
"cursor": cursor,
"timestamp": timestamp,
"content": content,
})
return entries
def _split_legacy_history_chunks(self, text: str) -> list[str]:
lines = text.split("\n")
chunks: list[str] = []
current: list[str] = []
saw_blank_separator = False
for line in lines:
if saw_blank_separator and line.strip() and current:
chunks.append("\n".join(current).strip())
current = [line]
saw_blank_separator = False
continue
if self._should_start_new_legacy_chunk(line, current):
chunks.append("\n".join(current).strip())
current = [line]
saw_blank_separator = False
continue
current.append(line)
saw_blank_separator = not line.strip()
if current:
chunks.append("\n".join(current).strip())
return [chunk for chunk in chunks if chunk]
def _should_start_new_legacy_chunk(self, line: str, current: list[str]) -> bool:
if not current:
return False
if not self._LEGACY_ENTRY_START_RE.match(line):
return False
if self._is_raw_legacy_chunk(current) and self._LEGACY_RAW_MESSAGE_RE.match(line):
return False
return True
def _is_raw_legacy_chunk(self, lines: list[str]) -> bool:
first_nonempty = next((line for line in lines if line.strip()), "")
match = self._LEGACY_TIMESTAMP_RE.match(first_nonempty)
if not match:
return False
return first_nonempty[match.end():].lstrip().startswith("[RAW]")
def _legacy_fallback_timestamp(self) -> str:
try:
return datetime.fromtimestamp(
self.legacy_history_file.stat().st_mtime,
).strftime("%Y-%m-%d %H:%M")
except OSError:
return datetime.now().strftime("%Y-%m-%d %H:%M")
def _next_legacy_backup_path(self) -> Path:
candidate = self.memory_dir / "HISTORY.md.bak"
suffix = 2
while candidate.exists():
candidate = self.memory_dir / f"HISTORY.md.bak.{suffix}"
suffix += 1
return candidate
# -- MEMORY.md (long-term facts) -----------------------------------------
def read_memory(self) -> str:
return self.read_file(self.memory_file)
def write_memory(self, content: str) -> None:
self.memory_file.write_text(content, encoding="utf-8")
# -- SOUL.md -------------------------------------------------------------
def read_soul(self) -> str:
return self.read_file(self.soul_file)
def write_soul(self, content: str) -> None:
self.soul_file.write_text(content, encoding="utf-8")
# -- USER.md -------------------------------------------------------------
def read_user(self) -> str:
return self.read_file(self.user_file)
def write_user(self, content: str) -> None:
self.user_file.write_text(content, encoding="utf-8")
# -- context injection (used by context.py) ------------------------------
def get_memory_context(self) -> str:
long_term = self.read_memory()
return f"## Long-term Memory\n{long_term}" if long_term else ""
# -- history.jsonl — append-only, JSONL format ---------------------------
def append_history(
self,
entry: str,
*,
max_chars: int | None = None,
session_key: str | None = None,
) -> int:
"""Append *entry* to history.jsonl and return its auto-incrementing cursor.
Entries are passed through `strip_think` to drop template-level leaks
(e.g. unclosed `<think` prefixes, `<channel|>` markers) before being
persisted. If the cleaned content is empty but the raw entry wasn't,
the record is persisted with an empty string rather than falling back
to the raw leak — otherwise `strip_think`'s guarantees would be
undone by history replay / consolidation downstream.
A defensive cap (*max_chars*, default ``_HISTORY_ENTRY_HARD_CAP``) is
applied as a final safety net: individual callers should cap their own
content more tightly; this default only exists to catch unintentional
large writes (e.g. an LLM echoing its input back as a "summary").
"""
limit = max_chars if max_chars is not None else _HISTORY_ENTRY_HARD_CAP
ts = datetime.now().strftime("%Y-%m-%d %H:%M")
raw = entry.rstrip()
if len(raw) > limit:
if not self._oversize_logged:
self._oversize_logged = True
logger.warning(
"history entry exceeds {} chars ({}); truncating. "
"Usually means a caller forgot its own cap; "
"further occurrences suppressed.",
limit, len(raw),
)
raw = truncate_text(raw, limit)
content = strip_think(raw)
# Cursor allocation and the append must be atomic: concurrent writers
# could otherwise read the same current cursor and emit duplicates.
with self._append_lock:
cursor = self._next_cursor()
if raw and not content:
logger.debug(
"history entry {} stripped to empty (likely template leak); "
"persisting empty content to avoid re-polluting context",
cursor,
)
record = {"cursor": cursor, "timestamp": ts, "content": content}
if session_key:
record["session_key"] = session_key
with open(self.history_file, "a", encoding="utf-8") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
self._cursor_file.write_text(str(cursor), encoding="utf-8")
return cursor
@staticmethod
def _valid_cursor(value: Any) -> int | None:
"""Int cursors only — reject bool (``isinstance(True, int)`` is True)."""
if isinstance(value, bool) or not isinstance(value, int):
return None
return value
def _iter_valid_entries(self) -> Iterator[tuple[dict[str, Any], int]]:
"""Yield ``(entry, cursor)`` for entries with int cursors; warn once on corruption."""
poisoned: Any = None
for entry in self._read_entries():
raw = entry.get("cursor")
if raw is None:
continue
cursor = self._valid_cursor(raw)
if cursor is None:
poisoned = raw
continue
yield entry, cursor
if poisoned is not None and not self._corruption_logged:
self._corruption_logged = True
logger.warning(
"history.jsonl contains a non-int cursor ({!r}); dropping it. "
"Usually caused by an external writer; further occurrences suppressed.",
poisoned,
)
def _next_cursor(self) -> int:
"""Read the current cursor counter and return the next value."""
if self._cursor_file.exists():
with suppress(ValueError, OSError):
return int(self._cursor_file.read_text(encoding="utf-8").strip()) + 1
# Fast path: trust the tail when intact. Otherwise scan the whole
# file and take ``max`` — that stays correct even if the monotonic
# invariant was broken by external writes.
last = self._read_last_entry() or {}
cursor = self._valid_cursor(last.get("cursor"))
if cursor is not None:
return cursor + 1
return max((c for _, c in self._iter_valid_entries()), default=0) + 1
def read_unprocessed_history(self, since_cursor: int) -> list[dict[str, Any]]:
"""Return history entries with a valid cursor > *since_cursor*."""
return [e for e, c in self._iter_valid_entries() if c > since_cursor]
@classmethod
def _is_internal_history_session(cls, session_key: str | None) -> bool:
if not session_key:
return False
return (
session_key in cls._INTERNAL_HISTORY_SESSION_KEYS
or session_key.startswith(cls._INTERNAL_HISTORY_SESSION_PREFIXES)
)
def read_recent_history_for_prompt(
self,
since_cursor: int,
*,
session_key: str | None,
unified_session: bool = False,
) -> list[dict[str, Any]]:
"""Return unprocessed history entries safe to inject into a turn prompt."""
entries = self.read_unprocessed_history(since_cursor=since_cursor)
if session_key is None:
return entries
if not unified_session:
return [e for e in entries if e.get("session_key") == session_key]
return [
entry
for entry in entries
if (entry_session := entry.get("session_key")) == session_key
or not self._is_internal_history_session(entry_session)
]
def compact_history(self) -> None:
"""Drop oldest entries if the file exceeds *max_history_entries*."""
if self.max_history_entries <= 0:
return
entries = self._read_entries()
if len(entries) <= self.max_history_entries:
return
kept = entries[-self.max_history_entries:]
self._write_entries(kept)
# -- JSONL helpers -------------------------------------------------------
def _read_entries(self) -> list[dict[str, Any]]:
"""Read all entries from history.jsonl."""
entries: list[dict[str, Any]] = []
with suppress(FileNotFoundError):
with open(self.history_file, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if line:
try:
entries.append(json.loads(line))
except json.JSONDecodeError:
continue
return entries
def _read_last_entry(self) -> dict[str, Any] | None:
"""Read the last entry from the JSONL file efficiently."""
try:
with open(self.history_file, "rb") as f:
f.seek(0, 2)
size = f.tell()
if size == 0:
return None
read_size = min(size, 4096)
f.seek(size - read_size)
data = f.read().decode("utf-8")
lines = [line for line in data.split("\n") if line.strip()]
if not lines:
return None
return json.loads(lines[-1])
except (FileNotFoundError, json.JSONDecodeError, UnicodeDecodeError):
return None
def _write_entries(self, entries: list[dict[str, Any]]) -> None:
"""Overwrite history.jsonl with the given entries (atomic write)."""
tmp_path = self.history_file.with_suffix(self.history_file.suffix + ".tmp")
try:
with open(tmp_path, "w", encoding="utf-8") as f:
for entry in entries:
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, self.history_file)
# fsync the directory so the rename is durable.
# On Windows, opening a directory with O_RDONLY raises
# PermissionError — skip the dir sync there (NTFS
# journals metadata synchronously).
with suppress(PermissionError):
fd = os.open(str(self.history_file.parent), os.O_RDONLY)
try:
os.fsync(fd)
finally:
os.close(fd)
except BaseException:
tmp_path.unlink(missing_ok=True)
raise
# -- dream cursor --------------------------------------------------------
def get_last_dream_cursor(self) -> int:
if self._dream_cursor_file.exists():
with suppress(ValueError, OSError):
return int(self._dream_cursor_file.read_text(encoding="utf-8").strip())
return 0
def set_last_dream_cursor(self, cursor: int) -> None:
self._dream_cursor_file.write_text(str(cursor), encoding="utf-8")
def build_dream_prompt(self, *, max_entries: int = 20) -> tuple[str, int] | None:
"""Build the Dream prompt with unprocessed history context.
Returns ``(prompt, last_cursor)`` or ``None`` if nothing to process.
"""
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
last_cursor = self.get_last_dream_cursor()
entries = self.read_unprocessed_history(since_cursor=last_cursor)
if not entries:
return None
batch = entries[:max_entries]
history_text = "\n".join(
f"[{e['timestamp']}] {truncate_text(e['content'], 500)}"
for e in batch
)
skill_creator_path = str(BUILTIN_SKILLS_DIR / "skill-creator" / "SKILL.md")
template = render_template(
"agent/dream.md", strip=True, skill_creator_path=skill_creator_path,
)
prompt = f"{template}\n\n## Conversation History\n{history_text}"
return (prompt, batch[-1]["cursor"])
def build_dream_tools(self):
"""Build the restricted tool registry used by Dream runs."""
from nanobot.agent.skills import BUILTIN_SKILLS_DIR
from nanobot.agent.tools.apply_patch import ApplyPatchTool
from nanobot.agent.tools.file_state import FileStates
from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool, WriteFileTool
from nanobot.agent.tools.registry import ToolRegistry
tools = ToolRegistry()
file_states = FileStates()
workspace = self.workspace
skills_dir = workspace / "skills"
skills_dir.mkdir(parents=True, exist_ok=True)
extra_read = [BUILTIN_SKILLS_DIR] if BUILTIN_SKILLS_DIR.exists() else None
editable_roots = [self.soul_file, self.user_file, skills_dir]
tools.register(ReadFileTool(
workspace=workspace,
allowed_dir=workspace,
extra_allowed_dirs=extra_read,
file_states=file_states,
))
tools.register(EditFileTool(
workspace=workspace,
allowed_dir=self.memory_dir,
extra_allowed_dirs=editable_roots,
file_states=file_states,
))
tools.register(ApplyPatchTool(
workspace=workspace,
allowed_dir=self.memory_dir,
extra_allowed_dirs=editable_roots,
file_states=file_states,
))
tools.register(WriteFileTool(
workspace=workspace,
allowed_dir=skills_dir,
file_states=file_states,
))
return tools
@staticmethod
def dream_run_completed(resp: object | None) -> bool:
"""Return True only when an ephemeral Dream agent turn completed cleanly."""
metadata = getattr(resp, "metadata", None)
return isinstance(metadata, dict) and metadata.get("_stop_reason") == "completed"
# -- message formatting utility ------------------------------------------
@staticmethod
def _format_messages(messages: list[dict]) -> str:
lines = []
for message in messages:
if not message.get("content"):
continue
tools = f" [tools: {', '.join(message['tools_used'])}]" if message.get("tools_used") else ""
lines.append(
f"[{message.get('timestamp', '?')[:16]}] {message['role'].upper()}{tools}: {message['content']}"
)
return "\n".join(lines)
def raw_archive(
self,
messages: list[dict],
*,
max_chars: int | None = None,
session_key: str | None = None,
) -> None:
"""Fallback: dump raw messages to history.jsonl without LLM summarization."""
limit = max_chars if max_chars is not None else _RAW_ARCHIVE_MAX_CHARS
formatted = truncate_text(self._format_messages(messages), limit)
self.append_history(
f"[RAW] {len(messages)} messages\n"
f"{formatted}",
session_key=session_key,
)
logger.warning(
"Memory consolidation degraded: raw-archived {} messages", len(messages)
)
# ------------------------------------------------------------------
# Dream helpers
# ------------------------------------------------------------------
@staticmethod
def dream_session_key() -> str:
"""Return a unique session key for a Dream run, e.g. ``dream:20260528-100000``."""
return f"dream:{datetime.now():%Y%m%d-%H%M%S}"
@staticmethod
def build_dream_commit_message(prefix: str, resp: object | None) -> str:
"""Build a Dream auto-commit message, appending the LLM summary if present."""
msg = prefix
if resp is not None and getattr(resp, "content", None):
msg = f"{msg}\n\n{resp.content.strip()}"
return msg
@staticmethod
def prune_dream_sessions(sessions_dir: Path, *, keep: int = 10) -> None:
"""Remove the oldest Dream session files, keeping only the N most recent.
Only files matching ``dream_*.jsonl`` are considered. Non-dream session
files are never touched.
"""
dream_files = sorted(
sessions_dir.glob("dream_*.jsonl"), key=lambda p: p.stat().st_mtime,
)
if len(dream_files) <= keep:
return
to_remove = dream_files[: len(dream_files) - keep]
for path in to_remove:
try:
path.unlink()
logger.debug("Pruned old dream session: {}", path.stem)
except OSError:
logger.warning("Failed to prune dream session {}", path)
# ---------------------------------------------------------------------------
# Consolidator — lightweight token-budget triggered consolidation
# ---------------------------------------------------------------------------
# Individual history.jsonl writers cap their own payloads tightly; the
# _HISTORY_ENTRY_HARD_CAP at append_history() is a belt-and-suspenders default
# that catches any new caller that forgot to set its own cap.
_RAW_ARCHIVE_MAX_CHARS = 16_000 # fallback dump (LLM failed)
_ARCHIVE_SUMMARY_MAX_CHARS = 8_000 # LLM-produced consolidation summary
_HISTORY_ENTRY_HARD_CAP = 64_000 # emergency cap in append_history
class Consolidator:
"""Lightweight consolidation: summarizes evicted messages into history.jsonl."""
_MAX_CONSOLIDATION_ROUNDS = 5
_SAFETY_BUFFER = 1024 # extra headroom for tokenizer estimation drift
def __init__(
self,
store: MemoryStore,
provider: LLMProvider,
model: str,
sessions: SessionManager,
context_window_tokens: int,
build_messages: Callable[..., list[dict[str, Any]]],
get_tool_definitions: Callable[[], list[dict[str, Any]]],
max_completion_tokens: int = 4096,
consolidation_ratio: float = 0.5,
unified_session: bool = False,
):
self.store = store
self.provider = provider
self.model = model
self.sessions = sessions
self.context_window_tokens = context_window_tokens
self.max_completion_tokens = max_completion_tokens
self.consolidation_ratio = consolidation_ratio
self.unified_session = unified_session
self._build_messages = build_messages
self._get_tool_definitions = get_tool_definitions
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = (
weakref.WeakValueDictionary()
)
def set_provider(
self,
provider: LLMProvider,
model: str,
context_window_tokens: int,
) -> None:
self.provider = provider
self.model = model
self.context_window_tokens = context_window_tokens
self.max_completion_tokens = provider.generation.max_tokens
def get_lock(self, session_key: str) -> asyncio.Lock:
"""Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock())
def pick_consolidation_boundary(
self,
session: Session,
tokens_to_remove: int,
) -> tuple[int, int] | None:
"""Pick a user-turn boundary that removes enough old prompt tokens."""
start = session.last_consolidated
if start >= len(session.messages) or tokens_to_remove <= 0:
return None
removed_tokens = 0
last_boundary: tuple[int, int] | None = None
for idx in range(start, len(session.messages)):
message = session.messages[idx]
if idx > start and message.get("role") == "user":
last_boundary = (idx, removed_tokens)
if removed_tokens >= tokens_to_remove:
return last_boundary
removed_tokens += estimate_message_tokens(message)
return last_boundary
@staticmethod
def _full_unconsolidated_history(
session: Session,
*,
include_timestamps: bool = False,
) -> list[dict[str, Any]]:
"""Return the whole unconsolidated tail for consolidation decisions."""
unconsolidated_count = len(session.messages) - session.last_consolidated
if unconsolidated_count <= 0:
return []
return session.get_history(
max_messages=unconsolidated_count,
include_timestamps=include_timestamps,
)
@staticmethod
def _replay_overflow_boundary(
session: Session,
replay_max_messages: int | None,
) -> int | None:
if not replay_max_messages or replay_max_messages <= 0:
return None
tail = list(enumerate(session.messages[session.last_consolidated:], session.last_consolidated))
if len(tail) <= replay_max_messages:
return None
sliced = tail[-replay_max_messages:]
for i, (_idx, message) in enumerate(sliced):
if message.get("role") == "user":
start = i
if i > 0 and sliced[i - 1][1].get("_channel_delivery"):
start = i - 1
sliced = sliced[start:]
break
legal_start = find_legal_message_start([message for _idx, message in sliced])
if legal_start:
sliced = sliced[legal_start:]
if not sliced:
return len(session.messages)
first_visible_idx = sliced[0][0]
if first_visible_idx <= session.last_consolidated:
return None
return first_visible_idx
async def _consolidate_replay_overflow(
self,
session: Session,
replay_max_messages: int | None,
) -> str | None:
"""Archive messages that would be hidden by the replay message window."""
end_idx = self._replay_overflow_boundary(session, replay_max_messages)
if end_idx is None:
return None
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
return None
logger.info(
"Replay-window consolidation for {}: chunk={} msgs, replay_max={}",
session.key,
len(chunk),
replay_max_messages,
)
summary = await self.archive(chunk, session_key=session.key)
session.last_consolidated = end_idx
self.sessions.save(session)
return summary
def _persist_last_summary(self, session: Session, summary: str | None) -> None:
if summary and summary != "(nothing)":
session.metadata["_last_summary"] = {
"text": summary,
"last_active": session.updated_at.isoformat(),
}
self.sessions.save(session)
def estimate_session_prompt_tokens(
self,
session: Session,
) -> tuple[int, str]:
"""Estimate prompt size from the full unconsolidated session tail."""
history = self._full_unconsolidated_history(session, include_timestamps=True)
channel, chat_id = (session.key.split(":", 1) if ":" in session.key else (None, None))
# Include archived summary in estimation so the budget accounts for it.
meta = session.metadata.get("_last_summary")
summary = meta.get("text") if isinstance(meta, dict) else (meta if isinstance(meta, str) else None)
probe_messages = self._build_messages(
history=history,
current_message="[token-probe]",
channel=channel,
chat_id=chat_id,
sender_id=None,
session_summary=summary,
session_metadata=session.metadata,
session_key=session.key,
unified_session=self.unified_session,
)
return estimate_prompt_tokens_chain(
self.provider,
self.model,
probe_messages,
self._get_tool_definitions(),
)
@property
def _input_token_budget(self) -> int:
"""Available input token budget for consolidation LLM."""
return self.context_window_tokens - self.max_completion_tokens - self._SAFETY_BUFFER
def _truncate_to_token_budget(self, text: str) -> str:
"""Truncate text so it fits within the consolidation LLM's token budget."""
budget = self._input_token_budget
if budget <= 0:
return truncate_text(text, _RAW_ARCHIVE_MAX_CHARS)
try:
enc = tiktoken.get_encoding("cl100k_base")
tokens = enc.encode(text)
if len(tokens) <= budget:
return text
return enc.decode(tokens[:budget]) + "\n... (truncated)"
except Exception:
return truncate_text(text, budget * 4)
async def archive(
self,
messages: list[dict],
*,
session_key: str | None = None,
) -> str | None:
"""Summarize messages via LLM and append to history.jsonl.
Returns the summary text on success, None if nothing to archive.
"""
if not messages:
return None
try:
formatted = MemoryStore._format_messages(messages)
formatted = self._truncate_to_token_budget(formatted)
response = await self.provider.chat_with_retry(
model=self.model,
messages=[
{
"role": "system",
"content": render_template(
"agent/consolidator_archive.md",
strip=True,
),
},
{"role": "user", "content": formatted},
],
tools=None,
tool_choice=None,
)
if response.finish_reason == "error":
raise RuntimeError(f"LLM returned error: {response.content}")
summary = response.content or "[no summary]"
self.store.append_history(
summary,
max_chars=_ARCHIVE_SUMMARY_MAX_CHARS,
session_key=session_key,
)
return summary
except Exception:
logger.warning("Consolidation LLM call failed, raw-dumping to history")
self.store.raw_archive(messages, session_key=session_key)
return None
async def maybe_consolidate_by_tokens(
self,
session: Session,
*,
replay_max_messages: int | None = None,
) -> None:
"""Loop: archive old messages until prompt fits within safe budget.
The budget reserves space for completion tokens and a safety buffer
so the LLM request never exceeds the context window.
"""
if self.context_window_tokens <= 0:
return
lock = self.get_lock(session.key)
async with lock:
# Refresh session reference: AutoCompact may have replaced it.
fresh = self.sessions.get_or_create(session.key)
if fresh is not session:
session = fresh
if not session.messages:
return
budget = self._input_token_budget
target = int(budget * self.consolidation_ratio)
last_summary = await self._consolidate_replay_overflow(
session,
replay_max_messages,
)
try:
estimated, source = self.estimate_session_prompt_tokens(
session,
)
except Exception:
logger.exception("Token estimation failed for {}", session.key)
estimated, source = 0, "error"
if estimated <= 0:
self._persist_last_summary(session, last_summary)
return
if estimated < budget:
unconsolidated_count = len(session.messages) - session.last_consolidated
logger.debug(
"Token consolidation idle {}: {}/{} via {}, msgs={}",
session.key,
estimated,
self.context_window_tokens,
source,
unconsolidated_count,
)
self._persist_last_summary(session, last_summary)
return
for round_num in range(self._MAX_CONSOLIDATION_ROUNDS):
if estimated <= target:
break
boundary = self.pick_consolidation_boundary(session, max(1, estimated - target))
if boundary is None:
logger.debug(
"Token consolidation: no safe boundary for {} (round {})",
session.key,
round_num,
)
break
end_idx = boundary[0]
chunk = session.messages[session.last_consolidated:end_idx]
if not chunk:
break
logger.info(
"Token consolidation round {} for {}: {}/{} via {}, chunk={} msgs",
round_num,
session.key,
estimated,
self.context_window_tokens,
source,
len(chunk),
)
summary = await self.archive(chunk, session_key=session.key)
# Advance the cursor either way: on success the chunk was
# summarized; on failure archive() already raw-archived it as
# a breadcrumb. Re-archiving the same chunk on the next call
# would just emit duplicate [RAW] entries.
if summary:
last_summary = summary
session.last_consolidated = end_idx
self.sessions.save(session)
if not summary:
# LLM is degraded — stop hammering it this call;
# the next invocation can retry a fresh chunk.
break
try:
estimated, source = self.estimate_session_prompt_tokens(
session,
)
except Exception:
logger.exception("Token estimation failed for {}", session.key)
estimated, source = 0, "error"
if estimated <= 0:
break
# Persist the last summary to session metadata so it can be injected
# into the runtime context on the next prepare_session() call, aligning
# the summary injection strategy with AutoCompact._archive().
self._persist_last_summary(session, last_summary)
async def compact_idle_session(
self,
session_key: str,
max_suffix: int = 8,
) -> str | None:
"""Hard-truncate an idle session under the consolidation lock.
Used by AutoCompact so all session mutation goes through a single
lock-protected path. Returns the summary text on success, ``None``
if the LLM failed (raw_archive fallback), or ``""`` if there was
nothing to archive.
"""
lock = self.get_lock(session_key)
async with lock:
self.sessions.invalidate(session_key)
session = self.sessions.get_or_create(session_key)
tail = list(session.messages[session.last_consolidated:])
if not tail:
session.updated_at = datetime.now()
self.sessions.save(session)
return ""
probe = Session(
key=session.key,
messages=tail.copy(),
created_at=session.created_at,
updated_at=session.updated_at,
metadata={},
last_consolidated=0,
)
dropped, already_consolidated = probe.retain_recent_legal_suffix(max_suffix)
kept = probe.messages
archive_msgs = dropped[already_consolidated:]
if not archive_msgs and not kept:
session.updated_at = datetime.now()
self.sessions.save(session)
return ""
last_active = session.updated_at
summary: str | None = ""
if archive_msgs:
summary = await self.archive(archive_msgs, session_key=session_key)
if summary and summary != "(nothing)":
session.metadata["_last_summary"] = {
"text": summary,
"last_active": last_active.isoformat(),
}
session.messages = kept
session.last_consolidated = 0
session.updated_at = datetime.now()
self.sessions.save(session)
if archive_msgs:
logger.info(
"Idle-session compact for {}: archived={}, kept={}, summary={}",
session_key,
len(archive_msgs),
len(kept),
bool(summary),
)
return summary