fix(memory): use parent tree in dream-restore revert to properly undo commit

The revert method was using the commit's own tree instead of its parent's,
which meant /dream-restore would restore TO that commit rather than UNDO it.
Also add root commit guard to prevent crash when reverting the initial commit.
This commit is contained in:
chengyongru 2026-04-03 14:05:17 +08:00
parent 475ea06294
commit a662ace8dd
2 changed files with 19 additions and 7 deletions

View File

@ -238,10 +238,11 @@ class GitStore:
# -- restore --------------------------------------------------------------- # -- restore ---------------------------------------------------------------
def revert(self, commit: str) -> str | None: def revert(self, commit: str) -> str | None:
"""Restore all tracked memory files to their state at the given commit. """Revert (undo) the changes introduced by the given commit.
Restores all tracked memory files to the state at the commit's parent,
then creates a new commit recording the revert.
This reads the file contents from the target commit, writes them back,
and creates a new commit. Does not require merge3.
Returns the new commit SHA, or None on failure. Returns the new commit SHA, or None on failure.
""" """
if not self.is_initialized(): if not self.is_initialized():
@ -255,13 +256,20 @@ class GitStore:
logger.warning("Git revert: SHA not found: {}", commit) logger.warning("Git revert: SHA not found: {}", commit)
return None return None
restored: list[str] = []
with Repo(str(self._workspace)) as repo: with Repo(str(self._workspace)) as repo:
commit_obj = repo[full_sha] commit_obj = repo[full_sha]
if commit_obj.type_name != b"commit": if commit_obj.type_name != b"commit":
return None return None
tree = repo[commit_obj.tree]
if not commit_obj.parents:
logger.warning("Git revert: cannot revert root commit {}", commit)
return None
# Use the parent's tree — this undoes the commit's changes
parent_obj = repo[commit_obj.parents[0]]
tree = repo[parent_obj.tree]
restored: list[str] = []
for filepath in self._tracked_files: for filepath in self._tracked_files:
content = self._read_blob_from_tree(repo, tree, filepath) content = self._read_blob_from_tree(repo, tree, filepath)
if content is not None: if content is not None:
@ -273,7 +281,7 @@ class GitStore:
return None return None
# Commit the restored state # Commit the restored state
msg = f"revert: restore to {commit}" msg = f"revert: undo {commit}"
return self.auto_commit(msg) return self.auto_commit(msg)
except Exception: except Exception:
logger.warning("Git revert failed for {}", commit) logger.warning("Git revert failed for {}", commit)

View File

@ -205,10 +205,14 @@ class TestRevert:
git_ready.auto_commit("v2") git_ready.auto_commit("v2")
commits = git_ready.log() commits = git_ready.log()
new_sha = git_ready.revert(commits[1].sha) # revert to init new_sha = git_ready.revert(commits[0].sha) # undo v2 → back to init
assert new_sha is not None assert new_sha is not None
assert (ws / "SOUL.md").read_text(encoding="utf-8") == "" assert (ws / "SOUL.md").read_text(encoding="utf-8") == ""
def test_cannot_revert_root_commit(self, git_ready):
commits = git_ready.log()
assert git_ready.revert(commits[-1].sha) is None
def test_invalid_sha_returns_none(self, git_ready): def test_invalid_sha_returns_none(self, git_ready):
assert git_ready.revert("deadbeef") is None assert git_ready.revert("deadbeef") is None