nanobot/nanobot/session/manager.py
df37a36174 · Xubin Ren · 2026-04-26 17:42:58 +00:00
fix(agent): expose session timestamps in model context
Include persisted turn timestamps when assembling LLM prompts so relative-date references like "yesterday" and "today" have concrete anchors.

"""Session management for conversation history."""
import json
import os
import shutil
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any
from loguru import logger
from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import (
ensure_dir,
find_legal_message_start,
image_placeholder_text,
safe_filename,
)
@dataclass
class Session:
"""A conversation session."""
key: str # channel:chat_id
messages: list[dict[str, Any]] = field(default_factory=list)
created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now)
metadata: dict[str, Any] = field(default_factory=dict)
last_consolidated: int = 0 # Number of messages already consolidated to files
@staticmethod
def _annotate_message_time(message: dict[str, Any], content: Any) -> Any:
"""Expose persisted turn timestamps to the model for relative-date reasoning."""
timestamp = message.get("timestamp")
if (
not timestamp
or message.get("role") not in {"user", "assistant"}
or not isinstance(content, str)
):
return content
return f"[Message Time: {timestamp}]\n{content}"

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Add a message to the session."""
        msg = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            **kwargs,
        }
        self.messages.append(msg)
        self.updated_at = datetime.now()

    def get_history(
        self,
        max_messages: int = 500,
        *,
        include_timestamps: bool = False,
    ) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
        unconsolidated = self.messages[self.last_consolidated:]
        # A non-positive budget keeps nothing, mirroring retain_recent_legal_suffix();
        # a bare [-0:] slice would silently return the full history.
        sliced = unconsolidated[-max_messages:] if max_messages > 0 else []
        # Avoid starting mid-turn when possible, except for proactive
        # assistant deliveries that the user may be replying to.
        for i, message in enumerate(sliced):
            if message.get("role") == "user":
                start = i
                if i > 0 and sliced[i - 1].get("_channel_delivery"):
                    start = i - 1
                sliced = sliced[start:]
                break
        # Drop orphan tool results at the front.
        start = find_legal_message_start(sliced)
        if start:
            sliced = sliced[start:]
        out: list[dict[str, Any]] = []
        for message in sliced:
            content = message.get("content", "")
            # Synthesize an ``[image: path]`` breadcrumb from the persisted
            # ``media`` kwarg so LLM replay still sees *something* where the
            # image used to be. Without this, an image-only user turn
            # replays as an empty user message — the assistant's reply then
            # looks like it's responding to nothing.
            media = message.get("media")
            if isinstance(media, list) and media and isinstance(content, str):
                breadcrumbs = "\n".join(
                    image_placeholder_text(p) for p in media if isinstance(p, str) and p
                )
                content = f"{content}\n{breadcrumbs}" if content else breadcrumbs
            if include_timestamps:
                content = self._annotate_message_time(message, content)
            entry: dict[str, Any] = {"role": message["role"], "content": content}
            for key in ("tool_calls", "tool_call_id", "name", "reasoning_content"):
                if key in message:
                    entry[key] = message[key]
            out.append(entry)
        return out
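
    # Illustrative example (values hypothetical): an image-only user turn
    # persisted as {"role": "user", "content": "", "media": ["photos/cat.jpg"]}
    # replays with a synthesized breadcrumb such as "[image: photos/cat.jpg]"
    # (the exact text comes from image_placeholder_text) instead of an empty
    # content string.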

    def clear(self) -> None:
        """Clear all messages and reset session to initial state."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()

    def retain_recent_legal_suffix(self, max_messages: int) -> None:
        """Keep a legal recent suffix, mirroring get_history boundary rules."""
        if max_messages <= 0:
            self.clear()
            return
        if len(self.messages) <= max_messages:
            return
        start_idx = max(0, len(self.messages) - max_messages)
        # If the cutoff lands mid-turn, extend backward to the nearest user turn.
        while start_idx > 0 and self.messages[start_idx].get("role") != "user":
            start_idx -= 1
        retained = self.messages[start_idx:]
        # Mirror get_history(): avoid persisting orphan tool results at the front.
        start = find_legal_message_start(retained)
        if start:
            retained = retained[start:]
        dropped = len(self.messages) - len(retained)
        self.messages = retained
        self.last_consolidated = max(0, self.last_consolidated - dropped)
        self.updated_at = datetime.now()
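

# Minimal usage sketch (illustrative, not part of the original module): shows
# how a Session accumulates turns and how get_history() annotates them when
# include_timestamps=True. Printed values depend on the wall clock.
def _demo_session_history() -> None:
    session = Session(key="demo:1")
    session.add_message("user", "what did we discuss yesterday?")
    session.add_message("assistant", "We talked about the release plan.")
    for entry in session.get_history(include_timestamps=True):
        # Each content string is prefixed with a "[Message Time: <iso8601>]" line.
        print(entry["role"], "->", entry["content"].splitlines()[0])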


class SessionManager:
    """
    Manages conversation sessions.

    Sessions are stored as JSONL files in the sessions directory.
    """

    def __init__(self, workspace: Path):
        self.workspace = workspace
        self.sessions_dir = ensure_dir(self.workspace / "sessions")
        self.legacy_sessions_dir = get_legacy_sessions_dir()
        self._cache: dict[str, Session] = {}

    @staticmethod
    def safe_key(key: str) -> str:
        """Public helper used by HTTP handlers to map an arbitrary key to a stable filename stem."""
        return safe_filename(key.replace(":", "_"))

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session."""
        return self.sessions_dir / f"{self.safe_key(key)}.jsonl"
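
    # Illustrative mapping (hypothetical key): a session key such as
    # "telegram:12345" becomes sessions/telegram_12345.jsonl; safe_key()
    # replaces every ":" with "_" and safe_filename() sanitizes the rest.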

    def _get_legacy_session_path(self, key: str) -> Path:
        """Legacy global session path (~/.nanobot/sessions/)."""
        return self.legacy_sessions_dir / f"{self.safe_key(key)}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """
        Get an existing session or create a new one.

        Args:
            key: Session key (usually channel:chat_id).

        Returns:
            The session.
        """
        if key in self._cache:
            return self._cache[key]
        session = self._load(key)
        if session is None:
            session = Session(key=key)
        self._cache[key] = session
        return session
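
    # Note: repeat calls with the same key return the same cached Session
    # object, so in-memory mutations (e.g. add_message) are visible to every
    # caller until save() persists them.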

    def _load(self, key: str) -> Session | None:
        """Load a session from disk."""
        path = self._get_session_path(key)
        if not path.exists():
            legacy_path = self._get_legacy_session_path(key)
            if legacy_path.exists():
                try:
                    shutil.move(str(legacy_path), str(path))
                    logger.info("Migrated session {} from legacy path", key)
                except Exception:
                    logger.exception("Failed to migrate session {}", key)
        if not path.exists():
            return None
        try:
            messages = []
            metadata = {}
            created_at = None
            updated_at = None
            last_consolidated = 0
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data = json.loads(line)
                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
                        updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)
            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
        except Exception as e:
            logger.warning("Failed to load session {}: {}", key, e)
            repaired = self._repair(key)
            if repaired is not None:
                logger.info("Recovered session {} from corrupt file ({} messages)", key, len(repaired.messages))
            return repaired

    def _repair(self, key: str) -> Session | None:
        """Attempt to recover a session from a corrupt JSONL file."""
        path = self._get_session_path(key)
        if not path.exists():
            return None
        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: datetime | None = None
            updated_at: datetime | None = None
            last_consolidated = 0
            skipped = 0
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue
                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        if data.get("created_at"):
                            try:
                                created_at = datetime.fromisoformat(data["created_at"])
                            except (ValueError, TypeError):
                                pass
                        if data.get("updated_at"):
                            try:
                                updated_at = datetime.fromisoformat(data["updated_at"])
                            except (ValueError, TypeError):
                                pass
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)
            if skipped:
                logger.warning("Skipped {} corrupt lines in session {}", skipped, key)
            if not messages and not metadata:
                return None
            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
        except Exception as e:
            logger.warning("Repair failed for session {}: {}", key, e)
            return None
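
    # On-disk JSONL layout (values illustrative): one metadata line followed
    # by one JSON object per message, as written by save() below:
    #   {"_type": "metadata", "key": "demo:1", "created_at": "...",
    #    "updated_at": "...", "metadata": {}, "last_consolidated": 0}
    #   {"role": "user", "content": "hi", "timestamp": "..."}
    #   {"role": "assistant", "content": "hello!", "timestamp": "..."}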

    @staticmethod
    def _session_payload(session: Session) -> dict[str, Any]:
        return {
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "messages": session.messages,
        }

    def save(self, session: Session, *, fsync: bool = False) -> None:
        """Save a session to disk atomically.

        When *fsync* is ``True`` the final file and its parent directory are
        explicitly flushed to durable storage. This is intentionally off by
        default (the OS page-cache is sufficient for normal operation) but
        should be enabled during graceful shutdown so that filesystems with
        write-back caching (e.g. rclone VFS, NFS, FUSE mounts) do not lose
        the most recent writes.
        """
        path = self._get_session_path(session.key)
        tmp_path = path.with_suffix(".jsonl.tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                metadata_line = {
                    "_type": "metadata",
                    "key": session.key,
                    "created_at": session.created_at.isoformat(),
                    "updated_at": session.updated_at.isoformat(),
                    "metadata": session.metadata,
                    "last_consolidated": session.last_consolidated,
                }
                f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
                for msg in session.messages:
                    f.write(json.dumps(msg, ensure_ascii=False) + "\n")
                if fsync:
                    f.flush()
                    os.fsync(f.fileno())
            os.replace(tmp_path, path)
            if fsync:
                # fsync the directory so the rename is durable.
                # On Windows, opening a directory with O_RDONLY raises
                # PermissionError — skip the dir sync there (NTFS
                # journals metadata synchronously).
                try:
                    fd = os.open(str(path.parent), os.O_RDONLY)
                    try:
                        os.fsync(fd)
                    finally:
                        os.close(fd)
                except PermissionError:
                    pass  # Windows — directory fsync not supported
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise
        self._cache[session.key] = session

    def flush_all(self) -> int:
        """Re-save every cached session with fsync for durable shutdown.

        Returns the number of sessions flushed. Errors on individual
        sessions are logged but do not prevent other sessions from being
        flushed.
        """
        flushed = 0
        for key, session in list(self._cache.items()):
            try:
                self.save(session, fsync=True)
                flushed += 1
            except Exception:
                logger.warning("Failed to flush session {}", key, exc_info=True)
        return flushed
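
    # Illustrative wiring (an assumption about the host process, not original
    # code): a long-running service could register flush_all() so sessions
    # are fsynced on graceful shutdown, e.g.
    #
    #   import atexit
    #   atexit.register(manager.flush_all)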

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def delete_session(self, key: str) -> bool:
        """Remove a session from disk and the in-memory cache.

        Returns True if a JSONL file was found and unlinked.
        """
        path = self._get_session_path(key)
        self.invalidate(key)
        if not path.exists():
            return False
        try:
            path.unlink()
            return True
        except OSError as e:
            logger.warning("Failed to delete session file {}: {}", path, e)
            return False

    def read_session_file(self, key: str) -> dict[str, Any] | None:
        """Load a session from disk without caching; intended for read-only HTTP endpoints.

        Returns ``{"key", "created_at", "updated_at", "metadata", "messages"}`` or
        ``None`` when the session file does not exist or fails to parse.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None
        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: str | None = None
            updated_at: str | None = None
            stored_key: str | None = None
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data = json.loads(line)
                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = data.get("created_at")
                        updated_at = data.get("updated_at")
                        stored_key = data.get("key")
                    else:
                        messages.append(data)
            return {
                "key": stored_key or key,
                "created_at": created_at,
                "updated_at": updated_at,
                "metadata": metadata,
                "messages": messages,
            }
        except Exception as e:
            logger.warning("Failed to read session {}: {}", key, e)
            repaired = self._repair(key)
            if repaired is not None:
                logger.info("Recovered read-only session view {} from corrupt file", key)
                return self._session_payload(repaired)
            return None

    def list_sessions(self) -> list[dict[str, Any]]:
        """
        List all sessions.

        Returns:
            List of session info dicts.
        """
        sessions = []
        for path in self.sessions_dir.glob("*.jsonl"):
            fallback_key = path.stem.replace("_", ":", 1)
            try:
                # Read just the metadata line
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                if first_line:
                    data = json.loads(first_line)
                    if data.get("_type") == "metadata":
                        sessions.append({
                            "key": data.get("key") or fallback_key,
                            "created_at": data.get("created_at"),
                            "updated_at": data.get("updated_at"),
                            "path": str(path),
                        })
            except Exception:
                repaired = self._repair(fallback_key)
                if repaired is not None:
                    sessions.append({
                        "key": repaired.key,
                        "created_at": repaired.created_at.isoformat(),
                        "updated_at": repaired.updated_at.isoformat(),
                        "path": str(path),
                    })
                continue
        # Sort newest-first; "or" guards against a null updated_at, which
        # would otherwise make str/None comparisons raise TypeError.
        return sorted(sessions, key=lambda x: x.get("updated_at") or "", reverse=True)
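

# End-to-end sketch (illustrative; paths and keys are hypothetical): create a
# manager in a temporary workspace, record a turn, persist it durably, then
# list what landed on disk.
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        manager = SessionManager(workspace=Path(tmp))
        session = manager.get_or_create("cli:demo")
        session.add_message("user", "hello")
        manager.save(session, fsync=True)  # durable write, as at shutdown
        print(manager.list_sessions())     # one entry for key "cli:demo"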