mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-01 14:31:14 +00:00
test(loop): cover /stop checkpoint recovery
This commit is contained in:
parent
ee061f0595
commit
65a15f39ee
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
@ -308,3 +309,111 @@ async def test_next_turn_after_crash_closes_pending_user_turn_before_new_input(t
|
|||||||
{"role": "assistant", "content": "new answer"},
|
{"role": "assistant", "content": "new answer"},
|
||||||
]
|
]
|
||||||
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_preserves_runtime_checkpoint_for_next_turn(tmp_path: Path) -> None:
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
from nanobot.command.router import CommandContext
|
||||||
|
|
||||||
|
loop = _make_full_loop(tmp_path)
|
||||||
|
loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
checkpoint_saved = asyncio.Event()
|
||||||
|
|
||||||
|
async def interrupted_run_agent_loop(_initial_messages, *, session=None, **_kwargs):
|
||||||
|
assert session is not None
|
||||||
|
loop._set_runtime_checkpoint(
|
||||||
|
session,
|
||||||
|
{
|
||||||
|
"assistant_message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "working",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_done",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "read_file", "arguments": "{}"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_pending",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "exec", "arguments": "{}"},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"completed_tool_results": [
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_done",
|
||||||
|
"name": "read_file",
|
||||||
|
"content": "ok",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"pending_tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_pending",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "exec", "arguments": "{}"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
checkpoint_saved.set()
|
||||||
|
await asyncio.Event().wait()
|
||||||
|
|
||||||
|
loop._run_agent_loop = interrupted_run_agent_loop # type: ignore[method-assign]
|
||||||
|
|
||||||
|
first_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="keep progress")
|
||||||
|
task = asyncio.create_task(loop._process_message(first_msg))
|
||||||
|
loop._active_tasks[first_msg.session_key] = [task]
|
||||||
|
await asyncio.wait_for(checkpoint_saved.wait(), timeout=1.0)
|
||||||
|
|
||||||
|
stop_msg = InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="/stop")
|
||||||
|
stop_ctx = CommandContext(msg=stop_msg, session=None, key=stop_msg.session_key, raw="/stop", loop=loop)
|
||||||
|
stop_result = await cmd_stop(stop_ctx)
|
||||||
|
|
||||||
|
assert "Stopped 1 task" in stop_result.content
|
||||||
|
assert task.done()
|
||||||
|
|
||||||
|
loop.sessions.invalidate("feishu:c4")
|
||||||
|
interrupted = loop.sessions.get_or_create("feishu:c4")
|
||||||
|
assert interrupted.metadata.get(AgentLoop._PENDING_USER_TURN_KEY) is True
|
||||||
|
assert interrupted.metadata.get(AgentLoop._RUNTIME_CHECKPOINT_KEY) is not None
|
||||||
|
|
||||||
|
async def resumed_run_agent_loop(initial_messages, **_kwargs):
|
||||||
|
return (
|
||||||
|
"next answer",
|
||||||
|
None,
|
||||||
|
[*initial_messages, {"role": "assistant", "content": "next answer"}],
|
||||||
|
"stop",
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
loop._run_agent_loop = resumed_run_agent_loop # type: ignore[method-assign]
|
||||||
|
result = await loop._process_message(
|
||||||
|
InboundMessage(channel="feishu", sender_id="u1", chat_id="c4", content="continue here")
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.content == "next answer"
|
||||||
|
|
||||||
|
session = loop.sessions.get_or_create("feishu:c4")
|
||||||
|
assert [
|
||||||
|
{k: v for k, v in m.items() if k in {"role", "content", "tool_call_id", "name"}}
|
||||||
|
for m in session.messages
|
||||||
|
] == [
|
||||||
|
{"role": "user", "content": "keep progress"},
|
||||||
|
{"role": "assistant", "content": "working"},
|
||||||
|
{"role": "tool", "tool_call_id": "call_done", "name": "read_file", "content": "ok"},
|
||||||
|
{
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": "call_pending",
|
||||||
|
"name": "exec",
|
||||||
|
"content": "Error: Task interrupted before this tool finished.",
|
||||||
|
},
|
||||||
|
{"role": "user", "content": "continue here"},
|
||||||
|
{"role": "assistant", "content": "next answer"},
|
||||||
|
]
|
||||||
|
assert AgentLoop._PENDING_USER_TURN_KEY not in session.metadata
|
||||||
|
assert AgentLoop._RUNTIME_CHECKPOINT_KEY not in session.metadata
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user