refactor(agent): extract checkpoint.py and turn_writer.py from loop.py

Extract checkpoint and turn-persistence logic from AgentLoop into two
dedicated modules to reduce loop.py's size and improve reviewability.

- checkpoint.py: CheckpointManager handles runtime checkpoints, pending
  user-turn markers, and session recovery after cancellation/crash.
- turn_writer.py: TurnWriter handles message sanitization, turn saving,
  early user-message persistence, and subagent-followup persistence.

AgentLoop retains thin delegate methods for full backward compatibility.
Test helpers updated to initialize the new sub-managers.
This commit is contained in:
chengyongru 2026-05-16 18:35:09 +08:00
parent e804f2fddb
commit 4c751bb8e3
5 changed files with 315 additions and 202 deletions

122
nanobot/agent/checkpoint.py Normal file
View File

@ -0,0 +1,122 @@
"""Runtime checkpoint management for session recovery."""
from __future__ import annotations
from typing import Any
from nanobot.session.manager import Session, SessionManager
class CheckpointManager:
    """Track in-flight turn state in ``session.metadata`` for later recovery.

    Two markers are stored on the session:

    * a runtime checkpoint (assistant message, completed tool results and
      still-pending tool calls) that lets a cancelled or crashed turn be
      written into session history on the next request, and
    * a pending-user-turn flag set once only the user message of a turn has
      been persisted.
    """

    RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
    PENDING_USER_TURN_KEY = "pending_user_turn"

    def __init__(self, sessions: SessionManager) -> None:
        # Used to save the session whenever a checkpoint is written.
        self.sessions = sessions

    def set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
        """Store *payload* as the session's runtime checkpoint and save the session."""
        session.metadata[self.RUNTIME_CHECKPOINT_KEY] = payload
        self.sessions.save(session)

    def mark_pending_user_turn(self, session: Session) -> None:
        """Set the pending-user-turn flag (the session is not saved here)."""
        session.metadata[self.PENDING_USER_TURN_KEY] = True

    def clear_pending_user_turn(self, session: Session) -> None:
        """Remove the pending-user-turn flag if it is set."""
        session.metadata.pop(self.PENDING_USER_TURN_KEY, None)

    def clear_runtime_checkpoint(self, session: Session) -> None:
        """Remove the runtime checkpoint if one is stored."""
        session.metadata.pop(self.RUNTIME_CHECKPOINT_KEY, None)

    @staticmethod
    def checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
        """Return the tuple of fields used to decide whether two messages match.

        Timestamps are deliberately not part of the key.
        """
        identity_fields = (
            "role",
            "content",
            "tool_call_id",
            "name",
            "tool_calls",
            "reasoning_content",
            "thinking_blocks",
        )
        return tuple(message.get(field) for field in identity_fields)

    def restore_runtime_checkpoint(self, session: Session) -> bool:
        """Write an unfinished turn into session history before a new request.

        Returns:
            ``False`` when the session carries no dict-shaped checkpoint,
            otherwise ``True`` after the checkpointed messages were merged
            into ``session.messages`` and both markers were cleared.
        """
        from datetime import datetime

        state = session.metadata.get(self.RUNTIME_CHECKPOINT_KEY)
        if not isinstance(state, dict):
            return False

        def with_timestamp(source: dict[str, Any]) -> dict[str, Any]:
            # Copy so the checkpoint payload itself is never mutated.
            clone = dict(source)
            clone.setdefault("timestamp", datetime.now().isoformat())
            return clone

        head = state.get("assistant_message")
        finished = state.get("completed_tool_results") or []
        unfinished = state.get("pending_tool_calls") or []

        recovered: list[dict[str, Any]] = []
        if isinstance(head, dict):
            recovered.append(with_timestamp(head))
        recovered.extend(
            with_timestamp(result) for result in finished if isinstance(result, dict)
        )
        for call in unfinished:
            if isinstance(call, dict):
                call_id = call.get("id")
                call_name = ((call.get("function") or {}).get("name")) or "tool"
                # Tool calls that never completed get a synthetic error result.
                recovered.append(
                    {
                        "role": "tool",
                        "tool_call_id": call_id,
                        "name": call_name,
                        "content": "Error: Task interrupted before this tool finished.",
                        "timestamp": datetime.now().isoformat(),
                    }
                )

        history = session.messages
        key_of = self.checkpoint_message_key

        def tail_matches(size: int) -> bool:
            return all(
                key_of(old) == key_of(new)
                for old, new in zip(history[-size:], recovered[:size])
            )

        # Longest suffix of the history that equals a prefix of the recovered
        # messages; those entries are already stored and must not be repeated.
        longest = min(len(history), len(recovered))
        already_stored = next(
            (size for size in range(longest, 0, -1) if tail_matches(size)),
            0,
        )

        history.extend(recovered[already_stored:])
        self.clear_pending_user_turn(session)
        self.clear_runtime_checkpoint(session)
        return True

    def restore_pending_user_turn(self, session: Session) -> bool:
        """Close a turn that only persisted the user message before crashing.

        Returns:
            ``False`` when the pending-user-turn flag is not set, otherwise
            ``True`` after the flag was cleared.
        """
        from datetime import datetime

        if not session.metadata.get(self.PENDING_USER_TURN_KEY):
            return False

        history = session.messages
        ends_with_user = bool(history) and history[-1].get("role") == "user"
        if ends_with_user:
            # Answer the dangling user message with an explicit error reply.
            history.append(
                {
                    "role": "assistant",
                    "content": "Error: Task interrupted before a response was generated.",
                    "timestamp": datetime.now().isoformat(),
                }
            )
            session.updated_at = datetime.now()

        self.clear_pending_user_turn(session)
        return True

View File

@ -16,6 +16,7 @@ from loguru import logger
from nanobot.agent import model_presets as preset_helpers from nanobot.agent import model_presets as preset_helpers
from nanobot.agent.autocompact import AutoCompact from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.checkpoint import CheckpointManager
from nanobot.agent.context import ContextBuilder from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, CompositeHook from nanobot.agent.hook import AgentHook, CompositeHook
from nanobot.agent.memory import Consolidator, Dream from nanobot.agent.memory import Consolidator, Dream
@ -26,6 +27,7 @@ from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, res
from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.self import MyTool from nanobot.agent.tools.self import MyTool
from nanobot.agent.turn_writer import TurnWriter
from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus from nanobot.bus.queue import MessageBus
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
@ -40,8 +42,6 @@ from nanobot.session.goal_state import (
from nanobot.session.manager import Session, SessionManager from nanobot.session.manager import Session, SessionManager
from nanobot.utils.artifacts import generated_image_paths_from_messages from nanobot.utils.artifacts import generated_image_paths_from_messages
from nanobot.utils.document import extract_documents from nanobot.utils.document import extract_documents
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
from nanobot.utils.image_generation_intent import image_generation_prompt from nanobot.utils.image_generation_intent import image_generation_prompt
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
@ -137,8 +137,8 @@ class AgentLoop:
def tool_names(self) -> list[str]: def tool_names(self) -> list[str]:
return self.tools.tool_names return self.tools.tool_names
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint" _RUNTIME_CHECKPOINT_KEY = CheckpointManager.RUNTIME_CHECKPOINT_KEY
_PENDING_USER_TURN_KEY = "pending_user_turn" _PENDING_USER_TURN_KEY = CheckpointManager.PENDING_USER_TURN_KEY
# Event-driven state transition table. # Event-driven state transition table.
# Handlers return an event string; the driver looks up the next state here. # Handlers return an event string; the driver looks up the next state here.
@ -238,6 +238,8 @@ class AgentLoop:
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
self.sessions = session_manager or SessionManager(workspace) self.sessions = session_manager or SessionManager(workspace)
self._checkpoint_mgr = CheckpointManager(self.sessions)
self._turn_writer = TurnWriter(self.sessions, self._checkpoint_mgr, self.max_tool_result_chars)
self.tools = ToolRegistry() self.tools = ToolRegistry()
# One file-read/write tracker per logical session. The tool registry is # One file-read/write tracker per logical session. The tool registry is
# shared by this loop, so tools resolve the active state via contextvars. # shared by this loop, so tools resolve the active state via contextvars.
@ -579,21 +581,7 @@ class AgentLoop:
session: Session, session: Session,
**kwargs: Any, **kwargs: Any,
) -> bool: ) -> bool:
"""Persist the triggering user message before the turn starts. return self._turn_writer.persist_user_message_early(msg, session, **kwargs)
Returns True if the message was persisted.
"""
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
has_text = isinstance(msg.content, str) and msg.content.strip()
if has_text or media_paths:
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
extra.update(kwargs)
text = msg.content if isinstance(msg.content, str) else ""
session.add_message("user", text, **extra)
self._mark_pending_user_turn(session)
self.sessions.save(session)
return True
return False
def _build_initial_messages( def _build_initial_messages(
self, self,
@ -1440,38 +1428,9 @@ class AgentLoop:
should_truncate_text: bool = False, should_truncate_text: bool = False,
drop_runtime: bool = False, drop_runtime: bool = False,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Strip volatile multimodal payloads before writing session history.""" return self._turn_writer.sanitize_persisted_blocks(
filtered: list[dict[str, Any]] = [] content, should_truncate_text=should_truncate_text, drop_runtime=drop_runtime,
for block in content: )
if not isinstance(block, dict):
filtered.append(block)
continue
if (
drop_runtime
and block.get("type") == "text"
and isinstance(block.get("text"), str)
and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
):
continue
if block.get("type") == "image_url" and block.get("image_url", {}).get(
"url", ""
).startswith("data:image/"):
path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"]
if should_truncate_text and len(text) > self.max_tool_result_chars:
text = truncate_text_fn(text, self.max_tool_result_chars)
filtered.append({**block, "text": text})
continue
filtered.append(block)
return filtered
def _save_turn( def _save_turn(
self, self,
@ -1481,169 +1440,28 @@ class AgentLoop:
*, *,
turn_latency_ms: int | None = None, turn_latency_ms: int | None = None,
) -> None: ) -> None:
"""Save new-turn messages into session, truncating large tool results.""" self._turn_writer.save_turn(session, messages, skip, turn_latency_ms=turn_latency_ms)
from datetime import datetime
last_assistant_idx: int | None = None
for m in messages[skip:]:
entry = dict(m)
role, content = entry.get("role"), entry.get("content")
if role == "assistant" and not content and not entry.get("tool_calls"):
continue # skip empty assistant messages — they poison session context
if role == "tool":
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
if not filtered:
continue
entry["content"] = filtered
elif role == "user":
if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content:
# Strip the runtime-context block appended at the end.
tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
before = content[:tag_pos].rstrip("\n ")
if before:
entry["content"] = before
else:
continue
if isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
if not filtered:
continue
entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
if role == "assistant":
last_assistant_idx = len(session.messages) - 1
if turn_latency_ms is not None and last_assistant_idx is not None:
session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
session.updated_at = datetime.now()
def _persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool: def _persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
"""Persist subagent follow-ups before prompt assembly so history stays durable. return self._turn_writer.persist_subagent_followup(session, msg)
Returns True if a new entry was appended; False if the follow-up was
deduped (same ``subagent_task_id`` already in session) or carries no
content worth persisting.
"""
if not msg.content:
return False
task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
if task_id and any(
m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
for m in session.messages
):
return False
session.add_message(
"assistant",
msg.content,
sender_id=msg.sender_id,
injected_event="subagent_result",
subagent_task_id=task_id,
)
return True
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None: def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
"""Persist the latest in-flight turn state into session metadata.""" self._checkpoint_mgr.set_runtime_checkpoint(session, payload)
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
self.sessions.save(session)
def _mark_pending_user_turn(self, session: Session) -> None: def _mark_pending_user_turn(self, session: Session) -> None:
session.metadata[self._PENDING_USER_TURN_KEY] = True self._checkpoint_mgr.mark_pending_user_turn(session)
def _clear_pending_user_turn(self, session: Session) -> None: def _clear_pending_user_turn(self, session: Session) -> None:
session.metadata.pop(self._PENDING_USER_TURN_KEY, None) self._checkpoint_mgr.clear_pending_user_turn(session)
def _clear_runtime_checkpoint(self, session: Session) -> None: def _clear_runtime_checkpoint(self, session: Session) -> None:
if self._RUNTIME_CHECKPOINT_KEY in session.metadata: self._checkpoint_mgr.clear_runtime_checkpoint(session)
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
@staticmethod
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
return (
message.get("role"),
message.get("content"),
message.get("tool_call_id"),
message.get("name"),
message.get("tool_calls"),
message.get("reasoning_content"),
message.get("thinking_blocks"),
)
def _restore_runtime_checkpoint(self, session: Session) -> bool: def _restore_runtime_checkpoint(self, session: Session) -> bool:
"""Materialize an unfinished turn into session history before a new request.""" return self._checkpoint_mgr.restore_runtime_checkpoint(session)
from datetime import datetime
checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
if not isinstance(checkpoint, dict):
return False
assistant_message = checkpoint.get("assistant_message")
completed_tool_results = checkpoint.get("completed_tool_results") or []
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
restored_messages: list[dict[str, Any]] = []
if isinstance(assistant_message, dict):
restored = dict(assistant_message)
restored.setdefault("timestamp", datetime.now().isoformat())
restored_messages.append(restored)
for message in completed_tool_results:
if isinstance(message, dict):
restored = dict(message)
restored.setdefault("timestamp", datetime.now().isoformat())
restored_messages.append(restored)
for tool_call in pending_tool_calls:
if not isinstance(tool_call, dict):
continue
tool_id = tool_call.get("id")
name = ((tool_call.get("function") or {}).get("name")) or "tool"
restored_messages.append(
{
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
}
)
overlap = 0
max_overlap = min(len(session.messages), len(restored_messages))
for size in range(max_overlap, 0, -1):
existing = session.messages[-size:]
restored = restored_messages[:size]
if all(
self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
for left, right in zip(existing, restored)
):
overlap = size
break
session.messages.extend(restored_messages[overlap:])
self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session)
return True
def _restore_pending_user_turn(self, session: Session) -> bool: def _restore_pending_user_turn(self, session: Session) -> bool:
"""Close a turn that only persisted the user message before crashing.""" return self._checkpoint_mgr.restore_pending_user_turn(session)
from datetime import datetime
if not session.metadata.get(self._PENDING_USER_TURN_KEY):
return False
if session.messages and session.messages[-1].get("role") == "user":
session.messages.append(
{
"role": "assistant",
"content": "Error: Task interrupted before a response was generated.",
"timestamp": datetime.now().isoformat(),
}
)
session.updated_at = datetime.now()
self._clear_pending_user_turn(session)
return True
async def process_direct( async def process_direct(
self, self,

View File

@ -0,0 +1,163 @@
"""Turn persistence: sanitize, save, and early-persist message history."""
from __future__ import annotations
from typing import Any
from nanobot.agent.checkpoint import CheckpointManager
from nanobot.agent.context import ContextBuilder
from nanobot.bus.events import InboundMessage
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
class TurnWriter:
    """Handles persisting turn messages into session history.

    Responsible for sanitising volatile content (data-URI images, runtime
    context blocks), truncating large tool results, and stamping latency
    metadata on the last assistant message.
    """

    def __init__(
        self,
        sessions: SessionManager,
        checkpoint: CheckpointManager,
        max_tool_result_chars: int,
    ) -> None:
        """Store the collaborators used while persisting turns.

        Args:
            sessions: Session manager used to save a session after the early
                user-message write.
            checkpoint: Checkpoint manager used to mark a pending user turn.
            max_tool_result_chars: Length above which persisted tool text is
                passed through ``truncate_text_fn``.
        """
        self.sessions = sessions
        self._checkpoint = checkpoint
        self.max_tool_result_chars = max_tool_result_chars

    def sanitize_persisted_blocks(
        self,
        content: list[dict[str, Any]],
        *,
        should_truncate_text: bool = False,
        drop_runtime: bool = False,
    ) -> list[dict[str, Any]]:
        """Strip volatile multimodal payloads before writing session history.

        Args:
            content: Content blocks of one message. Entries that are not
                dicts are kept unchanged.
            should_truncate_text: When true, text blocks longer than
                ``max_tool_result_chars`` are shortened.
            drop_runtime: When true, text blocks starting with the runtime
                context tag are removed.

        Returns:
            A new list of blocks; the input list and its dicts are not mutated.
        """
        filtered: list[dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                filtered.append(block)
                continue

            block_type = block.get("type")
            text = block.get("text")

            if (
                drop_runtime
                and block_type == "text"
                and isinstance(text, str)
                and text.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
            ):
                continue

            if block_type == "image_url":
                # Guard every level: a malformed block (``image_url`` or
                # ``url`` set to None / a non-string) must not abort saving
                # the whole turn, so it simply falls through unchanged.
                image_url = block.get("image_url")
                url = image_url.get("url") if isinstance(image_url, dict) else None
                if isinstance(url, str) and url.startswith("data:image/"):
                    meta = block.get("_meta")
                    path = meta.get("path", "") if isinstance(meta, dict) else ""
                    # Inline image data is replaced by a short text placeholder.
                    filtered.append({"type": "text", "text": image_placeholder_text(path)})
                    continue

            if block_type == "text" and isinstance(text, str):
                if should_truncate_text and len(text) > self.max_tool_result_chars:
                    text = truncate_text_fn(text, self.max_tool_result_chars)
                filtered.append({**block, "text": text})
                continue

            filtered.append(block)
        return filtered

    def save_turn(
        self,
        session: Session,
        messages: list[dict],
        skip: int,
        *,
        turn_latency_ms: int | None = None,
    ) -> None:
        """Save new-turn messages into session, truncating large tool results.

        Args:
            session: Session whose ``messages`` list receives the entries.
            messages: Full message list of the turn.
            skip: Number of leading messages to ignore; only
                ``messages[skip:]`` is considered.
            turn_latency_ms: Optional latency stored as ``latency_ms`` on the
                last assistant message appended by this call.
        """
        from datetime import datetime

        last_assistant_idx: int | None = None
        for m in messages[skip:]:
            # Work on a copy so the caller's message dicts stay untouched.
            entry = dict(m)
            role, content = entry.get("role"), entry.get("content")
            if role == "assistant" and not content and not entry.get("tool_calls"):
                continue  # skip empty assistant messages — they poison session context
            if role == "tool":
                if isinstance(content, str) and len(content) > self.max_tool_result_chars:
                    entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
                elif isinstance(content, list):
                    filtered = self.sanitize_persisted_blocks(content, should_truncate_text=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            elif role == "user":
                if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content:
                    # Strip the runtime-context block appended at the end.
                    tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
                    before = content[:tag_pos].rstrip("\n ")
                    if before:
                        entry["content"] = before
                    else:
                        # Nothing but runtime context: do not persist at all.
                        continue
                if isinstance(content, list):
                    filtered = self.sanitize_persisted_blocks(content, drop_runtime=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
            if role == "assistant":
                last_assistant_idx = len(session.messages) - 1
        if turn_latency_ms is not None and last_assistant_idx is not None:
            session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
        session.updated_at = datetime.now()

    def persist_user_message_early(
        self,
        msg: InboundMessage,
        session: Session,
        **kwargs: Any,
    ) -> bool:
        """Persist the triggering user message before the turn starts.

        The message is written when it has non-blank text or at least one
        non-empty string media path; the session is then flagged as having a
        pending user turn and saved. Extra keyword arguments are forwarded to
        ``session.add_message``.

        Returns True if the message was persisted.
        """
        media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
        has_text = isinstance(msg.content, str) and msg.content.strip()
        if has_text or media_paths:
            extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
            extra.update(kwargs)
            text = msg.content if isinstance(msg.content, str) else ""
            session.add_message("user", text, **extra)
            self._checkpoint.mark_pending_user_turn(session)
            self.sessions.save(session)
            return True
        return False

    def persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
        """Persist subagent follow-ups before prompt assembly so history stays durable.

        Returns True if a new entry was appended; False if the follow-up was
        deduped (same ``subagent_task_id`` already in session) or carries no
        content worth persisting.
        """
        if not msg.content:
            return False
        task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
        if task_id and any(
            m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
            for m in session.messages
        ):
            return False
        session.add_message(
            "assistant",
            msg.content,
            sender_id=msg.sender_id,
            injected_event="subagent_result",
            subagent_task_id=task_id,
        )
        return True

View File

@ -19,9 +19,17 @@ from nanobot.utils.webui_titles import (
def _mk_loop() -> AgentLoop: def _mk_loop() -> AgentLoop:
loop = AgentLoop.__new__(AgentLoop) loop = AgentLoop.__new__(AgentLoop)
from nanobot.agent.checkpoint import CheckpointManager
from nanobot.agent.turn_writer import TurnWriter
from nanobot.config.schema import AgentDefaults from nanobot.config.schema import AgentDefaults
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
loop._checkpoint_mgr = CheckpointManager(sessions=None) # type: ignore[arg-type]
loop._turn_writer = TurnWriter(
sessions=None, # type: ignore[arg-type]
checkpoint=loop._checkpoint_mgr,
max_tool_result_chars=loop.max_tool_result_chars,
)
return loop return loop

View File

@ -20,10 +20,12 @@ def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
assert "should_truncate_text" in sig.parameters assert "should_truncate_text" in sig.parameters
assert "truncate_text" not in sig.parameters assert "truncate_text" not in sig.parameters
dummy = SimpleNamespace(max_tool_result_chars=5) from nanobot.agent.turn_writer import TurnWriter
writer = TurnWriter(sessions=None, checkpoint=None, max_tool_result_chars=5) # type: ignore[arg-type]
content = [{"type": "text", "text": "0123456789"}] content = [{"type": "text", "text": "0123456789"}]
out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True) out = writer.sanitize_persisted_blocks(content, should_truncate_text=True)
assert isinstance(out, list) assert isinstance(out, list)
assert out and out[0]["type"] == "text" assert out and out[0]["type"] == "text"
assert isinstance(out[0]["text"], str) assert isinstance(out[0]["text"], str)