mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-05 01:05:51 +00:00
fix: isolate /new consolidation in API mode
This commit is contained in:
parent
f5cf0bfdee
commit
9d69ba9f56
@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from nanobot.agent.context import ContextBuilder
|
from nanobot.agent.context import ContextBuilder
|
||||||
from nanobot.agent.memory import MemoryConsolidator
|
from nanobot.agent.memory import MemoryConsolidator, MemoryStore
|
||||||
from nanobot.agent.subagent import SubagentManager
|
from nanobot.agent.subagent import SubagentManager
|
||||||
from nanobot.agent.tools.cron import CronTool
|
from nanobot.agent.tools.cron import CronTool
|
||||||
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
from nanobot.agent.tools.filesystem import EditFileTool, ListDirTool, ReadFileTool, WriteFileTool
|
||||||
@ -362,7 +362,7 @@ class AgentLoop:
|
|||||||
logger.info("Processing system message from {}", msg.sender_id)
|
logger.info("Processing system message from {}", msg.sender_id)
|
||||||
key = f"{channel}:{chat_id}"
|
key = f"{channel}:{chat_id}"
|
||||||
session = self.sessions.get_or_create(key)
|
session = self.sessions.get_or_create(key)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||||
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(channel, chat_id, msg.metadata.get("message_id"))
|
||||||
history = session.get_history(max_messages=0)
|
history = session.get_history(max_messages=0)
|
||||||
messages = self.context.build_messages(
|
messages = self.context.build_messages(
|
||||||
@ -375,7 +375,7 @@ class AgentLoop:
|
|||||||
)
|
)
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||||
return OutboundMessage(channel=channel, chat_id=chat_id,
|
return OutboundMessage(channel=channel, chat_id=chat_id,
|
||||||
content=final_content or "Background task completed.")
|
content=final_content or "Background task completed.")
|
||||||
|
|
||||||
@ -389,7 +389,9 @@ class AgentLoop:
|
|||||||
cmd = msg.content.strip().lower()
|
cmd = msg.content.strip().lower()
|
||||||
if cmd == "/new":
|
if cmd == "/new":
|
||||||
try:
|
try:
|
||||||
if not await self.memory_consolidator.archive_unconsolidated(session):
|
if not await self.memory_consolidator.archive_unconsolidated(
|
||||||
|
session, store=memory_store,
|
||||||
|
):
|
||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel,
|
channel=msg.channel,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
@ -419,7 +421,7 @@ class AgentLoop:
|
|||||||
return OutboundMessage(
|
return OutboundMessage(
|
||||||
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
channel=msg.channel, chat_id=msg.chat_id, content="\n".join(lines),
|
||||||
)
|
)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||||
|
|
||||||
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id"))
|
||||||
if message_tool := self.tools.get("message"):
|
if message_tool := self.tools.get("message"):
|
||||||
@ -453,7 +455,7 @@ class AgentLoop:
|
|||||||
|
|
||||||
self._save_turn(session, all_msgs, 1 + len(history))
|
self._save_turn(session, all_msgs, 1 + len(history))
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
await self.memory_consolidator.maybe_consolidate_by_tokens(session)
|
await self.memory_consolidator.maybe_consolidate_by_tokens(session, store=memory_store)
|
||||||
|
|
||||||
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -247,9 +247,14 @@ class MemoryConsolidator:
|
|||||||
"""Return the shared consolidation lock for one session."""
|
"""Return the shared consolidation lock for one session."""
|
||||||
return self._locks.setdefault(session_key, asyncio.Lock())
|
return self._locks.setdefault(session_key, asyncio.Lock())
|
||||||
|
|
||||||
async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool:
|
async def consolidate_messages(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, object]],
|
||||||
|
store: MemoryStore | None = None,
|
||||||
|
) -> bool:
|
||||||
"""Archive a selected message chunk into persistent memory."""
|
"""Archive a selected message chunk into persistent memory."""
|
||||||
return await self.store.consolidate(messages, self.provider, self.model)
|
target = store or self.store
|
||||||
|
return await target.consolidate(messages, self.provider, self.model)
|
||||||
|
|
||||||
def pick_consolidation_boundary(
|
def pick_consolidation_boundary(
|
||||||
self,
|
self,
|
||||||
@ -290,16 +295,24 @@ class MemoryConsolidator:
|
|||||||
self._get_tool_definitions(),
|
self._get_tool_definitions(),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def archive_unconsolidated(self, session: Session) -> bool:
|
async def archive_unconsolidated(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
store: MemoryStore | None = None,
|
||||||
|
) -> bool:
|
||||||
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
"""Archive the full unconsolidated tail for /new-style session rollover."""
|
||||||
lock = self.get_lock(session.key)
|
lock = self.get_lock(session.key)
|
||||||
async with lock:
|
async with lock:
|
||||||
snapshot = session.messages[session.last_consolidated:]
|
snapshot = session.messages[session.last_consolidated:]
|
||||||
if not snapshot:
|
if not snapshot:
|
||||||
return True
|
return True
|
||||||
return await self.consolidate_messages(snapshot)
|
return await self.consolidate_messages(snapshot, store=store)
|
||||||
|
|
||||||
async def maybe_consolidate_by_tokens(self, session: Session) -> None:
|
async def maybe_consolidate_by_tokens(
|
||||||
|
self,
|
||||||
|
session: Session,
|
||||||
|
store: MemoryStore | None = None,
|
||||||
|
) -> None:
|
||||||
"""Loop: archive old messages until prompt fits within half the context window."""
|
"""Loop: archive old messages until prompt fits within half the context window."""
|
||||||
if not session.messages or self.context_window_tokens <= 0:
|
if not session.messages or self.context_window_tokens <= 0:
|
||||||
return
|
return
|
||||||
@ -347,7 +360,7 @@ class MemoryConsolidator:
|
|||||||
source,
|
source,
|
||||||
len(chunk),
|
len(chunk),
|
||||||
)
|
)
|
||||||
if not await self.consolidate_messages(chunk):
|
if not await self.consolidate_messages(chunk, store=store):
|
||||||
return
|
return
|
||||||
session.last_consolidated = end_idx
|
session.last_consolidated = end_idx
|
||||||
self.sessions.save(session)
|
self.sessions.save(session)
|
||||||
|
|||||||
@ -516,7 +516,7 @@ class TestNewCommandArchival:
|
|||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
before_count = len(session.messages)
|
before_count = len(session.messages)
|
||||||
|
|
||||||
async def _failing_consolidate(_messages) -> bool:
|
async def _failing_consolidate(_messages, store=None) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign]
|
||||||
@ -542,7 +542,7 @@ class TestNewCommandArchival:
|
|||||||
|
|
||||||
archived_count = -1
|
archived_count = -1
|
||||||
|
|
||||||
async def _fake_consolidate(messages) -> bool:
|
async def _fake_consolidate(messages, store=None) -> bool:
|
||||||
nonlocal archived_count
|
nonlocal archived_count
|
||||||
archived_count = len(messages)
|
archived_count = len(messages)
|
||||||
return True
|
return True
|
||||||
@ -567,7 +567,7 @@ class TestNewCommandArchival:
|
|||||||
session.add_message("assistant", f"resp{i}")
|
session.add_message("assistant", f"resp{i}")
|
||||||
loop.sessions.save(session)
|
loop.sessions.save(session)
|
||||||
|
|
||||||
async def _ok_consolidate(_messages) -> bool:
|
async def _ok_consolidate(_messages, store=None) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign]
|
||||||
@ -578,3 +578,33 @@ class TestNewCommandArchival:
|
|||||||
assert response is not None
|
assert response is not None
|
||||||
assert "new session started" in response.content.lower()
|
assert "new session started" in response.content.lower()
|
||||||
assert loop.sessions.get_or_create("cli:test").messages == []
|
assert loop.sessions.get_or_create("cli:test").messages == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_archives_to_custom_store_when_provided(self, tmp_path: Path) -> None:
|
||||||
|
"""When memory_store is passed, /new must archive through that store."""
|
||||||
|
from nanobot.bus.events import InboundMessage
|
||||||
|
from nanobot.agent.memory import MemoryStore
|
||||||
|
|
||||||
|
loop = self._make_loop(tmp_path)
|
||||||
|
session = loop.sessions.get_or_create("cli:test")
|
||||||
|
for i in range(5):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
loop.sessions.save(session)
|
||||||
|
|
||||||
|
used_store = None
|
||||||
|
|
||||||
|
async def _tracking_consolidate(messages, store=None) -> bool:
|
||||||
|
nonlocal used_store
|
||||||
|
used_store = store
|
||||||
|
return True
|
||||||
|
|
||||||
|
loop.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
iso_store = MagicMock(spec=MemoryStore)
|
||||||
|
new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new")
|
||||||
|
response = await loop._process_message(new_msg, memory_store=iso_store)
|
||||||
|
|
||||||
|
assert response is not None
|
||||||
|
assert "new session started" in response.content.lower()
|
||||||
|
assert used_store is iso_store, "archive_unconsolidated must use the provided store"
|
||||||
|
|||||||
@ -158,7 +158,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) ->
|
|||||||
|
|
||||||
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200)
|
||||||
|
|
||||||
async def track_consolidate(messages):
|
async def track_consolidate(messages, store=None):
|
||||||
order.append("consolidate")
|
order.append("consolidate")
|
||||||
return True
|
return True
|
||||||
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign]
|
||||||
|
|||||||
@ -622,6 +622,53 @@ class TestConsolidationIsolation:
|
|||||||
assert (global_mem_dir / "MEMORY.md").read_text() == ""
|
assert (global_mem_dir / "MEMORY.md").read_text() == ""
|
||||||
assert (global_mem_dir / "HISTORY.md").read_text() == ""
|
assert (global_mem_dir / "HISTORY.md").read_text() == ""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_new_command_uses_isolated_store(self, tmp_path):
|
||||||
|
"""process_direct(isolate_memory=True) + /new must archive to the isolated store."""
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from nanobot.agent.loop import AgentLoop
|
||||||
|
from nanobot.agent.memory import MemoryStore
|
||||||
|
from nanobot.bus.queue import MessageBus
|
||||||
|
from nanobot.providers.base import LLMResponse
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.get_default_model.return_value = "test-model"
|
||||||
|
provider.estimate_prompt_tokens.return_value = (10_000, "test")
|
||||||
|
agent = AgentLoop(
|
||||||
|
bus=bus, provider=provider, workspace=tmp_path,
|
||||||
|
model="test-model", context_window_tokens=1,
|
||||||
|
)
|
||||||
|
agent._mcp_connected = True # skip MCP connect
|
||||||
|
agent.tools.get_definitions = MagicMock(return_value=[])
|
||||||
|
|
||||||
|
# Pre-populate session so /new has something to archive
|
||||||
|
session = agent.sessions.get_or_create("api:alice")
|
||||||
|
for i in range(3):
|
||||||
|
session.add_message("user", f"msg{i}")
|
||||||
|
session.add_message("assistant", f"resp{i}")
|
||||||
|
agent.sessions.save(session)
|
||||||
|
|
||||||
|
used_store = None
|
||||||
|
|
||||||
|
async def _tracking_consolidate(messages, store=None) -> bool:
|
||||||
|
nonlocal used_store
|
||||||
|
used_store = store
|
||||||
|
return True
|
||||||
|
|
||||||
|
agent.memory_consolidator.consolidate_messages = _tracking_consolidate # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = await agent.process_direct(
|
||||||
|
"/new", session_key="api:alice", isolate_memory=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "new session started" in result.lower()
|
||||||
|
assert used_store is not None, "consolidation must receive a store"
|
||||||
|
assert isinstance(used_store, MemoryStore)
|
||||||
|
assert "sessions" in str(used_store.memory_dir), (
|
||||||
|
"store must point to per-session dir, not global workspace"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user