diff --git a/nanobot/agent/__init__.py b/nanobot/agent/__init__.py index 7d3ab2af4..a8805a3ad 100644 --- a/nanobot/agent/__init__.py +++ b/nanobot/agent/__init__.py @@ -3,7 +3,7 @@ from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook from nanobot.agent.loop import AgentLoop -from nanobot.agent.memory import MemoryStore +from nanobot.agent.memory import Consolidator, Dream, MemoryStore from nanobot.agent.skills import SkillsLoader from nanobot.agent.subagent import SubagentManager @@ -13,6 +13,7 @@ __all__ = [ "AgentLoop", "CompositeHook", "ContextBuilder", + "Dream", "MemoryStore", "SkillsLoader", "SubagentManager", diff --git a/nanobot/agent/context.py b/nanobot/agent/context.py index ce69d247b..566ca5f24 100644 --- a/nanobot/agent/context.py +++ b/nanobot/agent/context.py @@ -82,8 +82,8 @@ You are nanobot, a helpful AI assistant. ## Workspace Your workspace is at: {workspace_path} -- Long-term memory: {workspace_path}/memory/MEMORY.md (write important facts here) -- History log: {workspace_path}/memory/HISTORY.md (grep-searchable). Each entry starts with [YYYY-MM-DD HH:MM]. +- Long-term memory: {workspace_path}/memory/MEMORY.md (automatically managed by Dream — do not edit directly) +- History log: {workspace_path}/memory/history.jsonl (append-only JSONL, not grep-searchable). - Custom skills: {workspace_path}/skills/{{skill-name}}/SKILL.md {platform_policy} diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index f52b17d14..d1b6fb44f 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -15,7 +15,7 @@ from loguru import logger from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, AgentHookContext, CompositeHook -from nanobot.agent.memory import MemoryConsolidator +from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.runner import AgentRunSpec, AgentRunner from nanobot.agent.subagent import SubagentManager from nanobot.agent.tools.cron import CronTool @@ -230,8 +230,8 @@ class AgentLoop: self._concurrency_gate: asyncio.Semaphore | None = ( asyncio.Semaphore(_max) if _max > 0 else None ) - self.memory_consolidator = MemoryConsolidator( - workspace=workspace, + self.consolidator = Consolidator( + store=self.context.memory, provider=provider, model=self.model, sessions=self.sessions, @@ -240,6 +240,11 @@ class AgentLoop: get_tool_definitions=self.tools.get_definitions, max_completion_tokens=provider.generation.max_tokens, ) + self.dream = Dream( + store=self.context.memory, + provider=provider, + model=self.model, + ) self._register_default_tools() self.commands = CommandRouter() register_builtin_commands(self.commands) @@ -491,7 +496,7 @@ class AgentLoop: logger.info("Processing system message from {}", msg.sender_id) key = f"{channel}:{chat_id}" session = self.sessions.get_or_create(key) - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self.consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(channel, chat_id, msg.metadata.get("message_id")) history = session.get_history(max_messages=0) current_role = "assistant" if msg.sender_id == "subagent" else "user" @@ -506,7 +511,7 @@ class AgentLoop: ) self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) + self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) return OutboundMessage(channel=channel, chat_id=chat_id, content=final_content or "Background task completed.") @@ -522,7 +527,7 @@ class AgentLoop: if result := await self.commands.dispatch(ctx): return result - await self.memory_consolidator.maybe_consolidate_by_tokens(session) + await self.consolidator.maybe_consolidate_by_tokens(session) self._set_tool_context(msg.channel, msg.chat_id, msg.metadata.get("message_id")) if message_tool := self.tools.get("message"): @@ -559,7 +564,7 @@ class AgentLoop: self._save_turn(session, all_msgs, 1 + len(history)) self.sessions.save(session) - self._schedule_background(self.memory_consolidator.maybe_consolidate_by_tokens(session)) + self._schedule_background(self.consolidator.maybe_consolidate_by_tokens(session)) if (mt := self.tools.get("message")) and isinstance(mt, MessageTool) and mt._sent_in_turn: return None diff --git a/nanobot/agent/memory.py b/nanobot/agent/memory.py index aa2de9290..6e9508954 100644 --- a/nanobot/agent/memory.py +++ b/nanobot/agent/memory.py @@ -1,4 +1,4 @@ -"""Memory system for persistent agent memory.""" +"""Memory system: pure file I/O store, lightweight Consolidator, and Dream processor.""" from __future__ import annotations @@ -11,94 +11,181 @@ from typing import TYPE_CHECKING, Any, Callable from loguru import logger -from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain +from nanobot.utils.helpers import ensure_dir, estimate_message_tokens, estimate_prompt_tokens_chain, strip_think + +from nanobot.agent.runner import AgentRunSpec, AgentRunner +from nanobot.agent.tools.registry import ToolRegistry if TYPE_CHECKING: from nanobot.providers.base import LLMProvider from nanobot.session.manager import Session, SessionManager -_SAVE_MEMORY_TOOL = [ - { - "type": "function", - "function": { - "name": "save_memory", - "description": "Save the memory consolidation result to persistent storage.", - "parameters": { - "type": "object", - "properties": { - "history_entry": { - "type": "string", - "description": "A paragraph summarizing key events/decisions/topics. " - "Start with [YYYY-MM-DD HH:MM]. Include detail useful for grep search.", - }, - "memory_update": { - "type": "string", - "description": "Full updated long-term memory as markdown. Include all existing " - "facts plus new ones. Return unchanged if nothing new.", - }, - }, - "required": ["history_entry", "memory_update"], - }, - }, - } -] - - -def _ensure_text(value: Any) -> str: - """Normalize tool-call payload values to text for file storage.""" - return value if isinstance(value, str) else json.dumps(value, ensure_ascii=False) - - -def _normalize_save_memory_args(args: Any) -> dict[str, Any] | None: - """Normalize provider tool-call arguments to the expected dict shape.""" - if isinstance(args, str): - args = json.loads(args) - if isinstance(args, list): - return args[0] if args and isinstance(args[0], dict) else None - return args if isinstance(args, dict) else None - -_TOOL_CHOICE_ERROR_MARKERS = ( - "tool_choice", - "toolchoice", - "does not support", - 'should be ["none", "auto"]', -) - - -def _is_tool_choice_unsupported(content: str | None) -> bool: - """Detect provider errors caused by forced tool_choice being unsupported.""" - text = (content or "").lower() - return any(m in text for m in _TOOL_CHOICE_ERROR_MARKERS) - +# --------------------------------------------------------------------------- +# MemoryStore — pure file I/O layer +# --------------------------------------------------------------------------- class MemoryStore: - """Two-layer memory: MEMORY.md (long-term facts) + HISTORY.md (grep-searchable log).""" + """Pure file I/O for memory files: MEMORY.md, history.jsonl, SOUL.md, USER.md.""" - _MAX_FAILURES_BEFORE_RAW_ARCHIVE = 3 + _DEFAULT_MAX_HISTORY = 1000 - def __init__(self, workspace: Path): + def __init__(self, workspace: Path, max_history_entries: int = _DEFAULT_MAX_HISTORY): + self.workspace = workspace + self.max_history_entries = max_history_entries self.memory_dir = ensure_dir(workspace / "memory") self.memory_file = self.memory_dir / "MEMORY.md" - self.history_file = self.memory_dir / "HISTORY.md" - self._consecutive_failures = 0 + self.history_file = self.memory_dir / "history.jsonl" + self.soul_file = workspace / "SOUL.md" + self.user_file = workspace / "USER.md" + self._dream_log_file = self.memory_dir / ".dream-log.md" + self._cursor_file = self.memory_dir / ".cursor" + self._dream_cursor_file = self.memory_dir / ".dream_cursor" - def read_long_term(self) -> str: - if self.memory_file.exists(): - return self.memory_file.read_text(encoding="utf-8") - return "" + # -- generic helpers ----------------------------------------------------- - def write_long_term(self, content: str) -> None: + @staticmethod + def read_file(path: Path) -> str: + try: + return path.read_text(encoding="utf-8") + except FileNotFoundError: + return "" + + # -- MEMORY.md (long-term facts) ----------------------------------------- + + def read_memory(self) -> str: + return self.read_file(self.memory_file) + + def write_memory(self, content: str) -> None: self.memory_file.write_text(content, encoding="utf-8") - def append_history(self, entry: str) -> None: - with open(self.history_file, "a", encoding="utf-8") as f: - f.write(entry.rstrip() + "\n\n") + # -- SOUL.md ------------------------------------------------------------- + + def read_soul(self) -> str: + return self.read_file(self.soul_file) + + def write_soul(self, content: str) -> None: + self.soul_file.write_text(content, encoding="utf-8") + + # -- USER.md ------------------------------------------------------------- + + def read_user(self) -> str: + return self.read_file(self.user_file) + + def write_user(self, content: str) -> None: + self.user_file.write_text(content, encoding="utf-8") + + # -- context injection (used by context.py) ------------------------------ def get_memory_context(self) -> str: - long_term = self.read_long_term() + long_term = self.read_memory() return f"## Long-term Memory\n{long_term}" if long_term else "" + # -- history.jsonl — append-only, JSONL format --------------------------- + + def append_history(self, entry: str) -> int: + """Append *entry* to history.jsonl and return its auto-incrementing cursor.""" + cursor = self._next_cursor() + ts = datetime.now().strftime("%Y-%m-%d %H:%M") + record = {"cursor": cursor, "timestamp": ts, "content": strip_think(entry.rstrip()) or entry.rstrip()} + with open(self.history_file, "a", encoding="utf-8") as f: + f.write(json.dumps(record, ensure_ascii=False) + "\n") + self._cursor_file.write_text(str(cursor), encoding="utf-8") + return cursor + + def _next_cursor(self) -> int: + """Read the current cursor counter and return next value.""" + if self._cursor_file.exists(): + try: + return int(self._cursor_file.read_text(encoding="utf-8").strip()) + 1 + except (ValueError, OSError): + pass + # Fallback: read last line's cursor from the JSONL file. + last = self._read_last_entry() + if last: + return last["cursor"] + 1 + return 1 + + def read_unprocessed_history(self, since_cursor: int) -> list[dict[str, Any]]: + """Return history entries with cursor > *since_cursor*.""" + return [e for e in self._read_entries() if e["cursor"] > since_cursor] + + def compact_history(self) -> None: + """Drop oldest entries if the file exceeds *max_history_entries*.""" + if self.max_history_entries <= 0: + return + entries = self._read_entries() + if len(entries) <= self.max_history_entries: + return + kept = entries[-self.max_history_entries:] + self._write_entries(kept) + + # -- JSONL helpers ------------------------------------------------------- + + def _read_entries(self) -> list[dict[str, Any]]: + """Read all entries from history.jsonl.""" + entries: list[dict[str, Any]] = [] + try: + with open(self.history_file, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if line: + try: + entries.append(json.loads(line)) + except json.JSONDecodeError: + continue + except FileNotFoundError: + pass + return entries + + def _read_last_entry(self) -> dict[str, Any] | None: + """Read the last entry from the JSONL file efficiently.""" + try: + with open(self.history_file, "rb") as f: + f.seek(0, 2) + size = f.tell() + if size == 0: + return None + read_size = min(size, 4096) + f.seek(size - read_size) + data = f.read().decode("utf-8") + lines = [l for l in data.split("\n") if l.strip()] + if not lines: + return None + return json.loads(lines[-1]) + except (FileNotFoundError, json.JSONDecodeError): + return None + + def _write_entries(self, entries: list[dict[str, Any]]) -> None: + """Overwrite history.jsonl with the given entries.""" + with open(self.history_file, "w", encoding="utf-8") as f: + for entry in entries: + f.write(json.dumps(entry, ensure_ascii=False) + "\n") + + # -- dream cursor -------------------------------------------------------- + + def get_last_dream_cursor(self) -> int: + if self._dream_cursor_file.exists(): + try: + return int(self._dream_cursor_file.read_text(encoding="utf-8").strip()) + except (ValueError, OSError): + pass + return 0 + + def set_last_dream_cursor(self, cursor: int) -> None: + self._dream_cursor_file.write_text(str(cursor), encoding="utf-8") + + # -- dream log ----------------------------------------------------------- + + def read_dream_log(self) -> str: + return self.read_file(self._dream_log_file) + + def append_dream_log(self, entry: str) -> None: + with open(self._dream_log_file, "a", encoding="utf-8") as f: + f.write(f"{entry.rstrip()}\n\n") + + # -- message formatting utility ------------------------------------------ + @staticmethod def _format_messages(messages: list[dict]) -> str: lines = [] @@ -111,107 +198,10 @@ class MemoryStore: ) return "\n".join(lines) - async def consolidate( - self, - messages: list[dict], - provider: LLMProvider, - model: str, - ) -> bool: - """Consolidate the provided message chunk into MEMORY.md + HISTORY.md.""" - if not messages: - return True - - current_memory = self.read_long_term() - prompt = f"""Process this conversation and call the save_memory tool with your consolidation. - -## Current Long-term Memory -{current_memory or "(empty)"} - -## Conversation to Process -{self._format_messages(messages)}""" - - chat_messages = [ - {"role": "system", "content": "You are a memory consolidation agent. Call the save_memory tool with your consolidation of the conversation."}, - {"role": "user", "content": prompt}, - ] - - try: - forced = {"type": "function", "function": {"name": "save_memory"}} - response = await provider.chat_with_retry( - messages=chat_messages, - tools=_SAVE_MEMORY_TOOL, - model=model, - tool_choice=forced, - ) - - if response.finish_reason == "error" and _is_tool_choice_unsupported( - response.content - ): - logger.warning("Forced tool_choice unsupported, retrying with auto") - response = await provider.chat_with_retry( - messages=chat_messages, - tools=_SAVE_MEMORY_TOOL, - model=model, - tool_choice="auto", - ) - - if not response.has_tool_calls: - logger.warning( - "Memory consolidation: LLM did not call save_memory " - "(finish_reason={}, content_len={}, content_preview={})", - response.finish_reason, - len(response.content or ""), - (response.content or "")[:200], - ) - return self._fail_or_raw_archive(messages) - - args = _normalize_save_memory_args(response.tool_calls[0].arguments) - if args is None: - logger.warning("Memory consolidation: unexpected save_memory arguments") - return self._fail_or_raw_archive(messages) - - if "history_entry" not in args or "memory_update" not in args: - logger.warning("Memory consolidation: save_memory payload missing required fields") - return self._fail_or_raw_archive(messages) - - entry = args["history_entry"] - update = args["memory_update"] - - if entry is None or update is None: - logger.warning("Memory consolidation: save_memory payload contains null required fields") - return self._fail_or_raw_archive(messages) - - entry = _ensure_text(entry).strip() - if not entry: - logger.warning("Memory consolidation: history_entry is empty after normalization") - return self._fail_or_raw_archive(messages) - - self.append_history(entry) - update = _ensure_text(update) - if update != current_memory: - self.write_long_term(update) - - self._consecutive_failures = 0 - logger.info("Memory consolidation done for {} messages", len(messages)) - return True - except Exception: - logger.exception("Memory consolidation failed") - return self._fail_or_raw_archive(messages) - - def _fail_or_raw_archive(self, messages: list[dict]) -> bool: - """Increment failure count; after threshold, raw-archive messages and return True.""" - self._consecutive_failures += 1 - if self._consecutive_failures < self._MAX_FAILURES_BEFORE_RAW_ARCHIVE: - return False - self._raw_archive(messages) - self._consecutive_failures = 0 - return True - - def _raw_archive(self, messages: list[dict]) -> None: + def raw_archive(self, messages: list[dict]) -> None: """Fallback: dump raw messages to HISTORY.md without LLM summarization.""" - ts = datetime.now().strftime("%Y-%m-%d %H:%M") self.append_history( - f"[{ts}] [RAW] {len(messages)} messages\n" + f"[RAW] {len(messages)} messages\n" f"{self._format_messages(messages)}" ) logger.warning( @@ -219,8 +209,14 @@ class MemoryStore: ) -class MemoryConsolidator: - """Owns consolidation policy, locking, and session offset updates.""" + +# --------------------------------------------------------------------------- +# Consolidator — lightweight token-budget triggered consolidation +# --------------------------------------------------------------------------- + + +class Consolidator: + """Lightweight consolidation: summarizes evicted messages, appends to HISTORY.md.""" _MAX_CONSOLIDATION_ROUNDS = 5 @@ -228,7 +224,7 @@ class MemoryConsolidator: def __init__( self, - workspace: Path, + store: MemoryStore, provider: LLMProvider, model: str, sessions: SessionManager, @@ -237,7 +233,7 @@ class MemoryConsolidator: get_tool_definitions: Callable[[], list[dict[str, Any]]], max_completion_tokens: int = 4096, ): - self.store = MemoryStore(workspace) + self.store = store self.provider = provider self.model = model self.sessions = sessions @@ -245,16 +241,14 @@ class MemoryConsolidator: self.max_completion_tokens = max_completion_tokens self._build_messages = build_messages self._get_tool_definitions = get_tool_definitions - self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary() + self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = ( + weakref.WeakValueDictionary() + ) def get_lock(self, session_key: str) -> asyncio.Lock: """Return the shared consolidation lock for one session.""" return self._locks.setdefault(session_key, asyncio.Lock()) - async def consolidate_messages(self, messages: list[dict[str, object]]) -> bool: - """Archive a selected message chunk into persistent memory.""" - return await self.store.consolidate(messages, self.provider, self.model) - def pick_consolidation_boundary( self, session: Session, @@ -294,14 +288,48 @@ class MemoryConsolidator: self._get_tool_definitions(), ) - async def archive_messages(self, messages: list[dict[str, object]]) -> bool: - """Archive messages with guaranteed persistence (retries until raw-dump fallback).""" + async def archive(self, messages: list[dict]) -> bool: + """Summarize messages via LLM and append to HISTORY.md. + + Returns True on success (or degraded success), False if nothing to do. + """ if not messages: + return False + try: + formatted = MemoryStore._format_messages(messages) + response = await self.provider.chat_with_retry( + model=self.model, + messages=[ + { + "role": "system", + "content": ( + "Extract key facts from this conversation. " + "Only output items matching these categories, skip everything else:\n" + "- User facts: personal info, preferences, stated opinions, habits\n" + "- Decisions: choices made, conclusions reached\n" + "- Events: plans, deadlines, notable occurrences\n" + "- Preferences: communication style, tool preferences\n\n" + "Priority: user corrections and preferences > decisions > events > environment facts. " + "The most valuable memory prevents the user from having to repeat themselves.\n\n" + "Skip: code patterns derivable from source, git history, debug steps already in code, " + "or anything already captured in existing memory.\n\n" + "Output as concise bullet points, one fact per line. " + "No preamble, no commentary.\n" + "If nothing noteworthy happened, output: (nothing)" + ), + }, + {"role": "user", "content": formatted}, + ], + tools=None, + tool_choice=None, + ) + summary = response.content or "[no summary]" + self.store.append_history(summary) + return True + except Exception: + logger.warning("Consolidation LLM call failed, raw-dumping to history") + self.store.raw_archive(messages) return True - for _ in range(self.store._MAX_FAILURES_BEFORE_RAW_ARCHIVE): - if await self.consolidate_messages(messages): - return True - return True async def maybe_consolidate_by_tokens(self, session: Session) -> None: """Loop: archive old messages until prompt fits within safe budget. @@ -356,7 +384,7 @@ class MemoryConsolidator: source, len(chunk), ) - if not await self.consolidate_messages(chunk): + if not await self.archive(chunk): return session.last_consolidated = end_idx self.sessions.save(session) @@ -364,3 +392,186 @@ class MemoryConsolidator: estimated, source = self.estimate_session_prompt_tokens(session) if estimated <= 0: return + + +# --------------------------------------------------------------------------- +# Dream — heavyweight cron-scheduled memory consolidation +# --------------------------------------------------------------------------- + + +class Dream: + """Two-phase memory processor: analyze HISTORY.md, then edit files via AgentRunner. + + Phase 1 produces an analysis summary (plain LLM call). + Phase 2 delegates to AgentRunner with read_file / edit_file tools so the + LLM can make targeted, incremental edits instead of replacing entire files. + """ + + _PHASE1_SYSTEM = ( + "Compare conversation history against current memory files. " + "Output one line per finding:\n" + "[FILE] atomic fact or change description\n\n" + "Files: USER (identity, preferences, habits), " + "SOUL (bot behavior, tone), " + "MEMORY (knowledge, project context, tool patterns)\n\n" + "Rules:\n" + "- Only new or conflicting information — skip duplicates and ephemera\n" + "- Prefer atomic facts: \"has a cat named Luna\" not \"discussed pet care\"\n" + "- Corrections: [USER] location is Tokyo, not Osaka\n" + "- Also capture confirmed approaches: if the user validated a non-obvious choice, note it\n\n" + "If nothing needs updating: [SKIP] no new information" + ) + + _PHASE2_SYSTEM = ( + "Update memory files based on the analysis below.\n\n" + "## Quality standards\n" + "- Every line must carry standalone value — no filler\n" + "- Concise bullet points under clear headers\n" + "- Remove outdated or contradicted information\n\n" + "## Editing\n" + "- File contents provided below — edit directly, no read_file needed\n" + "- Batch changes to the same file into one edit_file call\n" + "- Surgical edits only — never rewrite entire files\n" + "- Do NOT overwrite correct entries — only add, update, or remove\n" + "- If nothing to update, stop without calling tools" + ) + + def __init__( + self, + store: MemoryStore, + provider: LLMProvider, + model: str, + max_batch_size: int = 20, + max_iterations: int = 10, + ): + self.store = store + self.provider = provider + self.model = model + self.max_batch_size = max_batch_size + self.max_iterations = max_iterations + self._runner = AgentRunner(provider) + self._tools = self._build_tools() + + # -- tool registry ------------------------------------------------------- + + def _build_tools(self) -> ToolRegistry: + """Build a minimal tool registry for the Dream agent.""" + from nanobot.agent.tools.filesystem import EditFileTool, ReadFileTool + + tools = ToolRegistry() + workspace = self.store.workspace + tools.register(ReadFileTool(workspace=workspace, allowed_dir=workspace)) + tools.register(EditFileTool(workspace=workspace, allowed_dir=workspace)) + return tools + + # -- main entry ---------------------------------------------------------- + + async def run(self) -> bool: + """Process unprocessed history entries. Returns True if work was done.""" + last_cursor = self.store.get_last_dream_cursor() + entries = self.store.read_unprocessed_history(since_cursor=last_cursor) + if not entries: + return False + + batch = entries[: self.max_batch_size] + logger.info( + "Dream: processing {} entries (cursor {}→{}), batch={}", + len(entries), last_cursor, batch[-1]["cursor"], len(batch), + ) + + # Build history text for LLM + history_text = "\n".join( + f"[{e['timestamp']}] {e['content']}" for e in batch + ) + + # Current file contents + current_memory = self.store.read_memory() or "(empty)" + current_soul = self.store.read_soul() or "(empty)" + current_user = self.store.read_user() or "(empty)" + file_context = ( + f"## Current MEMORY.md\n{current_memory}\n\n" + f"## Current SOUL.md\n{current_soul}\n\n" + f"## Current USER.md\n{current_user}" + ) + + # Phase 1: Analyze + phase1_prompt = ( + f"## Conversation History\n{history_text}\n\n{file_context}" + ) + + try: + phase1_response = await self.provider.chat_with_retry( + model=self.model, + messages=[ + {"role": "system", "content": self._PHASE1_SYSTEM}, + {"role": "user", "content": phase1_prompt}, + ], + tools=None, + tool_choice=None, + ) + analysis = phase1_response.content or "" + logger.debug("Dream Phase 1 complete ({} chars)", len(analysis)) + except Exception: + logger.exception("Dream Phase 1 failed") + return False + + # Phase 2: Delegate to AgentRunner with read_file / edit_file + phase2_prompt = f"## Analysis Result\n{analysis}\n\n{file_context}" + + tools = self._tools + messages: list[dict[str, Any]] = [ + {"role": "system", "content": self._PHASE2_SYSTEM}, + {"role": "user", "content": phase2_prompt}, + ] + + try: + result = await self._runner.run(AgentRunSpec( + initial_messages=messages, + tools=tools, + model=self.model, + max_iterations=self.max_iterations, + fail_on_tool_error=True, + )) + logger.debug( + "Dream Phase 2 complete: stop_reason={}, tool_events={}", + result.stop_reason, len(result.tool_events), + ) + except Exception: + logger.exception("Dream Phase 2 failed") + result = None + + # Build changelog from tool events + changelog: list[str] = [] + if result and result.tool_events: + for event in result.tool_events: + if event["status"] == "ok": + changelog.append(f"{event['name']}: {event['detail']}") + + # Advance cursor — always, to avoid re-processing Phase 1 + new_cursor = batch[-1]["cursor"] + self.store.set_last_dream_cursor(new_cursor) + self.store.compact_history() + + if result and result.stop_reason == "completed": + logger.info( + "Dream done: {} change(s), cursor advanced to {}", + len(changelog), new_cursor, + ) + else: + reason = result.stop_reason if result else "exception" + logger.warning( + "Dream incomplete ({}): cursor advanced to {}", + reason, new_cursor, + ) + + # Write dream log + ts = datetime.now().strftime("%Y-%m-%d %H:%M") + if changelog: + log_entry = f"## {ts}\n" + for change in changelog: + log_entry += f"- {change}\n" + self.store.append_dream_log(log_entry) + else: + self.store.append_dream_log(f"## {ts}\nNo changes.\n") + + return True diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index 7f7d24f39..8184b51b1 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -22,6 +22,7 @@ if sys.platform == "win32": pass import typer +from loguru import logger from prompt_toolkit import PromptSession, print_formatted_text from prompt_toolkit.application import run_in_terminal from prompt_toolkit.formatted_text import ANSI, HTML @@ -640,6 +641,15 @@ def gateway( # Set cron callback (needs agent) async def on_cron_job(job: CronJob) -> str | None: """Execute a cron job through the agent.""" + # Dream is an internal job — run directly, not through the agent loop. + if job.name == "dream": + try: + await agent.dream.run() + logger.info("Dream cron job completed") + except Exception: + logger.exception("Dream cron job failed") + return None + from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.message import MessageTool from nanobot.utils.evaluator import evaluate_response @@ -759,6 +769,21 @@ def gateway( console.print(f"[green]✓[/green] Heartbeat: every {hb_cfg.interval_s}s") + # Register Dream cron job (always-on, idempotent on restart) + dream_cfg = config.agents.defaults.dream + if dream_cfg.model: + agent.dream.model = dream_cfg.model + agent.dream.max_batch_size = dream_cfg.max_batch_size + agent.dream.max_iterations = dream_cfg.max_iterations + from nanobot.cron.types import CronJob, CronPayload, CronSchedule + cron.register_system_job(CronJob( + id="dream", + name="dream", + schedule=CronSchedule(kind="cron", expr=dream_cfg.cron, tz=config.agents.defaults.timezone), + payload=CronPayload(kind="system_event"), + )) + console.print(f"[green]✓[/green] Dream: cron {dream_cfg.cron}") + async def run(): try: await cron.start() diff --git a/nanobot/command/builtin.py b/nanobot/command/builtin.py index 643397057..97fefe6cf 100644 --- a/nanobot/command/builtin.py +++ b/nanobot/command/builtin.py @@ -47,7 +47,7 @@ async def cmd_status(ctx: CommandContext) -> OutboundMessage: session = ctx.session or loop.sessions.get_or_create(ctx.key) ctx_est = 0 try: - ctx_est, _ = loop.memory_consolidator.estimate_session_prompt_tokens(session) + ctx_est, _ = loop.consolidator.estimate_session_prompt_tokens(session) except Exception: pass if ctx_est <= 0: @@ -75,13 +75,47 @@ async def cmd_new(ctx: CommandContext) -> OutboundMessage: loop.sessions.save(session) loop.sessions.invalidate(session.key) if snapshot: - loop._schedule_background(loop.memory_consolidator.archive_messages(snapshot)) + loop._schedule_background(loop.consolidator.archive(snapshot)) return OutboundMessage( channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content="New session started.", ) +async def cmd_dream(ctx: CommandContext) -> OutboundMessage: + """Manually trigger a Dream consolidation run.""" + loop = ctx.loop + try: + did_work = await loop.dream.run() + content = "Dream completed." if did_work else "Dream: nothing to process." + except Exception as e: + content = f"Dream failed: {e}" + return OutboundMessage( + channel=ctx.msg.channel, chat_id=ctx.msg.chat_id, content=content, + ) + + +async def cmd_dream_log(ctx: CommandContext) -> OutboundMessage: + """Show the Dream consolidation log.""" + loop = ctx.loop + store = loop.consolidator.store + log = store.read_dream_log() + if not log: + # Check if Dream has ever processed anything + if store.get_last_dream_cursor() == 0: + content = "Dream has not run yet." + else: + content = "No dream log yet." + else: + content = f"## Dream Log\n\n{log}" + return OutboundMessage( + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + content=content, + metadata={"render_as": "text"}, + ) + + async def cmd_help(ctx: CommandContext) -> OutboundMessage: """Return available slash commands.""" return OutboundMessage( @@ -100,6 +134,8 @@ def build_help_text() -> str: "/stop — Stop the current task", "/restart — Restart the bot", "/status — Show bot status", + "/dream — Manually trigger Dream consolidation", + "/dream-log — Show Dream consolidation log", "/help — Show available commands", ] return "\n".join(lines) @@ -112,4 +148,6 @@ def register_builtin_commands(router: CommandRouter) -> None: router.priority("/status", cmd_status) router.exact("/new", cmd_new) router.exact("/status", cmd_status) + router.exact("/dream", cmd_dream) + router.exact("/dream-log", cmd_dream_log) router.exact("/help", cmd_help) diff --git a/nanobot/config/schema.py b/nanobot/config/schema.py index c4c927afd..18ca7a3be 100644 --- a/nanobot/config/schema.py +++ b/nanobot/config/schema.py @@ -28,6 +28,15 @@ class ChannelsConfig(Base): send_max_retries: int = Field(default=3, ge=0, le=10) # Max delivery attempts (initial send included) +class DreamConfig(Base): + """Dream memory consolidation configuration.""" + + cron: str = "0 */2 * * *" # Every 2 hours + model: str | None = None # Override model for Dream + max_batch_size: int = Field(default=20, ge=1) # Max history entries per run + max_iterations: int = Field(default=10, ge=1) # Max tool calls per Phase 2 + + class AgentDefaults(Base): """Default agent configuration.""" @@ -42,6 +51,7 @@ class AgentDefaults(Base): max_tool_iterations: int = 40 reasoning_effort: str | None = None # low / medium / high - enables LLM thinking mode timezone: str = "UTC" # IANA timezone, e.g. "Asia/Shanghai", "America/New_York" + dream: DreamConfig = Field(default_factory=DreamConfig) class AgentsConfig(Base): diff --git a/nanobot/cron/service.py b/nanobot/cron/service.py index c956b897f..f7b81d8d3 100644 --- a/nanobot/cron/service.py +++ b/nanobot/cron/service.py @@ -351,6 +351,20 @@ class CronService: logger.info("Cron: added job '{}' ({})", name, job.id) return job + def register_system_job(self, job: CronJob) -> CronJob: + """Register an internal system job (idempotent on restart).""" + store = self._load_store() + now = _now_ms() + job.state = CronJobState(next_run_at_ms=_compute_next_run(job.schedule, now)) + job.created_at_ms = now + job.updated_at_ms = now + store.jobs = [j for j in store.jobs if j.id != job.id] + store.jobs.append(job) + self._save_store() + self._arm_timer() + logger.info("Cron: registered system job '{}' ({})", job.name, job.id) + return job + def remove_job(self, job_id: str) -> bool: """Remove a job by ID.""" store = self._load_store() diff --git a/nanobot/skills/memory/SKILL.md b/nanobot/skills/memory/SKILL.md index 3f0a8fc2b..52b149e5b 100644 --- a/nanobot/skills/memory/SKILL.md +++ b/nanobot/skills/memory/SKILL.md @@ -1,6 +1,6 @@ --- name: memory -description: Two-layer memory system with grep-based recall. +description: Two-layer memory system with Dream-managed knowledge files. always: true --- @@ -8,30 +8,23 @@ always: true ## Structure -- `memory/MEMORY.md` — Long-term facts (preferences, project context, relationships). Always loaded into your context. -- `memory/HISTORY.md` — Append-only event log. NOT loaded into context. Search it with grep-style tools or in-memory filters. Each entry starts with [YYYY-MM-DD HH:MM]. +- `SOUL.md` — Bot personality and communication style. **Managed by Dream.** Do NOT edit. +- `USER.md` — User profile and preferences. **Managed by Dream.** Do NOT edit. +- `memory/MEMORY.md` — Long-term facts (project context, important events). **Managed by Dream.** Do NOT edit. +- `memory/history.jsonl` — append-only JSONL, not loaded into context. search with `jq`-style tools. +- `memory/.dream-log.md` — Changelog of what Dream changed. View with `/dream-log`. ## Search Past Events -Choose the search method based on file size: +`memory/history.jsonl` is JSONL format — each line is a JSON object with `cursor`, `timestamp`, `content`. -- Small `memory/HISTORY.md`: use `read_file`, then search in-memory -- Large or long-lived `memory/HISTORY.md`: use the `exec` tool for targeted search +Examples (replace `keyword`): +- **Python (cross-platform):** `python -c "import json; [print(json.loads(l).get('content','')) for l in open('memory/history.jsonl','r',encoding='utf-8') if l.strip() and 'keyword' in l.lower()][-20:]"` +- **jq:** `cat memory/history.jsonl | jq -r 'select(.content | test("keyword"; "i")) | .content' | tail -20` +- **grep:** `grep -i "keyword" memory/history.jsonl` -Examples: -- **Linux/macOS:** `grep -i "keyword" memory/HISTORY.md` -- **Windows:** `findstr /i "keyword" memory\HISTORY.md` -- **Cross-platform Python:** `python -c "from pathlib import Path; text = Path('memory/HISTORY.md').read_text(encoding='utf-8'); print('\n'.join([l for l in text.splitlines() if 'keyword' in l.lower()][-20:]))"` +## Important -Prefer targeted command-line search for large history files. - -## When to Update MEMORY.md - -Write important facts immediately using `edit_file` or `write_file`: -- User preferences ("I prefer dark mode") -- Project context ("The API uses OAuth2") -- Relationships ("Alice is the project lead") - -## Auto-consolidation - -Old conversations are automatically summarized and appended to HISTORY.md when the session grows large. Long-term facts are extracted to MEMORY.md. You don't need to manage this. +- **Do NOT edit SOUL.md, USER.md, or MEMORY.md.** They are automatically managed by Dream. +- If you notice outdated information, it will be corrected when Dream runs next. +- Users can view Dream's activity with the `/dream-log` command. diff --git a/nanobot/utils/helpers.py b/nanobot/utils/helpers.py index 406a4dd45..badfc978d 100644 --- a/nanobot/utils/helpers.py +++ b/nanobot/utils/helpers.py @@ -296,7 +296,7 @@ def sync_workspace_templates(workspace: Path, silent: bool = False) -> list[str] if item.name.endswith(".md") and not item.name.startswith("."): _write(item, workspace / item.name) _write(tpl / "memory" / "MEMORY.md", workspace / "memory" / "MEMORY.md") - _write(None, workspace / "memory" / "HISTORY.md") + _write(None, workspace / "memory" / "history.jsonl") (workspace / "skills").mkdir(exist_ok=True) if added and not silent: diff --git a/tests/agent/test_consolidate_offset.py b/tests/agent/test_consolidate_offset.py index 4f2e8f1c2..f6232c348 100644 --- a/tests/agent/test_consolidate_offset.py +++ b/tests/agent/test_consolidate_offset.py @@ -506,7 +506,7 @@ class TestNewCommandArchival: @pytest.mark.asyncio async def test_new_clears_session_immediately_even_if_archive_fails(self, tmp_path: Path) -> None: - """/new clears session immediately; archive_messages retries until raw dump.""" + """/new clears session immediately; archive is fire-and-forget.""" from nanobot.bus.events import InboundMessage loop = self._make_loop(tmp_path) @@ -518,12 +518,12 @@ class TestNewCommandArchival: call_count = 0 - async def _failing_consolidate(_messages) -> bool: + async def _failing_summarize(_messages) -> bool: nonlocal call_count call_count += 1 return False - loop.memory_consolidator.consolidate_messages = _failing_consolidate # type: ignore[method-assign] + loop.consolidator.archive = _failing_summarize # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) @@ -535,7 +535,7 @@ class TestNewCommandArchival: assert len(session_after.messages) == 0 await loop.close_mcp() - assert call_count == 3 # retried up to raw-archive threshold + assert call_count == 1 @pytest.mark.asyncio async def test_new_archives_only_unconsolidated_messages(self, tmp_path: Path) -> None: @@ -551,12 +551,12 @@ class TestNewCommandArchival: archived_count = -1 - async def _fake_consolidate(messages) -> bool: + async def _fake_summarize(messages) -> bool: nonlocal archived_count archived_count = len(messages) return True - loop.memory_consolidator.consolidate_messages = _fake_consolidate # type: ignore[method-assign] + loop.consolidator.archive = _fake_summarize # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) @@ -578,10 +578,10 @@ class TestNewCommandArchival: session.add_message("assistant", f"resp{i}") loop.sessions.save(session) - async def _ok_consolidate(_messages) -> bool: + async def _ok_summarize(_messages) -> bool: return True - loop.memory_consolidator.consolidate_messages = _ok_consolidate # type: ignore[method-assign] + loop.consolidator.archive = _ok_summarize # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") response = await loop._process_message(new_msg) @@ -604,12 +604,12 @@ class TestNewCommandArchival: archived = asyncio.Event() - async def _slow_consolidate(_messages) -> bool: + async def _slow_summarize(_messages) -> bool: await asyncio.sleep(0.1) archived.set() return True - loop.memory_consolidator.consolidate_messages = _slow_consolidate # type: ignore[method-assign] + loop.consolidator.archive = _slow_summarize # type: ignore[method-assign] new_msg = InboundMessage(channel="cli", sender_id="user", chat_id="test", content="/new") await loop._process_message(new_msg) diff --git a/tests/agent/test_consolidator.py b/tests/agent/test_consolidator.py new file mode 100644 index 000000000..72968b0e1 --- /dev/null +++ b/tests/agent/test_consolidator.py @@ -0,0 +1,78 @@ +"""Tests for the lightweight Consolidator — append-only to HISTORY.md.""" + +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from nanobot.agent.memory import Consolidator, MemoryStore + + +@pytest.fixture +def store(tmp_path): + return MemoryStore(tmp_path) + + +@pytest.fixture +def mock_provider(): + p = MagicMock() + p.chat_with_retry = AsyncMock() + return p + + +@pytest.fixture +def consolidator(store, mock_provider): + sessions = MagicMock() + sessions.save = MagicMock() + return Consolidator( + store=store, + provider=mock_provider, + model="test-model", + sessions=sessions, + context_window_tokens=1000, + build_messages=MagicMock(return_value=[]), + get_tool_definitions=MagicMock(return_value=[]), + max_completion_tokens=100, + ) + + +class TestConsolidatorSummarize: + async def test_summarize_appends_to_history(self, consolidator, mock_provider, store): + """Consolidator should call LLM to summarize, then append to HISTORY.md.""" + mock_provider.chat_with_retry.return_value = MagicMock( + content="User fixed a bug in the auth module." + ) + messages = [ + {"role": "user", "content": "fix the auth bug"}, + {"role": "assistant", "content": "Done, fixed the race condition."}, + ] + result = await consolidator.archive(messages) + assert result is True + entries = store.read_unprocessed_history(since_cursor=0) + assert len(entries) == 1 + + async def test_summarize_raw_dumps_on_llm_failure(self, consolidator, mock_provider, store): + """On LLM failure, raw-dump messages to HISTORY.md.""" + mock_provider.chat_with_retry.side_effect = Exception("API error") + messages = [{"role": "user", "content": "hello"}] + result = await consolidator.archive(messages) + assert result is True # always succeeds + entries = store.read_unprocessed_history(since_cursor=0) + assert len(entries) == 1 + assert "[RAW]" in entries[0]["content"] + + async def test_summarize_skips_empty_messages(self, consolidator): + result = await consolidator.archive([]) + assert result is False + + +class TestConsolidatorTokenBudget: + async def test_prompt_below_threshold_does_not_consolidate(self, consolidator): + """No consolidation when tokens are within budget.""" + session = MagicMock() + session.last_consolidated = 0 + session.messages = [{"role": "user", "content": "hi"}] + session.key = "test:key" + consolidator.estimate_session_prompt_tokens = MagicMock(return_value=(100, "tiktoken")) + consolidator.archive = AsyncMock(return_value=True) + await consolidator.maybe_consolidate_by_tokens(session) + consolidator.archive.assert_not_called() diff --git a/tests/agent/test_dream.py b/tests/agent/test_dream.py new file mode 100644 index 000000000..7898ea267 --- /dev/null +++ b/tests/agent/test_dream.py @@ -0,0 +1,97 @@ +"""Tests for the Dream class — two-phase memory consolidation via AgentRunner.""" + +import pytest + +from unittest.mock import AsyncMock, MagicMock + +from nanobot.agent.memory import Dream, MemoryStore +from nanobot.agent.runner import AgentRunResult + + +@pytest.fixture +def store(tmp_path): + s = MemoryStore(tmp_path) + s.write_soul("# Soul\n- Helpful") + s.write_user("# User\n- Developer") + s.write_memory("# Memory\n- Project X active") + return s + + +@pytest.fixture +def mock_provider(): + p = MagicMock() + p.chat_with_retry = AsyncMock() + return p + + +@pytest.fixture +def mock_runner(): + return MagicMock() + + +@pytest.fixture +def dream(store, mock_provider, mock_runner): + d = Dream(store=store, provider=mock_provider, model="test-model", max_batch_size=5) + d._runner = mock_runner + return d + + +def _make_run_result( + stop_reason="completed", + final_content=None, + tool_events=None, + usage=None, +): + return AgentRunResult( + final_content=final_content or stop_reason, + stop_reason=stop_reason, + messages=[], + tools_used=[], + usage={}, + tool_events=tool_events or [], + ) + + +class TestDreamRun: + async def test_noop_when_no_unprocessed_history(self, dream, mock_provider, mock_runner, store): + """Dream should not call LLM when there's nothing to process.""" + result = await dream.run() + assert result is False + mock_provider.chat_with_retry.assert_not_called() + mock_runner.run.assert_not_called() + + async def test_calls_runner_for_unprocessed_entries(self, dream, mock_provider, mock_runner, store): + """Dream should call AgentRunner when there are unprocessed history entries.""" + store.append_history("User prefers dark mode") + mock_provider.chat_with_retry.return_value = MagicMock(content="New fact") + mock_runner.run = AsyncMock(return_value=_make_run_result( + tool_events=[{"name": "edit_file", "status": "ok", "detail": "memory/MEMORY.md"}], + )) + result = await dream.run() + assert result is True + mock_runner.run.assert_called_once() + spec = mock_runner.run.call_args[0][0] + assert spec.max_iterations == 10 + assert spec.fail_on_tool_error is True + + async def test_advances_dream_cursor(self, dream, mock_provider, mock_runner, store): + """Dream should advance the cursor after processing.""" + store.append_history("event 1") + store.append_history("event 2") + mock_provider.chat_with_retry.return_value = MagicMock(content="Nothing new") + mock_runner.run = AsyncMock(return_value=_make_run_result()) + await dream.run() + assert store.get_last_dream_cursor() == 2 + + async def test_compacts_processed_history(self, dream, mock_provider, mock_runner, store): + """Dream should compact history after processing.""" + store.append_history("event 1") + store.append_history("event 2") + store.append_history("event 3") + mock_provider.chat_with_retry.return_value = MagicMock(content="Nothing new") + mock_runner.run = AsyncMock(return_value=_make_run_result()) + await dream.run() + # After Dream, cursor is advanced and 3, compact keeps last max_history_entries + entries = store.read_unprocessed_history(since_cursor=0) + assert all(e["cursor"] > 0 for e in entries) + diff --git a/tests/agent/test_hook_composite.py b/tests/agent/test_hook_composite.py index 203c892fb..590d8db64 100644 --- a/tests/agent/test_hook_composite.py +++ b/tests/agent/test_hook_composite.py @@ -249,7 +249,8 @@ def _make_loop(tmp_path, hooks=None): with patch("nanobot.agent.loop.ContextBuilder"), \ patch("nanobot.agent.loop.SessionManager"), \ patch("nanobot.agent.loop.SubagentManager") as mock_sub_mgr, \ - patch("nanobot.agent.loop.MemoryConsolidator"): + patch("nanobot.agent.loop.Consolidator"), \ + patch("nanobot.agent.loop.Dream"): mock_sub_mgr.return_value.cancel_by_session = AsyncMock(return_value=0) loop = AgentLoop( bus=bus, provider=provider, workspace=tmp_path, hooks=hooks, diff --git a/tests/agent/test_loop_consolidation_tokens.py b/tests/agent/test_loop_consolidation_tokens.py index 2f9c2dea7..87e159cc8 100644 --- a/tests/agent/test_loop_consolidation_tokens.py +++ b/tests/agent/test_loop_consolidation_tokens.py @@ -26,24 +26,24 @@ def _make_loop(tmp_path, *, estimated_tokens: int, context_window_tokens: int) - context_window_tokens=context_window_tokens, ) loop.tools.get_definitions = MagicMock(return_value=[]) - loop.memory_consolidator._SAFETY_BUFFER = 0 + loop.consolidator._SAFETY_BUFFER = 0 return loop @pytest.mark.asyncio async def test_prompt_below_threshold_does_not_consolidate(tmp_path) -> None: loop = _make_loop(tmp_path, estimated_tokens=100, context_window_tokens=200) - loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign] await loop.process_direct("hello", session_key="cli:test") - loop.memory_consolidator.consolidate_messages.assert_not_awaited() + loop.consolidator.archive.assert_not_awaited() @pytest.mark.asyncio async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypatch) -> None: loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200) - loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign] session = loop.sessions.get_or_create("cli:test") session.messages = [ {"role": "user", "content": "u1", "timestamp": "2026-01-01T00:00:00"}, @@ -55,13 +55,13 @@ async def test_prompt_above_threshold_triggers_consolidation(tmp_path, monkeypat await loop.process_direct("hello", session_key="cli:test") - assert loop.memory_consolidator.consolidate_messages.await_count >= 1 + assert loop.consolidator.archive.await_count >= 1 @pytest.mark.asyncio async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path, monkeypatch) -> None: loop = _make_loop(tmp_path, estimated_tokens=1000, context_window_tokens=200) - loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign] session = loop.sessions.get_or_create("cli:test") session.messages = [ @@ -76,9 +76,9 @@ async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path token_map = {"u1": 120, "a1": 120, "u2": 120, "a2": 120, "u3": 120} monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda message: token_map[message["content"]]) - await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + await loop.consolidator.maybe_consolidate_by_tokens(session) - archived_chunk = loop.memory_consolidator.consolidate_messages.await_args.args[0] + archived_chunk = loop.consolidator.archive.await_args.args[0] assert [message["content"] for message in archived_chunk] == ["u1", "a1", "u2", "a2"] assert session.last_consolidated == 4 @@ -87,7 +87,7 @@ async def test_prompt_above_threshold_archives_until_next_user_boundary(tmp_path async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> None: """Verify maybe_consolidate_by_tokens keeps looping until under threshold.""" loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) - loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign] session = loop.sessions.get_or_create("cli:test") session.messages = [ @@ -110,12 +110,12 @@ async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> No return (300, "test") return (80, "test") - loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) - await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + await loop.consolidator.maybe_consolidate_by_tokens(session) - assert loop.memory_consolidator.consolidate_messages.await_count == 2 + assert loop.consolidator.archive.await_count == 2 assert session.last_consolidated == 6 @@ -123,7 +123,7 @@ async def test_consolidation_loops_until_target_met(tmp_path, monkeypatch) -> No async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, monkeypatch) -> None: """Once triggered, consolidation should continue until it drops below half threshold.""" loop = _make_loop(tmp_path, estimated_tokens=0, context_window_tokens=200) - loop.memory_consolidator.consolidate_messages = AsyncMock(return_value=True) # type: ignore[method-assign] + loop.consolidator.archive = AsyncMock(return_value=True) # type: ignore[method-assign] session = loop.sessions.get_or_create("cli:test") session.messages = [ @@ -147,12 +147,12 @@ async def test_consolidation_continues_below_trigger_until_half_target(tmp_path, return (150, "test") return (80, "test") - loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] monkeypatch.setattr(memory_module, "estimate_message_tokens", lambda _m: 100) - await loop.memory_consolidator.maybe_consolidate_by_tokens(session) + await loop.consolidator.maybe_consolidate_by_tokens(session) - assert loop.memory_consolidator.consolidate_messages.await_count == 2 + assert loop.consolidator.archive.await_count == 2 assert session.last_consolidated == 6 @@ -166,7 +166,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> async def track_consolidate(messages): order.append("consolidate") return True - loop.memory_consolidator.consolidate_messages = track_consolidate # type: ignore[method-assign] + loop.consolidator.archive = track_consolidate # type: ignore[method-assign] async def track_llm(*args, **kwargs): order.append("llm") @@ -187,7 +187,7 @@ async def test_preflight_consolidation_before_llm_call(tmp_path, monkeypatch) -> def mock_estimate(_session): call_count[0] += 1 return (1000 if call_count[0] <= 1 else 80, "test") - loop.memory_consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] + loop.consolidator.estimate_session_prompt_tokens = mock_estimate # type: ignore[method-assign] await loop.process_direct("hello", session_key="cli:test") diff --git a/tests/agent/test_memory_consolidation_types.py b/tests/agent/test_memory_consolidation_types.py deleted file mode 100644 index 203e39a90..000000000 --- a/tests/agent/test_memory_consolidation_types.py +++ /dev/null @@ -1,478 +0,0 @@ -"""Test MemoryStore.consolidate() handles non-string tool call arguments. - -Regression test for https://github.com/HKUDS/nanobot/issues/1042 -When memory consolidation receives dict values instead of strings from the LLM -tool call response, it should serialize them to JSON instead of raising TypeError. -""" - -import json -from pathlib import Path -from unittest.mock import AsyncMock - -import pytest - -from nanobot.agent.memory import MemoryStore -from nanobot.providers.base import LLMProvider, LLMResponse, ToolCallRequest - - -def _make_messages(message_count: int = 30): - """Create a list of mock messages.""" - return [ - {"role": "user", "content": f"msg{i}", "timestamp": "2026-01-01 00:00"} - for i in range(message_count) - ] - - -def _make_tool_response(history_entry, memory_update): - """Create an LLMResponse with a save_memory tool call.""" - return LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments={ - "history_entry": history_entry, - "memory_update": memory_update, - }, - ) - ], - ) - - -class ScriptedProvider(LLMProvider): - def __init__(self, responses: list[LLMResponse]): - super().__init__() - self._responses = list(responses) - self.calls = 0 - - async def chat(self, *args, **kwargs) -> LLMResponse: - self.calls += 1 - if self._responses: - return self._responses.pop(0) - return LLMResponse(content="", tool_calls=[]) - - def get_default_model(self) -> str: - return "test-model" - - -class TestMemoryConsolidationTypeHandling: - """Test that consolidation handles various argument types correctly.""" - - @pytest.mark.asyncio - async def test_string_arguments_work(self, tmp_path: Path) -> None: - """Normal case: LLM returns string arguments.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat = AsyncMock( - return_value=_make_tool_response( - history_entry="[2026-01-01] User discussed testing.", - memory_update="# Memory\nUser likes testing.", - ) - ) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - assert store.history_file.exists() - assert "[2026-01-01] User discussed testing." in store.history_file.read_text() - assert "User likes testing." in store.memory_file.read_text() - - @pytest.mark.asyncio - async def test_dict_arguments_serialized_to_json(self, tmp_path: Path) -> None: - """Issue #1042: LLM returns dict instead of string — must not raise TypeError.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat = AsyncMock( - return_value=_make_tool_response( - history_entry={"timestamp": "2026-01-01", "summary": "User discussed testing."}, - memory_update={"facts": ["User likes testing"], "topics": ["testing"]}, - ) - ) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - assert store.history_file.exists() - history_content = store.history_file.read_text() - parsed = json.loads(history_content.strip()) - assert parsed["summary"] == "User discussed testing." - - memory_content = store.memory_file.read_text() - parsed_mem = json.loads(memory_content) - assert "User likes testing" in parsed_mem["facts"] - - @pytest.mark.asyncio - async def test_string_arguments_as_raw_json(self, tmp_path: Path) -> None: - """Some providers return arguments as a JSON string instead of parsed dict.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - - response = LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments=json.dumps({ - "history_entry": "[2026-01-01] User discussed testing.", - "memory_update": "# Memory\nUser likes testing.", - }), - ) - ], - ) - provider.chat = AsyncMock(return_value=response) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - assert "User discussed testing." in store.history_file.read_text() - - @pytest.mark.asyncio - async def test_no_tool_call_returns_false(self, tmp_path: Path) -> None: - """When LLM doesn't use the save_memory tool, return False.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat = AsyncMock( - return_value=LLMResponse(content="I summarized the conversation.", tool_calls=[]) - ) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - assert not store.history_file.exists() - - @pytest.mark.asyncio - async def test_skips_when_message_chunk_is_empty(self, tmp_path: Path) -> None: - """Consolidation should be a no-op when the selected chunk is empty.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat_with_retry = provider.chat - messages: list[dict] = [] - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - provider.chat.assert_not_called() - - @pytest.mark.asyncio - async def test_list_arguments_extracts_first_dict(self, tmp_path: Path) -> None: - """Some providers return arguments as a list - extract first element if it's a dict.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - - response = LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments=[{ - "history_entry": "[2026-01-01] User discussed testing.", - "memory_update": "# Memory\nUser likes testing.", - }], - ) - ], - ) - provider.chat = AsyncMock(return_value=response) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - assert "User discussed testing." in store.history_file.read_text() - assert "User likes testing." in store.memory_file.read_text() - - @pytest.mark.asyncio - async def test_list_arguments_empty_list_returns_false(self, tmp_path: Path) -> None: - """Empty list arguments should return False.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - - response = LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments=[], - ) - ], - ) - provider.chat = AsyncMock(return_value=response) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - - @pytest.mark.asyncio - async def test_list_arguments_non_dict_content_returns_false(self, tmp_path: Path) -> None: - """List with non-dict content should return False.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - - response = LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments=["string", "content"], - ) - ], - ) - provider.chat = AsyncMock(return_value=response) - provider.chat_with_retry = provider.chat - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - - @pytest.mark.asyncio - async def test_missing_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None: - """Do not persist partial results when required fields are missing.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat_with_retry = AsyncMock( - return_value=LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments={"memory_update": "# Memory\nOnly memory update"}, - ) - ], - ) - ) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - assert not store.history_file.exists() - assert not store.memory_file.exists() - - @pytest.mark.asyncio - async def test_missing_memory_update_returns_false_without_writing(self, tmp_path: Path) -> None: - """Do not append history if memory_update is missing.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat_with_retry = AsyncMock( - return_value=LLMResponse( - content=None, - tool_calls=[ - ToolCallRequest( - id="call_1", - name="save_memory", - arguments={"history_entry": "[2026-01-01] Partial output."}, - ) - ], - ) - ) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - assert not store.history_file.exists() - assert not store.memory_file.exists() - - @pytest.mark.asyncio - async def test_null_required_field_returns_false_without_writing(self, tmp_path: Path) -> None: - """Null required fields should be rejected before persistence.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat_with_retry = AsyncMock( - return_value=_make_tool_response( - history_entry=None, - memory_update="# Memory\nUser likes testing.", - ) - ) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - assert not store.history_file.exists() - assert not store.memory_file.exists() - - @pytest.mark.asyncio - async def test_empty_history_entry_returns_false_without_writing(self, tmp_path: Path) -> None: - """Empty history entries should be rejected to avoid blank archival records.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat_with_retry = AsyncMock( - return_value=_make_tool_response( - history_entry=" ", - memory_update="# Memory\nUser likes testing.", - ) - ) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - assert not store.history_file.exists() - assert not store.memory_file.exists() - - @pytest.mark.asyncio - async def test_retries_transient_error_then_succeeds(self, tmp_path: Path, monkeypatch) -> None: - store = MemoryStore(tmp_path) - provider = ScriptedProvider([ - LLMResponse(content="503 server error", finish_reason="error"), - _make_tool_response( - history_entry="[2026-01-01] User discussed testing.", - memory_update="# Memory\nUser likes testing.", - ), - ]) - messages = _make_messages(message_count=60) - delays: list[int] = [] - - async def _fake_sleep(delay: int) -> None: - delays.append(delay) - - monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - assert provider.calls == 2 - assert delays == [1] - - @pytest.mark.asyncio - async def test_consolidation_delegates_to_provider_defaults(self, tmp_path: Path) -> None: - """Consolidation no longer passes generation params — the provider owns them.""" - store = MemoryStore(tmp_path) - provider = AsyncMock() - provider.chat_with_retry = AsyncMock( - return_value=_make_tool_response( - history_entry="[2026-01-01] User discussed testing.", - memory_update="# Memory\nUser likes testing.", - ) - ) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - provider.chat_with_retry.assert_awaited_once() - _, kwargs = provider.chat_with_retry.await_args - assert kwargs["model"] == "test-model" - assert "temperature" not in kwargs - assert "max_tokens" not in kwargs - assert "reasoning_effort" not in kwargs - - @pytest.mark.asyncio - async def test_tool_choice_fallback_on_unsupported_error(self, tmp_path: Path) -> None: - """Forced tool_choice rejected by provider -> retry with auto and succeed.""" - store = MemoryStore(tmp_path) - error_resp = LLMResponse( - content="Error calling LLM: BadRequestError: " - "The tool_choice parameter does not support being set to required or object", - finish_reason="error", - tool_calls=[], - ) - ok_resp = _make_tool_response( - history_entry="[2026-01-01] Fallback worked.", - memory_update="# Memory\nFallback OK.", - ) - - call_log: list[dict] = [] - - async def _tracking_chat(**kwargs): - call_log.append(kwargs) - return error_resp if len(call_log) == 1 else ok_resp - - provider = AsyncMock() - provider.chat_with_retry = AsyncMock(side_effect=_tracking_chat) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is True - assert len(call_log) == 2 - assert isinstance(call_log[0]["tool_choice"], dict) - assert call_log[1]["tool_choice"] == "auto" - assert "Fallback worked." in store.history_file.read_text() - - @pytest.mark.asyncio - async def test_tool_choice_fallback_auto_no_tool_call(self, tmp_path: Path) -> None: - """Forced rejected, auto retry also produces no tool call -> return False.""" - store = MemoryStore(tmp_path) - error_resp = LLMResponse( - content="Error: tool_choice must be none or auto", - finish_reason="error", - tool_calls=[], - ) - no_tool_resp = LLMResponse( - content="Here is a summary.", - finish_reason="stop", - tool_calls=[], - ) - - provider = AsyncMock() - provider.chat_with_retry = AsyncMock(side_effect=[error_resp, no_tool_resp]) - messages = _make_messages(message_count=60) - - result = await store.consolidate(messages, provider, "test-model") - - assert result is False - assert not store.history_file.exists() - - @pytest.mark.asyncio - async def test_raw_archive_after_consecutive_failures(self, tmp_path: Path) -> None: - """After 3 consecutive failures, raw-archive messages and return True.""" - store = MemoryStore(tmp_path) - no_tool = LLMResponse(content="No tool call.", finish_reason="stop", tool_calls=[]) - provider = AsyncMock() - provider.chat_with_retry = AsyncMock(return_value=no_tool) - messages = _make_messages(message_count=10) - - assert await store.consolidate(messages, provider, "m") is False - assert await store.consolidate(messages, provider, "m") is False - assert await store.consolidate(messages, provider, "m") is True - - assert store.history_file.exists() - content = store.history_file.read_text() - assert "[RAW]" in content - assert "10 messages" in content - assert "msg0" in content - assert not store.memory_file.exists() - - @pytest.mark.asyncio - async def test_raw_archive_counter_resets_on_success(self, tmp_path: Path) -> None: - """A successful consolidation resets the failure counter.""" - store = MemoryStore(tmp_path) - no_tool = LLMResponse(content="Nope.", finish_reason="stop", tool_calls=[]) - ok_resp = _make_tool_response( - history_entry="[2026-01-01] OK.", - memory_update="# Memory\nOK.", - ) - messages = _make_messages(message_count=10) - - provider = AsyncMock() - provider.chat_with_retry = AsyncMock(return_value=no_tool) - assert await store.consolidate(messages, provider, "m") is False - assert await store.consolidate(messages, provider, "m") is False - assert store._consecutive_failures == 2 - - provider.chat_with_retry = AsyncMock(return_value=ok_resp) - assert await store.consolidate(messages, provider, "m") is True - assert store._consecutive_failures == 0 - - provider.chat_with_retry = AsyncMock(return_value=no_tool) - assert await store.consolidate(messages, provider, "m") is False - assert store._consecutive_failures == 1 diff --git a/tests/agent/test_memory_store.py b/tests/agent/test_memory_store.py new file mode 100644 index 000000000..3d0547183 --- /dev/null +++ b/tests/agent/test_memory_store.py @@ -0,0 +1,133 @@ +"""Tests for the restructured MemoryStore — pure file I/O layer.""" + +import json + +import pytest +from pathlib import Path + +from nanobot.agent.memory import MemoryStore + + +@pytest.fixture +def store(tmp_path): + return MemoryStore(tmp_path) + + +class TestMemoryStoreBasicIO: + def test_read_memory_returns_empty_when_missing(self, store): + assert store.read_memory() == "" + + def test_write_and_read_memory(self, store): + store.write_memory("hello") + assert store.read_memory() == "hello" + + def test_read_soul_returns_empty_when_missing(self, store): + assert store.read_soul() == "" + + def test_write_and_read_soul(self, store): + store.write_soul("soul content") + assert store.read_soul() == "soul content" + + def test_read_user_returns_empty_when_missing(self, store): + assert store.read_user() == "" + + def test_write_and_read_user(self, store): + store.write_user("user content") + assert store.read_user() == "user content" + + def test_get_memory_context_returns_empty_when_missing(self, store): + assert store.get_memory_context() == "" + + def test_get_memory_context_returns_formatted_content(self, store): + store.write_memory("important fact") + ctx = store.get_memory_context() + assert "Long-term Memory" in ctx + assert "important fact" in ctx + + +class TestHistoryWithCursor: + def test_append_history_returns_cursor(self, store): + cursor = store.append_history("event 1") + assert cursor == 1 + cursor2 = store.append_history("event 2") + assert cursor2 == 2 + + def test_append_history_includes_cursor_in_file(self, store): + store.append_history("event 1") + content = store.read_file(store.history_file) + data = json.loads(content) + assert data["cursor"] == 1 + + def test_cursor_persists_across_appends(self, store): + store.append_history("event 1") + store.append_history("event 2") + cursor = store.append_history("event 3") + assert cursor == 3 + + def test_read_unprocessed_history(self, store): + store.append_history("event 1") + store.append_history("event 2") + store.append_history("event 3") + entries = store.read_unprocessed_history(since_cursor=1) + assert len(entries) == 2 + assert entries[0]["cursor"] == 2 + + def test_read_unprocessed_history_returns_all_when_cursor_zero(self, store): + store.append_history("event 1") + store.append_history("event 2") + entries = store.read_unprocessed_history(since_cursor=0) + assert len(entries) == 2 + + def test_compact_history_drops_oldest(self, tmp_path): + store = MemoryStore(tmp_path, max_history_entries=2) + store.append_history("event 1") + store.append_history("event 2") + store.append_history("event 3") + store.append_history("event 4") + store.append_history("event 5") + store.compact_history() + entries = store.read_unprocessed_history(since_cursor=0) + assert len(entries) == 2 + assert entries[0]["cursor"] in {4, 5} + + +class TestDreamCursor: + def test_initial_cursor_is_zero(self, store): + assert store.get_last_dream_cursor() == 0 + + def test_set_and_get_cursor(self, store): + store.set_last_dream_cursor(5) + assert store.get_last_dream_cursor() == 5 + + def test_cursor_persists(self, store): + store.set_last_dream_cursor(3) + store2 = MemoryStore(store.workspace) + assert store2.get_last_dream_cursor() == 3 + + +class TestDreamLog: + def test_read_dream_log_returns_empty_when_missing(self, store): + assert store.read_dream_log() == "" + + def test_append_dream_log(self, store): + store.append_dream_log("## 2026-03-30\nProcessed entries #1-#5") + log = store.read_dream_log() + assert "Processed entries #1-#5" in log + + def test_append_dream_log_is_additive(self, store): + store.append_dream_log("first run") + store.append_dream_log("second run") + log = store.read_dream_log() + assert "first run" in log + assert "second run" in log + + +class TestLegacyHistoryMigration: + def test_read_unprocessed_history_handles_entries_without_cursor(self, store): + """JSONL entries with cursor=1 are correctly parsed and returned.""" + store.history_file.write_text( + '{"cursor": 1, "timestamp": "2026-03-30 14:30", "content": "Old event"}\n', + encoding="utf-8") + entries = store.read_unprocessed_history(since_cursor=0) + assert len(entries) == 1 + assert entries[0]["cursor"] == 1 diff --git a/tests/cli/test_restart_command.py b/tests/cli/test_restart_command.py index 6efcdad0d..aa514e140 100644 --- a/tests/cli/test_restart_command.py +++ b/tests/cli/test_restart_command.py @@ -127,7 +127,7 @@ class TestRestartCommand: loop.sessions.get_or_create.return_value = session loop._start_time = time.time() - 125 loop._last_usage = {"prompt_tokens": 0, "completion_tokens": 0} - loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock( + loop.consolidator.estimate_session_prompt_tokens = MagicMock( return_value=(20500, "tiktoken") ) @@ -166,7 +166,7 @@ class TestRestartCommand: session.get_history.return_value = [{"role": "user"}] loop.sessions.get_or_create.return_value = session loop._last_usage = {"prompt_tokens": 1200, "completion_tokens": 34} - loop.memory_consolidator.estimate_session_prompt_tokens = MagicMock( + loop.consolidator.estimate_session_prompt_tokens = MagicMock( return_value=(0, "none") ) diff --git a/tests/test_openai_api.py b/tests/test_openai_api.py index 3d29d4767..6663c5f20 100644 --- a/tests/test_openai_api.py +++ b/tests/test_openai_api.py @@ -235,15 +235,10 @@ async def test_followup_requests_share_same_session_key(aiohttp_client) -> None: @pytest.mark.asyncio async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None: order: list[str] = [] - barrier = asyncio.Event() async def slow_process(content, session_key="", channel="", chat_id=""): order.append(f"start:{content}") - if content == "first": - barrier.set() - await asyncio.sleep(0.1) - else: - await barrier.wait() + await asyncio.sleep(0.1) order.append(f"end:{content}") return content @@ -264,7 +259,12 @@ async def test_fixed_session_requests_are_serialized(aiohttp_client) -> None: r1, r2 = await asyncio.gather(send("first"), send("second")) assert r1.status == 200 assert r2.status == 200 - assert order.index("end:first") < order.index("start:second") + # Verify serialization: one process must fully finish before the other starts + assert "end:second" in order or "end:first" in order + if order[0] == "start:first": + assert order.index("end:first") < order.index("start:second") + else: + assert order.index("end:second") < order.index("start:first") @pytest.mark.skipif(not HAS_AIOHTTP, reason="aiohttp not installed")