mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 21:23:39 +00:00
fix: use effective session key for _active_tasks in unified mode
This commit is contained in:
parent
985f9c443b
commit
b4c7cd654e
@ -3,7 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
from contextlib import AsyncExitStack, nullcontext
|
||||
@ -40,6 +42,8 @@ if TYPE_CHECKING:
|
||||
from nanobot.cron.service import CronService
|
||||
|
||||
|
||||
# Named constant for unified session key, used across multiple locations
|
||||
UNIFIED_SESSION_KEY = "unified:default"
|
||||
class _LoopHook(AgentHook):
|
||||
"""Core hook for the main loop."""
|
||||
|
||||
@ -385,15 +389,17 @@ class AgentLoop:
|
||||
if result:
|
||||
await self.bus.publish_outbound(result)
|
||||
continue
|
||||
# Compute the effective session key before dispatching
|
||||
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
|
||||
task = asyncio.create_task(self._dispatch(msg))
|
||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
import dataclasses
|
||||
msg = dataclasses.replace(msg, session_key_override="unified:default")
|
||||
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
|
||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
async with lock, gate:
|
||||
|
||||
@ -398,3 +398,105 @@ class TestConsolidationUnaffectedByUnifiedSession:
|
||||
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
|
||||
# but archive was not called (no valid boundary)
|
||||
consolidator.archive.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestStopCommandWithUnifiedSession — /stop command integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStopCommandWithUnifiedSession:
|
||||
"""Verify /stop command works correctly with unified session enabled."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
|
||||
"""When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
# Create a message from telegram channel
|
||||
msg = _make_msg(channel="telegram", chat_id="123456")
|
||||
|
||||
# Mock _dispatch to complete immediately
|
||||
async def fake_dispatch(m):
|
||||
pass
|
||||
|
||||
loop._dispatch = fake_dispatch # type: ignore[method-assign]
|
||||
|
||||
# Simulate the task creation flow (from _run loop)
|
||||
effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
|
||||
task = asyncio.create_task(loop._dispatch(msg))
|
||||
loop._active_tasks.setdefault(effective_key, []).append(task)
|
||||
|
||||
# Wait for task to complete
|
||||
await task
|
||||
|
||||
# Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
|
||||
assert UNIFIED_SESSION_KEY in loop._active_tasks
|
||||
assert "telegram:123456" not in loop._active_tasks
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
|
||||
"""cmd_stop can cancel tasks when unified_session=True."""
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
# Create a long-running task stored under UNIFIED_SESSION_KEY
|
||||
async def long_running():
|
||||
await asyncio.sleep(10) # Will be cancelled
|
||||
|
||||
task = asyncio.create_task(long_running())
|
||||
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
|
||||
|
||||
# Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
|
||||
msg = InboundMessage(
|
||||
channel="telegram",
|
||||
chat_id="123456",
|
||||
sender_id="user1",
|
||||
content="/stop",
|
||||
session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state
|
||||
)
|
||||
|
||||
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||
|
||||
# Execute /stop
|
||||
result = await cmd_stop(ctx)
|
||||
|
||||
# Verify task was cancelled
|
||||
assert task.cancelled() or task.done()
|
||||
assert "Stopped 1 task" in result.content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
|
||||
"""In unified mode, /stop from one channel cancels tasks from another channel."""
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.command.builtin import cmd_stop
|
||||
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
# Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
|
||||
async def long_running():
|
||||
await asyncio.sleep(10)
|
||||
|
||||
task1 = asyncio.create_task(long_running())
|
||||
task2 = asyncio.create_task(long_running())
|
||||
loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]
|
||||
|
||||
# /stop from discord should cancel tasks started from telegram
|
||||
msg = InboundMessage(
|
||||
channel="discord",
|
||||
chat_id="789012",
|
||||
sender_id="user2",
|
||||
content="/stop",
|
||||
session_key_override=UNIFIED_SESSION_KEY,
|
||||
)
|
||||
|
||||
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||
|
||||
result = await cmd_stop(ctx)
|
||||
|
||||
# Both tasks should be cancelled
|
||||
assert "Stopped 2 task" in result.content
|
||||
Loading…
x
Reference in New Issue
Block a user