mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-02 17:32:39 +00:00
fix: isolate /new consolidation in API mode
This commit is contained in:
parent
f5cf0bfdee
commit
9d69ba9f56
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.memory import MemoryConsolidator
|
||||
from nanobot.agent.memory import MemoryConsolidator, MemoryStore
|
||||
from nanobot.agent.subagent import SubagentManager
|
||||
from nanobot.agent.tools.cron import CronTool
|
||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||
@ -362,7 +362,7 @@ class AgentLoop:
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||
history = session.get_history(max_messages=0)
|
||||
messages = self.context.build_messages(
|
||||
@ -375,7 +375,7 @@ class AgentLoop:
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||
content=final_content or "Background task completed.")
|
||||
|
||||
@ -389,7 +389,9 @@ class AgentLoop:
|
||||
cmd = msg.content.strip().lower()
|
||||
if cmd == "/new":
|
||||
try:
|
||||
if not await self.memory_consolidator.archive_unconsolidated(session):
|
||||
if not await self.memory_consolidator.archive_unconsolidated(
|
||||
session, store=memory_store,
|
||||
):
|
||||
return OutboundMessage(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
@ -419,7 +421,7 @@ class AgentLoop:
|
||||
return OutboundMessage(
|
||||
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
||||
)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||
|
||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||
if message_tool := self.tools.get("message"):
|
||||
@ -453,7 +455,7 @@ class AgentLoop:
|
||||
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
self.sessions.save(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||
|
||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||
return None
|
||||
|
||||
@ -247,9 +247,14 @@ class MemoryConsolidator:
|
||||
"""Return the shared consolidation lock for one session."""
|
||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||
|
||||
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
||||
async def consolidate_messages(
|
||||
self,
|
||||
messages: list[dict[str, object]],
|
||||
store: MemoryStore | None = None,
|
||||
) -> bool:
|
||||
"""Archive a selected message chunk into persistent memory."""
|
||||
return await self.store.consolidate(messages, self.provider, self.model)
|
||||
target = store or self.store
|
||||
return await target.consolidate(messages, self.provider, self.model)
|
||||
|
||||
def pick_consolidation_boundary(
|
||||
self,
|
||||
@ -290,16 +295,24 @@ class MemoryConsolidator:
|
||||
self._get_tool_definitions(),
|
||||
)
|
||||
|
||||
async def archive_unconsolidated(self, session: Session) -> bool:
|
||||
async def archive_unconsolidated(
|
||||
self,
|
||||
session: Session,
|
||||
store: MemoryStore | None = None,
|
||||
) -> bool:
|
||||
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
||||
lock = self.get_lock(session.key)
|
||||
async with lock:
|
||||
snapshot = session.messages[session.last_consolidated:]
|
||||
if not snapshot:
|
||||
return True
|
||||
return await self.consolidate_messages(snapshot)
|
||||
return await self.consolidate_messages(snapshot, store=store)
|
||||
|
||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
||||
async def maybe_consolidate_by_tokens(
|
||||
self,
|
||||
session: Session,
|
||||
store: MemoryStore | None = None,
|
||||
) -> None:
|
||||
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||
if not session.messages or self.context_window_tokens <= 0:
|
||||
return
|
||||
@ -347,7 +360,7 @@ class MemoryConsolidator:
|
||||
source,
|
||||
len(chunk),
|
||||
)
|
||||
if not await self.consolidate_messages(chunk):
|
||||
if not await self.consolidate_messages(chunk, store=store):
|
||||
return
|
||||
session.last_consolidated = end_idx
|
||||
self.sessions.save(session)
|
||||
|
||||
@ -516,7 +516,7 @@ class TestNewCommandArchival:
|
||||
loop.sessions.save(session)
|
||||
before_count = len(session.messages)
|
||||
|
||||
async def _failing_consolidate(_messages) -> bool:
|
||||
async def _failing_consolidate(_messages, store=None) -> bool:
|
||||
return False
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||
@ -542,7 +542,7 @@ class TestNewCommandArchival:
|
||||
|
||||
archived_count = -1
|
||||
|
||||
async def _fake_consolidate(messages) -> bool:
|
||||
async def _fake_consolidate(messages, store=None) -> bool:
|
||||
nonlocal archived_count
|
||||
archived_count = len(messages)
|
||||
return True
|
||||
@ -567,7 +567,7 @@ class TestNewCommandArchival:
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
async def _ok_consolidate(_messages) -> bool:
|
||||
async def _ok_consolidate(_messages, store=None) -> bool:
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||
@ -578,3 +578,33 @@ class TestNewCommandArchival:
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_archives_to_custom_store_when_provided(self, tmp_path: Path) -> None:
|
||||
"""When memory_store is passed, /new must archive through that store."""
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
|
||||
loop = self._make_loop(tmp_path)
|
||||
session = loop.sessions.get_or_create("cli:test")
|
||||
for i in range(5):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
loop.sessions.save(session)
|
||||
|
||||
used_store = None
|
||||
|
||||
async def _tracking_consolidate(messages, store=None) -> bool:
|
||||
nonlocal used_store
|
||||
used_store = store
|
||||
return True
|
||||
|
||||
loop.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign]
|
||||
|
||||
iso_store = MagicMock(spec=MemoryStore)
|
||||
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||
response = await loop._process_message(new_msg, memory_store=iso_store)
|
||||
|
||||
assert response is not None
|
||||
assert "new session started" in response.content.lower()
|
||||
assert used_store is iso_store, "archive_unconsolidated must use the provided store"
|
||||
|
||||
@ -158,7 +158,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
|
||||
|
||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||
|
||||
async def track_consolidate(messages):
|
||||
async def track_consolidate(messages, store=None):
|
||||
order.append("consolidate")
|
||||
return True
|
||||
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||
|
||||
@ -622,6 +622,53 @@ class TestConsolidationIsolation:
|
||||
assert (global_mem_dir / "MEMORY.md").read_text() == ""
|
||||
assert (global_mem_dir / "HISTORY.md").read_text() == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_command_uses_isolated_store(self, tmp_path):
|
||||
"""process_direct(isolate_memory=True) + /new must archive to the isolated store."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from nanobot.agent.loop import AgentLoop
|
||||
from nanobot.agent.memory import MemoryStore
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.providers.base import LLMResponse
|
||||
|
||||
bus = MessageBus()
|
||||
provider = MagicMock()
|
||||
provider.get_default_model.return_value = "test-model"
|
||||
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||
agent = AgentLoop(
|
||||
bus=bus, provider=provider, workspace=tmp_path,
|
||||
model="test-model", context_window_tokens=1,
|
||||
)
|
||||
agent._mcp_connected = True # skip MCP connect
|
||||
agent.tools.get_definitions = MagicMock(return_value=[])
|
||||
|
||||
# Pre-populate session so /new has something to archive
|
||||
session = agent.sessions.get_or_create("api:alice")
|
||||
for i in range(3):
|
||||
session.add_message("user", f"msg{i}")
|
||||
session.add_message("assistant", f"resp{i}")
|
||||
agent.sessions.save(session)
|
||||
|
||||
used_store = None
|
||||
|
||||
async def _tracking_consolidate(messages, store=None) -> bool:
|
||||
nonlocal used_store
|
||||
used_store = store
|
||||
return True
|
||||
|
||||
agent.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign]
|
||||
|
||||
result = await agent.process_direct(
|
||||
"/new", session_key="api:alice", isolate_memory=True,
|
||||
)
|
||||
|
||||
assert "new session started" in result.lower()
|
||||
assert used_store is not None, "consolidation must receive a store"
|
||||
assert isinstance(used_store, MemoryStore)
|
||||
assert "sessions" in str(used_store.memory_dir), (
|
||||
"store must point to per-session dir, not global workspace"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user