From 4c751bb8e3be7c776df69cb4a4442f6573548776 Mon Sep 17 00:00:00 2001 From: chengyongru <2755839590@qq.com> Date: Sat, 16 May 2026 18:35:09 +0800 Subject: [PATCH] refactor(agent): extract checkpoint.py and turn_writer.py from loop.py Extract checkpoint and turn-persistence logic from AgentLoop into two dedicated modules to reduce loop.py's size and improve reviewability. - checkpoint.py: CheckpointManager handles runtime checkpoints, pending user-turn markers, and session recovery after cancellation/crash. - turn_writer.py: TurnWriter handles message sanitization, turn saving, early user-message persistence, and subagent-followup persistence. AgentLoop retains thin delegate methods for full backward compatibility. Test helpers updated to initialize the new sub-managers. --- nanobot/agent/checkpoint.py | 122 ++++++++++++++ nanobot/agent/loop.py | 218 +++----------------------- nanobot/agent/turn_writer.py | 163 +++++++++++++++++++ tests/agent/test_loop_save_turn.py | 8 + tests/test_truncate_text_shadowing.py | 6 +- 5 files changed, 315 insertions(+), 202 deletions(-) create mode 100644 nanobot/agent/checkpoint.py create mode 100644 nanobot/agent/turn_writer.py diff --git a/nanobot/agent/checkpoint.py b/nanobot/agent/checkpoint.py new file mode 100644 index 000000000..9c93f7f8f --- /dev/null +++ b/nanobot/agent/checkpoint.py @@ -0,0 +1,122 @@ +"""Runtime checkpoint management for session recovery.""" + +from __future__ import annotations + +from typing import Any + +from nanobot.session.manager import Session, SessionManager + + +class CheckpointManager: + """Manages runtime checkpoints and pending-user-turn markers on sessions. + + Checkpoints capture in-flight turn state (assistant messages, completed and + pending tool calls) so that a cancelled or crashed turn can be materialised + into session history on the next request, preventing data loss. + """ + + RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" + PENDING_USER_TURN_KEY = "pending_user_turn" + + def __init__(self, sessions: SessionManager) -> None: + self.sessions = sessions + + def set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None: + session.metadata[self.RUNTIME_CHECKPOINT_KEY] = payload + self.sessions.save(session) + + def mark_pending_user_turn(self, session: Session) -> None: + session.metadata[self.PENDING_USER_TURN_KEY] = True + + def clear_pending_user_turn(self, session: Session) -> None: + session.metadata.pop(self.PENDING_USER_TURN_KEY, None) + + def clear_runtime_checkpoint(self, session: Session) -> None: + if self.RUNTIME_CHECKPOINT_KEY in session.metadata: + session.metadata.pop(self.RUNTIME_CHECKPOINT_KEY, None) + + @staticmethod + def checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]: + return ( + message.get("role"), + message.get("content"), + message.get("tool_call_id"), + message.get("name"), + message.get("tool_calls"), + message.get("reasoning_content"), + message.get("thinking_blocks"), + ) + + def restore_runtime_checkpoint(self, session: Session) -> bool: + """Materialize an unfinished turn into session history before a new request.""" + from datetime import datetime + + checkpoint = session.metadata.get(self.RUNTIME_CHECKPOINT_KEY) + if not isinstance(checkpoint, dict): + return False + + assistant_message = checkpoint.get("assistant_message") + completed_tool_results = checkpoint.get("completed_tool_results") or [] + pending_tool_calls = checkpoint.get("pending_tool_calls") or [] + + restored_messages: list[dict[str, Any]] = [] + if isinstance(assistant_message, dict): + restored = dict(assistant_message) + restored.setdefault("timestamp", datetime.now().isoformat()) + restored_messages.append(restored) + for message in completed_tool_results: + if isinstance(message, dict): + restored = dict(message) + restored.setdefault("timestamp", datetime.now().isoformat()) + restored_messages.append(restored) + for tool_call in pending_tool_calls: + if not isinstance(tool_call, dict): + continue + tool_id = tool_call.get("id") + name = ((tool_call.get("function") or {}).get("name")) or "tool" + restored_messages.append( + { + "role": "tool", + "tool_call_id": tool_id, + "name": name, + "content": "Error: Task interrupted before this tool finished.", + "timestamp": datetime.now().isoformat(), + } + ) + + overlap = 0 + max_overlap = min(len(session.messages), len(restored_messages)) + for size in range(max_overlap, 0, -1): + existing = session.messages[-size:] + restored = restored_messages[:size] + if all( + self.checkpoint_message_key(left) == self.checkpoint_message_key(right) + for left, right in zip(existing, restored) + ): + overlap = size + break + session.messages.extend(restored_messages[overlap:]) + + self.clear_pending_user_turn(session) + self.clear_runtime_checkpoint(session) + return True + + def restore_pending_user_turn(self, session: Session) -> bool: + """Close a turn that only persisted the user message before crashing.""" + from datetime import datetime + + if not session.metadata.get(self.PENDING_USER_TURN_KEY): + return False + + if session.messages and session.messages[-1].get("role") == "user": + session.messages.append( + { + "role": "assistant", + "content": "Error: Task interrupted before a response was generated.", + "timestamp": datetime.now().isoformat(), + } + ) + session.updated_at = datetime.now() + + self.clear_pending_user_turn(session) + return True diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 0868ebb7c..75aa4acdb 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -16,6 +16,7 @@ from loguru import logger from nanobot.agent import model_presets as preset_helpers from nanobot.agent.autocompact import AutoCompact +from nanobot.agent.checkpoint import CheckpointManager from nanobot.agent.context import ContextBuilder from nanobot.agent.hook import AgentHook, CompositeHook from nanobot.agent.memory import Consolidator, Dream @@ -26,6 +27,7 @@ from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, res from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.self import MyTool +from nanobot.agent.turn_writer import TurnWriter from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.command import CommandContext, CommandRouter, register_builtin_commands @@ -40,8 +42,6 @@ from nanobot.session.goal_state import ( from nanobot.session.manager import Session, SessionManager from nanobot.utils.artifacts import generated_image_paths_from_messages from nanobot.utils.document import extract_documents -from nanobot.utils.helpers import image_placeholder_text -from nanobot.utils.helpers import truncate_text as truncate_text_fn from nanobot.utils.image_generation_intent import image_generation_prompt from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant @@ -137,8 +137,8 @@ class AgentLoop: def tool_names(self) -> list[str]: return self.tools.tool_names - _RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" - _PENDING_USER_TURN_KEY = "pending_user_turn" + _RUNTIME_CHECKPOINT_KEY = CheckpointManager.RUNTIME_CHECKPOINT_KEY + _PENDING_USER_TURN_KEY = CheckpointManager.PENDING_USER_TURN_KEY # Event-driven state transition table. # Handlers return an event string; the driver looks up the next state here. @@ -238,6 +238,8 @@ class AgentLoop: self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.sessions = session_manager or SessionManager(workspace) + self._checkpoint_mgr = CheckpointManager(self.sessions) + self._turn_writer = TurnWriter(self.sessions, self._checkpoint_mgr, self.max_tool_result_chars) self.tools = ToolRegistry() # One file-read/write tracker per logical session. The tool registry is # shared by this loop, so tools resolve the active state via contextvars. @@ -579,21 +581,7 @@ class AgentLoop: session: Session, **kwargs: Any, ) -> bool: - """Persist the triggering user message before the turn starts. - - Returns True if the message was persisted. - """ - media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p] - has_text = isinstance(msg.content, str) and msg.content.strip() - if has_text or media_paths: - extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {} - extra.update(kwargs) - text = msg.content if isinstance(msg.content, str) else "" - session.add_message("user", text, **extra) - self._mark_pending_user_turn(session) - self.sessions.save(session) - return True - return False + return self._turn_writer.persist_user_message_early(msg, session, **kwargs) def _build_initial_messages( self, @@ -1440,38 +1428,9 @@ class AgentLoop: should_truncate_text: bool = False, drop_runtime: bool = False, ) -> list[dict[str, Any]]: - """Strip volatile multimodal payloads before writing session history.""" - filtered: list[dict[str, Any]] = [] - for block in content: - if not isinstance(block, dict): - filtered.append(block) - continue - - if ( - drop_runtime - and block.get("type") == "text" - and isinstance(block.get("text"), str) - and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG) - ): - continue - - if block.get("type") == "image_url" and block.get("image_url", {}).get( - "url", "" - ).startswith("data:image/"): - path = (block.get("_meta") or {}).get("path", "") - filtered.append({"type": "text", "text": image_placeholder_text(path)}) - continue - - if block.get("type") == "text" and isinstance(block.get("text"), str): - text = block["text"] - if should_truncate_text and len(text) > self.max_tool_result_chars: - text = truncate_text_fn(text, self.max_tool_result_chars) - filtered.append({**block, "text": text}) - continue - - filtered.append(block) - - return filtered + return self._turn_writer.sanitize_persisted_blocks( + content, should_truncate_text=should_truncate_text, drop_runtime=drop_runtime, + ) def _save_turn( self, @@ -1481,169 +1440,28 @@ class AgentLoop: *, turn_latency_ms: int | None = None, ) -> None: - """Save new-turn messages into session, truncating large tool results.""" - from datetime import datetime - - last_assistant_idx: int | None = None - for m in messages[skip:]: - entry = dict(m) - role, content = entry.get("role"), entry.get("content") - if role == "assistant" and not content and not entry.get("tool_calls"): - continue # skip empty assistant messages — they poison session context - if role == "tool": - if isinstance(content, str) and len(content) > self.max_tool_result_chars: - entry["content"] = truncate_text_fn(content, self.max_tool_result_chars) - elif isinstance(content, list): - filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True) - if not filtered: - continue - entry["content"] = filtered - elif role == "user": - if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content: - # Strip the runtime-context block appended at the end. - tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG) - before = content[:tag_pos].rstrip("\n ") - if before: - entry["content"] = before - else: - continue - if isinstance(content, list): - filtered = self._sanitize_persisted_blocks(content, drop_runtime=True) - if not filtered: - continue - entry["content"] = filtered - entry.setdefault("timestamp", datetime.now().isoformat()) - session.messages.append(entry) - if role == "assistant": - last_assistant_idx = len(session.messages) - 1 - if turn_latency_ms is not None and last_assistant_idx is not None: - session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms) - session.updated_at = datetime.now() + self._turn_writer.save_turn(session, messages, skip, turn_latency_ms=turn_latency_ms) def _persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool: - """Persist subagent follow-ups before prompt assembly so history stays durable. - - Returns True if a new entry was appended; False if the follow-up was - deduped (same ``subagent_task_id`` already in session) or carries no - content worth persisting. - """ - if not msg.content: - return False - task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None - if task_id and any( - m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id - for m in session.messages - ): - return False - session.add_message( - "assistant", - msg.content, - sender_id=msg.sender_id, - injected_event="subagent_result", - subagent_task_id=task_id, - ) - return True + return self._turn_writer.persist_subagent_followup(session, msg) def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None: - """Persist the latest in-flight turn state into session metadata.""" - session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload - self.sessions.save(session) + self._checkpoint_mgr.set_runtime_checkpoint(session, payload) def _mark_pending_user_turn(self, session: Session) -> None: - session.metadata[self._PENDING_USER_TURN_KEY] = True + self._checkpoint_mgr.mark_pending_user_turn(session) def _clear_pending_user_turn(self, session: Session) -> None: - session.metadata.pop(self._PENDING_USER_TURN_KEY, None) + self._checkpoint_mgr.clear_pending_user_turn(session) def _clear_runtime_checkpoint(self, session: Session) -> None: - if self._RUNTIME_CHECKPOINT_KEY in session.metadata: - session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None) - - @staticmethod - def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]: - return ( - message.get("role"), - message.get("content"), - message.get("tool_call_id"), - message.get("name"), - message.get("tool_calls"), - message.get("reasoning_content"), - message.get("thinking_blocks"), - ) + self._checkpoint_mgr.clear_runtime_checkpoint(session) def _restore_runtime_checkpoint(self, session: Session) -> bool: - """Materialize an unfinished turn into session history before a new request.""" - from datetime import datetime - - checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY) - if not isinstance(checkpoint, dict): - return False - - assistant_message = checkpoint.get("assistant_message") - completed_tool_results = checkpoint.get("completed_tool_results") or [] - pending_tool_calls = checkpoint.get("pending_tool_calls") or [] - - restored_messages: list[dict[str, Any]] = [] - if isinstance(assistant_message, dict): - restored = dict(assistant_message) - restored.setdefault("timestamp", datetime.now().isoformat()) - restored_messages.append(restored) - for message in completed_tool_results: - if isinstance(message, dict): - restored = dict(message) - restored.setdefault("timestamp", datetime.now().isoformat()) - restored_messages.append(restored) - for tool_call in pending_tool_calls: - if not isinstance(tool_call, dict): - continue - tool_id = tool_call.get("id") - name = ((tool_call.get("function") or {}).get("name")) or "tool" - restored_messages.append( - { - "role": "tool", - "tool_call_id": tool_id, - "name": name, - "content": "Error: Task interrupted before this tool finished.", - "timestamp": datetime.now().isoformat(), - } - ) - - overlap = 0 - max_overlap = min(len(session.messages), len(restored_messages)) - for size in range(max_overlap, 0, -1): - existing = session.messages[-size:] - restored = restored_messages[:size] - if all( - self._checkpoint_message_key(left) == self._checkpoint_message_key(right) - for left, right in zip(existing, restored) - ): - overlap = size - break - session.messages.extend(restored_messages[overlap:]) - - self._clear_pending_user_turn(session) - self._clear_runtime_checkpoint(session) - return True + return self._checkpoint_mgr.restore_runtime_checkpoint(session) def _restore_pending_user_turn(self, session: Session) -> bool: - """Close a turn that only persisted the user message before crashing.""" - from datetime import datetime - - if not session.metadata.get(self._PENDING_USER_TURN_KEY): - return False - - if session.messages and session.messages[-1].get("role") == "user": - session.messages.append( - { - "role": "assistant", - "content": "Error: Task interrupted before a response was generated.", - "timestamp": datetime.now().isoformat(), - } - ) - session.updated_at = datetime.now() - - self._clear_pending_user_turn(session) - return True + return self._checkpoint_mgr.restore_pending_user_turn(session) async def process_direct( self, diff --git a/nanobot/agent/turn_writer.py b/nanobot/agent/turn_writer.py new file mode 100644 index 000000000..1b92dd016 --- /dev/null +++ b/nanobot/agent/turn_writer.py @@ -0,0 +1,163 @@ +"""Turn persistence: sanitize, save, and early-persist message history.""" + +from __future__ import annotations + +from typing import Any + +from nanobot.agent.checkpoint import CheckpointManager +from nanobot.agent.context import ContextBuilder +from nanobot.bus.events import InboundMessage +from nanobot.session.manager import Session, SessionManager +from nanobot.utils.helpers import image_placeholder_text +from nanobot.utils.helpers import truncate_text as truncate_text_fn + + +class TurnWriter: + """Handles persisting turn messages into session history. + + Responsible for sanitising volatile content (data-URI images, runtime + context blocks), truncating large tool results, and stamping latency + metadata on the last assistant message. + """ + + def __init__( + self, + sessions: SessionManager, + checkpoint: CheckpointManager, + max_tool_result_chars: int, + ) -> None: + self.sessions = sessions + self._checkpoint = checkpoint + self.max_tool_result_chars = max_tool_result_chars + + def sanitize_persisted_blocks( + self, + content: list[dict[str, Any]], + *, + should_truncate_text: bool = False, + drop_runtime: bool = False, + ) -> list[dict[str, Any]]: + """Strip volatile multimodal payloads before writing session history.""" + filtered: list[dict[str, Any]] = [] + for block in content: + if not isinstance(block, dict): + filtered.append(block) + continue + + if ( + drop_runtime + and block.get("type") == "text" + and isinstance(block.get("text"), str) + and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG) + ): + continue + + if block.get("type") == "image_url" and block.get("image_url", {}).get( + "url", "" + ).startswith("data:image/"): + path = (block.get("_meta") or {}).get("path", "") + filtered.append({"type": "text", "text": image_placeholder_text(path)}) + continue + + if block.get("type") == "text" and isinstance(block.get("text"), str): + text = block["text"] + if should_truncate_text and len(text) > self.max_tool_result_chars: + text = truncate_text_fn(text, self.max_tool_result_chars) + filtered.append({**block, "text": text}) + continue + + filtered.append(block) + + return filtered + + def save_turn( + self, + session: Session, + messages: list[dict], + skip: int, + *, + turn_latency_ms: int | None = None, + ) -> None: + """Save new-turn messages into session, truncating large tool results.""" + from datetime import datetime + + last_assistant_idx: int | None = None + for m in messages[skip:]: + entry = dict(m) + role, content = entry.get("role"), entry.get("content") + if role == "assistant" and not content and not entry.get("tool_calls"): + continue + if role == "tool": + if isinstance(content, str) and len(content) > self.max_tool_result_chars: + entry["content"] = truncate_text_fn(content, self.max_tool_result_chars) + elif isinstance(content, list): + filtered = self.sanitize_persisted_blocks(content, should_truncate_text=True) + if not filtered: + continue + entry["content"] = filtered + elif role == "user": + if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content: + tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG) + before = content[:tag_pos].rstrip("\n ") + if before: + entry["content"] = before + else: + continue + if isinstance(content, list): + filtered = self.sanitize_persisted_blocks(content, drop_runtime=True) + if not filtered: + continue + entry["content"] = filtered + entry.setdefault("timestamp", datetime.now().isoformat()) + session.messages.append(entry) + if role == "assistant": + last_assistant_idx = len(session.messages) - 1 + if turn_latency_ms is not None and last_assistant_idx is not None: + session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms) + session.updated_at = datetime.now() + + def persist_user_message_early( + self, + msg: InboundMessage, + session: Session, + **kwargs: Any, + ) -> bool: + """Persist the triggering user message before the turn starts. + + Returns True if the message was persisted. + """ + media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p] + has_text = isinstance(msg.content, str) and msg.content.strip() + if has_text or media_paths: + extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {} + extra.update(kwargs) + text = msg.content if isinstance(msg.content, str) else "" + session.add_message("user", text, **extra) + self._checkpoint.mark_pending_user_turn(session) + self.sessions.save(session) + return True + return False + + def persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool: + """Persist subagent follow-ups before prompt assembly so history stays durable. + + Returns True if a new entry was appended; False if the follow-up was + deduped (same ``subagent_task_id`` already in session) or carries no + content worth persisting. + """ + if not msg.content: + return False + task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None + if task_id and any( + m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id + for m in session.messages + ): + return False + session.add_message( + "assistant", + msg.content, + sender_id=msg.sender_id, + injected_event="subagent_result", + subagent_task_id=task_id, + ) + return True diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index c33ecf422..d09df6da1 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -19,9 +19,17 @@ from nanobot.utils.webui_titles import ( def _mk_loop() -> AgentLoop: loop = AgentLoop.__new__(AgentLoop) + from nanobot.agent.checkpoint import CheckpointManager + from nanobot.agent.turn_writer import TurnWriter from nanobot.config.schema import AgentDefaults loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars + loop._checkpoint_mgr = CheckpointManager(sessions=None) # type: ignore[arg-type] + loop._turn_writer = TurnWriter( + sessions=None, # type: ignore[arg-type] + checkpoint=loop._checkpoint_mgr, + max_tool_result_chars=loop.max_tool_result_chars, + ) return loop diff --git a/tests/test_truncate_text_shadowing.py b/tests/test_truncate_text_shadowing.py index 11132b511..1f16b3f6e 100644 --- a/tests/test_truncate_text_shadowing.py +++ b/tests/test_truncate_text_shadowing.py @@ -20,10 +20,12 @@ def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None: assert "should_truncate_text" in sig.parameters assert "truncate_text" not in sig.parameters - dummy = SimpleNamespace(max_tool_result_chars=5) + from nanobot.agent.turn_writer import TurnWriter + + writer = TurnWriter(sessions=None, checkpoint=None, max_tool_result_chars=5) # type: ignore[arg-type] content = [{"type": "text", "text": "0123456789"}] - out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True) + out = writer.sanitize_persisted_blocks(content, should_truncate_text=True) assert isinstance(out, list) assert out and out[0]["type"] == "text" assert isinstance(out[0]["text"], str)