refactor(memory): centralize cursor validation behind a single gate

Move the non-int cursor guard out of the two consumer sites and into a
shared ``_iter_valid_entries`` iterator so the invariant lives in one
place.  Closes three gaps left by the original fix:

* ``bool`` is now rejected — ``isinstance(True, int)`` is ``True`` in
  Python, so the previous guard silently treated ``{"cursor": true}`` as
  cursor ``1``.
* Recovery now returns ``max(valid cursors) + 1``.  Under adversarial
  corruption "first int scanning in reverse" is not the same thing, and
  only ``max`` keeps the recovered cursor strictly greater than every
  legitimate cursor still on disk.
* Non-int cursors are logged exactly once per ``MemoryStore``.  Silently
  dropping corrupted entries hides the root cause (an external writer
  to ``memory/history.jsonl``); rate-limiting keeps the log clean when
  the same poisoned file is read every turn.

All 7 tests from the original fix pass unchanged; 3 new tests pin the
invariants above.

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-21 05:54:17 +00:00 committed by Xubin Ren
parent c0a11c7cf4
commit c1957e14ff
2 changed files with 118 additions and 22 deletions

View File

@ -8,7 +8,7 @@ import re
import weakref import weakref
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable, Iterator
from loguru import logger from loguru import logger
@ -49,6 +49,7 @@ class MemoryStore:
self.user_file = workspace / "USER.md" self.user_file = workspace / "USER.md"
self._cursor_file = self.memory_dir / ".cursor" self._cursor_file = self.memory_dir / ".cursor"
self._dream_cursor_file = self.memory_dir / ".dream_cursor" self._dream_cursor_file = self.memory_dir / ".dream_cursor"
self._corruption_logged = False # rate-limit non-int cursor warning
self._git = GitStore(workspace, tracked_files=[ self._git = GitStore(workspace, tracked_files=[
"SOUL.md", "USER.md", "memory/MEMORY.md", "SOUL.md", "USER.md", "memory/MEMORY.md",
]) ])
@ -246,35 +247,52 @@ class MemoryStore:
self._cursor_file.write_text(str(cursor), encoding="utf-8") self._cursor_file.write_text(str(cursor), encoding="utf-8")
return cursor return cursor
@staticmethod
def _valid_cursor(value: Any) -> int | None:
"""Int cursors only — reject bool (``isinstance(True, int)`` is True)."""
if isinstance(value, bool) or not isinstance(value, int):
return None
return value
def _iter_valid_entries(self) -> Iterator[tuple[dict[str, Any], int]]:
"""Yield ``(entry, cursor)`` for entries with int cursors; warn once on corruption."""
poisoned: Any = None
for entry in self._read_entries():
raw = entry.get("cursor")
if raw is None:
continue
cursor = self._valid_cursor(raw)
if cursor is None:
poisoned = raw
continue
yield entry, cursor
if poisoned is not None and not self._corruption_logged:
self._corruption_logged = True
logger.warning(
"history.jsonl contains a non-int cursor ({!r}); dropping it. "
"Usually caused by an external writer; further occurrences suppressed.",
poisoned,
)
def _next_cursor(self) -> int: def _next_cursor(self) -> int:
"""Read the current cursor counter and return next value.""" """Read the current cursor counter and return the next value."""
if self._cursor_file.exists(): if self._cursor_file.exists():
try: try:
return int(self._cursor_file.read_text(encoding="utf-8").strip()) + 1 return int(self._cursor_file.read_text(encoding="utf-8").strip()) + 1
except (ValueError, OSError): except (ValueError, OSError):
pass pass
# Fallback: read last line's cursor from the JSONL file. # Fast path: trust the tail when intact. Otherwise scan the whole
last = self._read_last_entry() # file and take ``max`` — that stays correct even if the monotonic
if last and last.get("cursor"): # invariant was broken by external writes.
cursor = last["cursor"] last = self._read_last_entry() or {}
if isinstance(cursor, int): cursor = self._valid_cursor(last.get("cursor"))
return cursor + 1 if cursor is not None:
# Corrupted (non-int) cursor — scan all entries for the highest valid one. return cursor + 1
entries = self._read_entries() return max((c for _, c in self._iter_valid_entries()), default=0) + 1
for entry in reversed(entries):
c = entry.get("cursor")
if isinstance(c, int):
return c + 1
return 1
return 1
def read_unprocessed_history(self, since_cursor: int) -> list[dict[str, Any]]: def read_unprocessed_history(self, since_cursor: int) -> list[dict[str, Any]]:
"""Return history entries with cursor > *since_cursor*.""" """Return history entries with a valid cursor > *since_cursor*."""
return [ return [e for e, c in self._iter_valid_entries() if c > since_cursor]
e
for e in self._read_entries()
if isinstance(e.get("cursor"), int) and e["cursor"] > since_cursor
]
def compact_history(self) -> None: def compact_history(self) -> None:
"""Drop oldest entries if the file exceeds *max_history_entries*.""" """Drop oldest entries if the file exceeds *max_history_entries*."""

View File

@ -108,3 +108,81 @@ class TestReadUnprocessedWithCorruption:
assert len(entries) == 2 assert len(entries) == 2
assert entries[0]["cursor"] == 2 assert entries[0]["cursor"] == 2
assert entries[1]["cursor"] == 3 assert entries[1]["cursor"] == 3
class TestCursorValidationInvariant:
"""First-principles checks: the cursor validity rules and the
observability we layer on top of them."""
def test_bool_cursor_rejected(self, store):
"""``isinstance(True, int) is True`` in Python; the guard must
still treat ``{"cursor": true}`` as corruption, otherwise a
boolean silently becomes cursor ``1`` / ``0`` downstream.
"""
assert MemoryStore._valid_cursor(True) is None
assert MemoryStore._valid_cursor(False) is None
assert MemoryStore._valid_cursor(5) == 5
assert MemoryStore._valid_cursor(0) == 0
store.history_file.write_text(
'{"cursor": 4, "timestamp": "2026-04-01 10:00", "content": "real"}\n'
'{"cursor": true, "timestamp": "2026-04-01 10:01", "content": "bool"}\n',
encoding="utf-8",
)
store._cursor_file.unlink(missing_ok=True)
assert store.append_history("next") == 5
entries = store.read_unprocessed_history(since_cursor=0)
assert [e["cursor"] for e in entries] == [4, 5]
def test_next_cursor_returns_max_not_just_last_int(self, store):
"""Under adversarial corruption, file order ≠ numeric order. The
recovery scan must return ``max(valid cursors) + 1``, not the
first int seen from the tail, so the returned cursor is strictly
greater than every legitimate cursor already on disk.
"""
# Tail is corrupt → recovery scan runs. Valid cursors are 100
# and 5, in that order on disk; a naive "first int from the tail"
# recovery would return 6, which would then silently collide with
# the existing cursor 100. ``max`` is the only safe choice.
store.history_file.write_text(
'{"cursor": 100, "timestamp": "2026-04-01 10:00", "content": "high"}\n'
'{"cursor": 5, "timestamp": "2026-04-01 10:01", "content": "out of order"}\n'
'{"cursor": "poison", "timestamp": "2026-04-01 10:02", "content": "tail corrupt"}\n',
encoding="utf-8",
)
store._cursor_file.unlink(missing_ok=True)
assert store.append_history("safe next") == 101
def test_corruption_is_logged_exactly_once_per_store(self, store, caplog):
"""Observability without spam: the first non-int cursor emits one
warning, subsequent reads on the same store stay quiet. Without
this, a poisoned file produces one warning per agent turn."""
import logging
from loguru import logger as loguru_logger
store.history_file.write_text(
'{"cursor": "bad1", "timestamp": "2026-04-01 10:00", "content": "x"}\n'
'{"cursor": 2, "timestamp": "2026-04-01 10:01", "content": "y"}\n',
encoding="utf-8",
)
store._cursor_file.unlink(missing_ok=True)
handler_id = loguru_logger.add(
caplog.handler, format="{message}", level="WARNING"
)
try:
with caplog.at_level(logging.WARNING):
store.read_unprocessed_history(since_cursor=0)
store.read_unprocessed_history(since_cursor=0)
store.append_history("another")
finally:
loguru_logger.remove(handler_id)
corruption_warnings = [
r for r in caplog.records if "non-int cursor" in r.getMessage()
]
assert len(corruption_warnings) == 1, (
"Expected exactly one corruption warning per store instance; "
f"got {len(corruption_warnings)}: {[r.getMessage() for r in corruption_warnings]}"
)