Merge PR #3459: feat(session): enforce replay/file-cap invariants for history lifecycle

feat(session): enforce replay/file-cap invariants for history lifecycle
This commit is contained in:
Xubin Ren 2026-04-27 16:17:23 +08:00 committed by GitHub
commit 2fe8d21b6e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 262 additions and 13 deletions

View File

@ -477,6 +477,18 @@ class AgentLoop:
return UNIFIED_SESSION_KEY return UNIFIED_SESSION_KEY
return msg.session_key return msg.session_key
def _replay_token_budget(self) -> int:
"""Derive a token budget for session history replay from the context window."""
if self.context_window_tokens <= 0:
return 0
max_output = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
try:
reserved_output = int(max_output)
except (TypeError, ValueError):
reserved_output = 4096
budget = self.context_window_tokens - max(1, reserved_output) - 1024
return budget if budget > 0 else max(128, self.context_window_tokens // 2)
async def _run_agent_loop( async def _run_agent_loop(
self, self,
initial_messages: list[dict], initial_messages: list[dict],
@ -867,7 +879,10 @@ class AgentLoop:
channel, chat_id, msg.metadata.get("message_id"), channel, chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key, msg.metadata, session_key=key,
) )
history = session.get_history(max_messages=0, include_timestamps=True) history = session.get_history(
max_tokens=self._replay_token_budget(),
include_timestamps=True,
)
current_role = "assistant" if is_subagent else "user" current_role = "assistant" if is_subagent else "user"
# Subagent content is already in `history` above; passing it again # Subagent content is already in `history` above; passing it again
@ -888,6 +903,7 @@ class AgentLoop:
pending_queue=pending_queue, pending_queue=pending_queue,
) )
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_runtime_checkpoint(session) self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
@ -950,7 +966,10 @@ class AgentLoop:
if isinstance(message_tool, MessageTool): if isinstance(message_tool, MessageTool):
message_tool.start_turn() message_tool.start_turn()
history = session.get_history(max_messages=0, include_timestamps=True) history = session.get_history(
max_tokens=self._replay_token_budget(),
include_timestamps=True,
)
pending_ask_id = pending_ask_user_id(history) pending_ask_id = pending_ask_user_id(history)
if pending_ask_id: if pending_ask_id:
@ -1038,6 +1057,7 @@ class AgentLoop:
# Skip the already-persisted user message when saving the turn # Skip the already-persisted user message when saving the turn
save_skip = 1 + len(history) + (1 if user_persisted_early else 0) save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
self._save_turn(session, all_msgs, save_skip) self._save_turn(session, all_msgs, save_skip)
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_pending_user_turn(session) self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session) self._clear_runtime_checkpoint(session)
self.sessions.save(session) self.sessions.save(session)

View File

@ -12,6 +12,7 @@ from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import ( from nanobot.utils.helpers import (
estimate_message_tokens,
ensure_dir, ensure_dir,
find_legal_message_start, find_legal_message_start,
image_placeholder_text, image_placeholder_text,
@ -19,6 +20,10 @@ from nanobot.utils.helpers import (
) )
HISTORY_MAX_MESSAGES = 120
FILE_MAX_MESSAGES = 2000
@dataclass @dataclass
class Session: class Session:
"""A conversation session.""" """A conversation session."""
@ -69,11 +74,16 @@ class Session:
def get_history( def get_history(
self, self,
max_messages: int = 500, max_messages: int = HISTORY_MAX_MESSAGES,
*, *,
max_tokens: int = 0,
include_timestamps: bool = False, include_timestamps: bool = False,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.""" """Return unconsolidated messages for LLM input.
History is sliced by message count first (``max_messages``), then by
token budget from the tail (``max_tokens``) when provided.
"""
unconsolidated = self.messages[self.last_consolidated:] unconsolidated = self.messages[self.last_consolidated:]
sliced = unconsolidated[-max_messages:] sliced = unconsolidated[-max_messages:]
@ -113,6 +123,38 @@ class Session:
if key in message: if key in message:
entry[key] = message[key] entry[key] = message[key]
out.append(entry) out.append(entry)
if max_tokens > 0 and out:
kept: list[dict[str, Any]] = []
used = 0
for message in reversed(out):
tokens = estimate_message_tokens(message)
if kept and used + tokens > max_tokens:
break
kept.append(message)
used += tokens
kept.reverse()
# Keep history aligned to the first visible user turn.
first_user = next((i for i, m in enumerate(kept) if m.get("role") == "user"), None)
if first_user is not None:
kept = kept[first_user:]
else:
# Tight token budgets can otherwise leave assistant-only tails.
# If a user turn exists in the unsliced output, recover the
# nearest one even if it slightly exceeds the token budget.
recovered_user = next(
(i for i in range(len(out) - 1, -1, -1) if out[i].get("role") == "user"),
None,
)
if recovered_user is not None:
kept = out[recovered_user:]
# And keep a legal tool-call boundary at the front.
start = find_legal_message_start(kept)
if start:
kept = kept[start:]
out = kept
return out return out
def clear(self) -> None: def clear(self) -> None:
@ -122,31 +164,77 @@ class Session:
self.updated_at = datetime.now() self.updated_at = datetime.now()
def retain_recent_legal_suffix(self, max_messages: int) -> None: def retain_recent_legal_suffix(self, max_messages: int) -> None:
"""Keep a legal recent suffix, mirroring get_history boundary rules.""" """Keep a legal recent suffix constrained by a hard message cap."""
if max_messages <= 0: if max_messages <= 0:
self.clear() self.clear()
return return
if len(self.messages) <= max_messages: if len(self.messages) <= max_messages:
return return
start_idx = max(0, len(self.messages) - max_messages) retained = list(self.messages[-max_messages:])
# If the cutoff lands mid-turn, extend backward to the nearest user turn. # Prefer starting at a user turn when one exists within the tail.
while start_idx > 0 and self.messages[start_idx].get("role") != "user": first_user = next((i for i, m in enumerate(retained) if m.get("role") == "user"), None)
start_idx -= 1 if first_user is not None:
retained = retained[first_user:]
retained = self.messages[start_idx:] else:
# If the tail is assistant/tool-only, anchor to the latest user in
# the full session and take a capped forward window from there.
latest_user = next(
(i for i in range(len(self.messages) - 1, -1, -1)
if self.messages[i].get("role") == "user"),
None,
)
if latest_user is not None:
retained = list(self.messages[latest_user: latest_user + max_messages])
# Mirror get_history(): avoid persisting orphan tool results at the front. # Mirror get_history(): avoid persisting orphan tool results at the front.
start = find_legal_message_start(retained) start = find_legal_message_start(retained)
if start: if start:
retained = retained[start:] retained = retained[start:]
# Hard-cap guarantee: never keep more than max_messages.
if len(retained) > max_messages:
retained = retained[-max_messages:]
start = find_legal_message_start(retained)
if start:
retained = retained[start:]
dropped = len(self.messages) - len(retained) dropped = len(self.messages) - len(retained)
self.messages = retained self.messages = retained
self.last_consolidated = max(0, self.last_consolidated - dropped) self.last_consolidated = max(0, self.last_consolidated - dropped)
self.updated_at = datetime.now() self.updated_at = datetime.now()
def enforce_file_cap(
    self,
    on_archive: Any = None,
    limit: int = FILE_MAX_MESSAGES,
) -> None:
    """Bound session message growth by archiving and trimming old prefixes.

    Args:
        on_archive: Optional callable invoked with the dropped messages that
            were never consolidated, so they can be raw-archived rather than
            silently lost.
        limit: Hard cap on retained messages; non-positive disables the cap.
    """
    if limit <= 0 or len(self.messages) <= limit:
        return
    # Snapshot state first: retain_recent_legal_suffix mutates both
    # self.messages and self.last_consolidated.
    before = list(self.messages)
    before_last_consolidated = self.last_consolidated
    before_count = len(before)
    self.retain_recent_legal_suffix(limit)
    dropped_count = before_count - len(self.messages)
    if dropped_count <= 0:
        return
    dropped = before[:dropped_count]
    # Messages within the old last_consolidated watermark were already
    # consolidated elsewhere; only the unconsolidated tail of the dropped
    # prefix goes to the raw archive (avoids double-archiving).
    already_consolidated = min(before_last_consolidated, dropped_count)
    archive_chunk = dropped[already_consolidated:]
    if archive_chunk and on_archive:
        on_archive(archive_chunk)
    logger.info(
        "Session file cap hit for {}: dropped {}, raw-archived {}, kept {}",
        self.key,
        dropped_count,
        len(archive_chunk),
        len(self.messages),
    )
class SessionManager: class SessionManager:
""" """

View File

@ -2,8 +2,8 @@
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
@ -15,7 +15,10 @@ from nanobot.command import CommandContext
from nanobot.providers.base import LLMResponse from nanobot.providers.base import LLMResponse
def _make_loop(tmp_path: Path, session_ttl_minutes: int = 15) -> AgentLoop: def _make_loop(
tmp_path: Path,
session_ttl_minutes: int = 15,
) -> AgentLoop:
"""Create a minimal AgentLoop for testing.""" """Create a minimal AgentLoop for testing."""
bus = MessageBus() bus = MessageBus()
provider = MagicMock() provider = MagicMock()
@ -72,6 +75,12 @@ class TestSessionTTLConfig:
assert data["idleCompactAfterMinutes"] == 30 assert data["idleCompactAfterMinutes"] == 30
assert "sessionTtlMinutes" not in data assert "sessionTtlMinutes" not in data
def test_session_history_and_file_cap_are_internal_constants(self):
    """Session history/file cap should be internal constants, not config fields."""
    from nanobot.session.manager import HISTORY_MAX_MESSAGES, FILE_MAX_MESSAGES

    # Pin the values so an accidental change surfaces in review.
    assert HISTORY_MAX_MESSAGES == 120
    assert FILE_MAX_MESSAGES == 2000
class TestAgentLoopTTLParam: class TestAgentLoopTTLParam:
"""Test that AutoCompact receives and stores session_ttl_minutes.""" """Test that AutoCompact receives and stores session_ttl_minutes."""
@ -86,6 +95,75 @@ class TestAgentLoopTTLParam:
loop = _make_loop(tmp_path, session_ttl_minutes=0) loop = _make_loop(tmp_path, session_ttl_minutes=0)
assert loop.auto_compact._ttl == 0 assert loop.auto_compact._ttl == 0
@pytest.mark.asyncio
async def test_process_message_reads_history_with_token_budget(self, tmp_path):
    """_process_message should pass an auto-derived token budget to get_history."""
    loop = _make_loop(tmp_path)
    session = loop.sessions.get_or_create("cli:direct")
    # Stub out everything downstream of the history read.
    session.get_history = MagicMock(return_value=[])
    loop.context.build_messages = MagicMock(return_value=[])
    loop._run_agent_loop = AsyncMock(return_value=("ok", [], [], "stop", False))
    loop._save_turn = MagicMock()

    inbound = InboundMessage(
        channel="cli",
        sender_id="u1",
        chat_id="direct",
        content="hello",
    )
    await loop._process_message(inbound)

    session.get_history.assert_called_once()
    passed = session.get_history.call_args.kwargs
    budget = passed.get("max_tokens")
    assert isinstance(budget, int) and budget > 0
    assert passed["include_timestamps"] is True
@pytest.mark.asyncio
async def test_session_file_cap_archives_and_trims_old_messages(self, tmp_path):
    """Processed turns must never grow a session beyond the file cap."""
    from nanobot.session.manager import FILE_MAX_MESSAGES

    loop = _make_loop(tmp_path)
    loop.context.memory.raw_archive = MagicMock()
    for turn in range(4):
        await loop._process_message(
            InboundMessage(
                channel="cli",
                sender_id="u1",
                chat_id="direct",
                content=f"hello {turn}",
            )
        )
    session = loop.sessions.get_or_create("cli:direct")
    assert len(session.messages) <= FILE_MAX_MESSAGES
def test_session_enforce_file_cap_skips_archive_when_dropped_prefix_already_consolidated(self, tmp_path):
    """A fully consolidated dropped prefix must not be raw-archived again."""
    from nanobot.session.manager import Session

    on_archive = MagicMock()
    session = Session(key="cli:direct")
    for idx in range(8):
        session.add_message("user", f"u{idx}")
    # Everything the cap will drop has already been consolidated.
    session.last_consolidated = 6
    session.enforce_file_cap(on_archive=on_archive, limit=4)
    assert len(session.messages) <= 4
    on_archive.assert_not_called()
def test_session_enforce_file_cap_archives_only_unconsolidated_dropped_prefix(self, tmp_path):
    """Only the unconsolidated slice of the dropped prefix reaches the archive."""
    from nanobot.session.manager import Session

    on_archive = MagicMock()
    session = Session(key="cli:direct")
    for idx in range(8):
        session.add_message("user", f"u{idx}")
    # u0/u1 were consolidated; u2/u3 get dropped while still unconsolidated.
    session.last_consolidated = 2
    session.enforce_file_cap(on_archive=on_archive, limit=4)
    assert len(session.messages) <= 4
    on_archive.assert_called_once()
    archived_contents = [entry["content"] for entry in on_archive.call_args.args[0]]
    assert archived_contents == ["u2", "u3"]
class TestAutoCompact: class TestAutoCompact:
"""Test the _archive method.""" """Test the _archive method."""

View File

@ -350,3 +350,66 @@ def test_get_history_ignores_media_kwarg_on_non_user_rows():
# List content is passed through verbatim — the synthesizer only # List content is passed through verbatim — the synthesizer only
# rewrites plain-string content. # rewrites plain-string content.
assert history[0]["content"] == [{"type": "text", "text": "structured"}] assert history[0]["content"] == [{"type": "text", "text": "structured"}]
def test_get_history_respects_max_tokens(monkeypatch):
    """Token-budgeted replay keeps only the newest turns that fit the budget."""
    session = Session(key="test:token-cap")
    for content in ("u1", "a1", "u2", "a2", "u3", "a3"):
        role = "user" if content.startswith("u") else "assistant"
        session.messages.append({"role": role, "content": content})
    # Every message costs 50 tokens, so a 120-token budget fits exactly two.
    costs = dict.fromkeys(("u1", "a1", "u2", "a2", "u3", "a3"), 50)
    monkeypatch.setattr(
        "nanobot.session.manager.estimate_message_tokens",
        lambda message: costs.get(message.get("content"), 0),
    )
    history = session.get_history(max_messages=500, max_tokens=120)
    assert [entry["content"] for entry in history] == ["u3", "a3"]
def test_get_history_recovers_user_when_token_slice_would_be_assistant_only(monkeypatch):
    """If the budget leaves an assistant-only tail, pull back the nearest user turn."""
    session = Session(key="test:assistant-only-slice")
    for content in ("u1", "a1", "u2", "a2"):
        role = "user" if content.startswith("u") else "assistant"
        session.messages.append({"role": role, "content": content})
    costs = dict.fromkeys(("u1", "a1", "u2", "a2"), 100)
    monkeypatch.setattr(
        "nanobot.session.manager.estimate_message_tokens",
        lambda message: costs.get(message.get("content"), 0),
    )
    # 100 tokens only fits "a2"; the slice must recover "u2" anyway.
    history = session.get_history(max_messages=500, max_tokens=100)
    assert [entry["content"] for entry in history] == ["u2", "a2"]
def test_retain_recent_legal_suffix_hard_cap_with_long_non_user_chain():
    """The hard cap holds even when the tail is a long assistant-only chain."""
    session = Session(key="test:hard-cap-chain")
    tool_call = {"id": "c1", "type": "function", "function": {"name": "x", "arguments": "{}"}}
    session.messages.append({"role": "user", "content": "u0"})
    session.messages.append({"role": "assistant", "content": None, "tool_calls": [tool_call]})
    session.messages.extend(
        {"role": "assistant", "content": f"a{idx}"} for idx in range(12)
    )
    session.retain_recent_legal_suffix(6)
    assert len(session.messages) <= 6