fix: use effective session key for _active_tasks in unified mode

This commit is contained in:
whs 2026-04-08 21:39:12 +08:00 committed by Xubin Ren
parent 985f9c443b
commit b4c7cd654e
2 changed files with 112 additions and 4 deletions

View File

@ -3,7 +3,9 @@
from __future__ import annotations
import asyncio
import dataclasses
import json
import re
import os
import time
from contextlib import AsyncExitStack, nullcontext
@ -40,6 +42,8 @@ if TYPE_CHECKING:
from nanobot.cron.service import CronService
# Named constant for unified session key, used across multiple locations
UNIFIED_SESSION_KEY = "unified:default"
class _LoopHook(AgentHook):
"""Core hook for the main loop."""
@ -385,15 +389,17 @@ class AgentLoop:
if result:
await self.bus.publish_outbound(result)
continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task)
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent."""
if self._unified_session and not msg.session_key_override:
import dataclasses
msg = dataclasses.replace(msg, session_key_override="unified:default")
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext()
async with lock, gate:

View File

@ -398,3 +398,105 @@ class TestConsolidationUnaffectedByUnifiedSession:
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
# but archive was not called (no valid boundary)
consolidator.archive.assert_not_called()
# ---------------------------------------------------------------------------
# TestStopCommandWithUnifiedSession — /stop command integration
# ---------------------------------------------------------------------------
class TestStopCommandWithUnifiedSession:
"""Verify /stop command works correctly with unified session enabled."""
@pytest.mark.asyncio
async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
"""When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
loop = _make_loop(tmp_path, unified_session=True)
# Create a message from telegram channel
msg = _make_msg(channel="telegram", chat_id="123456")
# Mock _dispatch to complete immediately
async def fake_dispatch(m):
pass
loop._dispatch = fake_dispatch # type: ignore[method-assign]
# Simulate the task creation flow (from _run loop)
effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
task = asyncio.create_task(loop._dispatch(msg))
loop._active_tasks.setdefault(effective_key, []).append(task)
# Wait for task to complete
await task
# Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
assert UNIFIED_SESSION_KEY in loop._active_tasks
assert "telegram:123456" not in loop._active_tasks
@pytest.mark.asyncio
async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
"""cmd_stop can cancel tasks when unified_session=True."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
# Create a long-running task stored under UNIFIED_SESSION_KEY
async def long_running():
await asyncio.sleep(10) # Will be cancelled
task = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
# Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
msg = InboundMessage(
channel="telegram",
chat_id="123456",
sender_id="user1",
content="/stop",
session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
# Execute /stop
result = await cmd_stop(ctx)
# Verify task was cancelled
assert task.cancelled() or task.done()
assert "Stopped 1 task" in result.content
@pytest.mark.asyncio
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
"""In unified mode, /stop from one channel cancels tasks from another channel."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
# Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
async def long_running():
await asyncio.sleep(10)
task1 = asyncio.create_task(long_running())
task2 = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]
# /stop from discord should cancel tasks started from telegram
msg = InboundMessage(
channel="discord",
chat_id="789012",
sender_id="user2",
content="/stop",
session_key_override=UNIFIED_SESSION_KEY,
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
result = await cmd_stop(ctx)
# Both tasks should be cancelled
assert "Stopped 2 task" in result.content