mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
fix: honor unified session for webui automations
This commit is contained in:
parent
1ad9d77bc7
commit
0ff8cd0cb3
@ -528,13 +528,6 @@ class AgentLoop:
|
||||
effective_key = UNIFIED_SESSION_KEY
|
||||
else:
|
||||
effective_key = f"{channel}:{chat_id}"
|
||||
effective_key = self._tool_context_session_key(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
metadata=metadata,
|
||||
session_key=effective_key,
|
||||
)
|
||||
|
||||
request_ctx = RequestContext(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
@ -548,24 +541,6 @@ class AgentLoop:
|
||||
if tool and isinstance(tool, ContextAware):
|
||||
tool.set_context(request_ctx)
|
||||
|
||||
def _tool_context_session_key(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
metadata: dict | None,
|
||||
session_key: str,
|
||||
) -> str:
|
||||
"""Return the session key tools should use for ownership-scoped resources."""
|
||||
if (
|
||||
self._unified_session
|
||||
and channel == "websocket"
|
||||
and (metadata or {}).get("webui") is True
|
||||
and chat_id
|
||||
):
|
||||
return f"websocket:{chat_id}"
|
||||
return session_key
|
||||
|
||||
@staticmethod
|
||||
def _runtime_chat_id(msg: InboundMessage) -> str:
|
||||
"""Return the chat id shown in runtime metadata for the model."""
|
||||
|
||||
@ -125,6 +125,7 @@ class ChannelManager:
|
||||
runtime_model_name=self._webui_runtime_model_name,
|
||||
runtime_surface=self._webui_runtime_surface,
|
||||
runtime_capabilities_overrides=self._webui_runtime_capabilities,
|
||||
unified_session=self.config.agents.defaults.unified_session,
|
||||
cron_service=self._cron_service,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
@ -39,6 +39,7 @@ def build_gateway_services(
|
||||
runtime_model_name: Any | None,
|
||||
runtime_surface: str,
|
||||
runtime_capabilities_overrides: dict[str, Any] | None,
|
||||
unified_session: bool = False,
|
||||
disabled_skills: set[str] | None = None,
|
||||
cron_service: Any | None = None,
|
||||
logger: Any = default_logger,
|
||||
@ -61,6 +62,7 @@ def build_gateway_services(
|
||||
runtime_model_name=runtime_model_name,
|
||||
runtime_surface=runtime_surface,
|
||||
runtime_capabilities_overrides=runtime_capabilities_overrides,
|
||||
unified_session=unified_session,
|
||||
bus=bus,
|
||||
tokens=tokens,
|
||||
media=media,
|
||||
|
||||
@ -20,6 +20,7 @@ from loguru import logger
|
||||
from websockets.http11 import Request as WsRequest
|
||||
from websockets.http11 import Response
|
||||
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.command.builtin import builtin_command_palette
|
||||
from nanobot.utils.subagent_channel_display import scrub_subagent_messages_for_channel
|
||||
from nanobot.webui.file_preview import WebUIFilePreviewError, file_preview_payload
|
||||
@ -139,6 +140,7 @@ class GatewayHTTPHandler:
|
||||
runtime_model_name: Callable[[], str | None] | None,
|
||||
runtime_surface: str,
|
||||
runtime_capabilities_overrides: dict[str, Any] | None,
|
||||
unified_session: bool = False,
|
||||
bus: MessageBus,
|
||||
tokens: GatewayTokenStore,
|
||||
media: WebUIMediaGateway,
|
||||
@ -161,6 +163,7 @@ class GatewayHTTPHandler:
|
||||
self.cron_service = cron_service
|
||||
self._log = log
|
||||
self._runtime_surface = runtime_surface
|
||||
self._unified_session = unified_session
|
||||
|
||||
from nanobot.webui.settings_api import runtime_capabilities as _rc
|
||||
from nanobot.webui.settings_routes import WebUISettingsRouter
|
||||
@ -437,7 +440,7 @@ class GatewayHTTPHandler:
|
||||
if not _is_websocket_channel_session_key(decoded_key):
|
||||
return _http_error(404, "session not found")
|
||||
return _http_json_response(
|
||||
session_automations_payload(self.cron_service, decoded_key)
|
||||
session_automations_payload(self.cron_service, self._automation_display_key(decoded_key))
|
||||
)
|
||||
|
||||
def _handle_session_delete(self, request: WsRequest, key: str) -> Response:
|
||||
@ -468,6 +471,12 @@ class GatewayHTTPHandler:
|
||||
delete_webui_thread(decoded_key)
|
||||
return _http_json_response({"deleted": bool(deleted)})
|
||||
|
||||
def _automation_display_key(self, session_key: str) -> str:
|
||||
"""Return the cron ownership key shown for this WebUI thread."""
|
||||
if self._unified_session:
|
||||
return UNIFIED_SESSION_KEY
|
||||
return session_key
|
||||
|
||||
# -- Media routes -------------------------------------------------------
|
||||
|
||||
def _dispatch_media_routes(self, request: WsRequest, got: str) -> Response | None:
|
||||
|
||||
@ -11,6 +11,7 @@ from urllib.parse import urlencode
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from nanobot.agent.loop import UNIFIED_SESSION_KEY
|
||||
from nanobot.channels.websocket import WebSocketChannel, WebSocketConfig
|
||||
from nanobot.cron.service import CronService
|
||||
from nanobot.cron.types import CronJob, CronPayload, CronSchedule
|
||||
@ -29,6 +30,7 @@ def _make_handler(
|
||||
workspace_path: Path | None = None,
|
||||
runtime_model_name: Any | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
unified_session: bool = False,
|
||||
) -> GatewayServices:
|
||||
config = WebSocketConfig.model_validate(cfg) if isinstance(cfg, dict) else cfg
|
||||
workspace = workspace_path or Path.cwd()
|
||||
@ -42,6 +44,7 @@ def _make_handler(
|
||||
runtime_model_name=runtime_model_name,
|
||||
runtime_surface="browser",
|
||||
runtime_capabilities_overrides=None,
|
||||
unified_session=unified_session,
|
||||
cron_service=cron_service,
|
||||
)
|
||||
|
||||
@ -55,6 +58,7 @@ def _ch(
|
||||
port: int = _PORT,
|
||||
runtime_model_name: Any | None = None,
|
||||
cron_service: CronService | None = None,
|
||||
unified_session: bool = False,
|
||||
**extra: Any,
|
||||
) -> WebSocketChannel:
|
||||
cfg: dict[str, Any] = {
|
||||
@ -73,6 +77,7 @@ def _ch(
|
||||
workspace_path=workspace_path,
|
||||
runtime_model_name=runtime_model_name,
|
||||
cron_service=cron_service,
|
||||
unified_session=unified_session,
|
||||
)
|
||||
return WebSocketChannel(cfg, bus, gateway=gateway)
|
||||
|
||||
@ -237,6 +242,51 @@ async def test_session_automations_route_filters_by_webui_session(
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_automations_route_uses_unified_owner_when_enabled(
|
||||
bus: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
cron = CronService(tmp_path / "cron" / "jobs.json")
|
||||
hourly = CronSchedule(kind="every", every_ms=3_600_000)
|
||||
cron.add_job(
|
||||
name="Unified check",
|
||||
schedule=hourly,
|
||||
message="Check the shared session",
|
||||
session_key=UNIFIED_SESSION_KEY,
|
||||
)
|
||||
cron.add_job(
|
||||
name="Visible thread only",
|
||||
schedule=hourly,
|
||||
message="Do not show in unified mode",
|
||||
session_key="websocket:abc",
|
||||
)
|
||||
channel = _ch(
|
||||
bus,
|
||||
session_manager=_seed_session(tmp_path, key="websocket:abc"),
|
||||
cron_service=cron,
|
||||
unified_session=True,
|
||||
port=29917,
|
||||
)
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
boot = await _http_get("http://127.0.0.1:29917/webui/bootstrap")
|
||||
token = boot.json()["token"]
|
||||
auth = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
for key in ("websocket%3Aabc", "websocket%3Aother"):
|
||||
resp = await _http_get(
|
||||
f"http://127.0.0.1:29917/api/sessions/{key}/automations",
|
||||
headers=auth,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert [job["name"] for job in body["jobs"]] == ["Unified check"]
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webui_skills_route_requires_token_and_hides_paths(
|
||||
bus: MagicMock, tmp_path: Path
|
||||
@ -751,6 +801,50 @@ async def test_session_delete_can_cascade_bound_automations(
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_delete_does_not_cascade_unified_automations(
|
||||
bus: MagicMock, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||
sm = _seed_session(tmp_path, key="websocket:doomed")
|
||||
cron = CronService(tmp_path / "cron" / "jobs.json")
|
||||
cron.add_job(
|
||||
name="Shared daily check",
|
||||
schedule=CronSchedule(kind="every", every_ms=86_400_000),
|
||||
message="Check the shared session",
|
||||
session_key=UNIFIED_SESSION_KEY,
|
||||
)
|
||||
channel = _ch(
|
||||
bus,
|
||||
session_manager=sm,
|
||||
cron_service=cron,
|
||||
unified_session=True,
|
||||
port=29918,
|
||||
)
|
||||
server_task = asyncio.create_task(channel.start())
|
||||
await asyncio.sleep(0.3)
|
||||
try:
|
||||
boot = await _http_get("http://127.0.0.1:29918/webui/bootstrap")
|
||||
token = boot.json()["token"]
|
||||
auth = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
path = sm._get_session_path("websocket:doomed")
|
||||
resp = await _http_get(
|
||||
"http://127.0.0.1:29918/api/sessions/websocket:doomed/delete",
|
||||
headers=auth,
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["deleted"] is True
|
||||
assert not path.exists()
|
||||
assert [job.name for job in cron.list_bound_agent_jobs_for_session(UNIFIED_SESSION_KEY)] == [
|
||||
"Shared daily check"
|
||||
]
|
||||
finally:
|
||||
await channel.stop()
|
||||
await server_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_routes_accept_percent_encoded_websocket_keys(
|
||||
bus: MagicMock, tmp_path: Path
|
||||
|
||||
@ -245,8 +245,8 @@ async def test_cron_tool_basic_set_context_and_execute(tmp_path) -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webui_cron_tool_uses_visible_session_under_unified_session(tmp_path) -> None:
|
||||
"""WebUI-created automations should attach to the visible thread, not unified memory."""
|
||||
async def test_webui_cron_tool_uses_unified_session_when_enabled(tmp_path) -> None:
|
||||
"""WebUI-created automations should follow unified session ownership."""
|
||||
tool = CronTool(CronService(tmp_path / "jobs.json"))
|
||||
|
||||
class _Tools:
|
||||
@ -270,7 +270,7 @@ async def test_webui_cron_tool_uses_visible_session_under_unified_session(tmp_pa
|
||||
|
||||
jobs = tool._cron.list_jobs()
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].payload.session_key == "websocket:chat-123"
|
||||
assert jobs[0].payload.session_key == "unified:default"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user