nanobot/nanobot/agent/turn_writer.py
chengyongru 4c751bb8e3 refactor(agent): extract checkpoint.py and turn_writer.py from loop.py
Extract checkpoint and turn-persistence logic from AgentLoop into two
dedicated modules to reduce loop.py's size and improve reviewability.

- checkpoint.py: CheckpointManager handles runtime checkpoints, pending
  user-turn markers, and session recovery after cancellation/crash.
- turn_writer.py: TurnWriter handles message sanitization, turn saving,
  early user-message persistence, and subagent-followup persistence.

AgentLoop retains thin delegate methods for full backward compatibility.
Test helpers updated to initialize the new sub-managers.
2026-05-16 18:35:09 +08:00

164 lines
6.4 KiB
Python

"""Turn persistence: sanitize, save, and early-persist message history."""
from __future__ import annotations
from typing import Any
from nanobot.agent.checkpoint import CheckpointManager
from nanobot.agent.context import ContextBuilder
from nanobot.bus.events import InboundMessage
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import image_placeholder_text
from nanobot.utils.helpers import truncate_text as truncate_text_fn
class TurnWriter:
    """Persist turn messages into session history.

    Responsible for sanitising volatile content (data-URI images, runtime
    context blocks), truncating large tool results, and stamping latency
    metadata on the last assistant message.
    """

    def __init__(
        self,
        sessions: SessionManager,
        checkpoint: CheckpointManager,
        max_tool_result_chars: int,
    ) -> None:
        """Store the collaborators and the tool-result size limit.

        Args:
            sessions: Manager used to save a session once the user message
                has been persisted early.
            checkpoint: Manager used to mark a pending user turn.
            max_tool_result_chars: Text longer than this is truncated when
                tool results are saved.
        """
        self.sessions = sessions
        self._checkpoint = checkpoint
        self.max_tool_result_chars = max_tool_result_chars

    @staticmethod
    def _is_data_uri_image(block: dict[str, Any]) -> bool:
        """Return True when *block* is an ``image_url`` block with a ``data:image/`` URL.

        Malformed blocks (``image_url`` missing, ``None`` or not a dict, or a
        non-string ``url``) are reported as False instead of raising.
        """
        if block.get("type") != "image_url":
            return False
        image_url = block.get("image_url")
        if not isinstance(image_url, dict):
            return False
        url = image_url.get("url")
        return isinstance(url, str) and url.startswith("data:image/")

    def sanitize_persisted_blocks(
        self,
        content: list[dict[str, Any]],
        *,
        should_truncate_text: bool = False,
        drop_runtime: bool = False,
    ) -> list[dict[str, Any]]:
        """Strip volatile multimodal payloads before writing session history.

        Args:
            content: Content blocks of a single message. Entries that are not
                dicts are kept as-is.
            should_truncate_text: When True, text blocks longer than
                ``max_tool_result_chars`` are truncated.
            drop_runtime: When True, text blocks that start with the runtime
                context tag are removed.

        Returns:
            A new list. Data-URI image blocks are replaced by a text
            placeholder, text blocks are shallow-copied, and every other
            block is passed through unchanged.
        """
        filtered: list[dict[str, Any]] = []
        for block in content:
            if not isinstance(block, dict):
                filtered.append(block)
                continue

            text = block.get("text")
            is_text_block = block.get("type") == "text" and isinstance(text, str)

            if (
                drop_runtime
                and is_text_block
                and text.startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
            ):
                continue

            if self._is_data_uri_image(block):
                # Keep only a textual placeholder; the inline payload is not stored.
                meta = block.get("_meta")
                path = meta.get("path", "") if isinstance(meta, dict) else ""
                filtered.append({"type": "text", "text": image_placeholder_text(path)})
                continue

            if is_text_block:
                if should_truncate_text and len(text) > self.max_tool_result_chars:
                    text = truncate_text_fn(text, self.max_tool_result_chars)
                filtered.append({**block, "text": text})
                continue

            filtered.append(block)
        return filtered

    def save_turn(
        self,
        session: Session,
        messages: list[dict],
        skip: int,
        *,
        turn_latency_ms: int | None = None,
    ) -> None:
        """Save new-turn messages into session, truncating large tool results.

        Args:
            session: Session whose ``messages`` list receives the entries and
                whose ``updated_at`` is refreshed.
            messages: Full message list for the turn.
            skip: Number of leading messages to ignore; only ``messages[skip:]``
                is considered.
            turn_latency_ms: When given, stored as ``latency_ms`` on the last
                assistant message appended by this call.
        """
        from datetime import datetime

        last_assistant_idx: int | None = None
        for message in messages[skip:]:
            # Shallow copy so edits below do not touch the caller's dict.
            entry = dict(message)
            role, content = entry.get("role"), entry.get("content")

            # An assistant message with neither content nor tool calls carries nothing.
            if role == "assistant" and not content and not entry.get("tool_calls"):
                continue

            if role == "tool":
                if isinstance(content, str) and len(content) > self.max_tool_result_chars:
                    entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
                elif isinstance(content, list):
                    filtered = self.sanitize_persisted_blocks(content, should_truncate_text=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered
            elif role == "user":
                if isinstance(content, str):
                    tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
                    if tag_pos != -1:
                        # Keep only the text that precedes the runtime-context tag.
                        before = content[:tag_pos].rstrip("\n ")
                        if not before:
                            continue
                        entry["content"] = before
                elif isinstance(content, list):
                    filtered = self.sanitize_persisted_blocks(content, drop_runtime=True)
                    if not filtered:
                        continue
                    entry["content"] = filtered

            entry.setdefault("timestamp", datetime.now().isoformat())
            session.messages.append(entry)
            if role == "assistant":
                last_assistant_idx = len(session.messages) - 1

        if turn_latency_ms is not None and last_assistant_idx is not None:
            session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
        session.updated_at = datetime.now()

    def persist_user_message_early(
        self,
        msg: InboundMessage,
        session: Session,
        **kwargs: Any,
    ) -> bool:
        """Persist the triggering user message before the turn starts.

        The message is stored when it has non-blank text or at least one
        non-empty string media path. Extra keyword arguments are forwarded to
        ``session.add_message`` and override the ``media`` entry on a key clash.

        Returns True if the message was persisted.
        """
        media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
        has_text = isinstance(msg.content, str) and bool(msg.content.strip())
        if not (has_text or media_paths):
            return False

        extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
        extra.update(kwargs)
        text = msg.content if isinstance(msg.content, str) else ""
        session.add_message("user", text, **extra)
        self._checkpoint.mark_pending_user_turn(session)
        self.sessions.save(session)
        return True

    def persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
        """Persist subagent follow-ups before prompt assembly so history stays durable.

        Returns True if a new entry was appended; False if the follow-up was
        deduped (same ``subagent_task_id`` already in session) or carries no
        content worth persisting.
        """
        if not msg.content:
            return False

        metadata = msg.metadata
        task_id = metadata.get("subagent_task_id") if isinstance(metadata, dict) else None
        if task_id:
            already_recorded = any(
                m.get("injected_event") == "subagent_result"
                and m.get("subagent_task_id") == task_id
                for m in session.messages
            )
            if already_recorded:
                return False

        session.add_message(
            "assistant",
            msg.content,
            sender_id=msg.sender_id,
            injected_event="subagent_result",
            subagent_task_id=task_id,
        )
        return True