Preserve session key when archiving new sessions

This commit is contained in:
chengyongru 2026-06-10 16:32:45 +08:00 committed by Xubin Ren
parent bfc6febddc
commit 8c30dc5a57
3 changed files with 17 additions and 7 deletions

View File

@ -212,7 +212,7 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage:
loop.sessions.save(session)
loop.sessions.invalidate(session.key)
if snapshot:
loop._schedule_background(loop.consolidator.archive(snapshot))
loop._schedule_background(loop.consolidator.archive(snapshot, session_key=ctx.key))
return OutboundMessage(
channel=ctx.msg.channel, chat_id=ctx.msg.chat_id,
content="New session started.",

View File

@ -519,8 +519,9 @@ class TestNewCommandArchival:
call_count = 0
async def _failing_summarize(_messages) -> bool:
async def _failing_summarize(_messages, *, session_key=None) -> bool:
nonlocal call_count
assert session_key == "cli:test"
call_count += 1
return False
@ -551,10 +552,12 @@ class TestNewCommandArchival:
loop.sessions.save(session)
archived_count = -1
archived_session_key = None
async def _fake_summarize(messages) -> bool:
nonlocal archived_count
async def _fake_summarize(messages, *, session_key=None) -> bool:
nonlocal archived_count, archived_session_key
archived_count = len(messages)
archived_session_key = session_key
return True
loop.consolidator.archive = _fake_summarize # type: ignore[method-assign]
@ -567,6 +570,7 @@ class TestNewCommandArchival:
await loop.close_mcp()
assert archived_count == 3
assert archived_session_key == "cli:test"
@pytest.mark.asyncio
async def test_new_clears_session_and_responds(self, tmp_path: Path) -> None:
@ -579,7 +583,8 @@ class TestNewCommandArchival:
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
async def _ok_summarize(_messages) -> bool:
async def _ok_summarize(_messages, *, session_key=None) -> bool:
assert session_key == "cli:test"
return True
loop.consolidator.archive = _ok_summarize # type: ignore[method-assign]
@ -606,7 +611,8 @@ class TestNewCommandArchival:
archived = asyncio.Event()
release_archive = asyncio.Event()
async def _slow_summarize(_messages) -> bool:
async def _slow_summarize(_messages, *, session_key=None) -> bool:
assert session_key == "cli:test"
await release_archive.wait()
archived.set()
return True

View File

@ -219,8 +219,11 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages):
archived_session_keys: list[str | None] = []
async def track_consolidate(messages, *, session_key=None):
order.append("consolidate")
archived_session_keys.append(session_key)
return True
loop.consolidator.archive = track_consolidate # type: ignore[method-assign]
@ -251,3 +254,4 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
assert "consolidate" in order
assert "llm" in order
assert order.index("consolidate") < order.index("llm")
assert archived_session_keys == ["cli:test"]