fix(apply-patch): tighten edits-only boundaries

This commit is contained in:
Xubin Ren 2026-05-22 17:09:59 +08:00
parent 3d9f50a0cc
commit b0d3069621
2 changed files with 93 additions and 6 deletions

View File

@ -31,7 +31,7 @@ class _PatchError(ValueError):
pass pass
_ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\/]") _ABSOLUTE_WINDOWS_RE = re.compile(r"^[A-Za-z]:[\\/]")
def _validate_relative_path(path: str) -> str: def _validate_relative_path(path: str) -> str:
@ -42,7 +42,7 @@ def _validate_relative_path(path: str) -> str:
raise _PatchError(f"patch path contains a null byte: {path!r}") raise _PatchError(f"patch path contains a null byte: {path!r}")
if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized): if normalized.startswith(("~", "/", "\\")) or _ABSOLUTE_WINDOWS_RE.match(normalized):
raise _PatchError(f"patch path must be relative: {path}") raise _PatchError(f"patch path must be relative: {path}")
if any(part == ".." for part in re.split(r"[\/]+", normalized)): if any(part == ".." for part in re.split(r"[\\/]+", normalized)):
raise _PatchError(f"patch path must not contain '..': {path}") raise _PatchError(f"patch path must not contain '..': {path}")
return normalized return normalized
@ -109,6 +109,7 @@ def _format_summary(summary: _PatchSummary) -> str:
description="Validate and summarize the patch without writing files.", description="Validate and summarize the patch without writing files.",
default=False, default=False,
), ),
required=["edits"],
) )
) )
class ApplyPatchTool(_FsTool): class ApplyPatchTool(_FsTool):
@ -143,8 +144,15 @@ class ApplyPatchTool(_FsTool):
summaries: list[_PatchSummary] = [] summaries: list[_PatchSummary] = []
for edit in edits: for edit in edits:
path = _validate_relative_path(edit["path"]) if not isinstance(edit, dict):
action = edit["action"] raise _PatchError("each edit must be an object")
raw_path = edit.get("path")
if not isinstance(raw_path, str):
raise _PatchError("path required for edit")
path = _validate_relative_path(raw_path)
action = edit.get("action")
if not isinstance(action, str):
raise _PatchError(f"action required for edit: {path}")
source = self._resolve(path) source = self._resolve(path)
if action == "add": if action == "add":
@ -179,7 +187,9 @@ class ApplyPatchTool(_FsTool):
added, deleted = _line_diff_stats(content, new_norm) added, deleted = _line_diff_stats(content, new_norm)
action_name = "update" action_name = "update"
else: else:
new_norm = _lines_to_text(new_text.splitlines()) new_norm = new_text.replace("\r\n", "\n")
if new_norm and not new_norm.endswith("\n"):
new_norm += "\n"
writes[source] = new_norm writes[source] = new_norm
deletes.discard(source) deletes.discard(source)
added = _text_line_count(new_norm) added = _text_line_count(new_norm)
@ -274,7 +284,7 @@ class ApplyPatchTool(_FsTool):
if norm_content.find(norm_old, pos + 1) >= 0: if norm_content.find(norm_old, pos + 1) >= 0:
raise _PatchError(f"old_text appears multiple times in {path}") raise _PatchError(f"old_text appears multiple times in {path}")
if norm_old.strip() == norm_content.strip(): if norm_old == norm_content:
deletes.add(source) deletes.add(source)
writes.pop(source, None) writes.pop(source, None)
added, deleted = 0, _text_line_count(content) added, deleted = 0, _text_line_count(content)

View File

@ -46,6 +46,25 @@ def test_apply_patch_edits_add_new_file(tmp_path):
assert (tmp_path / "config.py").read_text() == "DEBUG = True\n" assert (tmp_path / "config.py").read_text() == "DEBUG = True\n"
def test_apply_patch_edits_preserves_new_file_trailing_blank_lines(tmp_path):
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(
tool.execute(
edits=[
{
"path": "notes.txt",
"action": "add",
"new_text": "one\n\n",
}
]
)
)
assert "add notes.txt" in result
assert (tmp_path / "notes.txt").read_text() == "one\n\n"
def test_apply_patch_edits_add_to_existing_file(tmp_path): def test_apply_patch_edits_add_to_existing_file(tmp_path):
target = tmp_path / "log.py" target = tmp_path / "log.py"
target.write_text("import logging\n\nlogger = logging.getLogger(__name__)\n") target.write_text("import logging\n\nlogger = logging.getLogger(__name__)\n")
@ -112,6 +131,28 @@ def test_apply_patch_edits_delete_entire_file(tmp_path):
assert not target.exists() assert not target.exists()
def test_apply_patch_edits_delete_substring_with_surrounding_whitespace(tmp_path):
target = tmp_path / "keep_whitespace.txt"
target.write_text(" token \n")
tool = ApplyPatchTool(workspace=tmp_path)
result = asyncio.run(
tool.execute(
edits=[
{
"path": "keep_whitespace.txt",
"action": "delete",
"old_text": "token",
}
]
)
)
assert "update keep_whitespace.txt" in result
assert target.exists()
assert target.read_text() == " \n"
def test_apply_patch_edits_batch_multiple_files(tmp_path): def test_apply_patch_edits_batch_multiple_files(tmp_path):
a = tmp_path / "a.py" a = tmp_path / "a.py"
a.write_text("X = 1\n") a.write_text("X = 1\n")
@ -220,12 +261,48 @@ def test_apply_patch_edits_rejects_absolute_and_parent_paths(tmp_path):
] ]
) )
) )
windows_absolute = asyncio.run(
tool.execute(
edits=[
{
"path": r"C:\owned.txt",
"action": "add",
"new_text": "nope",
}
]
)
)
windows_parent = asyncio.run(
tool.execute(
edits=[
{
"path": r"..\owned.txt",
"action": "add",
"new_text": "nope",
}
]
)
)
assert "must be relative" in absolute assert "must be relative" in absolute
assert "must not contain '..'" in parent assert "must not contain '..'" in parent
assert "must be relative" in windows_absolute
assert "must not contain '..'" in windows_parent
assert not (tmp_path.parent / "owned.txt").exists() assert not (tmp_path.parent / "owned.txt").exists()
def test_apply_patch_edits_reports_invalid_edit_shapes(tmp_path):
tool = ApplyPatchTool(workspace=tmp_path)
missing_path = asyncio.run(tool.execute(edits=[{"action": "add", "new_text": "x"}]))
missing_action = asyncio.run(tool.execute(edits=[{"path": "x.txt", "new_text": "x"}]))
non_object = asyncio.run(tool.execute(edits=["not an object"])) # type: ignore[list-item]
assert "path required for edit" in missing_path
assert "action required for edit: x.txt" in missing_action
assert "each edit must be an object" in non_object
def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path): def test_apply_patch_edits_rolls_back_when_late_operation_fails(tmp_path):
first = tmp_path / "first.txt" first = tmp_path / "first.txt"
first.write_text("before\n") first.write_text("before\n")