mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-25 12:26:00 +00:00
fix: use effective session key for _active_tasks in unified mode
This commit is contained in:
parent
985f9c443b
commit
b4c7cd654e
@ -3,7 +3,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from contextlib import AsyncExitStack, nullcontext
|
from contextlib import AsyncExitStack, nullcontext
|
||||||
@ -40,6 +42,8 @@ if TYPE_CHECKING:
|
|||||||
from nanobot.cron.service import CronService
|
from nanobot.cron.service import CronService
|
||||||
|
|
||||||
|
|
||||||
|
# Named constant for unified session key, used across multiple locations
|
||||||
|
UNIFIED_SESSION_KEY = "unified:default"
|
||||||
class _LoopHook(AgentHook):
|
class _LoopHook(AgentHook):
|
||||||
"""Core hook for the main loop."""
|
"""Core hook for the main loop."""
|
||||||
|
|
||||||
@ -385,15 +389,17 @@ class AgentLoop:
|
|||||||
if result:
|
if result:
|
||||||
await self.bus.publish_outbound(result)
|
await self.bus.publish_outbound(result)
|
||||||
continue
|
continue
|
||||||
|
# Compute the effective session key before dispatching
|
||||||
|
# This ensures /stop command can find tasks correctly when unified session is enabled
|
||||||
|
effective_key = UNIFIED_SESSION_KEY if self._unified_session and not msg.session_key_override else msg.session_key
|
||||||
task = asyncio.create_task(self._dispatch(msg))
|
task = asyncio.create_task(self._dispatch(msg))
|
||||||
self._active_tasks.setdefault(msg.session_key, []).append(task)
|
self._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
task.add_done_callback(lambda t, k=msg.session_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
task.add_done_callback(lambda t, k=effective_key: self._active_tasks.get(k, []) and self._active_tasks[k].remove(t) if t in self._active_tasks.get(k, []) else None)
|
||||||
|
|
||||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||||
"""Process a message: per-session serial, cross-session concurrent."""
|
"""Process a message: per-session serial, cross-session concurrent."""
|
||||||
if self._unified_session and not msg.session_key_override:
|
if self._unified_session and not msg.session_key_override:
|
||||||
import dataclasses
|
msg = dataclasses.replace(msg, session_key_override=UNIFIED_SESSION_KEY)
|
||||||
msg = dataclasses.replace(msg, session_key_override="unified:default")
|
|
||||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||||
gate = self._concurrency_gate or nullcontext()
|
gate = self._concurrency_gate or nullcontext()
|
||||||
async with lock, gate:
|
async with lock, gate:
|
||||||
|
|||||||
@ -398,3 +398,105 @@ class TestConsolidationUnaffectedByUnifiedSession:
|
|||||||
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
|
consolidator.estimate_session_prompt_tokens.assert_called_once_with(session)
|
||||||
# but archive was not called (no valid boundary)
|
# but archive was not called (no valid boundary)
|
||||||
consolidator.archive.assert_not_called()
|
consolidator.archive.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# TestStopCommandWithUnifiedSession — /stop command integration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestStopCommandWithUnifiedSession:
|
||||||
|
"""Verify /stop command works correctly with unified session enabled."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_active_tasks_use_effective_key_in_unified_mode(self, tmp_path: Path):
|
||||||
|
"""When unified_session=True, tasks are stored under UNIFIED_SESSION_KEY."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
# Create a message from telegram channel
|
||||||
|
msg = _make_msg(channel="telegram", chat_id="123456")
|
||||||
|
|
||||||
|
# Mock _dispatch to complete immediately
|
||||||
|
async def fake_dispatch(m):
|
||||||
|
pass
|
||||||
|
|
||||||
|
loop._dispatch = fake_dispatch # type: ignore[method-assign]
|
||||||
|
|
||||||
|
# Simulate the task creation flow (from _run loop)
|
||||||
|
effective_key = UNIFIED_SESSION_KEY if loop._unified_session and not msg.session_key_override else msg.session_key
|
||||||
|
task = asyncio.create_task(loop._dispatch(msg))
|
||||||
|
loop._active_tasks.setdefault(effective_key, []).append(task)
|
||||||
|
|
||||||
|
# Wait for task to complete
|
||||||
|
await task
|
||||||
|
|
||||||
|
# Verify the task is stored under UNIFIED_SESSION_KEY, not the original channel:chat_id
|
||||||
|
assert UNIFIED_SESSION_KEY in loop._active_tasks
|
||||||
|
assert "telegram:123456" not in loop._active_tasks
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_command_finds_task_in_unified_mode(self, tmp_path: Path):
|
||||||
|
"""cmd_stop can cancel tasks when unified_session=True."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
# Create a long-running task stored under UNIFIED_SESSION_KEY
|
||||||
|
async def long_running():
|
||||||
|
await asyncio.sleep(10) # Will be cancelled
|
||||||
|
|
||||||
|
task = asyncio.create_task(long_running())
|
||||||
|
loop._active_tasks[UNIFIED_SESSION_KEY] = [task]
|
||||||
|
|
||||||
|
# Create a message that would have session_key=UNIFIED_SESSION_KEY after dispatch
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="telegram",
|
||||||
|
chat_id="123456",
|
||||||
|
sender_id="user1",
|
||||||
|
content="/stop",
|
||||||
|
session_key_override=UNIFIED_SESSION_KEY, # Simulate post-dispatch state
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||||
|
|
||||||
|
# Execute /stop
|
||||||
|
result = await cmd_stop(ctx)
|
||||||
|
|
||||||
|
# Verify task was cancelled
|
||||||
|
assert task.cancelled() or task.done()
|
||||||
|
assert "Stopped 1 task" in result.content
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stop_command_cross_channel_in_unified_mode(self, tmp_path: Path):
|
||||||
|
"""In unified mode, /stop from one channel cancels tasks from another channel."""
|
||||||
|
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||||
|
from nanobot.command.builtin import cmd_stop
|
||||||
|
|
||||||
|
loop = _make_loop(tmp_path, unified_session=True)
|
||||||
|
|
||||||
|
# Create tasks from different channels, all stored under UNIFIED_SESSION_KEY
|
||||||
|
async def long_running():
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
task1 = asyncio.create_task(long_running())
|
||||||
|
task2 = asyncio.create_task(long_running())
|
||||||
|
loop._active_tasks[UNIFIED_SESSION_KEY] = [task1, task2]
|
||||||
|
|
||||||
|
# /stop from discord should cancel tasks started from telegram
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel="discord",
|
||||||
|
chat_id="789012",
|
||||||
|
sender_id="user2",
|
||||||
|
content="/stop",
|
||||||
|
session_key_override=UNIFIED_SESSION_KEY,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(msg=msg, session=None, key=UNIFIED_SESSION_KEY, raw="/stop", loop=loop)
|
||||||
|
|
||||||
|
result = await cmd_stop(ctx)
|
||||||
|
|
||||||
|
# Both tasks should be cancelled
|
||||||
|
assert "Stopped 2 task" in result.content
|
||||||
Loading…
x
Reference in New Issue
Block a user