mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-03 16:25:53 +00:00
Merge PR #3459: feat(session): enforce replay/file-cap invariants for history lifecycle
feat(session): enforce replay/file-cap invariants for history lifecycle
This commit is contained in:
commit
2fe8d21b6e
@ -477,6 +477,18 @@ class AgentLoop:
|
||||
return UNIFIED_SESSION_KEY
|
||||
return msg.session_key
|
||||
|
||||
def _replay_token_budget(self) -> int:
|
||||
"""Derive a token budget for session history replay from the context window."""
|
||||
if self.context_window_tokens <= 0:
|
||||
return 0
|
||||
max_output = getattr(getattr(self.provider, "generation", None), "max_tokens", 4096)
|
||||
try:
|
||||
reserved_output = int(max_output)
|
||||
except (TypeError, ValueError):
|
||||
reserved_output = 4096
|
||||
budget = self.context_window_tokens - max(1, reserved_output) - 1024
|
||||
return budget if budget > 0 else max(128, self.context_window_tokens // 2)
|
||||
|
||||
async def _run_agent_loop(
|
||||
self,
|
||||
initial_messages: list[dict],
|
||||
@ -867,7 +879,10 @@ class AgentLoop:
|
||||
channel, chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
history = session.get_history(max_messages=0, include_timestamps=True)
|
||||
history = session.get_history(
|
||||
max_tokens=self._replay_token_budget(),
|
||||
include_timestamps=True,
|
||||
)
|
||||
current_role = "assistant" if is_subagent else "user"
|
||||
|
||||
# Subagent content is already in `history` above; passing it again
|
||||
@ -888,6 +903,7 @@ class AgentLoop:
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session))
|
||||
@ -950,7 +966,10 @@ class AgentLoop:
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
history = session.get_history(max_messages=0, include_timestamps=True)
|
||||
history = session.get_history(
|
||||
max_tokens=self._replay_token_budget(),
|
||||
include_timestamps=True,
|
||||
)
|
||||
|
||||
pending_ask_id = pending_ask_user_id(history)
|
||||
if pending_ask_id:
|
||||
@ -1038,6 +1057,7 @@ class AgentLoop:
|
||||
# Skip the already-persisted user message when saving the turn
|
||||
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
|
||||
self._save_turn(session, all_msgs, save_skip)
|
||||
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
|
||||
@ -12,6 +12,7 @@ from loguru import logger
|
||||
|
||||
from nanobot.config.paths import get_legacy_sessions_dir
|
||||
from nanobot.utils.helpers import (
|
||||
estimate_message_tokens,
|
||||
ensure_dir,
|
||||
find_legal_message_start,
|
||||
image_placeholder_text,
|
||||
@ -19,6 +20,10 @@ from nanobot.utils.helpers import (
|
||||
)
|
||||
|
||||
|
||||
HISTORY_MAX_MESSAGES = 120
|
||||
FILE_MAX_MESSAGES = 2000
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""A conversation session."""
|
||||
@ -69,11 +74,16 @@ class Session:
|
||||
|
||||
def get_history(
|
||||
self,
|
||||
max_messages: int = 500,
|
||||
max_messages: int = HISTORY_MAX_MESSAGES,
|
||||
*,
|
||||
max_tokens: int = 0,
|
||||
include_timestamps: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
|
||||
"""Return unconsolidated messages for LLM input.
|
||||
|
||||
History is sliced by message count first (``max_messages``), then by
|
||||
token budget from the tail (``max_tokens``) when provided.
|
||||
"""
|
||||
unconsolidated = self.messages[self.last_consolidated:]
|
||||
sliced = unconsolidated[-max_messages:]
|
||||
|
||||
@ -113,6 +123,38 @@ class Session:
|
||||
if key in message:
|
||||
entry[key] = message[key]
|
||||
out.append(entry)
|
||||
|
||||
if max_tokens > 0 and out:
|
||||
kept: list[dict[str, Any]] = []
|
||||
used = 0
|
||||
for message in reversed(out):
|
||||
tokens = estimate_message_tokens(message)
|
||||
if kept and used + tokens > max_tokens:
|
||||
break
|
||||
kept.append(message)
|
||||
used += tokens
|
||||
kept.reverse()
|
||||
|
||||
# Keep history aligned to the first visible user turn.
|
||||
first_user = next((i for i, m in enumerate(kept) if m.get("role") == "user"), None)
|
||||
if first_user is not None:
|
||||
kept = kept[first_user:]
|
||||
else:
|
||||
# Tight token budgets can otherwise leave assistant-only tails.
|
||||
# If a user turn exists in the unsliced output, recover the
|
||||
# nearest one even if it slightly exceeds the token budget.
|
||||
recovered_user = next(
|
||||
(i for i in range(len(out) - 1, -1, -1) if out[i].get("role") == "user"),
|
||||
None,
|
||||
)
|
||||
if recovered_user is not None:
|
||||
kept = out[recovered_user:]
|
||||
|
||||
# And keep a legal tool-call boundary at the front.
|
||||
start = find_legal_message_start(kept)
|
||||
if start:
|
||||
kept = kept[start:]
|
||||
out = kept
|
||||
return out
|
||||
|
||||
def clear(self) -> None:
|
||||
@ -122,31 +164,77 @@ class Session:
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def retain_recent_legal_suffix(self, max_messages: int) -> None:
|
||||
"""Keep a legal recent suffix, mirroring get_history boundary rules."""
|
||||
"""Keep a legal recent suffix constrained by a hard message cap."""
|
||||
if max_messages <= 0:
|
||||
self.clear()
|
||||
return
|
||||
if len(self.messages) <= max_messages:
|
||||
return
|
||||
|
||||
start_idx = max(0, len(self.messages) - max_messages)
|
||||
retained = list(self.messages[-max_messages:])
|
||||
|
||||
# If the cutoff lands mid-turn, extend backward to the nearest user turn.
|
||||
while start_idx > 0 and self.messages[start_idx].get("role") != "user":
|
||||
start_idx -= 1
|
||||
|
||||
retained = self.messages[start_idx:]
|
||||
# Prefer starting at a user turn when one exists within the tail.
|
||||
first_user = next((i for i, m in enumerate(retained) if m.get("role") == "user"), None)
|
||||
if first_user is not None:
|
||||
retained = retained[first_user:]
|
||||
else:
|
||||
# If the tail is assistant/tool-only, anchor to the latest user in
|
||||
# the full session and take a capped forward window from there.
|
||||
latest_user = next(
|
||||
(i for i in range(len(self.messages) - 1, -1, -1)
|
||||
if self.messages[i].get("role") == "user"),
|
||||
None,
|
||||
)
|
||||
if latest_user is not None:
|
||||
retained = list(self.messages[latest_user: latest_user + max_messages])
|
||||
|
||||
# Mirror get_history(): avoid persisting orphan tool results at the front.
|
||||
start = find_legal_message_start(retained)
|
||||
if start:
|
||||
retained = retained[start:]
|
||||
|
||||
# Hard-cap guarantee: never keep more than max_messages.
|
||||
if len(retained) > max_messages:
|
||||
retained = retained[-max_messages:]
|
||||
start = find_legal_message_start(retained)
|
||||
if start:
|
||||
retained = retained[start:]
|
||||
|
||||
dropped = len(self.messages) - len(retained)
|
||||
self.messages = retained
|
||||
self.last_consolidated = max(0, self.last_consolidated - dropped)
|
||||
self.updated_at = datetime.now()
|
||||
|
||||
def enforce_file_cap(
|
||||
self,
|
||||
on_archive: Any = None,
|
||||
limit: int = FILE_MAX_MESSAGES,
|
||||
) -> None:
|
||||
"""Bound session message growth by archiving and trimming old prefixes."""
|
||||
if limit <= 0 or len(self.messages) <= limit:
|
||||
return
|
||||
|
||||
before = list(self.messages)
|
||||
before_last_consolidated = self.last_consolidated
|
||||
before_count = len(before)
|
||||
self.retain_recent_legal_suffix(limit)
|
||||
dropped_count = before_count - len(self.messages)
|
||||
if dropped_count <= 0:
|
||||
return
|
||||
|
||||
dropped = before[:dropped_count]
|
||||
already_consolidated = min(before_last_consolidated, dropped_count)
|
||||
archive_chunk = dropped[already_consolidated:]
|
||||
if archive_chunk and on_archive:
|
||||
on_archive(archive_chunk)
|
||||
logger.info(
|
||||
"Session file cap hit for {}: dropped {}, raw-archived {}, kept {}",
|
||||
self.key,
|
||||
dropped_count,
|
||||
len(archive_chunk),
|
||||
len(self.messages),
|
||||
)
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
|
||||
@ -2,8 +2,8 @@
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -15,7 +15,10 @@ from nanobot.command import CommandContext
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
|
||||
def _make_loop(tmp_path: Path, session_ttl_minutes: int = 15) -> AgentLoop:
|
||||
def _make_loop(
|
||||
tmp_path: Path,
|
||||
session_ttl_minutes: int = 15,
|
||||
) -> AgentLoop:
|
||||
"""Create a minimal AgentLoop for testing."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
@ -72,6 +75,12 @@ class TestSessionTTLConfig:
|
||||
assert data["idleCompactAfterMinutes"] == 30
|
||||
assert "sessionTtlMinutes" not in data
|
||||
|
||||
def test_session_history_and_file_cap_are_internal_constants(self):
|
||||
"""Session history/file cap should be internal constants, not config fields."""
|
||||
from nanobot.session.manager import HISTORY_MAX_MESSAGES, FILE_MAX_MESSAGES
|
||||
assert HISTORY_MAX_MESSAGES == 120
|
||||
assert FILE_MAX_MESSAGES == 2000
|
||||
|
||||
|
||||
class TestAgentLoopTTLParam:
|
||||
"""Test that AutoCompact receives and stores session_ttl_minutes."""
|
||||
@ -86,6 +95,75 @@ class TestAgentLoopTTLParam:
|
||||
loop = _make_loop(tmp_path, session_ttl_minutes=0)
|
||||
assert loop.auto_compact._ttl == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_message_reads_history_with_token_budget(self, tmp_path):
|
||||
"""_process_message should pass an auto-derived token budget to get_history."""
|
||||
loop = _make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:direct")
|
||||
session.get_history = MagicMock(return_value=[])
|
||||
loop.context.build_messages = MagicMock(return_value=[])
|
||||
loop._run_agent_loop = AsyncMock(return_value=("ok", [], [], "stop", False))
|
||||
loop._save_turn = MagicMock()
|
||||
|
||||
msg = InboundMessage(
|
||||
channel="cli",
|
||||
sender_id="u1",
|
||||
chat_id="direct",
|
||||
content="hello",
|
||||
)
|
||||
await loop._process_message(msg)
|
||||
session.get_history.assert_called_once()
|
||||
kwargs = session.get_history.call_args.kwargs
|
||||
assert isinstance(kwargs.get("max_tokens"), int)
|
||||
assert kwargs["max_tokens"] > 0
|
||||
assert kwargs["include_timestamps"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_file_cap_archives_and_trims_old_messages(self, tmp_path):
|
||||
loop = _make_loop(tmp_path)
|
||||
loop.context.memory.raw_archive = MagicMock()
|
||||
|
||||
for i in range(4):
|
||||
msg = InboundMessage(
|
||||
channel="cli",
|
||||
sender_id="u1",
|
||||
chat_id="direct",
|
||||
content=f"hello {i}",
|
||||
)
|
||||
await loop._process_message(msg)
|
||||
|
||||
session = loop.sessions.get_or_create("cli:direct")
|
||||
from nanobot.session.manager import FILE_MAX_MESSAGES
|
||||
assert len(session.messages) <= FILE_MAX_MESSAGES
|
||||
|
||||
def test_session_enforce_file_cap_skips_archive_when_dropped_prefix_already_consolidated(self, tmp_path):
|
||||
from nanobot.session.manager import Session
|
||||
archive_fn = MagicMock()
|
||||
session = Session(key="cli:direct")
|
||||
for i in range(8):
|
||||
session.add_message("user", f"u{i}")
|
||||
session.last_consolidated = 6
|
||||
|
||||
session.enforce_file_cap(on_archive=archive_fn, limit=4)
|
||||
|
||||
assert len(session.messages) <= 4
|
||||
archive_fn.assert_not_called()
|
||||
|
||||
def test_session_enforce_file_cap_archives_only_unconsolidated_dropped_prefix(self, tmp_path):
|
||||
from nanobot.session.manager import Session
|
||||
archive_fn = MagicMock()
|
||||
session = Session(key="cli:direct")
|
||||
for i in range(8):
|
||||
session.add_message("user", f"u{i}")
|
||||
session.last_consolidated = 2
|
||||
|
||||
session.enforce_file_cap(on_archive=archive_fn, limit=4)
|
||||
|
||||
assert len(session.messages) <= 4
|
||||
archive_fn.assert_called_once()
|
||||
archived = archive_fn.call_args.args[0]
|
||||
assert [m["content"] for m in archived] == ["u2", "u3"]
|
||||
|
||||
|
||||
class TestAutoCompact:
|
||||
"""Test the _archive method."""
|
||||
|
||||
@ -350,3 +350,66 @@ def test_get_history_ignores_media_kwarg_on_non_user_rows():
|
||||
# List content is passed through verbatim — the synthesizer only
|
||||
# rewrites plain-string content.
|
||||
assert history[0]["content"] == [{"type": "text", "text": "structured"}]
|
||||
|
||||
|
||||
def test_get_history_respects_max_tokens(monkeypatch):
|
||||
session = Session(key="test:token-cap")
|
||||
session.messages.extend(
|
||||
[
|
||||
{"role": "user", "content": "u1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "u2"},
|
||||
{"role": "assistant", "content": "a2"},
|
||||
{"role": "user", "content": "u3"},
|
||||
{"role": "assistant", "content": "a3"},
|
||||
]
|
||||
)
|
||||
|
||||
token_map = {"u1": 50, "a1": 50, "u2": 50, "a2": 50, "u3": 50, "a3": 50}
|
||||
monkeypatch.setattr(
|
||||
"nanobot.session.manager.estimate_message_tokens",
|
||||
lambda message: token_map.get(message.get("content"), 0),
|
||||
)
|
||||
|
||||
history = session.get_history(max_messages=500, max_tokens=120)
|
||||
assert [m["content"] for m in history] == ["u3", "a3"]
|
||||
|
||||
|
||||
def test_get_history_recovers_user_when_token_slice_would_be_assistant_only(monkeypatch):
|
||||
session = Session(key="test:assistant-only-slice")
|
||||
session.messages.extend(
|
||||
[
|
||||
{"role": "user", "content": "u1"},
|
||||
{"role": "assistant", "content": "a1"},
|
||||
{"role": "user", "content": "u2"},
|
||||
{"role": "assistant", "content": "a2"},
|
||||
]
|
||||
)
|
||||
token_map = {"u1": 100, "a1": 100, "u2": 100, "a2": 100}
|
||||
monkeypatch.setattr(
|
||||
"nanobot.session.manager.estimate_message_tokens",
|
||||
lambda message: token_map.get(message.get("content"), 0),
|
||||
)
|
||||
|
||||
history = session.get_history(max_messages=500, max_tokens=100)
|
||||
assert [m["content"] for m in history] == ["u2", "a2"]
|
||||
|
||||
|
||||
def test_retain_recent_legal_suffix_hard_cap_with_long_non_user_chain():
|
||||
session = Session(key="test:hard-cap-chain")
|
||||
session.messages.append({"role": "user", "content": "u0"})
|
||||
session.messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{"id": "c1", "type": "function", "function": {"name": "x", "arguments": "{}"}}
|
||||
],
|
||||
}
|
||||
)
|
||||
for i in range(12):
|
||||
session.messages.append({"role": "assistant", "content": f"a{i}"})
|
||||
|
||||
session.retain_recent_legal_suffix(6)
|
||||
|
||||
assert len(session.messages) <= 6
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user