mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-20 00:22:31 +00:00
fix(long-task): honor final signal and file tracking
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
5f5f3d5d97
commit
78e8cc3e55
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
@ -30,6 +30,7 @@ if TYPE_CHECKING:
|
|||||||
class HandoffState:
|
class HandoffState:
|
||||||
"""Structured progress state passed between long-task steps."""
|
"""Structured progress state passed between long-task steps."""
|
||||||
|
|
||||||
|
signal_type: str = ""
|
||||||
message: str = ""
|
message: str = ""
|
||||||
files_created: list[str] = field(default_factory=list)
|
files_created: list[str] = field(default_factory=list)
|
||||||
files_modified: list[str] = field(default_factory=list)
|
files_modified: list[str] = field(default_factory=list)
|
||||||
@ -39,6 +40,7 @@ class HandoffState:
|
|||||||
def is_empty(self) -> bool:
|
def is_empty(self) -> bool:
|
||||||
return not any(
|
return not any(
|
||||||
[
|
[
|
||||||
|
self.signal_type,
|
||||||
self.message,
|
self.message,
|
||||||
self.files_created,
|
self.files_created,
|
||||||
self.files_modified,
|
self.files_modified,
|
||||||
@ -105,6 +107,7 @@ class HandoffTool(Tool):
|
|||||||
verification: str = "",
|
verification: str = "",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
self._store.signal_type = "handoff"
|
||||||
self._store.message = message
|
self._store.message = message
|
||||||
self._store.files_created = list(files_created or [])
|
self._store.files_created = list(files_created or [])
|
||||||
self._store.files_modified = list(files_modified or [])
|
self._store.files_modified = list(files_modified or [])
|
||||||
@ -139,6 +142,7 @@ class CompleteTool(Tool):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def execute(self, summary: str, **kwargs: Any) -> str:
|
async def execute(self, summary: str, **kwargs: Any) -> str:
|
||||||
|
self._store.signal_type = "complete"
|
||||||
self._store.message = summary
|
self._store.message = summary
|
||||||
return "Task marked as complete. Awaiting validation."
|
return "Task marked as complete. Awaiting validation."
|
||||||
|
|
||||||
@ -242,9 +246,19 @@ def _extract_handoff_from_messages(messages: list[dict[str, Any]]) -> str:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
_FILE_EVENT_PREFIXES = ("Wrote ", "Edited ")
|
def _extract_write_path(detail: str) -> str | None:
|
||||||
# NOTE: path extraction depends on write_file/edit_file detail format.
|
prefix = "Successfully wrote "
|
||||||
# If those tools change their output format, this mapping must be updated.
|
marker = " to "
|
||||||
|
if not detail.startswith(prefix) or marker not in detail:
|
||||||
|
return None
|
||||||
|
return detail.rsplit(marker, 1)[1].strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_edit_path(detail: str) -> str | None:
|
||||||
|
for prefix in ("Successfully created ", "Successfully edited "):
|
||||||
|
if detail.startswith(prefix):
|
||||||
|
return detail.removeprefix(prefix).strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _extract_file_changes(
|
def _extract_file_changes(
|
||||||
@ -259,12 +273,18 @@ def _extract_file_changes(
|
|||||||
detail = event.get("detail", "")
|
detail = event.get("detail", "")
|
||||||
if status != "ok":
|
if status != "ok":
|
||||||
continue
|
continue
|
||||||
if name in ("write_file", "edit_file"):
|
|
||||||
if detail.startswith(_FILE_EVENT_PREFIXES):
|
|
||||||
path = detail.split(" ", 1)[1].split(":")[0].strip()
|
|
||||||
if name == "write_file":
|
if name == "write_file":
|
||||||
|
path = _extract_write_path(detail)
|
||||||
|
if path:
|
||||||
created.append(path)
|
created.append(path)
|
||||||
else:
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"long_task: skipping file event with unexpected detail: {}",
|
||||||
|
detail[:80],
|
||||||
|
)
|
||||||
|
elif name == "edit_file":
|
||||||
|
path = _extract_edit_path(detail)
|
||||||
|
if path:
|
||||||
modified.append(path)
|
modified.append(path)
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -502,7 +522,7 @@ class LongTaskTool(Tool):
|
|||||||
|
|
||||||
# Determine signal from tool events
|
# Determine signal from tool events
|
||||||
sig_type = "none"
|
sig_type = "none"
|
||||||
for event in tool_events:
|
for event in reversed(tool_events):
|
||||||
ev_name = event.get("name", "")
|
ev_name = event.get("name", "")
|
||||||
if ev_name == "complete":
|
if ev_name == "complete":
|
||||||
sig_type = "complete"
|
sig_type = "complete"
|
||||||
@ -512,11 +532,9 @@ class LongTaskTool(Tool):
|
|||||||
break
|
break
|
||||||
|
|
||||||
# Fallback: if no explicit signal but CompleteTool/HandoffTool was
|
# Fallback: if no explicit signal but CompleteTool/HandoffTool was
|
||||||
# called without arguments (message empty), use final_content
|
# called but the runner did not expose tool events, trust the store.
|
||||||
if sig_type == "none" and signal_store.message:
|
if sig_type == "none" and signal_store.signal_type:
|
||||||
# Tool was called but we couldn't detect from events;
|
sig_type = signal_store.signal_type
|
||||||
# use the store content as handoff
|
|
||||||
sig_type = "handoff"
|
|
||||||
elif sig_type == "none":
|
elif sig_type == "none":
|
||||||
signal_store.message = _extract_handoff_from_messages(
|
signal_store.message = _extract_handoff_from_messages(
|
||||||
getattr(result, "messages", []) or []
|
getattr(result, "messages", []) or []
|
||||||
|
|||||||
@ -1,22 +1,21 @@
|
|||||||
"""Tests for Long Task Tool: HandoffTool, CompleteTool, LongTaskTool."""
|
"""Tests for Long Task Tool: HandoffTool, CompleteTool, LongTaskTool."""
|
||||||
|
|
||||||
import pytest
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from nanobot.agent.tools.long_task import (
|
from nanobot.agent.tools.long_task import (
|
||||||
|
CompleteTool,
|
||||||
HandoffState,
|
HandoffState,
|
||||||
HandoffTool,
|
HandoffTool,
|
||||||
CompleteTool,
|
|
||||||
LongTaskTool,
|
LongTaskTool,
|
||||||
LongTaskEvent,
|
|
||||||
_build_system_prompt,
|
_build_system_prompt,
|
||||||
_build_user_message,
|
_build_user_message,
|
||||||
_extract_file_changes,
|
_extract_file_changes,
|
||||||
_extract_handoff_from_messages,
|
_extract_handoff_from_messages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Signal tool tests
|
# Signal tool tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@ -34,6 +33,7 @@ async def test_handoff_tool_stores_structured_signal():
|
|||||||
verification="Tests passed",
|
verification="Tests passed",
|
||||||
)
|
)
|
||||||
assert result == "Progress recorded. The next step will continue from here."
|
assert result == "Progress recorded. The next step will continue from here."
|
||||||
|
assert store.signal_type == "handoff"
|
||||||
assert store.message == "Processed items 1-8. Results in out.md."
|
assert store.message == "Processed items 1-8. Results in out.md."
|
||||||
assert store.files_created == ["out.md", "report.md"]
|
assert store.files_created == ["out.md", "report.md"]
|
||||||
assert store.files_modified == ["main.py"]
|
assert store.files_modified == ["main.py"]
|
||||||
@ -58,6 +58,7 @@ async def test_complete_tool_stores_signal():
|
|||||||
tool = CompleteTool(store)
|
tool = CompleteTool(store)
|
||||||
result = await tool.execute(summary="All 100 items processed. Summary in report.md")
|
result = await tool.execute(summary="All 100 items processed. Summary in report.md")
|
||||||
assert result == "Task marked as complete. Awaiting validation."
|
assert result == "Task marked as complete. Awaiting validation."
|
||||||
|
assert store.signal_type == "complete"
|
||||||
assert store.message == "All 100 items processed. Summary in report.md"
|
assert store.message == "All 100 items processed. Summary in report.md"
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +71,7 @@ async def test_signal_tools_overwrite_on_multiple_calls():
|
|||||||
await handoff.execute(message="first progress")
|
await handoff.execute(message="first progress")
|
||||||
assert store.message == "first progress"
|
assert store.message == "first progress"
|
||||||
await complete.execute(summary="done early")
|
await complete.execute(summary="done early")
|
||||||
|
assert store.signal_type == "complete"
|
||||||
assert store.message == "done early"
|
assert store.message == "done early"
|
||||||
|
|
||||||
|
|
||||||
@ -191,6 +193,44 @@ async def test_long_task_completes_after_multiple_handoffs():
|
|||||||
assert call_count == 4 # 3 main steps + validation
|
assert call_count == 4 # 3 main steps + validation
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_long_task_uses_last_signal_when_multiple_signals_called():
|
||||||
|
"""If a step calls handoff() then complete(), complete() should win."""
|
||||||
|
mgr = _make_manager_stub()
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
for t in extra_tools:
|
||||||
|
if t.name == "handoff":
|
||||||
|
await t.execute(message="Partial progress.")
|
||||||
|
elif t.name == "complete":
|
||||||
|
await t.execute(summary="Actually complete.")
|
||||||
|
return _step_result(
|
||||||
|
tools_used=["handoff", "complete"],
|
||||||
|
tool_events=[
|
||||||
|
{"name": "handoff", "status": "ok", "detail": ""},
|
||||||
|
{"name": "complete", "status": "ok", "detail": ""},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for t in extra_tools:
|
||||||
|
if t.name == "complete":
|
||||||
|
await t.execute(summary="Validated")
|
||||||
|
return _step_result(
|
||||||
|
tools_used=["complete"],
|
||||||
|
tool_events=[{"name": "complete", "status": "ok", "detail": ""}],
|
||||||
|
)
|
||||||
|
|
||||||
|
mgr.run_step.side_effect = fake_run_step
|
||||||
|
tool = LongTaskTool(manager=mgr)
|
||||||
|
result = await tool.execute(goal="Do something.", max_steps=1)
|
||||||
|
assert result == "Actually complete."
|
||||||
|
assert call_count == 2 # main step + validation
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_long_task_validation_falls_back_to_handoff():
|
async def test_long_task_validation_falls_back_to_handoff():
|
||||||
"""Subagent claims complete but validation fails — task continues."""
|
"""Subagent claims complete but validation fails — task continues."""
|
||||||
@ -566,14 +606,27 @@ def test_extract_handoff_from_empty_messages():
|
|||||||
|
|
||||||
def test_extract_file_changes_from_tool_events():
|
def test_extract_file_changes_from_tool_events():
|
||||||
events = [
|
events = [
|
||||||
{"name": "write_file", "status": "ok", "detail": "Wrote /workspace/a.py: done"},
|
{
|
||||||
{"name": "edit_file", "status": "ok", "detail": "Edited /workspace/b.py: patched"},
|
"name": "write_file",
|
||||||
|
"status": "ok",
|
||||||
|
"detail": "Successfully wrote 12 characters to /workspace/a.py",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "edit_file",
|
||||||
|
"status": "ok",
|
||||||
|
"detail": "Successfully edited /workspace/b.py",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "edit_file",
|
||||||
|
"status": "ok",
|
||||||
|
"detail": "Successfully created /workspace/c.py",
|
||||||
|
},
|
||||||
{"name": "read_file", "status": "ok", "detail": "Read /workspace/c.py"},
|
{"name": "read_file", "status": "ok", "detail": "Read /workspace/c.py"},
|
||||||
{"name": "write_file", "status": "error", "detail": "Failed"},
|
{"name": "write_file", "status": "error", "detail": "Failed"},
|
||||||
]
|
]
|
||||||
created, modified = _extract_file_changes(events)
|
created, modified = _extract_file_changes(events)
|
||||||
assert created == ["/workspace/a.py"]
|
assert created == ["/workspace/a.py"]
|
||||||
assert modified == ["/workspace/b.py"]
|
assert modified == ["/workspace/b.py", "/workspace/c.py"]
|
||||||
|
|
||||||
|
|
||||||
def test_extract_file_changes_empty():
|
def test_extract_file_changes_empty():
|
||||||
@ -736,7 +789,11 @@ async def test_explicit_file_changes_override_auto_detected():
|
|||||||
tool_events=[
|
tool_events=[
|
||||||
{"name": "handoff", "status": "ok", "detail": ""},
|
{"name": "handoff", "status": "ok", "detail": ""},
|
||||||
# Auto-detection would pick this up as "auto.py"
|
# Auto-detection would pick this up as "auto.py"
|
||||||
{"name": "write_file", "status": "ok", "detail": "Wrote auto.py: content"},
|
{
|
||||||
|
"name": "write_file",
|
||||||
|
"status": "ok",
|
||||||
|
"detail": "Successfully wrote 7 characters to auto.py",
|
||||||
|
},
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user