fix: isolate /new consolidation in API mode

This commit is contained in:
Tink 2026-03-13 19:26:50 +08:00
parent f5cf0bfdee
commit 9d69ba9f56
5 changed files with 108 additions and 16 deletions

View File

@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
from loguru import logger from loguru import logger
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.memory import MemoryConsolidator from nanobot.agent.memory import MemoryConsolidator, MemoryStore
from nanobot.agent.subagent import SubagentManager from nanobot.agent.subagent import SubagentManager
from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.cron import CronTool
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
@ -362,7 +362,7 @@ class AgentLoop:
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
key = f"{channel}:{chat_id}" key = f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
history = session.get_history(max_messages=0) history = session.get_history(max_messages=0)
messages = self.context.build_messages( messages = self.context.build_messages(
@ -375,7 +375,7 @@ class AgentLoop:
) )
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
return OutboundMessage(channel=channel, chat_id=chat_id, return OutboundMessage(channel=channel, chat_id=chat_id,
content=final_content or "Background task completed.") content=final_content or "Background task completed.")
@ -389,7 +389,9 @@ class AgentLoop:
cmd = msg.content.strip().lower() cmd = msg.content.strip().lower()
if cmd == "/new": if cmd == "/new":
try: try:
if not await self.memory_consolidator.archive_unconsolidated(session): if not await self.memory_consolidator.archive_unconsolidated(
session, store=memory_store,
):
return OutboundMessage( return OutboundMessage(
channel=msg.channel, channel=msg.channel,
chat_id=msg.chat_id, chat_id=msg.chat_id,
@ -419,7 +421,7 @@ class AgentLoop:
return OutboundMessage( return OutboundMessage(
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines), channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
) )
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
if message_tool := self.tools.get("message"): if message_tool := self.tools.get("message"):
@ -453,7 +455,7 @@ class AgentLoop:
self._save_turn(session, all_msgs, 1 + len(history)) self._save_turn(session, all_msgs, 1 + len(history))
self.sessions.save(session) self.sessions.save(session)
await self.memory_consolidator.maybe_consolidate_by_tokens(session) await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
return None return None

View File

@ -247,9 +247,14 @@ class MemoryConsolidator:
"""Return the shared consolidation lock for one session.""" """Return the shared consolidation lock for one session."""
return self._locks.setdefault(session_key, asyncio.Lock()) return self._locks.setdefault(session_key, asyncio.Lock())
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: async def consolidate_messages(
self,
messages: list[dict[str, object]],
store: MemoryStore | None = None,
) -> bool:
"""Archive a selected message chunk into persistent memory.""" """Archive a selected message chunk into persistent memory."""
return await self.store.consolidate(messages, self.provider, self.model) target = store or self.store
return await target.consolidate(messages, self.provider, self.model)
def pick_consolidation_boundary( def pick_consolidation_boundary(
self, self,
@ -290,16 +295,24 @@ class MemoryConsolidator:
self._get_tool_definitions(), self._get_tool_definitions(),
) )
async def archive_unconsolidated(self, session: Session) -> bool: async def archive_unconsolidated(
self,
session: Session,
store: MemoryStore | None = None,
) -> bool:
"""Archive the full unconsolidated tail for /new-style session rollover.""" """Archive the full unconsolidated tail for /new-style session rollover."""
lock = self.get_lock(session.key) lock = self.get_lock(session.key)
async with lock: async with lock:
snapshot = session.messages[session.last_consolidated:] snapshot = session.messages[session.last_consolidated:]
if not snapshot: if not snapshot:
return True return True
return await self.consolidate_messages(snapshot) return await self.consolidate_messages(snapshot, store=store)
async def maybe_consolidate_by_tokens(self, session: Session) -> None: async def maybe_consolidate_by_tokens(
self,
session: Session,
store: MemoryStore | None = None,
) -> None:
"""Loop: archive old messages until prompt fits within half the context window.""" """Loop: archive old messages until prompt fits within half the context window."""
if not session.messages or self.context_window_tokens <= 0: if not session.messages or self.context_window_tokens <= 0:
return return
@ -347,7 +360,7 @@ class MemoryConsolidator:
source, source,
len(chunk), len(chunk),
) )
if not await self.consolidate_messages(chunk): if not await self.consolidate_messages(chunk, store=store):
return return
session.last_consolidated = end_idx session.last_consolidated = end_idx
self.sessions.save(session) self.sessions.save(session)

View File

@ -516,7 +516,7 @@ class TestNewCommandArchival:
loop.sessions.save(session) loop.sessions.save(session)
before_count = len(session.messages) before_count = len(session.messages)
async def _failing_consolidate(_messages) -> bool: async def _failing_consolidate(_messages, store=None) -> bool:
return False return False
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
@ -542,7 +542,7 @@ class TestNewCommandArchival:
archived_count = -1 archived_count = -1
async def _fake_consolidate(messages) -> bool: async def _fake_consolidate(messages, store=None) -> bool:
nonlocal archived_count nonlocal archived_count
archived_count = len(messages) archived_count = len(messages)
return True return True
@ -567,7 +567,7 @@ class TestNewCommandArchival:
session.add_message("assistant", f"resp{i}") session.add_message("assistant", f"resp{i}")
loop.sessions.save(session) loop.sessions.save(session)
async def _ok_consolidate(_messages) -> bool: async def _ok_consolidate(_messages, store=None) -> bool:
return True return True
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
@ -578,3 +578,33 @@ class TestNewCommandArchival:
assert response is not None assert response is not None
assert "new session started" in response.content.lower() assert "new session started" in response.content.lower()
assert loop.sessions.get_or_create("cli:test").messages == [] assert loop.sessions.get_or_create("cli:test").messages == []
@pytest.mark.asyncio
async def test_new_archives_to_custom_store_when_provided(self, tmp_path: Path) -> None:
"""When memory_store is passed, /new must archive through that store."""
from nanobot.bus.events import InboundMessage
from nanobot.agent.memory import MemoryStore
loop = self._make_loop(tmp_path)
session = loop.sessions.get_or_create("cli:test")
for i in range(5):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
loop.sessions.save(session)
used_store = None
async def _tracking_consolidate(messages, store=None) -> bool:
nonlocal used_store
used_store = store
return True
loop.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign]
iso_store = MagicMock(spec=MemoryStore)
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
response = await loop._process_message(new_msg, memory_store=iso_store)
assert response is not None
assert "new session started" in response.content.lower()
assert used_store is iso_store, "archive_unconsolidated must use the provided store"

View File

@ -158,7 +158,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
async def track_consolidate(messages): async def track_consolidate(messages, store=None):
order.append("consolidate") order.append("consolidate")
return True return True
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign] loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]

View File

@ -622,6 +622,53 @@ class TestConsolidationIsolation:
assert (global_mem_dir / "MEMORY.md").read_text() == "" assert (global_mem_dir / "MEMORY.md").read_text() == ""
assert (global_mem_dir / "HISTORY.md").read_text() == "" assert (global_mem_dir / "HISTORY.md").read_text() == ""
@pytest.mark.asyncio
async def test_new_command_uses_isolated_store(self, tmp_path):
"""process_direct(isolate_memory=True) + /new must archive to the isolated store."""
from unittest.mock import AsyncMock, MagicMock
from nanobot.agent.loop import AgentLoop
from nanobot.agent.memory import MemoryStore
from nanobot.bus.queue import MessageBus
from nanobot.providers.base import LLMResponse
bus = MessageBus()
provider = MagicMock()
provider.get_default_model.return_value = "test-model"
provider.estimate_prompt_tokens.return_value = (10_000, "test")
agent = AgentLoop(
bus=bus, provider=provider, workspace=tmp_path,
model="test-model", context_window_tokens=1,
)
agent._mcp_connected = True # skip MCP connect
agent.tools.get_definitions = MagicMock(return_value=[])
# Pre-populate session so /new has something to archive
session = agent.sessions.get_or_create("api:alice")
for i in range(3):
session.add_message("user", f"msg{i}")
session.add_message("assistant", f"resp{i}")
agent.sessions.save(session)
used_store = None
async def _tracking_consolidate(messages, store=None) -> bool:
nonlocal used_store
used_store = store
return True
agent.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign]
result = await agent.process_direct(
"/new", session_key="api:alice", isolate_memory=True,
)
assert "new session started" in result.lower()
assert used_store is not None, "consolidation must receive a store"
assert isinstance(used_store, MemoryStore)
assert "sessions" in str(used_store.memory_dir), (
"store must point to per-session dir, not global workspace"
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------