fix(long-task): honor final signal and file tracking

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Xubin Ren 2026-05-13 16:53:58 +00:00
parent 5f5f3d5d97
commit 78e8cc3e55
2 changed files with 100 additions and 25 deletions

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING from typing import TYPE_CHECKING, Any
from loguru import logger from loguru import logger
@ -30,6 +30,7 @@ if TYPE_CHECKING:
class HandoffState: class HandoffState:
"""Structured progress state passed between long-task steps.""" """Structured progress state passed between long-task steps."""
signal_type: str = ""
message: str = "" message: str = ""
files_created: list[str] = field(default_factory=list) files_created: list[str] = field(default_factory=list)
files_modified: list[str] = field(default_factory=list) files_modified: list[str] = field(default_factory=list)
@ -39,6 +40,7 @@ class HandoffState:
def is_empty(self) -> bool: def is_empty(self) -> bool:
return not any( return not any(
[ [
self.signal_type,
self.message, self.message,
self.files_created, self.files_created,
self.files_modified, self.files_modified,
@ -105,6 +107,7 @@ class HandoffTool(Tool):
verification: str = "", verification: str = "",
**kwargs: Any, **kwargs: Any,
) -> str: ) -> str:
self._store.signal_type = "handoff"
self._store.message = message self._store.message = message
self._store.files_created = list(files_created or []) self._store.files_created = list(files_created or [])
self._store.files_modified = list(files_modified or []) self._store.files_modified = list(files_modified or [])
@ -139,6 +142,7 @@ class CompleteTool(Tool):
) )
async def execute(self, summary: str, **kwargs: Any) -> str: async def execute(self, summary: str, **kwargs: Any) -> str:
self._store.signal_type = "complete"
self._store.message = summary self._store.message = summary
return "Task marked as complete. Awaiting validation." return "Task marked as complete. Awaiting validation."
@ -242,9 +246,19 @@ def _extract_handoff_from_messages(messages: list[dict[str, Any]]) -> str:
return "" return ""
def _extract_write_path(detail: str) -> str | None:
    """Return the path reported by a successful ``write_file`` tool event.

    The detail is expected to look like
    ``"Successfully wrote <size> to <path>"``, for example
    ``"Successfully wrote 12 characters to /workspace/a.py"``.

    NOTE: path extraction depends on the write_file detail format. If that
    tool changes its output format, this parser must be updated.

    Args:
        detail: The ``detail`` string of a ``write_file`` tool event.

    Returns:
        The stripped path, or ``None`` when the detail does not match the
        expected format or carries no path.
    """
    prefix = "Successfully wrote "
    marker = " to "
    if not detail.startswith(prefix):
        return None
    # Split on the FIRST marker after the prefix rather than the last one:
    # the path itself may contain " to " (e.g. "docs/how to use.md"), and
    # splitting from the right would truncate it. This assumes the size
    # part (e.g. "12 characters") never contains " to ".
    # The search starts on the prefix's trailing space so a degenerate
    # "Successfully wrote to <path>" is still recognised.
    start = detail.find(marker, len(prefix) - 1)
    if start == -1:
        return None
    path = detail[start + len(marker):].strip()
    # Normalise an empty path to None so the result honours `str | None`.
    return path or None
def _extract_edit_path(detail: str) -> str | None:
    """Return the path reported by a successful ``edit_file`` tool event.

    Recognised details start with ``"Successfully created "`` or
    ``"Successfully edited "``; everything after the prefix is taken as the
    path.

    NOTE: this depends on the edit_file detail format. If that tool changes
    its output format, the prefixes below must be updated.

    Args:
        detail: The ``detail`` string of an ``edit_file`` tool event.

    Returns:
        The stripped path, or ``None`` when the detail has an unexpected
        format or carries no path after the prefix.
    """
    for prefix in ("Successfully created ", "Successfully edited "):
        if detail.startswith(prefix):
            # `or None` maps an empty remainder to None so the declared
            # `str | None` contract holds (previously "" was returned).
            return detail.removeprefix(prefix).strip() or None
    return None
def _extract_file_changes( def _extract_file_changes(
@ -259,13 +273,19 @@ def _extract_file_changes(
detail = event.get("detail", "") detail = event.get("detail", "")
if status != "ok": if status != "ok":
continue continue
if name in ("write_file", "edit_file"): if name == "write_file":
if detail.startswith(_FILE_EVENT_PREFIXES): path = _extract_write_path(detail)
path = detail.split(" ", 1)[1].split(":")[0].strip() if path:
if name == "write_file": created.append(path)
created.append(path) else:
else: logger.debug(
modified.append(path) "long_task: skipping file event with unexpected detail: {}",
detail[:80],
)
elif name == "edit_file":
path = _extract_edit_path(detail)
if path:
modified.append(path)
else: else:
logger.debug( logger.debug(
"long_task: skipping file event with unexpected detail: {}", "long_task: skipping file event with unexpected detail: {}",
@ -502,7 +522,7 @@ class LongTaskTool(Tool):
# Determine signal from tool events # Determine signal from tool events
sig_type = "none" sig_type = "none"
for event in tool_events: for event in reversed(tool_events):
ev_name = event.get("name", "") ev_name = event.get("name", "")
if ev_name == "complete": if ev_name == "complete":
sig_type = "complete" sig_type = "complete"
@ -512,11 +532,9 @@ class LongTaskTool(Tool):
break break
# Fallback: if no explicit signal but CompleteTool/HandoffTool was # Fallback: if no explicit signal but CompleteTool/HandoffTool was
# called without arguments (message empty), use final_content # called but the runner did not expose tool events, trust the store.
if sig_type == "none" and signal_store.message: if sig_type == "none" and signal_store.signal_type:
# Tool was called but we couldn't detect from events; sig_type = signal_store.signal_type
# use the store content as handoff
sig_type = "handoff"
elif sig_type == "none": elif sig_type == "none":
signal_store.message = _extract_handoff_from_messages( signal_store.message = _extract_handoff_from_messages(
getattr(result, "messages", []) or [] getattr(result, "messages", []) or []

View File

@ -1,22 +1,21 @@
"""Tests for Long Task Tool: HandoffTool, CompleteTool, LongTaskTool.""" """Tests for Long Task Tool: HandoffTool, CompleteTool, LongTaskTool."""
import pytest
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest
from nanobot.agent.tools.long_task import ( from nanobot.agent.tools.long_task import (
CompleteTool,
HandoffState, HandoffState,
HandoffTool, HandoffTool,
CompleteTool,
LongTaskTool, LongTaskTool,
LongTaskEvent,
_build_system_prompt, _build_system_prompt,
_build_user_message, _build_user_message,
_extract_file_changes, _extract_file_changes,
_extract_handoff_from_messages, _extract_handoff_from_messages,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Signal tool tests # Signal tool tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -34,6 +33,7 @@ async def test_handoff_tool_stores_structured_signal():
verification="Tests passed", verification="Tests passed",
) )
assert result == "Progress recorded. The next step will continue from here." assert result == "Progress recorded. The next step will continue from here."
assert store.signal_type == "handoff"
assert store.message == "Processed items 1-8. Results in out.md." assert store.message == "Processed items 1-8. Results in out.md."
assert store.files_created == ["out.md", "report.md"] assert store.files_created == ["out.md", "report.md"]
assert store.files_modified == ["main.py"] assert store.files_modified == ["main.py"]
@ -58,6 +58,7 @@ async def test_complete_tool_stores_signal():
tool = CompleteTool(store) tool = CompleteTool(store)
result = await tool.execute(summary="All 100 items processed. Summary in report.md") result = await tool.execute(summary="All 100 items processed. Summary in report.md")
assert result == "Task marked as complete. Awaiting validation." assert result == "Task marked as complete. Awaiting validation."
assert store.signal_type == "complete"
assert store.message == "All 100 items processed. Summary in report.md" assert store.message == "All 100 items processed. Summary in report.md"
@ -70,6 +71,7 @@ async def test_signal_tools_overwrite_on_multiple_calls():
await handoff.execute(message="first progress") await handoff.execute(message="first progress")
assert store.message == "first progress" assert store.message == "first progress"
await complete.execute(summary="done early") await complete.execute(summary="done early")
assert store.signal_type == "complete"
assert store.message == "done early" assert store.message == "done early"
@ -191,6 +193,44 @@ async def test_long_task_completes_after_multiple_handoffs():
assert call_count == 4 # 3 main steps + validation assert call_count == 4 # 3 main steps + validation
@pytest.mark.asyncio
async def test_long_task_uses_last_signal_when_multiple_signals_called():
    """If a step calls handoff() then complete(), complete() should win."""
    mgr = _make_manager_stub()
    call_count = 0

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # Call the signal tools in an explicit order (handoff first,
            # complete last) so the scenario does not depend on the order
            # in which the tools happen to be listed in ``extra_tools``.
            tools_by_name = {t.name: t for t in extra_tools}
            await tools_by_name["handoff"].execute(message="Partial progress.")
            await tools_by_name["complete"].execute(summary="Actually complete.")
            return _step_result(
                tools_used=["handoff", "complete"],
                tool_events=[
                    {"name": "handoff", "status": "ok", "detail": ""},
                    {"name": "complete", "status": "ok", "detail": ""},
                ],
            )
        # Any later call (the validation step) just confirms completion.
        for t in extra_tools:
            if t.name == "complete":
                await t.execute(summary="Validated")
        return _step_result(
            tools_used=["complete"],
            tool_events=[{"name": "complete", "status": "ok", "detail": ""}],
        )

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)
    result = await tool.execute(goal="Do something.", max_steps=1)
    assert result == "Actually complete."
    assert call_count == 2  # main step + validation
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_long_task_validation_falls_back_to_handoff(): async def test_long_task_validation_falls_back_to_handoff():
"""Subagent claims complete but validation fails — task continues.""" """Subagent claims complete but validation fails — task continues."""
@ -566,14 +606,27 @@ def test_extract_handoff_from_empty_messages():
def test_extract_file_changes_from_tool_events(): def test_extract_file_changes_from_tool_events():
events = [ events = [
{"name": "write_file", "status": "ok", "detail": "Wrote /workspace/a.py: done"}, {
{"name": "edit_file", "status": "ok", "detail": "Edited /workspace/b.py: patched"}, "name": "write_file",
"status": "ok",
"detail": "Successfully wrote 12 characters to /workspace/a.py",
},
{
"name": "edit_file",
"status": "ok",
"detail": "Successfully edited /workspace/b.py",
},
{
"name": "edit_file",
"status": "ok",
"detail": "Successfully created /workspace/c.py",
},
{"name": "read_file", "status": "ok", "detail": "Read /workspace/c.py"}, {"name": "read_file", "status": "ok", "detail": "Read /workspace/c.py"},
{"name": "write_file", "status": "error", "detail": "Failed"}, {"name": "write_file", "status": "error", "detail": "Failed"},
] ]
created, modified = _extract_file_changes(events) created, modified = _extract_file_changes(events)
assert created == ["/workspace/a.py"] assert created == ["/workspace/a.py"]
assert modified == ["/workspace/b.py"] assert modified == ["/workspace/b.py", "/workspace/c.py"]
def test_extract_file_changes_empty(): def test_extract_file_changes_empty():
@ -736,7 +789,11 @@ async def test_explicit_file_changes_override_auto_detected():
tool_events=[ tool_events=[
{"name": "handoff", "status": "ok", "detail": ""}, {"name": "handoff", "status": "ok", "detail": ""},
# Auto-detection would pick this up as "auto.py" # Auto-detection would pick this up as "auto.py"
{"name": "write_file", "status": "ok", "detail": "Wrote auto.py: content"}, {
"name": "write_file",
"status": "ok",
"detail": "Successfully wrote 7 characters to auto.py",
},
], ],
) )