mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-10 21:23:39 +00:00
feat(session): add unified_session config to share one session across all channels
This commit is contained in:
parent
bfec06a2c1
commit
743e73da3f
@ -143,6 +143,7 @@ class AgentLoop:
|
||||
channels_config: ChannelsConfig | None = None,
|
||||
timezone: str | None = None,
|
||||
hooks: list[AgentHook] | None = None,
|
||||
unified_session: bool = False,
|
||||
):
|
||||
from nanobot.config.schema import ExecToolConfig, WebToolsConfig
|
||||
|
||||
@ -189,7 +190,7 @@ class AgentLoop:
|
||||
exec_config=self.exec_config,
|
||||
restrict_to_workspace=restrict_to_workspace,
|
||||
)
|
||||
|
||||
self._unified_session = unified_session
|
||||
self._running = False
|
||||
self._mcp_servers = mcp_servers or {}
|
||||
self._mcp_stack: AsyncExitStack | None = None
|
||||
@ -390,6 +391,9 @@ class AgentLoop:
|
||||
|
||||
async def _dispatch(self, msg: InboundMessage) -> None:
|
||||
"""Process a message: per-session serial, cross-session concurrent."""
|
||||
if self._unified_session and not msg.session_key_override:
|
||||
import dataclasses
|
||||
msg = dataclasses.replace(msg, session_key_override="unified:default")
|
||||
lock = self._session_locks.setdefault(msg.session_key, asyncio.Lock())
|
||||
gate = self._concurrency_gate or nullcontext()
|
||||
async with lock, gate:
|
||||
|
||||
@ -590,6 +590,7 @@ def serve(
|
||||
mcp_servers=runtime_config.tools.mcp_servers,
|
||||
channels_config=runtime_config.channels,
|
||||
timezone=runtime_config.agents.defaults.timezone,
|
||||
unified_session=runtime_config.agents.defaults.unified_session,
|
||||
)
|
||||
|
||||
model_name = runtime_config.agents.defaults.model
|
||||
@ -681,6 +682,7 @@ def gateway(
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=config.agents.defaults.timezone,
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
)
|
||||
|
||||
# Set cron callback (needs agent)
|
||||
@ -912,6 +914,7 @@ def agent(
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
channels_config=config.channels,
|
||||
timezone=config.agents.defaults.timezone,
|
||||
unified_session=config.agents.defaults.unified_session,
|
||||
)
|
||||
restart_notice = consume_restart_notice_from_env()
|
||||
if restart_notice and should_show_cli_restart_notice(restart_notice, session_id):
|
||||
|
||||
@ -76,6 +76,7 @@ class AgentDefaults(Base):
|
||||
provider_retry_mode: Literal["standard", "persistent"] = "standard"
|
||||
reasoning_effort: str | None = None # low / medium / high / adaptive - enables LLM thinking mode
|
||||
timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York"
|
||||
unified_session: bool = False # Share one session across all channels (single-user multi-device)
|
||||
dream: DreamConfig = Field(default_factory=DreamConfig)
|
||||
|
||||
|
||||
|
||||
@ -81,6 +81,7 @@ class Nanobot:
|
||||
restrict_to_workspace=config.tools.restrict_to_workspace,
|
||||
mcp_servers=config.tools.mcp_servers,
|
||||
timezone=defaults.timezone,
|
||||
unified_session=defaults.unified_session,
|
||||
)
|
||||
return cls(loop)
|
||||
|
||||
|
||||
195
tests/agent/test_unified_session.py
Normal file
195
tests/agent/test_unified_session.py
Normal file
@ -0,0 +1,195 @@
|
||||
"""Tests for unified_session feature.
|
||||
|
||||
Covers:
|
||||
- AgentLoop._dispatch() rewrites session_key to "unified:default" when enabled
|
||||
- Existing session_key_override is respected (not overwritten)
|
||||
- Feature is off by default (no behavior change for existing users)
|
||||
- Config schema serialises unified_session as camelCase "unifiedSession"
|
||||
- onboard-generated config.json contains "unifiedSession" key
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.config.schema import AgentDefaults, Config
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_loop(tmp_path: Path, unified_session: bool = False) -> AgentLoop:
|
||||
"""Create a minimal AgentLoop for dispatch-level tests."""
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
|
||||
with patch("nanobot.agent.loop.SessionManager"), \
|
||||
patch("nanobot.agent.loop.SubagentManager") as MockSubMgr, \
|
||||
patch("nanobot.agent.loop.Dream"):
|
||||
MockSubMgr.return_value.cancel_by_session = AsyncMock(return_value=0)
|
||||
loop = AgentLoop(
|
||||
bus=bus,
|
||||
provider=provider,
|
||||
workspace=tmp_path,
|
||||
unified_session=unified_session,
|
||||
)
|
||||
return loop
|
||||
|
||||
|
||||
def _make_msg(channel: str = "telegram", chat_id: str = "111",
|
||||
session_key_override: str | None = None) -> InboundMessage:
|
||||
return InboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
sender_id="user1",
|
||||
content="hello",
|
||||
session_key_override=session_key_override,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUnifiedSessionDispatch — core behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUnifiedSessionDispatch:
|
||||
"""AgentLoop._dispatch() session key rewriting logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unified_session_rewrites_key_to_unified_default(self, tmp_path: Path):
|
||||
"""When unified_session=True, all messages use 'unified:default' as session key."""
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
captured: list[str] = []
|
||||
|
||||
async def fake_process(msg, **kwargs):
|
||||
captured.append(msg.session_key)
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process # type: ignore[method-assign]
|
||||
|
||||
msg = _make_msg(channel="telegram", chat_id="111")
|
||||
await loop._dispatch(msg)
|
||||
|
||||
assert captured == ["unified:default"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unified_session_different_channels_share_same_key(self, tmp_path: Path):
|
||||
"""Messages from different channels all resolve to the same session key."""
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
captured: list[str] = []
|
||||
|
||||
async def fake_process(msg, **kwargs):
|
||||
captured.append(msg.session_key)
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process # type: ignore[method-assign]
|
||||
|
||||
await loop._dispatch(_make_msg(channel="telegram", chat_id="111"))
|
||||
await loop._dispatch(_make_msg(channel="discord", chat_id="222"))
|
||||
await loop._dispatch(_make_msg(channel="cli", chat_id="direct"))
|
||||
|
||||
assert captured == ["unified:default", "unified:default", "unified:default"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unified_session_disabled_preserves_original_key(self, tmp_path: Path):
|
||||
"""When unified_session=False (default), session key is channel:chat_id as usual."""
|
||||
loop = _make_loop(tmp_path, unified_session=False)
|
||||
|
||||
captured: list[str] = []
|
||||
|
||||
async def fake_process(msg, **kwargs):
|
||||
captured.append(msg.session_key)
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process # type: ignore[method-assign]
|
||||
|
||||
msg = _make_msg(channel="telegram", chat_id="999")
|
||||
await loop._dispatch(msg)
|
||||
|
||||
assert captured == ["telegram:999"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unified_session_respects_existing_override(self, tmp_path: Path):
|
||||
"""If session_key_override is already set (e.g. Telegram thread), it is NOT overwritten."""
|
||||
loop = _make_loop(tmp_path, unified_session=True)
|
||||
|
||||
captured: list[str] = []
|
||||
|
||||
async def fake_process(msg, **kwargs):
|
||||
captured.append(msg.session_key)
|
||||
return None
|
||||
|
||||
loop._process_message = fake_process # type: ignore[method-assign]
|
||||
|
||||
msg = _make_msg(channel="telegram", chat_id="111", session_key_override="telegram:thread:42")
|
||||
await loop._dispatch(msg)
|
||||
|
||||
assert captured == ["telegram:thread:42"]
|
||||
|
||||
def test_unified_session_default_is_false(self, tmp_path: Path):
|
||||
"""unified_session defaults to False — no behavior change for existing users."""
|
||||
loop = _make_loop(tmp_path)
|
||||
assert loop._unified_session is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestUnifiedSessionConfig — schema & serialisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestUnifiedSessionConfig:
|
||||
"""Config schema and onboard serialisation for unified_session."""
|
||||
|
||||
def test_agent_defaults_unified_session_default_is_false(self):
|
||||
"""AgentDefaults.unified_session defaults to False."""
|
||||
defaults = AgentDefaults()
|
||||
assert defaults.unified_session is False
|
||||
|
||||
def test_agent_defaults_unified_session_can_be_enabled(self):
|
||||
"""AgentDefaults.unified_session can be set to True."""
|
||||
defaults = AgentDefaults(unified_session=True)
|
||||
assert defaults.unified_session is True
|
||||
|
||||
def test_config_serialises_unified_session_as_camel_case(self):
|
||||
"""model_dump(by_alias=True) outputs 'unifiedSession' (camelCase) for JSON."""
|
||||
config = Config()
|
||||
data = config.model_dump(mode="json", by_alias=True)
|
||||
agents_defaults = data["agents"]["defaults"]
|
||||
assert "unifiedSession" in agents_defaults
|
||||
assert agents_defaults["unifiedSession"] is False
|
||||
|
||||
def test_config_parses_unified_session_from_camel_case(self):
|
||||
"""Config can be loaded from JSON with camelCase 'unifiedSession'."""
|
||||
raw = {"agents": {"defaults": {"unifiedSession": True}}}
|
||||
config = Config.model_validate(raw)
|
||||
assert config.agents.defaults.unified_session is True
|
||||
|
||||
def test_config_parses_unified_session_from_snake_case(self):
|
||||
"""Config also accepts snake_case 'unified_session' (populate_by_name=True)."""
|
||||
raw = {"agents": {"defaults": {"unified_session": True}}}
|
||||
config = Config.model_validate(raw)
|
||||
assert config.agents.defaults.unified_session is True
|
||||
|
||||
def test_onboard_generated_config_contains_unified_session(self, tmp_path: Path):
|
||||
"""save_config() writes 'unifiedSession' into config.json (simulates nanobot onboard)."""
|
||||
from nanobot.config.loader import save_config
|
||||
|
||||
config = Config()
|
||||
config_path = tmp_path / "config.json"
|
||||
save_config(config, config_path)
|
||||
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
agents_defaults = data["agents"]["defaults"]
|
||||
assert "unifiedSession" in agents_defaults, (
|
||||
"onboard-generated config.json must contain 'unifiedSession' key"
|
||||
)
|
||||
assert agents_defaults["unifiedSession"] is False
|
||||
Loading…
x
Reference in New Issue
Block a user