From 65a15f39ee7ebfe8b9585231165222cf5ee1cd76 Mon Sep 17 00:00:00 2001 From: yeyitech Date: Tue, 14 Apr 2026 13:42:59 +0800 Subject: [PATCH] test(loop): cover /stop checkpoint recovery --- tests/agent/test_loop_save_turn.py | 109 +++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index c965ccd8c..8885e0cc0 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -1,3 +1,4 @@ +import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock @@ -308,3 +309,111 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t {"role": "assistant", "content": "new answer"}, ] assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + + +@pytest.mark.asyncio +async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None: + from nanobot.command.builtin import cmd_stop + from nanobot.command.router import CommandContext + + loop = _make_full_loop(tmp_path) + loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] + + checkpoint_saved = asyncio.Event() + + async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs): + assert session is not None + loop._set_runtime_checkpoint( + session, + { + "assistant_message": { + "role": "assistant", + "content": "working", + "tool_calls": [ + { + "id": "call_done", + "type": "function", + "function": {"name": "read_file", "arguments": "{}"}, + }, + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + }, + ], + }, + "completed_tool_results": [ + { + "role": "tool", + "tool_call_id": "call_done", + "name": "read_file", + "content": "ok", + } + ], + "pending_tool_calls": [ + { + "id": "call_pending", + "type": "function", + "function": {"name": "exec", "arguments": "{}"}, + } + ], + }, + ) + checkpoint_saved.set() + await asyncio.Event().wait() + + loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign] + + first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress") + task = asyncio.create_task(loop._process_message(first_msg)) + loop._active_tasks[first_msg.session_key] = [task] + await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0) + + stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop") + stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop) + stop_result = await cmd_stop(stop_ctx) + + assert "Stopped 1 task" in stop_result.content + assert task.done() + + loop.sessions.invalidate("feishu:c4") + interrupted = loop.sessions.get_or_create("feishu:c4") + assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True + assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None + + async def resumed_run_agent_loop(initial_messages, **_kwargs): + return ( + "next answer", + None, + [*initial_messages, {"role": "assistant", "content": "next answer"}], + "stop", + False, + ) + + loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign] + result = await loop._process_message( + InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here") + ) + + assert result is not None + assert result.content == "next answer" + + session = loop.sessions.get_or_create("feishu:c4") + assert [ + {k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}} + for m in session.messages + ] == [ + {"role": "user", "content": "keep progress"}, + {"role": "assistant", "content": "working"}, + {"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"}, + { + "role": "tool", + "tool_call_id": "call_pending", + "name": "exec", + "content": "Error: Task interrupted before this tool finished.", + }, + {"role": "user", "content": "continue here"}, + {"role": "assistant", "content": "next answer"}, + ] + assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata + assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata