mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-30 06:45:55 +00:00
382 lines
12 KiB
Python
382 lines
12 KiB
Python
"""End-to-end tests for the embedded webui's HTTP routes on the WebSocket channel."""
|
|
|
|
import asyncio
|
|
import functools
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from nanobot.channels.websocket import WebSocketChannel
|
|
from nanobot.session.manager import Session, SessionManager
|
|
|
|
# Default listen port used by ``_ch`` when a caller passes no ``port=``.
# Every test in this module passes its own explicit port; keep those values
# distinct so servers started by different tests never compete for a socket.
_PORT = 29900
|
|
|
|
|
|
def _ch(
    bus: Any,
    *,
    session_manager: SessionManager | None = None,
    static_dist_path: Path | None = None,
    port: int = _PORT,
    **extra: Any,
) -> WebSocketChannel:
    """Construct a loopback ``WebSocketChannel`` for the HTTP-route tests.

    The base config enables the channel, allows every sender, listens on
    ``127.0.0.1:<port>`` at path ``/`` and disables the WebSocket token
    requirement. Keys passed through ``extra`` override or extend that base
    config before the channel is built.
    """
    settings: dict[str, Any] = {
        "enabled": True,
        "allowFrom": ["*"],
        "host": "127.0.0.1",
        "port": port,
        "path": "/",
        "websocketRequiresToken": False,
        # Caller-supplied keys win over the defaults above.
        **extra,
    }
    return WebSocketChannel(
        settings,
        bus,
        session_manager=session_manager,
        static_dist_path=static_dist_path,
    )
|
|
|
|
|
|
@pytest.fixture()
def bus() -> MagicMock:
    """Provide a stand-in message bus whose ``publish_inbound`` can be awaited."""
    return MagicMock(publish_inbound=AsyncMock())
|
|
|
|
|
|
async def _http_get(
    url: str, headers: dict[str, str] | None = None
) -> httpx.Response:
    """Issue a blocking ``httpx.get`` on a worker thread and return its response.

    The tests run the channel server as a task on the current event loop,
    so the synchronous client call is moved off the loop to let the server
    answer it. No headers are sent unless given; the request times out
    after 5 seconds.
    """
    request_headers = headers or {}
    return await asyncio.to_thread(
        httpx.get, url, headers=request_headers, timeout=5.0
    )
|
|
|
|
|
|
def _seed_session(workspace: Path, key: str = "websocket:test") -> SessionManager:
    """Save one session ``key`` under ``workspace`` and return its manager.

    The stored history is a ``user`` message "hi" followed by an
    ``assistant`` reply "hello back".
    """
    manager = SessionManager(workspace)
    session = Session(key=key)
    for role, content in (("user", "hi"), ("assistant", "hello back")):
        session.add_message(role, content)
    manager.save(session)
    return manager
|
|
|
|
|
|
def _seed_many(workspace: Path, keys: list[str]) -> SessionManager:
    """Save one single-message session per key and return the manager.

    Each session holds a lone ``user`` message of the form
    ``"hi from <key>"``.
    """
    manager = SessionManager(workspace)
    for session_key in keys:
        session = Session(key=session_key)
        session.add_message("user", f"hi from {session_key}")
        manager.save(session)
    return manager
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_bootstrap_returns_token_for_localhost(
    bus: MagicMock, tmp_path: Path
) -> None:
    """A loopback ``GET /webui/bootstrap`` mints a token plus client settings."""
    base = "http://127.0.0.1:29901"
    channel = _ch(bus, session_manager=_seed_session(tmp_path), port=29901)
    server_task = asyncio.create_task(channel.start())
    # Give the server a moment to bind before the first request.
    await asyncio.sleep(0.3)
    try:
        reply = await _http_get(f"{base}/webui/bootstrap")
        assert reply.status_code == 200
        payload = reply.json()
        assert payload["token"].startswith("nbwt_")
        assert payload["ws_path"] == "/"
        assert payload["expires_in"] > 0
        assert isinstance(payload.get("model_name"), str)
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sessions_routes_require_bearer_token(
    bus: MagicMock, tmp_path: Path
) -> None:
    """``/api/sessions`` refuses anonymous calls and serves bearer-token ones."""
    base = "http://127.0.0.1:29902"
    sm = _seed_session(tmp_path, key="websocket:abc")
    channel = _ch(bus, session_manager=sm, port=29902)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # No Authorization header at all -> 401.
        anonymous = await _http_get(f"{base}/api/sessions")
        assert anonymous.status_code == 401

        # Mint a token through bootstrap and retry with it.
        token = (await _http_get(f"{base}/webui/bootstrap")).json()["token"]
        auth = {"Authorization": f"Bearer {token}"}

        listing = await _http_get(f"{base}/api/sessions", headers=auth)
        assert listing.status_code == 200
        sessions = listing.json()["sessions"]
        assert "websocket:abc" in [entry["key"] for entry in sessions]
        # The server is an opaque source: filesystem paths must not reach the wire.
        assert not any("path" in entry for entry in sessions)

        history = await _http_get(
            f"{base}/api/sessions/websocket:abc/messages", headers=auth
        )
        assert history.status_code == 200
        body = history.json()
        assert body["key"] == "websocket:abc"
        assert [m["role"] for m in body["messages"]] == ["user", "assistant"]
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sessions_list_only_returns_websocket_sessions_by_default(
    bus: MagicMock, tmp_path: Path
) -> None:
    """The session listing hides sessions that belong to other channels."""
    base = "http://127.0.0.1:29906"
    # Mirror a real multi-channel workspace: CLI, Slack, Lark and websocket
    # sessions all share one ``sessions/`` directory.
    sm = _seed_many(
        tmp_path,
        [
            "cli:direct",
            "slack:C123",
            "lark:oc_abc",
            "websocket:alpha",
            "websocket:beta",
        ],
    )
    channel = _ch(bus, session_manager=sm, port=29906)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        token = (await _http_get(f"{base}/webui/bootstrap")).json()["token"]
        listing = await _http_get(
            f"{base}/api/sessions",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert listing.status_code == 200
        # Only websocket-channel sessions belong to the webui; CLI / Slack /
        # Lark rows could not be resumed from the browser.
        listed = {entry["key"] for entry in listing.json()["sessions"]}
        assert listed == {"websocket:alpha", "websocket:beta"}
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_delete_removes_file(bus: MagicMock, tmp_path: Path) -> None:
    """``GET .../delete`` on a websocket session removes its file from disk."""
    base = "http://127.0.0.1:29903"
    sm = _seed_session(tmp_path, key="websocket:doomed")
    channel = _ch(bus, session_manager=sm, port=29903)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        token = (await _http_get(f"{base}/webui/bootstrap")).json()["token"]
        auth = {"Authorization": f"Bearer {token}"}

        session_file = sm._get_session_path("websocket:doomed")
        assert session_file.exists()
        outcome = await _http_get(
            f"{base}/api/sessions/websocket:doomed/delete", headers=auth
        )
        assert outcome.status_code == 200
        assert outcome.json()["deleted"] is True
        assert not session_file.exists()
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_routes_accept_percent_encoded_websocket_keys(
    bus: MagicMock, tmp_path: Path
) -> None:
    """Session routes decode ``%3A`` in the key for both read and delete."""
    base = "http://127.0.0.1:29910"
    encoded = "websocket%3Aencoded-key"
    sm = _seed_session(tmp_path, key="websocket:encoded-key")
    channel = _ch(bus, session_manager=sm, port=29910)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        token = (await _http_get(f"{base}/webui/bootstrap")).json()["token"]
        auth = {"Authorization": f"Bearer {token}"}

        history = await _http_get(
            f"{base}/api/sessions/{encoded}/messages", headers=auth
        )
        assert history.status_code == 200
        # The response reports the decoded key, not the encoded URL segment.
        assert history.json()["key"] == "websocket:encoded-key"

        session_file = sm._get_session_path("websocket:encoded-key")
        assert session_file.exists()
        removal = await _http_get(
            f"{base}/api/sessions/{encoded}/delete", headers=auth
        )
        assert removal.status_code == 200
        assert removal.json()["deleted"] is True
        assert not session_file.exists()
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_routes_reject_non_websocket_keys(
    bus: MagicMock, tmp_path: Path
) -> None:
    """Hand-crafted URLs cannot read or delete sessions of other channels."""
    base = "http://127.0.0.1:29909"
    sm = _seed_many(tmp_path, ["websocket:kept", "cli:direct", "slack:C123"])
    channel = _ch(bus, session_manager=sm, port=29909)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        token = (await _http_get(f"{base}/webui/bootstrap")).json()["token"]
        auth = {"Authorization": f"Bearer {token}"}

        # The listing already hides non-websocket sessions; direct URLs must
        # enforce the same boundary instead of exposing or deleting them.
        cli_history = await _http_get(
            f"{base}/api/sessions/cli:direct/messages", headers=auth
        )
        assert cli_history.status_code == 404

        slack_file = sm._get_session_path("slack:C123")
        assert slack_file.exists()
        refused = await _http_get(
            f"{base}/api/sessions/slack:C123/delete", headers=auth
        )
        assert refused.status_code == 404
        # A refused delete must leave the file in place.
        assert slack_file.exists()
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_routes_reject_invalid_key(
    bus: MagicMock, tmp_path: Path
) -> None:
    """A session key containing a space is refused by the messages route."""
    base = "http://127.0.0.1:29904"
    channel = _ch(bus, session_manager=_seed_session(tmp_path), port=29904)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        token = (await _http_get(f"{base}/webui/bootstrap")).json()["token"]
        # ``bad%20key`` decodes to "bad key". Either the key is rejected
        # outright (400) or the session route does not match and the request
        # falls through to the channel's 404; both count as a refusal here.
        reply = await _http_get(
            f"{base}/api/sessions/bad%20key/messages",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert reply.status_code in {400, 404}
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_serves_index_when_dist_present(
    bus: MagicMock, tmp_path: Path
) -> None:
    """With a built ``dist`` directory the channel serves the SPA over HTTP."""
    base = "http://127.0.0.1:29905"
    dist = tmp_path / "dist"
    dist.mkdir()
    for filename, content in (
        ("index.html", "<!doctype html><title>nbweb</title>"),
        ("favicon.svg", "<svg/>"),
    ):
        (dist / filename).write_text(content)
    sm = _seed_session(tmp_path / "ws_state")
    channel = _ch(bus, session_manager=sm, static_dist_path=dist, port=29905)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # A bare ``GET /`` is a browser opening the app, so it must get the
        # SPA's index.html rather than the WS-upgrade handler's 401/426.
        index = await _http_get(f"{base}/")
        assert index.status_code == 200
        assert "nbweb" in index.text

        # Real files under dist are served directly.
        favicon = await _http_get(f"{base}/favicon.svg")
        assert favicon.status_code == 200
        assert "<svg" in favicon.text

        # Unknown SPA routes fall back to index.html.
        deep_link = await _http_get(f"{base}/sessions/abc")
        assert deep_link.status_code == 200
        assert "nbweb" in deep_link.text
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_static_rejects_path_traversal(
    bus: MagicMock, tmp_path: Path
) -> None:
    """Static serving never returns files outside ``dist`` through ``..`` paths.

    httpx normalizes ``/../secret.txt`` to ``/secret.txt`` on the client side,
    so an httpx request alone never puts a dot-dot segment on the wire. The
    extra requests below go through ``http.client``, which sends the request
    target verbatim, so the server's own traversal handling is what gets
    exercised.
    """
    import http.client

    # Unique port: 29906 is already used by the session-listing test.
    port = 29911
    dist = tmp_path / "dist"
    dist.mkdir()
    (dist / "index.html").write_text("ok")
    # Sibling of ``dist``: reachable from it only by climbing one level up.
    secret = tmp_path / "secret.txt"
    secret.write_text("classified")

    def _raw_get(target: str) -> str:
        """Send ``GET <target>`` without client-side path normalization."""
        conn = http.client.HTTPConnection("127.0.0.1", port, timeout=5.0)
        try:
            conn.request("GET", target)
            return conn.getresponse().read().decode("utf-8", errors="replace")
        finally:
            conn.close()

    channel = _ch(bus, static_dist_path=dist, port=port)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        # httpx turns this into /secret.txt; it must still not leak the file.
        resp = await _http_get(f"http://127.0.0.1:{port}/../secret.txt")
        assert "classified" not in resp.text

        # Literal and percent-encoded dot-dot segments, sent exactly as written.
        for target in ("/../secret.txt", "/%2e%2e/secret.txt"):
            body = await asyncio.to_thread(_raw_get, target)
            assert "classified" not in body, target
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_unknown_route_returns_404(bus: MagicMock) -> None:
    """An unregistered ``/api`` path answers 404 (no sessions or dist configured)."""
    channel = _ch(bus, port=29907)
    server_task = asyncio.create_task(channel.start())
    await asyncio.sleep(0.3)
    try:
        missing = await _http_get("http://127.0.0.1:29907/api/unknown")
        assert missing.status_code == 404
    finally:
        await channel.stop()
        await server_task
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_api_token_pool_purges_expired(bus: MagicMock, tmp_path: Path) -> None:
    """``_check_api_token`` refuses an expired bearer token and accepts a live one.

    No server is started: tokens are planted directly in ``_api_tokens`` with
    ``time.monotonic()``-based expiry stamps and checked against stub requests.
    """
    import time

    channel = _ch(bus, session_manager=_seed_session(tmp_path), port=29908)
    channel._api_tokens["expired"] = time.monotonic() - 1
    channel._api_tokens["live"] = time.monotonic() + 60

    def _request_with(token: str) -> Any:
        # Stub request carrying a ``path`` and a bearer ``Authorization`` header.
        stub_cls = type(
            "_StubReq",
            (),
            {
                "path": "/api/sessions",
                "headers": {"Authorization": f"Bearer {token}"},
            },
        )
        return stub_cls()

    # Check the expired token first, then the live one (same order as before).
    assert channel._check_api_token(_request_with("expired")) is False
    assert channel._check_api_token(_request_with("live")) is True
|