From 5acae58a13223d641bd06249d9e208a56416c17c Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Wed, 13 May 2026 01:26:01 +0800 Subject: [PATCH] test(long-task): add boundary tests and fix race conditions - Add 7 edge-case tests: validation crash resilience, hook exception safety, mid-run correction injection, FIFO correction ordering, explicit file changes overriding auto-detection, final budget for max_steps=1, and dynamic budget switching boundaries - Fix assertion in test_long_task_completes_after_multiple_handoffs to match exact prompt format - Remove asyncio timing hack from test_state_exposure - Add asyncio.sleep(0) yield in test_inject_correction_during_execution to prevent race between signal injection and step continuation - All 34 tests passing --- tests/agent/tools/test_long_task.py | 229 +++++++++++++++++++++++++++- 1 file changed, 221 insertions(+), 8 deletions(-) diff --git a/tests/agent/tools/test_long_task.py b/tests/agent/tools/test_long_task.py index c3bc5646c..8207c4576 100644 --- a/tests/agent/tools/test_long_task.py +++ b/tests/agent/tools/test_long_task.py @@ -158,7 +158,7 @@ async def test_long_task_completes_after_multiple_handoffs(): ) elif call_count == 2: assert "Processed 1-8." in user_message - assert "Step 2" in user_message or "Step 2 of" in user_message + assert "Step 2 of 20" in user_message for t in extra_tools: if t.name == "handoff": await t.execute(message="Processed 9-16.") @@ -458,13 +458,7 @@ async def test_state_exposure(): tool = LongTaskTool(manager=mgr) assert tool.status == "idle" - # Start execution in background so we can inspect mid-run - import asyncio - task = asyncio.create_task(tool.execute(goal="Test state.", max_steps=3)) - # Give it a moment to start - await asyncio.sleep(0.01) - # Task should have finished by now since mocks are instant - await task + await tool.execute(goal="Test state.", max_steps=3) assert tool.goal == "Test state." assert tool.total_steps == 3 @@ -586,6 +580,225 @@ def test_extract_file_changes_empty(): assert _extract_file_changes([]) == ([], []) +# --------------------------------------------------------------------------- +# Boundary and edge-case tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_validation_step_crash_continues_task(): + """If the validation round itself crashes, the task should continue, not die.""" + mgr = _make_manager_stub() + call_count = 0 + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + nonlocal call_count + call_count += 1 + if call_count == 1: + for t in extra_tools: + if t.name == "complete": + await t.execute(summary="Claimed done.") + return _step_result( + tools_used=["complete"], + tool_events=[{"name": "complete", "status": "ok", "detail": ""}], + ) + elif call_count == 2: + # Validation round crashes + raise RuntimeError("Validator exploded") + else: + # After validation crash, task continues and completes for real + for t in extra_tools: + if t.name == "complete": + await t.execute(summary="Actually done.") + return _step_result( + tools_used=["complete"], + tool_events=[{"name": "complete", "status": "ok", "detail": ""}], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + result = await tool.execute(goal="Test validation crash.", max_steps=5) + assert "Actually done." == result + assert call_count == 4 # main step + validation crash + continue + re-validation + + +@pytest.mark.asyncio +async def test_hook_exception_does_not_stop_execution(): + """A misbehaving hook must not break the long-task execution.""" + mgr = _make_manager_stub() + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + for t in extra_tools: + if t.name == "complete": + await t.execute(summary="Done.") + return _step_result( + tools_used=["complete"], + tool_events=[{"name": "complete", "status": "ok", "detail": ""}], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + tool.set_hooks({ + "on_step_start": lambda **kw: (_ for _ in ()).throw(RuntimeError("bad hook")), + "on_event": lambda ev: (_ for _ in ()).throw(RuntimeError("bad catch-all")), + }) + result = await tool.execute(goal="Test hook resilience.") + assert result == "Done." + + +@pytest.mark.asyncio +async def test_inject_correction_during_execution(): + """Correction injected while the task is running should reach the next step.""" + import asyncio + + mgr = _make_manager_stub() + step1_done = asyncio.Event() + captured_messages = [] + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + captured_messages.append(user_message) + if len(captured_messages) == 1: + for t in extra_tools: + if t.name == "handoff": + await t.execute(message="Step 1 done.") + step1_done.set() + await asyncio.sleep(0) # yield so the test coroutine can inject correction + return _step_result( + tools_used=["handoff"], + tool_events=[{"name": "handoff", "status": "ok", "detail": ""}], + ) + else: + for t in extra_tools: + if t.name == "complete": + await t.execute(summary="Step 2 done.") + return _step_result( + tools_used=["complete"], + tool_events=[{"name": "complete", "status": "ok", "detail": ""}], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + + task = asyncio.create_task(tool.execute(goal="Test mid-run correction.", max_steps=5)) + await step1_done.wait() + tool.inject_correction("Change direction.") + result = await task + + assert result == "Step 2 done." + assert any("Change direction." in msg for msg in captured_messages) + + +@pytest.mark.asyncio +async def test_multiple_corrections_consumed_in_order(): + """Multiple corrections should be consumed FIFO across steps.""" + mgr = _make_manager_stub() + captured_messages = [] + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + captured_messages.append(user_message) + for t in extra_tools: + if t.name == "handoff": + await t.execute(message=f"Step {len(captured_messages)} done.") + return _step_result( + tools_used=["handoff"], + tool_events=[{"name": "handoff", "status": "ok", "detail": ""}], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + tool.inject_correction("First correction.") + tool.inject_correction("Second correction.") + await tool.execute(goal="Test FIFO.", max_steps=3) + + # Step 0 should see First correction, step 1 should see Second + assert any("First correction." in msg for msg in captured_messages) + assert any("Second correction." in msg for msg in captured_messages) + # Verify ordering: First appears in an earlier index than Second + first_idx = next(i for i, msg in enumerate(captured_messages) if "First correction." in msg) + second_idx = next(i for i, msg in enumerate(captured_messages) if "Second correction." in msg) + assert first_idx < second_idx + + +@pytest.mark.asyncio +async def test_explicit_file_changes_override_auto_detected(): + """If subagent explicitly reports files_created, auto-detection must not overwrite.""" + mgr = _make_manager_stub() + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + for t in extra_tools: + if t.name == "handoff": + await t.execute( + message="Done.", + files_created=["explicit.py"], + ) + return _step_result( + tools_used=["handoff"], + tool_events=[ + {"name": "handoff", "status": "ok", "detail": ""}, + # Auto-detection would pick this up as "auto.py" + {"name": "write_file", "status": "ok", "detail": "Wrote auto.py: content"}, + ], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + await tool.execute(goal="Test file merge.", max_steps=2) + + # Explicit report should win over auto-detection + assert tool.last_handoff.files_created == ["explicit.py"] + + +@pytest.mark.asyncio +async def test_max_steps_one_uses_final_budget(): + """max_steps=1 should immediately use the final-step budget of 4.""" + mgr = _make_manager_stub() + captured_budgets = [] + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + captured_budgets.append(max_iterations) + for t in extra_tools: + if t.name == "complete": + await t.execute(summary="Done.") + return _step_result( + tools_used=["complete"], + tool_events=[{"name": "complete", "status": "ok", "detail": ""}], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + await tool.execute(goal="Test max_steps=1.", max_steps=1) + assert captured_budgets[0] == 4 # final step budget + assert captured_budgets[1] == 4 # validation also uses short budget + + +@pytest.mark.asyncio +async def test_budget_switches_at_correct_step(): + """With max_steps=5, budget should switch from 8 to 4 at step 3 (max_steps - 2).""" + mgr = _make_manager_stub() + captured_budgets = [] + + async def fake_run_step(*, system_prompt, user_message, extra_tools, max_iterations=None): + captured_budgets.append(max_iterations) + for t in extra_tools: + if t.name == "handoff": + await t.execute(message=f"Step {len(captured_budgets)}.") + return _step_result( + tools_used=["handoff"], + tool_events=[{"name": "handoff", "status": "ok", "detail": ""}], + ) + + mgr.run_step.side_effect = fake_run_step + tool = LongTaskTool(manager=mgr) + await tool.execute(goal="Test budget switch.", max_steps=5) + + # Steps 0,1,2 use 8; steps 3,4 use 4 + assert captured_budgets[0] == 8 + assert captured_budgets[1] == 8 + assert captured_budgets[2] == 8 + assert captured_budgets[3] == 4 + assert captured_budgets[4] == 4 + + # --------------------------------------------------------------------------- # Integration: verify LongTaskTool is wired into the main agent loop # ---------------------------------------------------------------------------