nanobot/nanobot/agent/tools/file_state.py
2026-05-01 19:15:07 +08:00

206 lines
6.6 KiB
Python

"""Track file-read state for read-before-edit warnings and read deduplication."""
from __future__ import annotations
import hashlib
import os
from contextvars import ContextVar, Token
from dataclasses import dataclass
from pathlib import Path
@dataclass(slots=True)
class ReadState:
mtime: float
offset: int
limit: int | None
content_hash: str | None
can_dedup: bool
def _hash_file(p: str) -> str | None:
try:
return hashlib.sha256(Path(p).read_bytes()).hexdigest()
except OSError:
return None
class FileStates:
"""Per-session read/write tracker.
Owns its own state dict so read-dedup ("File unchanged since last read")
and read-before-edit warnings stay scoped to one agent session and do
not leak across sessions sharing this process.
"""
__slots__ = ("_state",)
def __init__(self) -> None:
self._state: dict[str, ReadState] = {}
def record_read(self, path: str | Path, offset: int = 1, limit: int | None = None) -> None:
"""Record that a file was read (called after successful read)."""
p = str(Path(path).resolve())
try:
mtime = os.path.getmtime(p)
except OSError:
return
self._state[p] = ReadState(
mtime=mtime,
offset=offset,
limit=limit,
content_hash=_hash_file(p),
can_dedup=True,
)
def record_write(self, path: str | Path) -> None:
"""Record that a file was written (updates mtime in state)."""
p = str(Path(path).resolve())
try:
mtime = os.path.getmtime(p)
except OSError:
self._state.pop(p, None)
return
self._state[p] = ReadState(
mtime=mtime,
offset=1,
limit=None,
content_hash=_hash_file(p),
can_dedup=False,
)
def check_read(self, path: str | Path) -> str | None:
"""Check if a file has been read and is fresh.
Returns None if OK, or a warning string.
When mtime changed but file content is identical (e.g. touch, editor save),
the check passes to avoid false-positive staleness warnings.
"""
p = str(Path(path).resolve())
entry = self._state.get(p)
if entry is None:
return "Warning: file has not been read yet. Read it first to verify content before editing."
try:
current_mtime = os.path.getmtime(p)
except OSError:
return None
if current_mtime != entry.mtime:
if entry.content_hash and _hash_file(p) == entry.content_hash:
entry.mtime = current_mtime
return None
return "Warning: file has been modified since last read. Re-read to verify content before editing."
# mtime unchanged - still check content hash to detect quick modifications
if entry.content_hash and _hash_file(p) != entry.content_hash:
return "Warning: file has been modified since last read. Re-read to verify content before editing."
return None
def is_unchanged(self, path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
"""Return True if file was previously read with same params and content is unchanged."""
p = str(Path(path).resolve())
entry = self._state.get(p)
if entry is None:
return False
if not entry.can_dedup:
return False
if entry.offset != offset or entry.limit != limit:
return False
try:
current_mtime = os.path.getmtime(p)
except OSError:
return False
if current_mtime != entry.mtime:
# mtime changed - check if content also changed
current_hash = _hash_file(p)
if current_hash != entry.content_hash:
# Content actually changed - don't dedup
entry.can_dedup = False
return False
# Content identical despite mtime change (e.g. touch) - mark as not dedupable to force full read next time
entry.can_dedup = False
return True
# mtime unchanged - content must be identical
return True
def get(self, path: str | Path) -> ReadState | None:
"""Return the raw ReadState entry for a path, or None."""
return self._state.get(str(Path(path).resolve()))
def clear(self) -> None:
"""Clear all tracked state (useful for testing)."""
self._state.clear()
class FileStateStore:
"""Lookup table for per-session file read/write state."""
__slots__ = ("_states_by_key",)
def __init__(self) -> None:
self._states_by_key: dict[str, FileStates] = {}
def for_session(self, session_key: str | None) -> FileStates:
key = session_key or "__default__"
states = self._states_by_key.get(key)
if states is None:
states = FileStates()
self._states_by_key[key] = states
return states
def clear(self) -> None:
self._states_by_key.clear()
_current_file_states: ContextVar[FileStates | None] = ContextVar(
"nanobot_file_states",
default=None,
)
def current_file_states(default: FileStates) -> FileStates:
"""Return the FileStates bound to the current agent task, or a fallback."""
return _current_file_states.get() or default
def bind_file_states(file_states: FileStates) -> Token[FileStates | None]:
"""Bind file read/write state for the current async task."""
return _current_file_states.set(file_states)
def reset_file_states(token: Token[FileStates | None]) -> None:
_current_file_states.reset(token)
# Module-level default instance, retained for backward compatibility with
# tests and callers that reach in directly. Per-session callers should hold
# their own FileStates instance instead of touching this one.
_default = FileStates()
def record_read(path: str | Path, offset: int = 1, limit: int | None = None) -> None:
_default.record_read(path, offset=offset, limit=limit)
def record_write(path: str | Path) -> None:
_default.record_write(path)
def check_read(path: str | Path) -> str | None:
return _default.check_read(path)
def is_unchanged(path: str | Path, offset: int = 1, limit: int | None = None) -> bool:
return _default.is_unchanged(path, offset=offset, limit=limit)
def clear() -> None:
_default.clear()
# Legacy attribute for callers that reached into the module-level dict
# directly (filesystem.py used to do this). Kept as a property-like accessor
# so existing imports keep working.
def __getattr__(name: str):
if name == "_state":
return _default._state
raise AttributeError(name)