fix(agent): skip auto-compact for sessions with active agent tasks

Prevent proactive compaction from archiving sessions that have an
in-flight agent task, avoiding mid-turn context truncation when a
task runs longer than the idle TTL.
This commit is contained in:
chengyongru 2026-04-13 11:27:16 +08:00 committed by chengyongru
parent 217e1fc957
commit 62bd54ac4a
3 changed files with 115 additions and 7 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Collection
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Coroutine
@ -23,12 +24,13 @@ class AutoCompact:
self._archiving: set[str] = set()
self._summaries: dict[str, tuple[str, datetime]] = {}
def _is_expired(self, ts: datetime | str | None,
                now: datetime | None = None) -> bool:
    """Return True when *ts* is at least the configured TTL in the past.

    Args:
        ts: Last-activity timestamp, either a ``datetime`` or an ISO-8601
            string accepted by ``datetime.fromisoformat``. A falsy value
            (``None`` or ``""``) is never considered expired.
        now: Reference time to compare against. Defaults to
            ``datetime.now()``; callers checking many timestamps can pass
            one shared value so every comparison uses the same instant.

    Returns:
        ``False`` when the TTL is disabled (``self._ttl <= 0``) or *ts* is
        falsy; otherwise whether the elapsed time is >= ``self._ttl`` minutes.
    """
    if self._ttl <= 0 or not ts:
        return False
    if isinstance(ts, str):
        ts = datetime.fromisoformat(ts)
    reference = now if now is not None else datetime.now()
    ts_is_naive = ts.utcoffset() is None
    reference_is_naive = reference.utcoffset() is None
    if ts_is_naive != reference_is_naive:
        # Subtracting an aware datetime from a naive one raises TypeError
        # (e.g. an ISO string carrying a UTC offset vs. datetime.now()).
        # astimezone() interprets a naive value as system local time, which
        # is what datetime.now() produces.
        if ts_is_naive:
            ts = ts.astimezone()
        else:
            reference = reference.astimezone()
    return (reference - ts).total_seconds() >= self._ttl * 60
@staticmethod
def _format_summary(text: str, last_active: datetime) -> str:
@ -56,10 +58,17 @@ class AutoCompact:
cut = len(tail) - len(kept)
return tail[:cut], kept
def check_expired(self, schedule_background: Callable[[Coroutine], None],
                  active_session_keys: Collection[str] = ()) -> None:
    """Schedule archival for idle sessions, skipping those with in-flight agent tasks.

    Args:
        schedule_background: Callback that receives the archival coroutine
            for each session selected on this pass.
        active_session_keys: Keys of sessions that must not be archived on
            this pass, even if their timestamp has expired. Defaults to none.
    """
    if self._ttl <= 0:
        # _is_expired() never reports expiry while the TTL is disabled, so
        # there is nothing to schedule; skip listing the sessions entirely.
        return
    # One shared reference time so every session is judged against the
    # same instant within a single pass.
    now = datetime.now()
    for info in self.sessions.list_sessions():
        key = info.get("key", "")
        if not key or key in self._archiving or key in active_session_keys:
            continue
        if not self._is_expired(info.get("updated_at"), now):
            continue
        # Mark the key before scheduling so a later pass does not queue
        # the same session again while it is still in self._archiving.
        self._archiving.add(key)
        logger.debug("Auto-compact: scheduling archival for {} (idle > {} min)", key, self._ttl)
        schedule_background(self._archive(key))

View File

@ -433,7 +433,10 @@ class AgentLoop:
try:
msg = await asyncio.wait_for(self.bus.consume_inbound(), timeout=1.0)
except asyncio.TimeoutError:
self.auto_compact.check_expired(self._schedule_background)
self.auto_compact.check_expired(
self._schedule_background,
active_session_keys=self._pending_queues.keys(),
)
continue
except asyncio.CancelledError:
# Preserve real task cancellation so shutdown can complete cleanly.

View File

@ -560,9 +560,12 @@ class TestProactiveAutoCompact:
"""Test proactive auto-new on idle ticks (TimeoutError path in run loop)."""
@staticmethod
async def _run_check_expired(loop, active_session_keys=()):
    """Helper: run check_expired via callback and wait for background tasks.

    Args:
        loop: Object exposing ``auto_compact.check_expired`` and
            ``_schedule_background``.
        active_session_keys: Session keys forwarded to ``check_expired`` as
            the sessions to leave alone on this pass.
    """
    loop.auto_compact.check_expired(
        loop._schedule_background,
        active_session_keys=active_session_keys,
    )
    # Yield to the event loop briefly so anything scheduled above can run
    # before the caller asserts on the outcome.
    await asyncio.sleep(0.1)
@pytest.mark.asyncio
@ -701,6 +704,99 @@ class TestProactiveAutoCompact:
assert not archive_called
await loop.close_mcp()
@pytest.mark.asyncio
async def test_skip_expired_session_with_active_agent_task(self, tmp_path):
    """An expired session is left untouched while its agent task is active."""
    session_key = "cli:test"
    loop = _make_loop(tmp_path, session_ttl_minutes=15)

    # Seed a session whose last activity (20 min ago) is past the 15 min TTL.
    stale = loop.sessions.get_or_create(session_key)
    _add_turns(stale, 6, prefix="old")
    stale.updated_at = datetime.now() - timedelta(minutes=20)
    loop.sessions.save(stale)

    archived_batches = []

    async def _record_archive(messages):
        archived_batches.append(messages)
        return "Summary."

    loop.consolidator.archive = _record_archive

    # Report the session as active so the idle tick has to skip it.
    await self._run_check_expired(loop, active_session_keys={session_key})

    assert len(archived_batches) == 0
    reloaded = loop.sessions.get_or_create(session_key)
    assert len(reloaded.messages) == 12  # All messages preserved
    await loop.close_mcp()
@pytest.mark.asyncio
async def test_archive_after_active_task_completes(self, tmp_path):
    """A skipped session is archived on the first tick after its task ends."""
    session_key = "cli:test"
    loop = _make_loop(tmp_path, session_ttl_minutes=15)

    # Seed a session whose last activity (20 min ago) is past the 15 min TTL.
    stale = loop.sessions.get_or_create(session_key)
    _add_turns(stale, 6, prefix="old")
    stale.updated_at = datetime.now() - timedelta(minutes=20)
    loop.sessions.save(stale)

    archived_batches = []

    async def _record_archive(messages):
        archived_batches.append(messages)
        return "Summary."

    loop.consolidator.archive = _record_archive

    # Tick 1: the session is reported active, so nothing is archived.
    await self._run_check_expired(loop, active_session_keys={session_key})
    assert len(archived_batches) == 0

    # Tick 2: no active keys are reported, so the session is archived once.
    await self._run_check_expired(loop)
    assert len(archived_batches) == 1
    await loop.close_mcp()
@pytest.mark.asyncio
async def test_partial_active_set_only_archives_inactive_expired(self, tmp_path):
    """Of several sessions, only the expired one without an active task is archived."""
    idle_key = "cli:expired_idle"
    active_key = "cli:expired_active"
    recent_key = "cli:recent"
    loop = _make_loop(tmp_path, session_ttl_minutes=15)

    # Two sessions past the TTL: the first has no active task (expected to
    # be archived), the second will be reported active (expected to be kept).
    for key, prefix in ((idle_key, "old_a"), (active_key, "old_b")):
        expired = loop.sessions.get_or_create(key)
        _add_turns(expired, 6, prefix=prefix)
        expired.updated_at = datetime.now() - timedelta(minutes=20)
        loop.sessions.save(expired)

    # A freshly used session with no active task: expected to be kept.
    fresh = loop.sessions.get_or_create(recent_key)
    fresh.add_message("user", "recent")
    loop.sessions.save(fresh)

    archived_batches = []

    async def _record_archive(messages):
        archived_batches.append(messages)
        return "Summary."

    loop.consolidator.archive = _record_archive

    await self._run_check_expired(loop, active_session_keys={active_key})

    assert len(archived_batches) == 1
    idle_after = loop.sessions.get_or_create(idle_key)
    assert len(idle_after.messages) == loop.auto_compact._RECENT_SUFFIX_MESSAGES
    active_after = loop.sessions.get_or_create(active_key)
    assert len(active_after.messages) == 12  # Preserved
    recent_after = loop.sessions.get_or_create(recent_key)
    assert len(recent_after.messages) == 1  # Preserved
    await loop.close_mcp()
@pytest.mark.asyncio
async def test_no_reschedule_after_successful_archive(self, tmp_path):
"""Already-archived session should NOT be re-scheduled on subsequent ticks."""