test(long-task): add boundary tests and fix race conditions

- Add 7 edge-case tests: validation crash resilience, hook exception safety, mid-run correction injection, FIFO correction ordering, explicit file changes overriding auto-detection, final budget for max_steps=1, and dynamic budget switching boundaries

- Fix assertion in test_long_task_completes_after_multiple_handoffs to match exact prompt format

- Remove asyncio timing hack from test_state_exposure

- Add asyncio.sleep(0) yield in test_inject_correction_during_execution to prevent race between signal injection and step continuation

- All 34 tests passing
This commit is contained in:
chengyongru 2026-05-13 01:26:01 +08:00
parent 78ecb2a99a
commit 5acae58a13

View File

@ -158,7 +158,7 @@ async def test_long_task_completes_after_multiple_handoffs():
)
elif call_count == 2:
assert "Processed 1-8." in user_message
assert "Step 2" in user_message or "Step 2 of" in user_message
assert "Step 2 of 20" in user_message
for t in extra_tools:
if t.name == "handoff":
await t.execute(message="Processed 9-16.")
@ -458,13 +458,7 @@ async def test_state_exposure():
tool = LongTaskTool(manager=mgr)
assert tool.status == "idle"
# Start execution in background so we can inspect mid-run
import asyncio
task = asyncio.create_task(tool.execute(goal="Test state.", max_steps=3))
# Give it a moment to start
await asyncio.sleep(0.01)
# Task should have finished by now since mocks are instant
await task
await tool.execute(goal="Test state.", max_steps=3)
assert tool.goal == "Test state."
assert tool.total_steps == 3
@ -586,6 +580,225 @@ def test_extract_file_changes_empty():
assert _extract_file_changes([]) == ([], [])
# ---------------------------------------------------------------------------
# Boundary and edge-case tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_validation_step_crash_continues_task():
    """A crash inside the validation round must not kill the long task.

    The stubbed step runner claims completion on its first call, raises on
    its second call (standing in for the validation round), and reports
    completion again on every call after that.
    """
    mgr = _make_manager_stub()
    calls_made = 0

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        nonlocal calls_made
        calls_made += 1

        if calls_made == 2:
            # Validation round crashes.
            raise RuntimeError("Validator exploded")

        # Call 1 is the premature claim; anything after the crash is the
        # task continuing and completing for real.
        summary = "Claimed done." if calls_made == 1 else "Actually done."
        for extra in extra_tools:
            if extra.name == "complete":
                await extra.execute(summary=summary)
        return _step_result(
            tools_used=["complete"],
            tool_events=[{"name": "complete", "status": "ok", "detail": ""}],
        )

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)

    result = await tool.execute(goal="Test validation crash.", max_steps=5)

    assert result == "Actually done."
    # main step + validation crash + continue + re-validation
    assert calls_made == 4
@pytest.mark.asyncio
async def test_hook_exception_does_not_stop_execution():
    """Hooks that raise must not prevent the task from returning its summary."""
    mgr = _make_manager_stub()

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        for extra in extra_tools:
            if extra.name == "complete":
                await extra.execute(summary="Done.")
        return _step_result(
            tools_used=["complete"],
            tool_events=[{"name": "complete", "status": "ok", "detail": ""}],
        )

    def failing_step_start(**kw):
        # Keyword-only hook that always blows up when invoked.
        raise RuntimeError("bad hook")

    def failing_catch_all(ev):
        # Single-argument hook that always blows up when invoked.
        raise RuntimeError("bad catch-all")

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)
    tool.set_hooks({
        "on_step_start": failing_step_start,
        "on_event": failing_catch_all,
    })

    outcome = await tool.execute(goal="Test hook resilience.")

    assert outcome == "Done."
@pytest.mark.asyncio
async def test_inject_correction_during_execution():
    """A correction injected mid-run must show up in a captured step message.

    The task runs in the background; the test waits until the first stubbed
    step has handed off, injects a correction, and then checks that some
    user message passed to the step runner contains the correction text.
    """
    import asyncio

    mgr = _make_manager_stub()
    first_step_finished = asyncio.Event()
    seen_prompts = []

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        seen_prompts.append(user_message)

        if len(seen_prompts) != 1:
            # Every call after the first one completes the task.
            for extra in extra_tools:
                if extra.name == "complete":
                    await extra.execute(summary="Step 2 done.")
            return _step_result(
                tools_used=["complete"],
                tool_events=[{"name": "complete", "status": "ok", "detail": ""}],
            )

        # First call: hand off, signal the test, then yield control.
        for extra in extra_tools:
            if extra.name == "handoff":
                await extra.execute(message="Step 1 done.")
        first_step_finished.set()
        # Yield so the test coroutine can inject the correction before this
        # step returns. The order of set() followed by sleep(0) matters.
        await asyncio.sleep(0)
        return _step_result(
            tools_used=["handoff"],
            tool_events=[{"name": "handoff", "status": "ok", "detail": ""}],
        )

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)

    running = asyncio.create_task(tool.execute(goal="Test mid-run correction.", max_steps=5))
    await first_step_finished.wait()
    tool.inject_correction("Change direction.")
    outcome = await running

    assert outcome == "Step 2 done."
    assert any("Change direction." in prompt for prompt in seen_prompts)
@pytest.mark.asyncio
async def test_multiple_corrections_consumed_in_order():
    """Corrections queued before the run must be delivered first-in, first-out."""
    mgr = _make_manager_stub()
    seen_prompts = []

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        seen_prompts.append(user_message)
        for extra in extra_tools:
            if extra.name == "handoff":
                await extra.execute(message=f"Step {len(seen_prompts)} done.")
        return _step_result(
            tools_used=["handoff"],
            tool_events=[{"name": "handoff", "status": "ok", "detail": ""}],
        )

    def index_of_first_prompt_with(text):
        # Position of the earliest captured message that mentions *text*.
        return next(pos for pos, prompt in enumerate(seen_prompts) if text in prompt)

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)
    tool.inject_correction("First correction.")
    tool.inject_correction("Second correction.")

    await tool.execute(goal="Test FIFO.", max_steps=3)

    # Both corrections must have been delivered in some step message.
    assert any("First correction." in prompt for prompt in seen_prompts)
    assert any("Second correction." in prompt for prompt in seen_prompts)

    # Ordering: the first correction lands in a strictly earlier message
    # than the second one.
    first_idx = index_of_first_prompt_with("First correction.")
    second_idx = index_of_first_prompt_with("Second correction.")
    assert first_idx < second_idx
@pytest.mark.asyncio
async def test_explicit_file_changes_override_auto_detected():
    """An explicit files_created report must win over auto-detected changes."""
    mgr = _make_manager_stub()

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        for extra in extra_tools:
            if extra.name == "handoff":
                await extra.execute(message="Done.", files_created=["explicit.py"])

        events = [
            {"name": "handoff", "status": "ok", "detail": ""},
            # Auto-detection would pick this up as "auto.py".
            {"name": "write_file", "status": "ok", "detail": "Wrote auto.py: content"},
        ]
        return _step_result(tools_used=["handoff"], tool_events=events)

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)

    await tool.execute(goal="Test file merge.", max_steps=2)

    # The explicitly reported list is kept as-is; "auto.py" is not merged in.
    assert tool.last_handoff.files_created == ["explicit.py"]
@pytest.mark.asyncio
async def test_max_steps_one_uses_final_budget():
    """With max_steps=1 the very first step already runs on the budget of 4."""
    mgr = _make_manager_stub()
    budgets_seen = []

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        budgets_seen.append(max_iterations)
        for extra in extra_tools:
            if extra.name == "complete":
                await extra.execute(summary="Done.")
        return _step_result(
            tools_used=["complete"],
            tool_events=[{"name": "complete", "status": "ok", "detail": ""}],
        )

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)

    await tool.execute(goal="Test max_steps=1.", max_steps=1)

    # Call 0 is the final-step budget; call 1 is the validation round, which
    # also uses the short budget.
    for call_index in (0, 1):
        assert budgets_seen[call_index] == 4
@pytest.mark.asyncio
async def test_budget_switches_at_correct_step():
    """With max_steps=5 the budget drops from 8 to 4 at step 3 (max_steps - 2)."""
    mgr = _make_manager_stub()
    budgets_seen = []

    async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None):
        budgets_seen.append(max_iterations)
        for extra in extra_tools:
            if extra.name == "handoff":
                await extra.execute(message=f"Step {len(budgets_seen)}.")
        return _step_result(
            tools_used=["handoff"],
            tool_events=[{"name": "handoff", "status": "ok", "detail": ""}],
        )

    mgr.run_step.side_effect = fake_run_step
    tool = LongTaskTool(manager=mgr)

    await tool.execute(goal="Test budget switch.", max_steps=5)

    # Steps 0,1,2 use 8; steps 3,4 use 4.
    expected_budgets = (8, 8, 8, 4, 4)
    for step_index, expected in enumerate(expected_budgets):
        assert budgets_seen[step_index] == expected
# ---------------------------------------------------------------------------
# Integration: verify LongTaskTool is wired into the main agent loop
# ---------------------------------------------------------------------------