refactor(agent): extract checkpoint.py and turn_writer.py from loop.py

Extract checkpoint and turn-persistence logic from AgentLoop into two
dedicated modules to reduce loop.py's size and improve reviewability.

- checkpoint.py: CheckpointManager handles runtime checkpoints, pending
  user-turn markers, and session recovery after cancellation/crash.
- turn_writer.py: TurnWriter handles message sanitization, turn saving,
  early user-message persistence, and subagent-followup persistence.

AgentLoop retains thin delegate methods for full backward compatibility.
Test helpers updated to initialize the new sub-managers.
This commit is contained in:
chengyongru 2026-05-16 18:35:09 +08:00
parent e804f2fddb
commit 4c751bb8e3
5 changed files with 315 additions and 202 deletions

122
nanobot/agent/checkpoint.py Normal file
View File

@ -0,0 +1,122 @@
"""Runtime checkpoint management for session recovery."""
from __future__ import annotations
from typing import Any
from nanobot.session.manager import Session, SessionManager
class CheckpointManager:
    """Manages runtime checkpoints and pending-user-turn markers on sessions.

    Checkpoints capture in-flight turn state (assistant messages, completed and
    pending tool calls) so that a cancelled or crashed turn can be materialised
    into session history on the next request, preventing data loss.

    All state is kept in ``session.metadata`` under ``RUNTIME_CHECKPOINT_KEY``
    and ``PENDING_USER_TURN_KEY``. Only ``set_runtime_checkpoint`` saves the
    session itself; every other method mutates the in-memory session and
    leaves saving to the caller.
    """

    RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
    PENDING_USER_TURN_KEY = "pending_user_turn"

    def __init__(self, sessions: SessionManager) -> None:
        """Store the session manager used to save checkpoint updates."""
        self.sessions = sessions

    def set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
        """Record the latest in-flight turn state on *session* and save it."""
        session.metadata[self.RUNTIME_CHECKPOINT_KEY] = payload
        self.sessions.save(session)

    def mark_pending_user_turn(self, session: Session) -> None:
        """Flag that a user message was stored but its turn has not finished."""
        session.metadata[self.PENDING_USER_TURN_KEY] = True

    def clear_pending_user_turn(self, session: Session) -> None:
        """Remove the pending-user-turn flag; a no-op when it is not set."""
        session.metadata.pop(self.PENDING_USER_TURN_KEY, None)

    def clear_runtime_checkpoint(self, session: Session) -> None:
        """Remove the runtime checkpoint; a no-op when none is stored."""
        session.metadata.pop(self.RUNTIME_CHECKPOINT_KEY, None)

    @staticmethod
    def checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
        """Return the fields that identify a message when de-duplicating.

        The timestamp is deliberately not part of the key, so a message that
        was already written to history matches its checkpointed copy even if
        the two were stamped at different times.
        """
        return (
            message.get("role"),
            message.get("content"),
            message.get("tool_call_id"),
            message.get("name"),
            message.get("tool_calls"),
            message.get("reasoning_content"),
            message.get("thinking_blocks"),
        )

    @staticmethod
    def _dict_entries(value: Any) -> list[dict[str, Any]]:
        """Return the dict items of *value* when it is a list/tuple, else ``[]``."""
        if not isinstance(value, (list, tuple)):
            return []
        return [item for item in value if isinstance(item, dict)]

    @staticmethod
    def _with_timestamp(message: dict[str, Any], stamp: str) -> dict[str, Any]:
        """Return a shallow copy of *message* that carries a ``timestamp`` key."""
        stamped = dict(message)
        stamped.setdefault("timestamp", stamp)
        return stamped

    def restore_runtime_checkpoint(self, session: Session) -> bool:
        """Materialize an unfinished turn into session history before a new request.

        Rebuilds, in order, the checkpointed assistant message, the completed
        tool results, and one synthetic error result per pending tool call,
        then appends whatever part of that sequence is not already at the end
        of ``session.messages``.

        Malformed checkpoint fields (non-list collections, non-dict entries,
        a non-dict ``function``) are ignored rather than raised on: an
        exception here would leave the checkpoint in place for the next
        restore attempt to trip over again.

        Returns:
            True if a checkpoint dict was found (both markers are cleared),
            False if there was nothing to restore.
        """
        from datetime import datetime

        checkpoint = session.metadata.get(self.RUNTIME_CHECKPOINT_KEY)
        if not isinstance(checkpoint, dict):
            return False

        # One timestamp for everything materialised by this restore.
        now = datetime.now().isoformat()
        restored_messages: list[dict[str, Any]] = []

        assistant_message = checkpoint.get("assistant_message")
        if isinstance(assistant_message, dict):
            restored_messages.append(self._with_timestamp(assistant_message, now))

        restored_messages.extend(
            self._with_timestamp(result, now)
            for result in self._dict_entries(checkpoint.get("completed_tool_results"))
        )

        for tool_call in self._dict_entries(checkpoint.get("pending_tool_calls")):
            function = tool_call.get("function")
            name = function.get("name") if isinstance(function, dict) else None
            restored_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tool_call.get("id"),
                    "name": name or "tool",
                    "content": "Error: Task interrupted before this tool finished.",
                    "timestamp": now,
                }
            )

        # Find the longest prefix of the restored messages that is already the
        # tail of the history, so those messages are not appended twice.
        key = self.checkpoint_message_key
        overlap = 0
        for size in range(min(len(session.messages), len(restored_messages)), 0, -1):
            tail = session.messages[-size:]
            if all(key(old) == key(new) for old, new in zip(tail, restored_messages)):
                overlap = size
                break
        session.messages.extend(restored_messages[overlap:])

        self.clear_pending_user_turn(session)
        self.clear_runtime_checkpoint(session)
        return True

    def restore_pending_user_turn(self, session: Session) -> bool:
        """Close a turn that only persisted the user message before crashing.

        When the pending flag is set and the last history entry is a user
        message, an assistant error message is appended so the history does
        not end on an unanswered user message.

        Returns:
            True if the pending flag was set (it is cleared), False otherwise.
        """
        from datetime import datetime

        if not session.metadata.get(self.PENDING_USER_TURN_KEY):
            return False

        if session.messages and session.messages[-1].get("role") == "user":
            session.messages.append(
                {
                    "role": "assistant",
                    "content": "Error: Task interrupted before a response was generated.",
                    "timestamp": datetime.now().isoformat(),
                }
            )
            session.updated_at = datetime.now()

        self.clear_pending_user_turn(session)
        return True

View File

@ -16,6 +16,7 @@ from loguru import logger
from nanobot.agent import model_presets as preset_helpers
from nanobot.agent.autocompact import AutoCompact
from nanobot.agent.checkpoint import CheckpointManager
from nanobot.agent.context import ContextBuilder
from nanobot.agent.hook import AgentHook, CompositeHook
from nanobot.agent.memory import Consolidator, Dream
@ -26,6 +27,7 @@ from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, res
from nanobot.agent.tools.message import MessageTool
from nanobot.agent.tools.registry import ToolRegistry
from nanobot.agent.tools.self import MyTool
from nanobot.agent.turn_writer import TurnWriter
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
@ -40,8 +42,6 @@ from nanobot.session.goal_state import (
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.artifacts import generated_image_paths_from_messages
from nanobot.utils.document import extract_documents
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
from nanobot.utils.image_generation_intent import image_generation_prompt
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
@ -137,8 +137,8 @@ class AgentLoop:
def tool_names(self) -> list[str]:
return self.tools.tool_names
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
_PENDING_USER_TURN_KEY = "pending_user_turn"
_RUNTIME_CHECKPOINT_KEY = CheckpointManager.RUNTIME_CHECKPOINT_KEY
_PENDING_USER_TURN_KEY = CheckpointManager.PENDING_USER_TURN_KEY
# Event-driven state transition table.
# Handlers return an event string; the driver looks up the next state here.
@ -238,6 +238,8 @@ class AgentLoop:
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
self.sessions = session_manager or SessionManager(workspace)
self._checkpoint_mgr = CheckpointManager(self.sessions)
self._turn_writer = TurnWriter(self.sessions, self._checkpoint_mgr, self.max_tool_result_chars)
self.tools = ToolRegistry()
# One file-read/write tracker per logical session. The tool registry is
# shared by this loop, so tools resolve the active state via contextvars.
@ -579,21 +581,7 @@ class AgentLoop:
session: Session,
**kwargs: Any,
) -> bool:
"""Persist the triggering user message before the turn starts.
Returns True if the message was persisted.
"""
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
has_text = isinstance(msg.content, str) and msg.content.strip()
if has_text or media_paths:
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
extra.update(kwargs)
text = msg.content if isinstance(msg.content, str) else ""
session.add_message("user", text, **extra)
self._mark_pending_user_turn(session)
self.sessions.save(session)
return True
return False
return self._turn_writer.persist_user_message_early(msg, session, **kwargs)
def _build_initial_messages(
self,
@ -1440,38 +1428,9 @@ class AgentLoop:
should_truncate_text: bool = False,
drop_runtime: bool = False,
) -> list[dict[str, Any]]:
"""Strip volatile multimodal payloads before writing session history."""
filtered: list[dict[str, Any]] = []
for block in content:
if not isinstance(block, dict):
filtered.append(block)
continue
if (
drop_runtime
and block.get("type") == "text"
and isinstance(block.get("text"), str)
and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
):
continue
if block.get("type") == "image_url" and block.get("image_url", {}).get(
"url", ""
).startswith("data:image/"):
path = (block.get("_meta") or {}).get("path", "")
filtered.append({"type": "text", "text": image_placeholder_text(path)})
continue
if block.get("type") == "text" and isinstance(block.get("text"), str):
text = block["text"]
if should_truncate_text and len(text) > self.max_tool_result_chars:
text = truncate_text_fn(text, self.max_tool_result_chars)
filtered.append({**block, "text": text})
continue
filtered.append(block)
return filtered
return self._turn_writer.sanitize_persisted_blocks(
content, should_truncate_text=should_truncate_text, drop_runtime=drop_runtime,
)
def _save_turn(
self,
@ -1481,169 +1440,28 @@ class AgentLoop:
*,
turn_latency_ms: int | None = None,
) -> None:
"""Save new-turn messages into session, truncating large tool results."""
from datetime import datetime
last_assistant_idx: int | None = None
for m in messages[skip:]:
entry = dict(m)
role, content = entry.get("role"), entry.get("content")
if role == "assistant" and not content and not entry.get("tool_calls"):
continue # skip empty assistant messages — they poison session context
if role == "tool":
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
elif isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
if not filtered:
continue
entry["content"] = filtered
elif role == "user":
if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content:
# Strip the runtime-context block appended at the end.
tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
before = content[:tag_pos].rstrip("\n ")
if before:
entry["content"] = before
else:
continue
if isinstance(content, list):
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
if not filtered:
continue
entry["content"] = filtered
entry.setdefault("timestamp", datetime.now().isoformat())
session.messages.append(entry)
if role == "assistant":
last_assistant_idx = len(session.messages) - 1
if turn_latency_ms is not None and last_assistant_idx is not None:
session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
session.updated_at = datetime.now()
self._turn_writer.save_turn(session, messages, skip, turn_latency_ms=turn_latency_ms)
def _persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
"""Persist subagent follow-ups before prompt assembly so history stays durable.
Returns True if a new entry was appended; False if the follow-up was
deduped (same ``subagent_task_id`` already in session) or carries no
content worth persisting.
"""
if not msg.content:
return False
task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
if task_id and any(
m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
for m in session.messages
):
return False
session.add_message(
"assistant",
msg.content,
sender_id=msg.sender_id,
injected_event="subagent_result",
subagent_task_id=task_id,
)
return True
return self._turn_writer.persist_subagent_followup(session, msg)
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
"""Persist the latest in-flight turn state into session metadata."""
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
self.sessions.save(session)
self._checkpoint_mgr.set_runtime_checkpoint(session, payload)
def _mark_pending_user_turn(self, session: Session) -> None:
session.metadata[self._PENDING_USER_TURN_KEY] = True
self._checkpoint_mgr.mark_pending_user_turn(session)
def _clear_pending_user_turn(self, session: Session) -> None:
session.metadata.pop(self._PENDING_USER_TURN_KEY, None)
self._checkpoint_mgr.clear_pending_user_turn(session)
def _clear_runtime_checkpoint(self, session: Session) -> None:
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
@staticmethod
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
return (
message.get("role"),
message.get("content"),
message.get("tool_call_id"),
message.get("name"),
message.get("tool_calls"),
message.get("reasoning_content"),
message.get("thinking_blocks"),
)
self._checkpoint_mgr.clear_runtime_checkpoint(session)
def _restore_runtime_checkpoint(self, session: Session) -> bool:
"""Materialize an unfinished turn into session history before a new request."""
from datetime import datetime
checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
if not isinstance(checkpoint, dict):
return False
assistant_message = checkpoint.get("assistant_message")
completed_tool_results = checkpoint.get("completed_tool_results") or []
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
restored_messages: list[dict[str, Any]] = []
if isinstance(assistant_message, dict):
restored = dict(assistant_message)
restored.setdefault("timestamp", datetime.now().isoformat())
restored_messages.append(restored)
for message in completed_tool_results:
if isinstance(message, dict):
restored = dict(message)
restored.setdefault("timestamp", datetime.now().isoformat())
restored_messages.append(restored)
for tool_call in pending_tool_calls:
if not isinstance(tool_call, dict):
continue
tool_id = tool_call.get("id")
name = ((tool_call.get("function") or {}).get("name")) or "tool"
restored_messages.append(
{
"role": "tool",
"tool_call_id": tool_id,
"name": name,
"content": "Error: Task interrupted before this tool finished.",
"timestamp": datetime.now().isoformat(),
}
)
overlap = 0
max_overlap = min(len(session.messages), len(restored_messages))
for size in range(max_overlap, 0, -1):
existing = session.messages[-size:]
restored = restored_messages[:size]
if all(
self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
for left, right in zip(existing, restored)
):
overlap = size
break
session.messages.extend(restored_messages[overlap:])
self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session)
return True
return self._checkpoint_mgr.restore_runtime_checkpoint(session)
def _restore_pending_user_turn(self, session: Session) -> bool:
"""Close a turn that only persisted the user message before crashing."""
from datetime import datetime
if not session.metadata.get(self._PENDING_USER_TURN_KEY):
return False
if session.messages and session.messages[-1].get("role") == "user":
session.messages.append(
{
"role": "assistant",
"content": "Error: Task interrupted before a response was generated.",
"timestamp": datetime.now().isoformat(),
}
)
session.updated_at = datetime.now()
self._clear_pending_user_turn(session)
return True
return self._checkpoint_mgr.restore_pending_user_turn(session)
async def process_direct(
self,

View File

@ -0,0 +1,163 @@
"""Turn persistence: sanitize, save, and early-persist message history."""
from __future__ import annotations
from typing import Any
from nanobot.agent.checkpoint import CheckpointManager
from nanobot.agent.context import ContextBuilder
from nanobot.bus.events import InboundMessage
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
class TurnWriter:
    """Handles persisting turn messages into session history.

    Responsible for sanitising volatile content (data-URI images, runtime
    context blocks), truncating large tool results, and stamping latency
    metadata on the last assistant message.
    """

    def __init__(
        self,
        sessions: SessionManager,
        checkpoint: CheckpointManager,
        max_tool_result_chars: int,
    ) -> None:
        """Store collaborators and the size limit applied to tool-result text."""
        self.sessions = sessions
        self._checkpoint = checkpoint
        self.max_tool_result_chars = max_tool_result_chars

    def sanitize_persisted_blocks(
        self,
        content: list[dict[str, Any]],
        *,
        should_truncate_text: bool = False,
        drop_runtime: bool = False,
    ) -> list[dict[str, Any]]:
        """Strip volatile multimodal payloads before writing session history.

        Args:
            content: Content blocks of one message. Non-dict items are kept
                as they are.
            should_truncate_text: Truncate text blocks longer than
                ``max_tool_result_chars``.
            drop_runtime: Drop text blocks that start with the runtime-context
                tag.

        Returns:
            A new list. Inline ``data:image/`` blocks are replaced by a text
            placeholder, text blocks are copied, all other blocks are passed
            through unchanged.
        """
        filtered: list[dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                filtered.append(block)
                continue

            block_type = block.get("type")
            text = block.get("text")

            if block_type == "text" and isinstance(text, str):
                if drop_runtime and text.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG):
                    continue
                if should_truncate_text and len(text) > self.max_tool_result_chars:
                    text = truncate_text_fn(text, self.max_tool_result_chars)
                filtered.append({**block, "text": text})
                continue

            if block_type == "image_url":
                # ``image_url`` may be present but None or malformed; treat any
                # shape other than {"url": <str>} as "not an inline image"
                # instead of raising while saving history.
                image_url = block.get("image_url")
                url = image_url.get("url") if isinstance(image_url, dict) else None
                if isinstance(url, str) and url.startswith("data:image/"):
                    meta = block.get("_meta")
                    path = meta.get("path", "") if isinstance(meta, dict) else ""
                    filtered.append({"type": "text", "text": image_placeholder_text(path)})
                    continue

            filtered.append(block)
        return filtered

    def save_turn(
        self,
        session: Session,
        messages: list[dict],
        skip: int,
        *,
        turn_latency_ms: int | None = None,
    ) -> None:
        """Save new-turn messages into session, truncating large tool results.

        Args:
            session: Session whose ``messages`` list is appended to.
            messages: Full message list of the turn.
            skip: Number of leading messages to leave out.
            turn_latency_ms: When given, stored as ``latency_ms`` on the last
                assistant message appended by this call.
        """
        from datetime import datetime

        runtime_tag = ContextBuilder._RUNTIME_CONTEXT_TAG if messages[skip:] else ""
        last_assistant_idx: int | None = None

        for m in messages[skip:]:
            entry = dict(m)
            role, content = entry.get("role"), entry.get("content")

            if role == "assistant" and not content and not entry.get("tool_calls"):
                continue  # skip empty assistant messages so they never reach history

            if role == "tool":
                if isinstance(content, str) and len(content) > self.max_tool_result_chars:
                    entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
                elif isinstance(content, list):
                    filtered = self.sanitize_persisted_blocks(content, should_truncate_text=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            elif role == "user":
                if isinstance(content, str):
                    # Keep only the text in front of the runtime-context tag.
                    tag_pos = content.find(runtime_tag)
                    if tag_pos != -1:
                        before = content[:tag_pos].rstrip("\n ")
                        if not before:
                            continue
                        entry["content"] = before
                elif isinstance(content, list):
                    filtered = self.sanitize_persisted_blocks(content, drop_runtime=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered

            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
            if role == "assistant":
                last_assistant_idx = len(session.messages) - 1

        if turn_latency_ms is not None and last_assistant_idx is not None:
            session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
        session.updated_at = datetime.now()

    def persist_user_message_early(
        self,
        msg: InboundMessage,
        session: Session,
        **kwargs: Any,
    ) -> bool:
        """Persist the triggering user message before the turn starts.

        The message is stored when it has non-blank text or at least one
        non-empty string media path. Storing it also marks the session as
        having a pending user turn and saves the session.

        Returns True if the message was persisted.
        """
        media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
        has_text = isinstance(msg.content, str) and msg.content.strip()
        if not (has_text or media_paths):
            return False

        extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
        extra.update(kwargs)
        text = msg.content if isinstance(msg.content, str) else ""
        session.add_message("user", text, **extra)
        self._checkpoint.mark_pending_user_turn(session)
        self.sessions.save(session)
        return True

    def persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
        """Persist subagent follow-ups before prompt assembly so history stays durable.

        Returns True if a new entry was appended; False if the follow-up was
        deduped (same ``subagent_task_id`` already in session) or carries no
        content worth persisting.
        """
        if not msg.content:
            return False

        task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
        if task_id and any(
            m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
            for m in session.messages
        ):
            return False

        session.add_message(
            "assistant",
            msg.content,
            sender_id=msg.sender_id,
            injected_event="subagent_result",
            subagent_task_id=task_id,
        )
        return True

View File

@ -19,9 +19,17 @@ from nanobot.utils.webui_titles import (
def _mk_loop() -> AgentLoop:
    """Build an ``AgentLoop`` for tests without running ``AgentLoop.__init__``.

    The instance only receives ``max_tool_result_chars`` (taken from
    ``AgentDefaults``) plus the checkpoint and turn-writer sub-managers, both
    created with ``sessions=None``.
    """
    from nanobot.agent.checkpoint import CheckpointManager
    from nanobot.agent.turn_writer import TurnWriter
    from nanobot.config.schema import AgentDefaults

    bare_loop = AgentLoop.__new__(AgentLoop)
    bare_loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars

    checkpoints = CheckpointManager(sessions=None)  # type: ignore[arg-type]
    bare_loop._checkpoint_mgr = checkpoints
    bare_loop._turn_writer = TurnWriter(
        sessions=None,  # type: ignore[arg-type]
        checkpoint=checkpoints,
        max_tool_result_chars=bare_loop.max_tool_result_chars,
    )
    return bare_loop

View File

@ -20,10 +20,12 @@ def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
assert "should_truncate_text" in sig.parameters
assert "truncate_text" not in sig.parameters
dummy = SimpleNamespace(max_tool_result_chars=5)
from nanobot.agent.turn_writer import TurnWriter
writer = TurnWriter(sessions=None, checkpoint=None, max_tool_result_chars=5) # type: ignore[arg-type]
content = [{"type": "text", "text": "0123456789"}]
out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True)
out = writer.sanitize_persisted_blocks(content, should_truncate_text=True)
assert isinstance(out, list)
assert out and out[0]["type"] == "text"
assert isinstance(out[0]["text"], str)