mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 08:02:30 +00:00
refactor(agent): extract checkpoint.py and turn_writer.py from loop.py
Extract checkpoint and turn-persistence logic from AgentLoop into two dedicated modules to reduce loop.py's size and improve reviewability. - checkpoint.py: CheckpointManager handles runtime checkpoints, pending user-turn markers, and session recovery after cancellation/crash. - turn_writer.py: TurnWriter handles message sanitization, turn saving, early user-message persistence, and subagent-followup persistence. AgentLoop retains thin delegate methods for full backward compatibility. Test helpers updated to initialize the new sub-managers.
This commit is contained in:
parent
e804f2fddb
commit
4c751bb8e3
122
nanobot/agent/checkpoint.py
Normal file
122
nanobot/agent/checkpoint.py
Normal file
@ -0,0 +1,122 @@
|
||||
"""Runtime checkpoint management for session recovery."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
|
||||
|
||||
class CheckpointManager:
|
||||
"""Manages runtime checkpoints and pending-user-turn markers on sessions.
|
||||
|
||||
Checkpoints capture in-flight turn state (assistant messages, completed and
|
||||
pending tool calls) so that a cancelled or crashed turn can be materialised
|
||||
into session history on the next request, preventing data loss.
|
||||
"""
|
||||
|
||||
RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||
|
||||
def __init__(self, sessions: SessionManager) -> None:
|
||||
self.sessions = sessions
|
||||
|
||||
def set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
|
||||
session.metadata[self.RUNTIME_CHECKPOINT_KEY] = payload
|
||||
self.sessions.save(session)
|
||||
|
||||
def mark_pending_user_turn(self, session: Session) -> None:
|
||||
session.metadata[self.PENDING_USER_TURN_KEY] = True
|
||||
|
||||
def clear_pending_user_turn(self, session: Session) -> None:
|
||||
session.metadata.pop(self.PENDING_USER_TURN_KEY, None)
|
||||
|
||||
def clear_runtime_checkpoint(self, session: Session) -> None:
|
||||
if self.RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||
session.metadata.pop(self.RUNTIME_CHECKPOINT_KEY, None)
|
||||
|
||||
@staticmethod
|
||||
def checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
|
||||
return (
|
||||
message.get("role"),
|
||||
message.get("content"),
|
||||
message.get("tool_call_id"),
|
||||
message.get("name"),
|
||||
message.get("tool_calls"),
|
||||
message.get("reasoning_content"),
|
||||
message.get("thinking_blocks"),
|
||||
)
|
||||
|
||||
def restore_runtime_checkpoint(self, session: Session) -> bool:
|
||||
"""Materialize an unfinished turn into session history before a new request."""
|
||||
from datetime import datetime
|
||||
|
||||
checkpoint = session.metadata.get(self.RUNTIME_CHECKPOINT_KEY)
|
||||
if not isinstance(checkpoint, dict):
|
||||
return False
|
||||
|
||||
assistant_message = checkpoint.get("assistant_message")
|
||||
completed_tool_results = checkpoint.get("completed_tool_results") or []
|
||||
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
|
||||
|
||||
restored_messages: list[dict[str, Any]] = []
|
||||
if isinstance(assistant_message, dict):
|
||||
restored = dict(assistant_message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for message in completed_tool_results:
|
||||
if isinstance(message, dict):
|
||||
restored = dict(message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for tool_call in pending_tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
for size in range(max_overlap, 0, -1):
|
||||
existing = session.messages[-size:]
|
||||
restored = restored_messages[:size]
|
||||
if all(
|
||||
self.checkpoint_message_key(left) == self.checkpoint_message_key(right)
|
||||
for left, right in zip(existing, restored)
|
||||
):
|
||||
overlap = size
|
||||
break
|
||||
session.messages.extend(restored_messages[overlap:])
|
||||
|
||||
self.clear_pending_user_turn(session)
|
||||
self.clear_runtime_checkpoint(session)
|
||||
return True
|
||||
|
||||
def restore_pending_user_turn(self, session: Session) -> bool:
|
||||
"""Close a turn that only persisted the user message before crashing."""
|
||||
from datetime import datetime
|
||||
|
||||
if not session.metadata.get(self.PENDING_USER_TURN_KEY):
|
||||
return False
|
||||
|
||||
if session.messages and session.messages[-1].get("role") == "user":
|
||||
session.messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Error: Task interrupted before a response was generated.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
self.clear_pending_user_turn(session)
|
||||
return True
|
||||
@ -16,6 +16,7 @@ from loguru import logger
|
||||
|
||||
from nanobot.agent import model_presets as preset_helpers
|
||||
from nanobot.agent.autocompact import AutoCompact
|
||||
from nanobot.agent.checkpoint import CheckpointManager
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.agent.hook import AgentHook, CompositeHook
|
||||
from nanobot.agent.memory import Consolidator, Dream
|
||||
@ -26,6 +27,7 @@ from nanobot.agent.tools.file_state import FileStateStore, bind_file_states, res
|
||||
from nanobot.agent.tools.message import MessageTool
|
||||
from nanobot.agent.tools.registry import ToolRegistry
|
||||
from nanobot.agent.tools.self import MyTool
|
||||
from nanobot.agent.turn_writer import TurnWriter
|
||||
from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
@ -40,8 +42,6 @@ from nanobot.session.goal_state import (
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.artifacts import generated_image_paths_from_messages
|
||||
from nanobot.utils.document import extract_documents
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
from nanobot.utils.image_generation_intent import image_generation_prompt
|
||||
from nanobot.utils.runtime import EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
from nanobot.utils.session_attachments import merge_turn_media_into_last_assistant
|
||||
@ -137,8 +137,8 @@ class AgentLoop:
|
||||
def tool_names(self) -> list[str]:
|
||||
return self.tools.tool_names
|
||||
|
||||
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
|
||||
_PENDING_USER_TURN_KEY = "pending_user_turn"
|
||||
_RUNTIME_CHECKPOINT_KEY = CheckpointManager.RUNTIME_CHECKPOINT_KEY
|
||||
_PENDING_USER_TURN_KEY = CheckpointManager.PENDING_USER_TURN_KEY
|
||||
|
||||
# Event-driven state transition table.
|
||||
# Handlers return an event string; the driver looks up the next state here.
|
||||
@ -238,6 +238,8 @@ class AgentLoop:
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||
self.sessions = session_manager or SessionManager(workspace)
|
||||
self._checkpoint_mgr = CheckpointManager(self.sessions)
|
||||
self._turn_writer = TurnWriter(self.sessions, self._checkpoint_mgr, self.max_tool_result_chars)
|
||||
self.tools = ToolRegistry()
|
||||
# One file-read/write tracker per logical session. The tool registry is
|
||||
# shared by this loop, so tools resolve the active state via contextvars.
|
||||
@ -579,21 +581,7 @@ class AgentLoop:
|
||||
session: Session,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Persist the triggering user message before the turn starts.
|
||||
|
||||
Returns True if the message was persisted.
|
||||
"""
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if has_text or media_paths:
|
||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||
extra.update(kwargs)
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
session.add_message("user", text, **extra)
|
||||
self._mark_pending_user_turn(session)
|
||||
self.sessions.save(session)
|
||||
return True
|
||||
return False
|
||||
return self._turn_writer.persist_user_message_early(msg, session, **kwargs)
|
||||
|
||||
def _build_initial_messages(
|
||||
self,
|
||||
@ -1440,38 +1428,9 @@ class AgentLoop:
|
||||
should_truncate_text: bool = False,
|
||||
drop_runtime: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Strip volatile multimodal payloads before writing session history."""
|
||||
filtered: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
filtered.append(block)
|
||||
continue
|
||||
|
||||
if (
|
||||
drop_runtime
|
||||
and block.get("type") == "text"
|
||||
and isinstance(block.get("text"), str)
|
||||
and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
|
||||
):
|
||||
continue
|
||||
|
||||
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||
"url", ""
|
||||
).startswith("data:image/"):
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text = block["text"]
|
||||
if should_truncate_text and len(text) > self.max_tool_result_chars:
|
||||
text = truncate_text_fn(text, self.max_tool_result_chars)
|
||||
filtered.append({**block, "text": text})
|
||||
continue
|
||||
|
||||
filtered.append(block)
|
||||
|
||||
return filtered
|
||||
return self._turn_writer.sanitize_persisted_blocks(
|
||||
content, should_truncate_text=should_truncate_text, drop_runtime=drop_runtime,
|
||||
)
|
||||
|
||||
def _save_turn(
|
||||
self,
|
||||
@ -1481,169 +1440,28 @@ class AgentLoop:
|
||||
*,
|
||||
turn_latency_ms: int | None = None,
|
||||
) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
|
||||
last_assistant_idx: int | None = None
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||
continue # skip empty assistant messages — they poison session context
|
||||
if role == "tool":
|
||||
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
|
||||
elif isinstance(content, list):
|
||||
filtered = self._sanitize_persisted_blocks(content, should_truncate_text=True)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content:
|
||||
# Strip the runtime-context block appended at the end.
|
||||
tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
|
||||
before = content[:tag_pos].rstrip("\n ")
|
||||
if before:
|
||||
entry["content"] = before
|
||||
else:
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = self._sanitize_persisted_blocks(content, drop_runtime=True)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
if role == "assistant":
|
||||
last_assistant_idx = len(session.messages) - 1
|
||||
if turn_latency_ms is not None and last_assistant_idx is not None:
|
||||
session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
|
||||
session.updated_at = datetime.now()
|
||||
self._turn_writer.save_turn(session, messages, skip, turn_latency_ms=turn_latency_ms)
|
||||
|
||||
def _persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
|
||||
"""Persist subagent follow-ups before prompt assembly so history stays durable.
|
||||
|
||||
Returns True if a new entry was appended; False if the follow-up was
|
||||
deduped (same ``subagent_task_id`` already in session) or carries no
|
||||
content worth persisting.
|
||||
"""
|
||||
if not msg.content:
|
||||
return False
|
||||
task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
|
||||
if task_id and any(
|
||||
m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
|
||||
for m in session.messages
|
||||
):
|
||||
return False
|
||||
session.add_message(
|
||||
"assistant",
|
||||
msg.content,
|
||||
sender_id=msg.sender_id,
|
||||
injected_event="subagent_result",
|
||||
subagent_task_id=task_id,
|
||||
)
|
||||
return True
|
||||
return self._turn_writer.persist_subagent_followup(session, msg)
|
||||
|
||||
def _set_runtime_checkpoint(self, session: Session, payload: dict[str, Any]) -> None:
|
||||
"""Persist the latest in-flight turn state into session metadata."""
|
||||
session.metadata[self._RUNTIME_CHECKPOINT_KEY] = payload
|
||||
self.sessions.save(session)
|
||||
self._checkpoint_mgr.set_runtime_checkpoint(session, payload)
|
||||
|
||||
def _mark_pending_user_turn(self, session: Session) -> None:
|
||||
session.metadata[self._PENDING_USER_TURN_KEY] = True
|
||||
self._checkpoint_mgr.mark_pending_user_turn(session)
|
||||
|
||||
def _clear_pending_user_turn(self, session: Session) -> None:
|
||||
session.metadata.pop(self._PENDING_USER_TURN_KEY, None)
|
||||
self._checkpoint_mgr.clear_pending_user_turn(session)
|
||||
|
||||
def _clear_runtime_checkpoint(self, session: Session) -> None:
|
||||
if self._RUNTIME_CHECKPOINT_KEY in session.metadata:
|
||||
session.metadata.pop(self._RUNTIME_CHECKPOINT_KEY, None)
|
||||
|
||||
@staticmethod
|
||||
def _checkpoint_message_key(message: dict[str, Any]) -> tuple[Any, ...]:
|
||||
return (
|
||||
message.get("role"),
|
||||
message.get("content"),
|
||||
message.get("tool_call_id"),
|
||||
message.get("name"),
|
||||
message.get("tool_calls"),
|
||||
message.get("reasoning_content"),
|
||||
message.get("thinking_blocks"),
|
||||
)
|
||||
self._checkpoint_mgr.clear_runtime_checkpoint(session)
|
||||
|
||||
def _restore_runtime_checkpoint(self, session: Session) -> bool:
|
||||
"""Materialize an unfinished turn into session history before a new request."""
|
||||
from datetime import datetime
|
||||
|
||||
checkpoint = session.metadata.get(self._RUNTIME_CHECKPOINT_KEY)
|
||||
if not isinstance(checkpoint, dict):
|
||||
return False
|
||||
|
||||
assistant_message = checkpoint.get("assistant_message")
|
||||
completed_tool_results = checkpoint.get("completed_tool_results") or []
|
||||
pending_tool_calls = checkpoint.get("pending_tool_calls") or []
|
||||
|
||||
restored_messages: list[dict[str, Any]] = []
|
||||
if isinstance(assistant_message, dict):
|
||||
restored = dict(assistant_message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for message in completed_tool_results:
|
||||
if isinstance(message, dict):
|
||||
restored = dict(message)
|
||||
restored.setdefault("timestamp", datetime.now().isoformat())
|
||||
restored_messages.append(restored)
|
||||
for tool_call in pending_tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
tool_id = tool_call.get("id")
|
||||
name = ((tool_call.get("function") or {}).get("name")) or "tool"
|
||||
restored_messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_id,
|
||||
"name": name,
|
||||
"content": "Error: Task interrupted before this tool finished.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
overlap = 0
|
||||
max_overlap = min(len(session.messages), len(restored_messages))
|
||||
for size in range(max_overlap, 0, -1):
|
||||
existing = session.messages[-size:]
|
||||
restored = restored_messages[:size]
|
||||
if all(
|
||||
self._checkpoint_message_key(left) == self._checkpoint_message_key(right)
|
||||
for left, right in zip(existing, restored)
|
||||
):
|
||||
overlap = size
|
||||
break
|
||||
session.messages.extend(restored_messages[overlap:])
|
||||
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
return True
|
||||
return self._checkpoint_mgr.restore_runtime_checkpoint(session)
|
||||
|
||||
def _restore_pending_user_turn(self, session: Session) -> bool:
|
||||
"""Close a turn that only persisted the user message before crashing."""
|
||||
from datetime import datetime
|
||||
|
||||
if not session.metadata.get(self._PENDING_USER_TURN_KEY):
|
||||
return False
|
||||
|
||||
if session.messages and session.messages[-1].get("role") == "user":
|
||||
session.messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Error: Task interrupted before a response was generated.",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
self._clear_pending_user_turn(session)
|
||||
return True
|
||||
return self._checkpoint_mgr.restore_pending_user_turn(session)
|
||||
|
||||
async def process_direct(
|
||||
self,
|
||||
|
||||
163
nanobot/agent/turn_writer.py
Normal file
163
nanobot/agent/turn_writer.py
Normal file
@ -0,0 +1,163 @@
|
||||
"""Turn persistence: sanitize, save, and early-persist message history."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from nanobot.agent.checkpoint import CheckpointManager
|
||||
from nanobot.agent.context import ContextBuilder
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.session.manager import Session, SessionManager
|
||||
from nanobot.utils.helpers import image_placeholder_text
|
||||
from nanobot.utils.helpers import truncate_text as truncate_text_fn
|
||||
|
||||
|
||||
class TurnWriter:
|
||||
"""Handles persisting turn messages into session history.
|
||||
|
||||
Responsible for sanitising volatile content (data-URI images, runtime
|
||||
context blocks), truncating large tool results, and stamping latency
|
||||
metadata on the last assistant message.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sessions: SessionManager,
|
||||
checkpoint: CheckpointManager,
|
||||
max_tool_result_chars: int,
|
||||
) -> None:
|
||||
self.sessions = sessions
|
||||
self._checkpoint = checkpoint
|
||||
self.max_tool_result_chars = max_tool_result_chars
|
||||
|
||||
def sanitize_persisted_blocks(
|
||||
self,
|
||||
content: list[dict[str, Any]],
|
||||
*,
|
||||
should_truncate_text: bool = False,
|
||||
drop_runtime: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Strip volatile multimodal payloads before writing session history."""
|
||||
filtered: list[dict[str, Any]] = []
|
||||
for block in content:
|
||||
if not isinstance(block, dict):
|
||||
filtered.append(block)
|
||||
continue
|
||||
|
||||
if (
|
||||
drop_runtime
|
||||
and block.get("type") == "text"
|
||||
and isinstance(block.get("text"), str)
|
||||
and block["text"].startswith(ContextBuilder._RUNTIME_CONTEXT_TAG)
|
||||
):
|
||||
continue
|
||||
|
||||
if block.get("type") == "image_url" and block.get("image_url", {}).get(
|
||||
"url", ""
|
||||
).startswith("data:image/"):
|
||||
path = (block.get("_meta") or {}).get("path", "")
|
||||
filtered.append({"type": "text", "text": image_placeholder_text(path)})
|
||||
continue
|
||||
|
||||
if block.get("type") == "text" and isinstance(block.get("text"), str):
|
||||
text = block["text"]
|
||||
if should_truncate_text and len(text) > self.max_tool_result_chars:
|
||||
text = truncate_text_fn(text, self.max_tool_result_chars)
|
||||
filtered.append({**block, "text": text})
|
||||
continue
|
||||
|
||||
filtered.append(block)
|
||||
|
||||
return filtered
|
||||
|
||||
def save_turn(
|
||||
self,
|
||||
session: Session,
|
||||
messages: list[dict],
|
||||
skip: int,
|
||||
*,
|
||||
turn_latency_ms: int | None = None,
|
||||
) -> None:
|
||||
"""Save new-turn messages into session, truncating large tool results."""
|
||||
from datetime import datetime
|
||||
|
||||
last_assistant_idx: int | None = None
|
||||
for m in messages[skip:]:
|
||||
entry = dict(m)
|
||||
role, content = entry.get("role"), entry.get("content")
|
||||
if role == "assistant" and not content and not entry.get("tool_calls"):
|
||||
continue
|
||||
if role == "tool":
|
||||
if isinstance(content, str) and len(content) > self.max_tool_result_chars:
|
||||
entry["content"] = truncate_text_fn(content, self.max_tool_result_chars)
|
||||
elif isinstance(content, list):
|
||||
filtered = self.sanitize_persisted_blocks(content, should_truncate_text=True)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
elif role == "user":
|
||||
if isinstance(content, str) and ContextBuilder._RUNTIME_CONTEXT_TAG in content:
|
||||
tag_pos = content.find(ContextBuilder._RUNTIME_CONTEXT_TAG)
|
||||
before = content[:tag_pos].rstrip("\n ")
|
||||
if before:
|
||||
entry["content"] = before
|
||||
else:
|
||||
continue
|
||||
if isinstance(content, list):
|
||||
filtered = self.sanitize_persisted_blocks(content, drop_runtime=True)
|
||||
if not filtered:
|
||||
continue
|
||||
entry["content"] = filtered
|
||||
entry.setdefault("timestamp", datetime.now().isoformat())
|
||||
session.messages.append(entry)
|
||||
if role == "assistant":
|
||||
last_assistant_idx = len(session.messages) - 1
|
||||
if turn_latency_ms is not None and last_assistant_idx is not None:
|
||||
session.messages[last_assistant_idx]["latency_ms"] = int(turn_latency_ms)
|
||||
session.updated_at = datetime.now()
|
||||
|
||||
def persist_user_message_early(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session: Session,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""Persist the triggering user message before the turn starts.
|
||||
|
||||
Returns True if the message was persisted.
|
||||
"""
|
||||
media_paths = [p for p in (msg.media or []) if isinstance(p, str) and p]
|
||||
has_text = isinstance(msg.content, str) and msg.content.strip()
|
||||
if has_text or media_paths:
|
||||
extra: dict[str, Any] = {"media": list(media_paths)} if media_paths else {}
|
||||
extra.update(kwargs)
|
||||
text = msg.content if isinstance(msg.content, str) else ""
|
||||
session.add_message("user", text, **extra)
|
||||
self._checkpoint.mark_pending_user_turn(session)
|
||||
self.sessions.save(session)
|
||||
return True
|
||||
return False
|
||||
|
||||
def persist_subagent_followup(self, session: Session, msg: InboundMessage) -> bool:
|
||||
"""Persist subagent follow-ups before prompt assembly so history stays durable.
|
||||
|
||||
Returns True if a new entry was appended; False if the follow-up was
|
||||
deduped (same ``subagent_task_id`` already in session) or carries no
|
||||
content worth persisting.
|
||||
"""
|
||||
if not msg.content:
|
||||
return False
|
||||
task_id = msg.metadata.get("subagent_task_id") if isinstance(msg.metadata, dict) else None
|
||||
if task_id and any(
|
||||
m.get("injected_event") == "subagent_result" and m.get("subagent_task_id") == task_id
|
||||
for m in session.messages
|
||||
):
|
||||
return False
|
||||
session.add_message(
|
||||
"assistant",
|
||||
msg.content,
|
||||
sender_id=msg.sender_id,
|
||||
injected_event="subagent_result",
|
||||
subagent_task_id=task_id,
|
||||
)
|
||||
return True
|
||||
@ -19,9 +19,17 @@ from nanobot.utils.webui_titles import (
|
||||
|
||||
def _mk_loop() -> AgentLoop:
|
||||
loop = AgentLoop.__new__(AgentLoop)
|
||||
from nanobot.agent.checkpoint import CheckpointManager
|
||||
from nanobot.agent.turn_writer import TurnWriter
|
||||
from nanobot.config.schema import AgentDefaults
|
||||
|
||||
loop.max_tool_result_chars = AgentDefaults().max_tool_result_chars
|
||||
loop._checkpoint_mgr = CheckpointManager(sessions=None) # type: ignore[arg-type]
|
||||
loop._turn_writer = TurnWriter(
|
||||
sessions=None, # type: ignore[arg-type]
|
||||
checkpoint=loop._checkpoint_mgr,
|
||||
max_tool_result_chars=loop.max_tool_result_chars,
|
||||
)
|
||||
return loop
|
||||
|
||||
|
||||
|
||||
@ -20,10 +20,12 @@ def test_sanitize_persisted_blocks_truncate_text_shadowing_regression() -> None:
|
||||
assert "should_truncate_text" in sig.parameters
|
||||
assert "truncate_text" not in sig.parameters
|
||||
|
||||
dummy = SimpleNamespace(max_tool_result_chars=5)
|
||||
from nanobot.agent.turn_writer import TurnWriter
|
||||
|
||||
writer = TurnWriter(sessions=None, checkpoint=None, max_tool_result_chars=5) # type: ignore[arg-type]
|
||||
content = [{"type": "text", "text": "0123456789"}]
|
||||
|
||||
out = AgentLoop._sanitize_persisted_blocks(dummy, content, should_truncate_text=True)
|
||||
out = writer.sanitize_persisted_blocks(content, should_truncate_text=True)
|
||||
assert isinstance(out, list)
|
||||
assert out and out[0]["type"] == "text"
|
||||
assert isinstance(out[0]["text"], str)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user