From b4c7cd654ee69ac167c5a49a61e2e04641c086ab Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 8 Apr 2026 21:39:12 +0800 Subject: [PATCH] fix: use effective session key for _active_tasks in unified mode --- nanobot/agent/loop.py | 14 ++-- tests/agent/test_unified_session.py | 102 ++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 4 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 593331c3f..76bed4158 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -3,7 +3,9 @@ from __future__ import annotations import asyncio +import dataclasses import json +import re import os import time from contextlib import AsyncExitStack, nullcontext @@ -40,6 +42,8 @@ if TYPE_CHECKING: from nanobot.cron.service import CronService +# Named constant for unified session key, used across multiple locations +UNIFIED_SESSION_KEY = "unified:default" class _LoopHook(AgentHook): """Core hook for the main loop.""" @@ -385,15 +389,17 @@ class AgentLoop: if result: await self.bus.publish_outbound(result) continue + # Compute the effective session key before dispatching + # This ensures /stop command can find tasks correctly when unified session is enabled + effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key task = asyncio.create_task(self._dispatch(msg)) - self._active_tasks.setdefault(msg.session_key, []).append(task) - task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) + self._active_tasks.setdefault(effective_key, []).append(task) + task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None) async def _dispatch(self, msg: InboundMessage) -> None: """Process a message: per-session serial, cross-session concurrent.""" if self._unified_session and not msg.session_key_override: - import dataclasses - msg = dataclasses.replace(msg, session_key_override="unified:default") + msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY) lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock()) gate = self._concurrency_gate or nullcontext() async with lock, gate: diff --git a/tests/agent/test_unified_session.py b/tests/agent/test_unified_session.py index 1d9eaad64..557beaca7 100644 --- a/tests/agent/test_unified_session.py +++ b/tests/agent/test_unified_session.py @@ -398,3 +398,105 @@ class TestConsolidationUnaffectedByUnifiedSession: consolidator.estimate_session_prompt_tokens.assert_called_once_with(session) # but archive was not called (no valid boundary) consolidator.archive.assert_not_called() + + +# --------------------------------------------------------------------------- +# TestStopCommandWithUnifiedSession — /stop command integration +# --------------------------------------------------------------------------- + + +class TestStopCommandWithUnifiedSession: + """Verify /stop command works correctly with unified session enabled.""" + + @pytest.mark.asyncio + async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path): + """When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + + loop = _make_loop(tmp_path, unified_session=True) + + # Create a message from telegram channel + msg = _make_msg(channel="telegram", chat_id="123456") + + # Mock _dispatch to complete immediately + async def fake_dispatch(m): + pass + + loop._dispatch = fake_dispatch # type: ignore[method-assign] + + # Simulate the task creation flow (from _run loop) + effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key + task = asyncio.create_task(loop._dispatch(msg)) + loop._active_tasks.setdefault(effective_key, []).append(task) + + # Wait for task to complete + await task + + # Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id + assert UNIFIED_SESSION_KEY in loop._active_tasks + assert "telegram:123456" not in loop._active_tasks + + @pytest.mark.asyncio + async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path): + """cmd_stop can cancel tasks when unified_session=True.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.command.builtin import cmd_stop + + loop = _make_loop(tmp_path, unified_session=True) + + # Create a long-running task stored under UNIFIED_SESSION_KEY + async def long_running(): + await asyncio.sleep(10) # Will be cancelled + + task = asyncio.create_task(long_running()) + loop._active_tasks[UNIFIED_SESSION_KEY] = [task] + + # Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch + msg = InboundMessage( + channel="telegram", + chat_id="123456", + sender_id="user1", + content="/stop", + session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state + ) + + ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop) + + # Execute /stop + result = await cmd_stop(ctx) + + # Verify task was cancelled + assert task.cancelled() or task.done() + assert "Stopped 1 task" in result.content + + @pytest.mark.asyncio + async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path): + """In unified mode, /stop from one channel cancels tasks from another channel.""" + from nanobot.agent.loop import UNIFIED_SESSION_KEY + from nanobot.command.builtin import cmd_stop + + loop = _make_loop(tmp_path, unified_session=True) + + # Create tasks from different channels, all stored under UNIFIED_SESSION_KEY + async def long_running(): + await asyncio.sleep(10) + + task1 = asyncio.create_task(long_running()) + task2 = asyncio.create_task(long_running()) + loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2] + + # /stop from discord should cancel tasks started from telegram + msg = InboundMessage( + channel="discord", + chat_id="789012", + sender_id="user2", + content="/stop", + session_key_override=UNIFIED_SESSION_KEY, + ) + + ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop) + + result = await cmd_stop(ctx) + + # Both tasks should be cancelled + assert "Stopped 2 task" in result.content \ No newline at end of file