fix: use effective session key for _active_tasks in unified mode

This commit is contained in:
whs 2026-04-08 21:39:12 +08:00 committed by Xubin Ren
parent 985f9c443b
commit b4c7cd654e
2 changed files with 112 additions and 4 deletions

View File

@ -3,7 +3,9 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import json import json
import re
import os import os
import time import time
from contextlib import AsyncExitStack, nullcontext from contextlib import AsyncExitStack, nullcontext
@ -40,6 +42,8 @@ if TYPE_CHECKING:
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
# Named constant for unified session key, used across multiple locations
UNIFIED_SESSION_KEY = "unified:default"
class _LoopHook(AgentHook): class _LoopHook(AgentHook):
"""Core hook for the main loop.""" """Core hook for the main loop."""
@ -385,15 +389,17 @@ class AgentLoop:
if result: if result:
await self.bus.publish_outbound(result) await self.bus.publish_outbound(result)
continue continue
# Compute the effective session key before dispatching
# This ensures /stop command can find tasks correctly when unified session is enabled
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
task = asyncio.create_task(self._dispatch(msg)) task = asyncio.create_task(self._dispatch(msg))
self._active_tasks.setdefault(msg.session_key, []).append(task) self._active_tasks.setdefault(effective_key, []).append(task)
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
async def _dispatch(self, msg: InboundMessage) -> None: async def _dispatch(self, msg: InboundMessage) -> None:
"""Process a message: per-session serial, cross-session concurrent.""" """Process a message: per-session serial, cross-session concurrent."""
if self._unified_session and not msg.session_key_override: if self._unified_session and not msg.session_key_override:
import dataclasses msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
msg = dataclasses.replace(msg, session_key_override="unified:default")
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
gate = self._concurrency_gate or nullcontext() gate = self._concurrency_gate or nullcontext()
async with lock, gate: async with lock, gate:

View File

@ -398,3 +398,105 @@ class TestConsolidationUnaffectedByUnifiedSession:
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session) consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
# but archive was not called (no valid boundary) # but archive was not called (no valid boundary)
consolidator.archive.assert_not_called() consolidator.archive.assert_not_called()
# ---------------------------------------------------------------------------
# TestStopCommandWithUnifiedSession — /stop command integration
# ---------------------------------------------------------------------------
class TestStopCommandWithUnifiedSession:
"""Verify /stop command works correctly with unified session enabled."""
@pytest.mark.asyncio
async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
"""When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
loop = _make_loop(tmp_path, unified_session=True)
# Create a message from telegram channel
msg = _make_msg(channel="telegram", chat_id="123456")
# Mock _dispatch to complete immediately
async def fake_dispatch(m):
pass
loop._dispatch = fake_dispatch # type: ignore[method-assign]
# Simulate the task creation flow (from _run loop)
effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
task = asyncio.create_task(loop._dispatch(msg))
loop._active_tasks.setdefault(effective_key, []).append(task)
# Wait for task to complete
await task
# Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
assert UNIFIED_SESSION_KEY in loop._active_tasks
assert "telegram:123456" not in loop._active_tasks
@pytest.mark.asyncio
async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
"""cmd_stop can cancel tasks when unified_session=True."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
# Create a long-running task stored under UNIFIED_SESSION_KEY
async def long_running():
await asyncio.sleep(10) # Will be cancelled
task = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
# Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
msg = InboundMessage(
channel="telegram",
chat_id="123456",
sender_id="user1",
content="/stop",
session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
# Execute /stop
result = await cmd_stop(ctx)
# Verify task was cancelled
assert task.cancelled() or task.done()
assert "Stopped 1 task" in result.content
@pytest.mark.asyncio
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
"""In unified mode, /stop from one channel cancels tasks from another channel."""
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import cmd_stop
loop = _make_loop(tmp_path, unified_session=True)
# Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
async def long_running():
await asyncio.sleep(10)
task1 = asyncio.create_task(long_running())
task2 = asyncio.create_task(long_running())
loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]
# /stop from discord should cancel tasks started from telegram
msg = InboundMessage(
channel="discord",
chat_id="789012",
sender_id="user2",
content="/stop",
session_key_override=UNIFIED_SESSION_KEY,
)
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
result = await cmd_stop(ctx)
# Both tasks should be cancelled
assert "Stopped 2 task" in result.content