fix: honor unified session for webui automations

This commit is contained in:
chengyongru 2026-06-12 10:19:12 +08:00
parent 1ad9d77bc7
commit 0ff8cd0cb3
6 changed files with 110 additions and 29 deletions

View File

@ -528,13 +528,6 @@ class AgentLoop:
effective_key = UNIFIED_SESSION_KEY effective_key = UNIFIED_SESSION_KEY
else: else:
effective_key = f"{channel}:{chat_id}" effective_key = f"{channel}:{chat_id}"
effective_key = self._tool_context_session_key(
channel=channel,
chat_id=chat_id,
metadata=metadata,
session_key=effective_key,
)
request_ctx = RequestContext( request_ctx = RequestContext(
channel=channel, channel=channel,
chat_id=chat_id, chat_id=chat_id,
@ -548,24 +541,6 @@ class AgentLoop:
if tool and isinstance(tool, ContextAware): if tool and isinstance(tool, ContextAware):
tool.set_context(request_ctx) tool.set_context(request_ctx)
def _tool_context_session_key(
self,
*,
channel: str,
chat_id: str,
metadata: dict | None,
session_key: str,
) -> str:
"""Return the session key tools should use for ownership-scoped resources."""
if (
self._unified_session
and channel == "websocket"
and (metadata or {}).get("webui") is True
and chat_id
):
return f"websocket:{chat_id}"
return session_key
@staticmethod @staticmethod
def _runtime_chat_id(msg: InboundMessage) -> str: def _runtime_chat_id(msg: InboundMessage) -> str:
"""Return the chat id shown in runtime metadata for the model.""" """Return the chat id shown in runtime metadata for the model."""

View File

@ -125,6 +125,7 @@ class ChannelManager:
runtime_model_name=self._webui_runtime_model_name, runtime_model_name=self._webui_runtime_model_name,
runtime_surface=self._webui_runtime_surface, runtime_surface=self._webui_runtime_surface,
runtime_capabilities_overrides=self._webui_runtime_capabilities, runtime_capabilities_overrides=self._webui_runtime_capabilities,
unified_session=self.config.agents.defaults.unified_session,
cron_service=self._cron_service, cron_service=self._cron_service,
logger=logger, logger=logger,
) )

View File

@ -39,6 +39,7 @@ def build_gateway_services(
runtime_model_name: Any | None, runtime_model_name: Any | None,
runtime_surface: str, runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None, runtime_capabilities_overrides: dict[str, Any] | None,
unified_session: bool = False,
disabled_skills: set[str] | None = None, disabled_skills: set[str] | None = None,
cron_service: Any | None = None, cron_service: Any | None = None,
logger: Any = default_logger, logger: Any = default_logger,
@ -61,6 +62,7 @@ def build_gateway_services(
runtime_model_name=runtime_model_name, runtime_model_name=runtime_model_name,
runtime_surface=runtime_surface, runtime_surface=runtime_surface,
runtime_capabilities_overrides=runtime_capabilities_overrides, runtime_capabilities_overrides=runtime_capabilities_overrides,
unified_session=unified_session,
bus=bus, bus=bus,
tokens=tokens, tokens=tokens,
media=media, media=media,

View File

@ -20,6 +20,7 @@ from loguru import logger
from websockets.http11 import Request as WsRequest from websockets.http11 import Request as WsRequest
from websockets.http11 import Response from websockets.http11 import Response
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.command.builtin import builtin_command_palette from nanobot.command.builtin import builtin_command_palette
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
from nanobot.webui.file_preview import WebUIFilePreviewError, file_preview_payload from nanobot.webui.file_preview import WebUIFilePreviewError, file_preview_payload
@ -139,6 +140,7 @@ class GatewayHTTPHandler:
runtime_model_name: Callable[[], str | None] | None, runtime_model_name: Callable[[], str | None] | None,
runtime_surface: str, runtime_surface: str,
runtime_capabilities_overrides: dict[str, Any] | None, runtime_capabilities_overrides: dict[str, Any] | None,
unified_session: bool = False,
bus: MessageBus, bus: MessageBus,
tokens: GatewayTokenStore, tokens: GatewayTokenStore,
media: WebUIMediaGateway, media: WebUIMediaGateway,
@ -161,6 +163,7 @@ class GatewayHTTPHandler:
self.cron_service = cron_service self.cron_service = cron_service
self._log = log self._log = log
self._runtime_surface = runtime_surface self._runtime_surface = runtime_surface
self._unified_session = unified_session
from nanobot.webui.settings_api import runtime_capabilities as _rc from nanobot.webui.settings_api import runtime_capabilities as _rc
from nanobot.webui.settings_routes import WebUISettingsRouter from nanobot.webui.settings_routes import WebUISettingsRouter
@ -437,7 +440,7 @@ class GatewayHTTPHandler:
if not _is_websocket_channel_session_key(decoded_key): if not _is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found") return _http_error(404, "session not found")
return _http_json_response( return _http_json_response(
session_automations_payload(self.cron_service, decoded_key) session_automations_payload(self.cron_service, self._automation_display_key(decoded_key))
) )
def _handle_session_delete(self, request: WsRequest, key: str) -> Response: def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
@ -468,6 +471,12 @@ class GatewayHTTPHandler:
delete_webui_thread(decoded_key) delete_webui_thread(decoded_key)
return _http_json_response({"deleted": bool(deleted)}) return _http_json_response({"deleted": bool(deleted)})
def _automation_display_key(self, session_key: str) -> str:
"""Return the cron ownership key shown for this WebUI thread."""
if self._unified_session:
return UNIFIED_SESSION_KEY
return session_key
# -- Media routes ------------------------------------------------------- # -- Media routes -------------------------------------------------------
def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None: def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None:

View File

@ -11,6 +11,7 @@ from urllib.parse import urlencode
import httpx import httpx
import pytest import pytest
from nanobot.agent.loop import UNIFIED_SESSION_KEY
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
from nanobot.cron.service import CronService from nanobot.cron.service import CronService
from nanobot.cron.types import CronJob, CronPayload, CronSchedule from nanobot.cron.types import CronJob, CronPayload, CronSchedule
@ -29,6 +30,7 @@ def _make_handler(
workspace_path: Path | None = None, workspace_path: Path | None = None,
runtime_model_name: Any | None = None, runtime_model_name: Any | None = None,
cron_service: CronService | None = None, cron_service: CronService | None = None,
unified_session: bool = False,
) -> GatewayServices: ) -> GatewayServices:
config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg
workspace = workspace_path or Path.cwd() workspace = workspace_path or Path.cwd()
@ -42,6 +44,7 @@ def _make_handler(
runtime_model_name=runtime_model_name, runtime_model_name=runtime_model_name,
runtime_surface="browser", runtime_surface="browser",
runtime_capabilities_overrides=None, runtime_capabilities_overrides=None,
unified_session=unified_session,
cron_service=cron_service, cron_service=cron_service,
) )
@ -55,6 +58,7 @@ def _ch(
port: int = _PORT, port: int = _PORT,
runtime_model_name: Any | None = None, runtime_model_name: Any | None = None,
cron_service: CronService | None = None, cron_service: CronService | None = None,
unified_session: bool = False,
**extra: Any, **extra: Any,
) -> WebSocketChannel: ) -> WebSocketChannel:
cfg: dict[str, Any] = { cfg: dict[str, Any] = {
@ -73,6 +77,7 @@ def _ch(
workspace_path=workspace_path, workspace_path=workspace_path,
runtime_model_name=runtime_model_name, runtime_model_name=runtime_model_name,
cron_service=cron_service, cron_service=cron_service,
unified_session=unified_session,
) )
return WebSocketChannel(cfg, bus, gateway=gateway) return WebSocketChannel(cfg, bus, gateway=gateway)
@ -237,6 +242,51 @@ async def test_session_automations_route_filters_by_webui_session(
await server_task await server_task
@pytest.mark.asyncio
async def test_session_automations_route_uses_unified_owner_when_enabled(
bus: MagicMock, tmp_path: Path
) -> None:
cron = CronService(tmp_path / "cron" / "jobs.json")
hourly = CronSchedule(kind="every", every_ms=3_600_000)
cron.add_job(
name="Unified check",
schedule=hourly,
message="Check the shared session",
session_key=UNIFIED_SESSION_KEY,
)
cron.add_job(
name="Visible thread only",
schedule=hourly,
message="Do not show in unified mode",
session_key="websocket:abc",
)
channel = _ch(
bus,
session_manager=_seed_session(tmp_path, key="websocket:abc"),
cron_service=cron,
unified_session=True,
port=29917,
)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
boot = await _http_get("http://127.0.0.1:29917/webui/bootstrap")
token = boot.json()["token"]
auth = {"Authorization": f"Bearer {token}"}
for key in ("websocket%3Aabc", "websocket%3Aother"):
resp = await _http_get(
f"http://127.0.0.1:29917/api/sessions/{key}/automations",
headers=auth,
)
assert resp.status_code == 200
body = resp.json()
assert [job["name"] for job in body["jobs"]] == ["Unified check"]
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_webui_skills_route_requires_token_and_hides_paths( async def test_webui_skills_route_requires_token_and_hides_paths(
bus: MagicMock, tmp_path: Path bus: MagicMock, tmp_path: Path
@ -751,6 +801,50 @@ async def test_session_delete_can_cascade_bound_automations(
await server_task await server_task
@pytest.mark.asyncio
async def test_session_delete_does_not_cascade_unified_automations(
bus: MagicMock, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
sm = _seed_session(tmp_path, key="websocket:doomed")
cron = CronService(tmp_path / "cron" / "jobs.json")
cron.add_job(
name="Shared daily check",
schedule=CronSchedule(kind="every", every_ms=86_400_000),
message="Check the shared session",
session_key=UNIFIED_SESSION_KEY,
)
channel = _ch(
bus,
session_manager=sm,
cron_service=cron,
unified_session=True,
port=29918,
)
server_task = asyncio.create_task(channel.start())
await asyncio.sleep(0.3)
try:
boot = await _http_get("http://127.0.0.1:29918/webui/bootstrap")
token = boot.json()["token"]
auth = {"Authorization": f"Bearer {token}"}
path = sm._get_session_path("websocket:doomed")
resp = await _http_get(
"http://127.0.0.1:29918/api/sessions/websocket:doomed/delete",
headers=auth,
)
assert resp.status_code == 200
assert resp.json()["deleted"] is True
assert not path.exists()
assert [job.name for job in cron.list_bound_agent_jobs_for_session(UNIFIED_SESSION_KEY)] == [
"Shared daily check"
]
finally:
await channel.stop()
await server_task
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_session_routes_accept_percent_encoded_websocket_keys( async def test_session_routes_accept_percent_encoded_websocket_keys(
bus: MagicMock, tmp_path: Path bus: MagicMock, tmp_path: Path

View File

@ -245,8 +245,8 @@ async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_webui_cron_tool_uses_visible_session_under_unified_session(tmp_path) -> None: async def test_webui_cron_tool_uses_unified_session_when_enabled(tmp_path) -> None:
"""WebUI-created automations should attach to the visible thread, not unified memory.""" """WebUI-created automations should follow unified session ownership."""
tool = CronTool(CronService(tmp_path / "jobs.json")) tool = CronTool(CronService(tmp_path / "jobs.json"))
class _Tools: class _Tools:
@ -270,7 +270,7 @@ async def test_webui_cron_tool_uses_visible_session_under_unified_session(tmp_pa
jobs = tool._cron.list_jobs() jobs = tool._cron.list_jobs()
assert len(jobs) == 1 assert len(jobs) == 1
assert jobs[0].payload.session_key == "websocket:chat-123" assert jobs[0].payload.session_key == "unified:default"
@pytest.mark.asyncio @pytest.mark.asyncio