mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
838 lines
29 KiB
Python
838 lines
29 KiB
Python
"""Tests for the OpenAI-compatible API server."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from nanobot.api.server import _SessionLocks, _chat_completion_response, _error_json, create_app
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# aiohttp test client helper
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Probe for aiohttp at import time. HAS_AIOHTTP records the outcome and is
# consulted by the ``skipif`` markers on the integration tests in this module.
try:
    from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
    from aiohttp import web

    HAS_AIOHTTP = True
except ImportError:
    HAS_AIOHTTP = False

# Ask pytest to load the pytest-asyncio plugin for this module's async tests.
pytest_plugins = ("pytest_asyncio",)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Unit tests — no aiohttp required
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSessionLocks:
    """Checks for the per-key lock registry ``_SessionLocks``."""

    def test_acquire_creates_lock(self):
        """Acquiring a key hands back an ``asyncio.Lock``."""
        registry = _SessionLocks()
        assert isinstance(registry.acquire("k1"), asyncio.Lock)

    def test_same_key_returns_same_lock(self):
        """Two acquisitions of one key share a single lock object."""
        registry = _SessionLocks()
        first, second = registry.acquire("k1"), registry.acquire("k1")
        assert first is second

    def test_different_keys_different_locks(self):
        """Distinct keys never share a lock object."""
        registry = _SessionLocks()
        lock_one = registry.acquire("k1")
        lock_two = registry.acquire("k2")
        assert lock_one is not lock_two

    def test_release_cleans_up(self):
        """Releasing the only holder removes the key from ``_locks``."""
        registry = _SessionLocks()
        registry.acquire("k1")
        registry.release("k1")
        assert "k1" not in registry._locks

    def test_release_keeps_lock_if_still_referenced(self):
        """The entry survives until every acquisition has been released."""
        registry = _SessionLocks()
        for _ in range(2):
            registry.acquire("k1")

        registry.release("k1")
        assert "k1" in registry._locks  # one holder still outstanding

        registry.release("k1")
        assert "k1" not in registry._locks
|
|
|
|
|
|
class TestResponseHelpers:
    """Checks for the response-building helpers of the API server."""

    def test_error_json(self):
        """``_error_json`` carries the status plus an error object with message and code."""
        response = _error_json(400, "bad request")
        assert response.status == 400

        error = json.loads(response.body)["error"]
        assert error["message"] == "bad request"
        assert error["code"] == 400

    def test_chat_completion_response(self):
        """``_chat_completion_response`` produces a single-choice chat.completion payload."""
        payload = _chat_completion_response("hello world", "test-model")
        assert payload["object"] == "chat.completion"
        assert payload["model"] == "test-model"

        first_choice = payload["choices"][0]
        assert first_choice["message"]["content"] == "hello world"
        assert first_choice["finish_reason"] == "stop"
        assert payload["id"].startswith("chatcmpl-")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration tests — require aiohttp
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_mock_agent(response_text: str = "mock response") -> MagicMock:
    """Build a stand-in agent for the API server tests.

    The returned ``MagicMock`` exposes three awaitable attributes:
    ``process_direct`` (resolving to *response_text*), ``_connect_mcp`` and
    ``close_mcp``.
    """
    stub = MagicMock()
    stub._connect_mcp = AsyncMock()
    stub.close_mcp = AsyncMock()
    stub.process_direct = AsyncMock(return_value=response_text)
    return stub
|
|
|
|
|
|
@pytest.fixture
def mock_agent():
    """Provide a stub agent built by ``_make_mock_agent`` with its default reply."""
    stub = _make_mock_agent()
    return stub
|
|
|
|
|
|
@pytest.fixture
def app(mock_agent):
    """Provide an application wired to ``mock_agent`` with a fixed model name and timeout."""
    settings = {"model_name": "test-model", "request_timeout": 10.0}
    return create_app(mock_agent, **settings)
|
|
|
|
|
|
@pytest.fixture
def cli(event_loop, aiohttp_client, app):
    """Provide a ready test client for ``app``, built synchronously on ``event_loop``.

    NOTE(review): this depends on an ``event_loop`` fixture; recent
    pytest-asyncio releases no longer ship one — confirm the pinned version
    still provides it.
    """
    pending_client = aiohttp_client(app)
    return event_loop.run_until_complete(pending_client)
|
|
|
|
|
|
# ---- Missing header tests ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_session_key_returns_400(aiohttp_client, app):
    """A completion request sent without the session header is rejected with 400."""
    http = await aiohttp_client(app)
    payload = {"messages": [{"role": "user", "content": "hello"}]}

    reply = await http.post("/v1/chat/completions", json=payload)

    assert reply.status == 400
    error = (await reply.json())["error"]
    assert "x-session-key" in error["message"]
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_session_key_returns_400(aiohttp_client, app):
    """A blank ``x-session-key`` header value is rejected with 400."""
    http = await aiohttp_client(app)
    payload = {"messages": [{"role": "user", "content": "hello"}]}
    blank_key = {"x-session-key": " "}

    reply = await http.post("/v1/chat/completions", json=payload, headers=blank_key)

    assert reply.status == 400
|
|
|
|
|
|
# ---- Missing messages tests ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_missing_messages_returns_400(aiohttp_client, app):
    """A request body without a ``messages`` field is rejected with 400."""
    http = await aiohttp_client(app)
    payload = {"model": "test"}
    session = {"x-session-key": "test-key"}

    reply = await http.post("/v1/chat/completions", json=payload, headers=session)

    assert reply.status == 400
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_no_user_message_returns_400(aiohttp_client, app):
    """A conversation holding only a system message is rejected with 400."""
    http = await aiohttp_client(app)
    payload = {"messages": [{"role": "system", "content": "you are a bot"}]}
    session = {"x-session-key": "test-key"}

    reply = await http.post("/v1/chat/completions", json=payload, headers=session)

    assert reply.status == 400
|
|
|
|
|
|
# ---- Stream not supported ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_stream_true_returns_400(aiohttp_client, app):
    """Requesting ``stream: true`` yields a 400 whose message mentions streaming."""
    http = await aiohttp_client(app)
    payload = {
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True,
    }
    session = {"x-session-key": "test-key"}

    reply = await http.post("/v1/chat/completions", json=payload, headers=session)

    assert reply.status == 400
    error_message = (await reply.json())["error"]["message"]
    assert "stream" in error_message.lower()
|
|
|
|
|
|
# ---- Successful request ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_successful_request(aiohttp_client, mock_agent):
    """A valid request returns the agent's reply and forwards the expected arguments."""
    http = await aiohttp_client(create_app(mock_agent, model_name="test-model"))
    session_key = "wx:dm:user1"

    reply = await http.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": session_key},
    )

    assert reply.status == 200
    payload = await reply.json()
    assert payload["choices"][0]["message"]["content"] == "mock response"
    assert payload["model"] == "test-model"

    # The handler must hand the agent exactly these keyword arguments.
    expected_call = {
        "content": "hello",
        "session_key": session_key,
        "channel": "api",
        "chat_id": session_key,
        "isolate_memory": True,
        "disabled_tools": {"read_file", "write_file", "edit_file", "list_dir", "exec"},
    }
    mock_agent.process_direct.assert_called_once_with(**expected_call)
|
|
|
|
|
|
# ---- Session isolation ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_session_isolation_different_keys(aiohttp_client):
    """Each request's ``x-session-key`` must reach the agent as its own ``session_key``."""
    seen_keys: list[str] = []

    async def fake_process(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        # Record which session the handler routed this call to.
        seen_keys.append(session_key)
        return f"reply to {session_key}"

    agent = MagicMock()
    agent.process_direct = fake_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    http = await aiohttp_client(create_app(agent, model_name="m"))

    async def ask(text, key):
        return await http.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": text}]},
            headers={"x-session-key": key},
        )

    alice_reply = await ask("msg1", "wx:dm:alice")
    bob_reply = await ask("msg2", "wx:group:g1:user:bob")

    assert alice_reply.status == 200
    assert bob_reply.status == 200

    alice_body = await alice_reply.json()
    bob_body = await bob_reply.json()
    assert alice_body["choices"][0]["message"]["content"] == "reply to wx:dm:alice"
    assert bob_body["choices"][0]["message"]["content"] == "reply to wx:group:g1:user:bob"
    assert seen_keys == ["wx:dm:alice", "wx:group:g1:user:bob"]
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_same_session_key_serialized(aiohttp_client):
    """Concurrent requests with the same session key must run serially.

    "second" is dispatched only after "first" has signalled that it is
    executing inside ``process_direct``. Which request gets there first
    therefore never depends on task scheduling order, and the test cannot
    hang with "second" waiting on a request that is itself queued behind it.
    """
    order: list[str] = []
    first_started = asyncio.Event()

    async def slow_process(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        order.append(f"start:{content}")
        if content == "first":
            first_started.set()
            await asyncio.sleep(0.1)  # stay busy while "second" arrives
        order.append(f"end:{content}")
        return content

    agent = MagicMock()
    agent.process_direct = slow_process
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    app = create_app(agent, model_name="m")
    client = await aiohttp_client(app)

    async def send(msg):
        return await client.post(
            "/v1/chat/completions",
            json={"messages": [{"role": "user", "content": msg}]},
            headers={"x-session-key": "same-key"},
        )

    first_task = asyncio.ensure_future(send("first"))
    try:
        # Bounded wait: fail fast instead of hanging if "first" never reaches the agent.
        await asyncio.wait_for(first_started.wait(), timeout=5.0)
        r2 = await send("second")
        r1 = await first_task
    finally:
        first_task.cancel()  # no-op once the task has finished

    assert r1.status == 200
    assert r2.status == 200
    # "first" must fully complete before "second" starts
    assert order.index("end:first") < order.index("start:second")
|
|
|
|
|
|
# ---- /v1/models ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_models_endpoint(aiohttp_client, app):
    """``GET /v1/models`` returns a list whose first entry is the configured model."""
    http = await aiohttp_client(app)

    reply = await http.get("/v1/models")

    assert reply.status == 200
    listing = await reply.json()
    assert listing["object"] == "list"
    models = listing["data"]
    assert len(models) >= 1
    assert models[0]["id"] == "test-model"
|
|
|
|
|
|
# ---- /health ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_health_endpoint(aiohttp_client, app):
    """``GET /health`` answers 200 with ``{"status": "ok"}``."""
    http = await aiohttp_client(app)

    reply = await http.get("/health")

    assert reply.status == 200
    health = await reply.json()
    assert health["status"] == "ok"
|
|
|
|
|
|
# ---- Multimodal content array ----
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_multimodal_content_extracts_text(aiohttp_client, mock_agent):
    """For an array-style ``content``, only the text part is forwarded to the agent."""
    http = await aiohttp_client(create_app(mock_agent, model_name="m"))

    text_part = {"type": "text", "text": "describe this"}
    image_part = {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}
    payload = {"messages": [{"role": "user", "content": [text_part, image_part]}]}

    reply = await http.post(
        "/v1/chat/completions",
        json=payload,
        headers={"x-session-key": "test"},
    )

    assert reply.status == 200
    mock_agent.process_direct.assert_called_once()
    agent_call = mock_agent.process_direct.call_args
    assert agent_call.kwargs["content"] == "describe this"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Memory isolation regression tests (root cause of cross-session leakage)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMemoryIsolation:
    """Verify that per-session-key memory prevents cross-session context leakage.

    Root cause: ContextBuilder.build_system_prompt() reads a SHARED
    workspace/memory/MEMORY.md into the system prompt of ALL users.
    If user_1 writes "my name is Alice" and the agent persists it to
    MEMORY.md, user_2/user_N will see it.

    Fix: API mode passes a per-session MemoryStore so each session reads/
    writes its own MEMORY.md.
    """

    @staticmethod
    def _make_workspace(tmp_path, global_memory=None):
        """Create ``workspace/memory`` under *tmp_path* and return the workspace.

        When *global_memory* is not ``None`` it is written to the shared
        ``memory/MEMORY.md``; otherwise that file is not created.
        """
        workspace = tmp_path / "workspace"
        workspace.mkdir()
        (workspace / "memory").mkdir()
        if global_memory is not None:
            (workspace / "memory" / "MEMORY.md").write_text(global_memory)
        return workspace

    @staticmethod
    def _store_at(memory_dir):
        """Return a ``MemoryStore`` bound to *memory_dir* without running ``__init__``.

        Only the three path attributes these tests rely on are populated.
        """
        from nanobot.agent.memory import MemoryStore

        store = MemoryStore.__new__(MemoryStore)
        store.memory_dir = memory_dir
        store.memory_file = memory_dir / "MEMORY.md"
        store.history_file = memory_dir / "HISTORY.md"
        return store

    def test_context_builder_uses_override_memory(self, tmp_path):
        """build_system_prompt with memory_store= must use the override, not global."""
        from nanobot.agent.context import ContextBuilder

        workspace = self._make_workspace(tmp_path, "Global: I am shared context")
        ctx = ContextBuilder(workspace)

        # Without override → sees global memory
        prompt_global = ctx.build_system_prompt()
        assert "I am shared context" in prompt_global

        # With override → sees only the override's memory
        override_dir = tmp_path / "isolated" / "memory"
        override_dir.mkdir(parents=True)
        (override_dir / "MEMORY.md").write_text("User Alice's private note")
        override_store = self._store_at(override_dir)

        prompt_isolated = ctx.build_system_prompt(memory_store=override_store)
        assert "User Alice's private note" in prompt_isolated
        assert "I am shared context" not in prompt_isolated

    def test_different_session_keys_get_different_memory_dirs(self, tmp_path):
        """_isolated_memory_store must return distinct paths for distinct keys."""
        from nanobot.agent.loop import AgentLoop

        agent = MagicMock(spec=AgentLoop)
        agent.workspace = tmp_path
        # Bind the real implementation to the mock so only that method is exercised.
        agent._isolated_memory_store = AgentLoop._isolated_memory_store.__get__(agent)

        store_a = agent._isolated_memory_store("wx:dm:alice")
        store_b = agent._isolated_memory_store("wx:dm:bob")

        assert store_a.memory_file != store_b.memory_file
        assert store_a.memory_dir != store_b.memory_dir
        assert store_a.memory_file.parent.exists()
        assert store_b.memory_file.parent.exists()

    def test_isolated_memory_does_not_leak_across_sessions(self, tmp_path):
        """End-to-end: writing to one session's memory must not appear in another's."""
        from nanobot.agent.context import ContextBuilder

        workspace = self._make_workspace(tmp_path, "")
        ctx = ContextBuilder(workspace)

        # Simulate two isolated memory stores (as the API server would create)
        def make_store(name):
            session_dir = tmp_path / "sessions" / name / "memory"
            session_dir.mkdir(parents=True)
            return self._store_at(session_dir)

        store_alice = make_store("wx_dm_alice")
        store_bob = make_store("wx_dm_bob")

        # Use unique markers that won't appear in builtin skills/prompts
        alice_marker = "XYZZY_ALICE_PRIVATE_MARKER_42"
        store_alice.write_long_term(alice_marker)

        # Alice's prompt sees it
        prompt_alice = ctx.build_system_prompt(memory_store=store_alice)
        assert alice_marker in prompt_alice

        # Bob's prompt must NOT see it
        prompt_bob = ctx.build_system_prompt(memory_store=store_bob)
        assert alice_marker not in prompt_bob

        # Global prompt must NOT see it either
        prompt_global = ctx.build_system_prompt()
        assert alice_marker not in prompt_global

    def test_build_messages_passes_memory_store(self, tmp_path):
        """build_messages must forward memory_store to build_system_prompt."""
        from nanobot.agent.context import ContextBuilder

        workspace = self._make_workspace(tmp_path, "GLOBAL_SECRET")
        ctx = ContextBuilder(workspace)

        override_dir = tmp_path / "per_session" / "memory"
        override_dir.mkdir(parents=True)
        (override_dir / "MEMORY.md").write_text("SESSION_PRIVATE")
        override_store = self._store_at(override_dir)

        messages = ctx.build_messages(
            history=[], current_message="hello",
            memory_store=override_store,
        )
        system_content = messages[0]["content"]
        assert "SESSION_PRIVATE" in system_content
        assert "GLOBAL_SECRET" not in system_content

    def test_api_handler_passes_isolate_memory_and_disabled_tools(self):
        """The API handler must call process_direct with isolate_memory=True and disabled filesystem tools.

        This is a static check: it scans every keyword argument in
        ``nanobot/api/server.py`` rather than executing the handler.
        """
        import ast
        from pathlib import Path

        server_path = Path(__file__).parent.parent / "nanobot" / "api" / "server.py"
        # Python source defaults to UTF-8; do not depend on the locale encoding.
        source = server_path.read_text(encoding="utf-8")
        keywords = [node for node in ast.walk(ast.parse(source)) if isinstance(node, ast.keyword)]

        found_isolate = any(
            kw.arg == "isolate_memory"
            and isinstance(kw.value, ast.Constant)
            and kw.value.value is True
            for kw in keywords
        )
        found_disabled = any(kw.arg == "disabled_tools" for kw in keywords)

        assert found_isolate, "server.py must call process_direct with isolate_memory=True"
        assert found_disabled, "server.py must call process_direct with disabled_tools"

    def test_disabled_tools_constant_blocks_filesystem_and_exec(self):
        """_API_DISABLED_TOOLS must include all filesystem tool names and exec."""
        from nanobot.api.server import _API_DISABLED_TOOLS

        for name in ("read_file", "write_file", "edit_file", "list_dir", "exec"):
            assert name in _API_DISABLED_TOOLS, f"{name} missing from _API_DISABLED_TOOLS"

    def test_system_prompt_uses_isolated_memory_path(self, tmp_path):
        """When memory_store is provided, the system prompt must reference
        the store's paths, NOT the global workspace/memory/MEMORY.md."""
        from nanobot.agent.context import ContextBuilder

        workspace = self._make_workspace(tmp_path)
        ctx = ContextBuilder(workspace)

        # Default prompt references global path
        default_prompt = ctx.build_system_prompt()
        assert "memory/MEMORY.md" in default_prompt

        # Isolated store
        iso_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
        iso_dir.mkdir(parents=True)
        store = self._store_at(iso_dir)

        iso_prompt = ctx.build_system_prompt(memory_store=store)
        # Must reference the isolated path
        assert str(iso_dir / "MEMORY.md") in iso_prompt
        assert str(iso_dir / "HISTORY.md") in iso_prompt
        # Must NOT reference the global workspace memory path
        global_mem = str(workspace.resolve() / "memory" / "MEMORY.md")
        assert global_mem not in iso_prompt

    def test_run_agent_loop_filters_disabled_tools(self):
        """Filtering registry definitions by a disabled-name set drops exactly those tools.

        NOTE(review): the filter expression is reproduced locally here; this
        test does not invoke ``AgentLoop._run_agent_loop`` and does not cover
        rejection of disabled-tool execution. Consider adding a test that
        drives the loop itself.
        """
        from nanobot.agent.tools.registry import ToolRegistry

        registry = ToolRegistry()

        # Create minimal fake tool definitions
        class FakeTool:
            def __init__(self, n):
                self._name = n

            @property
            def name(self):
                return self._name

            def to_schema(self):
                return {"type": "function", "function": {"name": self._name, "parameters": {}}}

            def validate_params(self, params):
                return []

            async def execute(self, **kw):
                return "ok"

        for n in ("read_file", "write_file", "web_search", "exec"):
            registry.register(FakeTool(n))

        all_defs = registry.get_definitions()
        assert len(all_defs) == 4

        disabled = {"read_file", "write_file"}
        filtered = [d for d in all_defs
                    if d.get("function", {}).get("name") not in disabled]
        assert len(filtered) == 2
        names = {d["function"]["name"] for d in filtered}
        assert names == {"web_search", "exec"}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Consolidation isolation regression tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConsolidationIsolation:
    """Check that memory writes made in isolated (API) mode land in the
    per-session directory and leave the global workspace/memory untouched."""

    def test_consolidate_writes_to_isolated_dir_not_global(self, tmp_path):
        """Long-term and history writes through an isolated store must stay in
        its own directory; the global MEMORY.md / HISTORY.md remain empty."""
        from nanobot.agent.memory import MemoryStore

        # Global workspace memory, seeded with empty files.
        shared_dir = tmp_path / "workspace" / "memory"
        shared_dir.mkdir(parents=True)
        shared_memory = shared_dir / "MEMORY.md"
        shared_history = shared_dir / "HISTORY.md"
        shared_memory.write_text("")
        shared_history.write_text("")

        # Per-session store, built without running MemoryStore.__init__.
        session_dir = tmp_path / "sessions" / "wx_dm_alice" / "memory"
        session_dir.mkdir(parents=True)
        session_store = MemoryStore.__new__(MemoryStore)
        session_store.memory_dir = session_dir
        session_store.memory_file = session_dir / "MEMORY.md"
        session_store.history_file = session_dir / "HISTORY.md"

        session_store.write_long_term("Alice's private data")
        session_store.append_history("[2025-01-01 00:00] Alice asked about X")

        # The isolated store holds what was written...
        assert "Alice's private data" in session_store.read_long_term()
        assert "Alice asked about X" in session_store.history_file.read_text()

        # ...and nothing reached the global files.
        assert shared_memory.read_text() == ""
        assert shared_history.read_text() == ""

    @pytest.mark.asyncio
    async def test_new_command_uses_isolated_store(self, tmp_path):
        """``/new`` issued via process_direct(isolate_memory=True) must hand the
        consolidator a per-session MemoryStore."""
        from unittest.mock import AsyncMock, MagicMock
        from nanobot.agent.loop import AgentLoop
        from nanobot.agent.memory import MemoryStore
        from nanobot.bus.queue import MessageBus
        from nanobot.providers.base import LLMResponse

        message_bus = MessageBus()
        llm = MagicMock()
        llm.get_default_model.return_value = "test-model"
        llm.estimate_prompt_tokens.return_value = (10_000, "test")
        loop = AgentLoop(
            bus=message_bus,
            provider=llm,
            workspace=tmp_path,
            model="test-model",
            context_window_tokens=1,
        )
        loop._mcp_connected = True  # skip MCP connect
        loop.tools.get_definitions = MagicMock(return_value=[])

        # Seed the session so /new has history to archive.
        session = loop.sessions.get_or_create("api:alice")
        for i in range(3):
            for role, prefix in (("user", "msg"), ("assistant", "resp")):
                session.add_message(role, f"{prefix}{i}")
        loop.sessions.save(session)

        captured = {"store": None}

        async def _tracking_consolidate(messages, store=None) -> bool:
            # Remember which store the loop asked the consolidator to use.
            captured["store"] = store
            return True

        loop.memory_consolidator.consolidate_messages = _tracking_consolidate  # type: ignore[method-assign]

        reply = await loop.process_direct(
            "/new", session_key="api:alice", isolate_memory=True,
        )

        used_store = captured["store"]
        assert "new session started" in reply.lower()
        assert used_store is not None, "consolidation must receive a store"
        assert isinstance(used_store, MemoryStore)
        assert "sessions" in str(used_store.memory_dir), (
            "store must point to per-session dir, not global workspace"
        )
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Empty response retry + fallback tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_retry_then_success(aiohttp_client):
    """An empty first reply is retried once and the second reply is returned."""
    attempts: list[str] = []

    async def sometimes_empty(content, session_key="", channel="", chat_id="",
                              isolate_memory=False, disabled_tools=None):
        attempts.append(content)
        # Only the very first attempt comes back empty.
        return "" if len(attempts) == 1 else "recovered response"

    agent = MagicMock()
    agent.process_direct = sometimes_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    http = await aiohttp_client(create_app(agent, model_name="m"))
    reply = await http.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "retry-test"},
    )

    assert reply.status == 200
    payload = await reply.json()
    assert payload["choices"][0]["message"]["content"] == "recovered response"
    assert len(attempts) == 2
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_empty_response_both_empty_returns_fallback(aiohttp_client):
    """When the attempt and its retry are both empty, the fallback text is served."""
    attempts: list[str] = []

    async def always_empty(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        attempts.append(content)
        return ""

    agent = MagicMock()
    agent.process_direct = always_empty
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    http = await aiohttp_client(create_app(agent, model_name="m"))
    reply = await http.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "fallback-test"},
    )

    assert reply.status == 200
    payload = await reply.json()
    answer = payload["choices"][0]["message"]["content"]
    assert answer == "I've completed processing but have no response to give."
    assert len(attempts) == 2
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_whitespace_only_response_triggers_retry(aiohttp_client):
    """A reply made only of whitespace counts as empty and is retried."""
    attempts: list[str] = []

    async def whitespace_then_ok(content, session_key="", channel="", chat_id="",
                                 isolate_memory=False, disabled_tools=None):
        attempts.append(content)
        if len(attempts) > 1:
            return "real answer"
        return " \n "

    agent = MagicMock()
    agent.process_direct = whitespace_then_ok
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    http = await aiohttp_client(create_app(agent, model_name="m"))
    reply = await http.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "ws-test"},
    )

    assert reply.status == 200
    payload = await reply.json()
    assert payload["choices"][0]["message"]["content"] == "real answer"
    assert len(attempts) == 2
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_none_response_triggers_retry(aiohttp_client):
    """A ``None`` reply counts as empty and is retried."""
    attempts: list[str] = []

    async def none_then_ok(content, session_key="", channel="", chat_id="",
                           isolate_memory=False, disabled_tools=None):
        attempts.append(content)
        if len(attempts) > 1:
            return "got it"
        return None

    agent = MagicMock()
    agent.process_direct = none_then_ok
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    http = await aiohttp_client(create_app(agent, model_name="m"))
    reply = await http.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "none-test"},
    )

    assert reply.status == 200
    payload = await reply.json()
    assert payload["choices"][0]["message"]["content"] == "got it"
    assert len(attempts) == 2
|
|
|
|
|
|
@pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")
@pytest.mark.asyncio
async def test_nonempty_response_no_retry(aiohttp_client):
    """A reply with real content is served after exactly one agent call."""
    attempts: list[str] = []

    async def normal_response(content, session_key="", channel="", chat_id="",
                              isolate_memory=False, disabled_tools=None):
        attempts.append(content)
        return "immediate answer"

    agent = MagicMock()
    agent.process_direct = normal_response
    agent._connect_mcp = AsyncMock()
    agent.close_mcp = AsyncMock()

    http = await aiohttp_client(create_app(agent, model_name="m"))
    reply = await http.post(
        "/v1/chat/completions",
        json={"messages": [{"role": "user", "content": "hello"}]},
        headers={"x-session-key": "normal-test"},
    )

    assert reply.status == 200
    payload = await reply.json()
    assert payload["choices"][0]["message"]["content"] == "immediate answer"
    assert len(attempts) == 1
|