mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-01 15:25:56 +00:00
Include persisted turn timestamps when assembling LLM prompts so that relative-date references like "yesterday" and "today" have concrete anchors. Made-with: Cursor
472 lines
18 KiB
Python
472 lines
18 KiB
Python
"""Session management for conversation history."""
|
|
|
|
import json
|
|
import os
|
|
import shutil
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from loguru import logger
|
|
|
|
from nanobot.config.paths import get_legacy_sessions_dir
|
|
from nanobot.utils.helpers import (
|
|
ensure_dir,
|
|
find_legal_message_start,
|
|
image_placeholder_text,
|
|
safe_filename,
|
|
)
|
|
|
|
|
|
@dataclass
class Session:
    """A conversation session."""

    key: str  # channel:chat_id
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    last_consolidated: int = 0  # Number of messages already consolidated to files

    @staticmethod
    def _annotate_message_time(message: dict[str, Any], content: Any) -> Any:
        """Expose persisted turn timestamps to the model for relative-date reasoning."""
        ts = message.get("timestamp")
        # Only annotate real conversational turns that carry a timestamp
        # and plain-string content; anything else passes through untouched.
        eligible = (
            bool(ts)
            and message.get("role") in {"user", "assistant"}
            and isinstance(content, str)
        )
        if not eligible:
            return content
        return f"[Message Time: {ts}]\n{content}"

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Append one message to the transcript and refresh the update time."""
        record: dict[str, Any] = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        }
        # Extra kwargs (e.g. media, tool_calls) are persisted alongside the
        # core fields and may deliberately override them.
        record.update(kwargs)
        self.messages.append(record)
        self.updated_at = datetime.now()

    def get_history(
        self,
        max_messages: int = 500,
        *,
        include_timestamps: bool = False,
    ) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary."""
        window = self.messages[self.last_consolidated:][-max_messages:]

        # Prefer starting on a user turn; keep a proactive assistant
        # delivery immediately before it, since the user may be replying
        # to that delivery.
        for idx, msg in enumerate(window):
            if msg.get("role") != "user":
                continue
            begin = idx
            if idx and window[idx - 1].get("_channel_delivery"):
                begin = idx - 1
            window = window[begin:]
            break

        # Drop orphan tool results at the front.
        boundary = find_legal_message_start(window)
        if boundary:
            window = window[boundary:]

        history: list[dict[str, Any]] = []
        for msg in window:
            body = msg.get("content", "")
            # Synthesize an ``[image: path]`` breadcrumb from the persisted
            # ``media`` kwarg so LLM replay still sees *something* where the
            # image used to be. Without this, an image-only user turn
            # replays as an empty user message — the assistant's reply then
            # looks like it's responding to nothing.
            attachments = msg.get("media")
            if isinstance(attachments, list) and attachments and isinstance(body, str):
                crumbs = "\n".join(
                    image_placeholder_text(p) for p in attachments if isinstance(p, str) and p
                )
                body = f"{body}\n{crumbs}" if body else crumbs
            if include_timestamps:
                body = self._annotate_message_time(msg, body)
            item: dict[str, Any] = {"role": msg["role"], "content": body}
            for extra in ("tool_calls", "tool_call_id", "name", "reasoning_content"):
                if extra in msg:
                    item[extra] = msg[extra]
            history.append(item)
        return history

    def clear(self) -> None:
        """Clear all messages and reset session to initial state."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()

    def retain_recent_legal_suffix(self, max_messages: int) -> None:
        """Keep a legal recent suffix, mirroring get_history boundary rules."""
        if max_messages <= 0:
            self.clear()
            return
        total = len(self.messages)
        if total <= max_messages:
            return

        cut = total - max_messages

        # If the cutoff lands mid-turn, extend backward to the nearest user turn.
        while cut > 0 and self.messages[cut].get("role") != "user":
            cut -= 1

        suffix = self.messages[cut:]

        # Mirror get_history(): avoid persisting orphan tool results at the front.
        boundary = find_legal_message_start(suffix)
        if boundary:
            suffix = suffix[boundary:]

        dropped = total - len(suffix)
        self.messages = suffix
        self.last_consolidated = max(0, self.last_consolidated - dropped)
        self.updated_at = datetime.now()
|
|
|
|
|
|
class SessionManager:
    """
    Manages conversation sessions.

    Sessions are stored as JSONL files in the sessions directory: the first
    line is a ``{"_type": "metadata", ...}`` record and every subsequent
    line is one message dict. Loaded sessions are kept in an in-memory
    cache keyed by the raw session key.
    """

    def __init__(self, workspace: Path):
        self.workspace = workspace
        # Per-workspace session storage, created on demand.
        self.sessions_dir = ensure_dir(self.workspace / "sessions")
        # Old global location (~/.nanobot/sessions/); files found there are
        # migrated lazily in _load().
        self.legacy_sessions_dir = get_legacy_sessions_dir()
        self._cache: dict[str, Session] = {}

    @staticmethod
    def safe_key(key: str) -> str:
        """Public helper used by HTTP handlers to map an arbitrary key to a stable filename stem."""
        return safe_filename(key.replace(":", "_"))

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session."""
        return self.sessions_dir / f"{self.safe_key(key)}.jsonl"

    def _get_legacy_session_path(self, key: str) -> Path:
        """Legacy global session path (~/.nanobot/sessions/)."""
        return self.legacy_sessions_dir / f"{self.safe_key(key)}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """
        Get an existing session or create a new one.

        Args:
            key: Session key (usually channel:chat_id).

        Returns:
            The session.
        """
        if key in self._cache:
            return self._cache[key]

        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Session | None:
        """Load a session from disk, migrating from the legacy directory first.

        Returns ``None`` when no file exists. On parse failure, falls back
        to :meth:`_repair` and returns whatever could be recovered (or
        ``None``).
        """
        path = self._get_session_path(key)
        if not path.exists():
            legacy_path = self._get_legacy_session_path(key)
            if legacy_path.exists():
                try:
                    shutil.move(str(legacy_path), str(path))
                    logger.info("Migrated session {} from legacy path", key)
                except Exception:
                    logger.exception("Failed to migrate session {}", key)

        if not path.exists():
            return None

        try:
            messages = []
            metadata = {}
            created_at = None
            updated_at = None
            last_consolidated = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    data = json.loads(line)

                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = datetime.fromisoformat(data["created_at"]) if data.get("created_at") else None
                        updated_at = datetime.fromisoformat(data["updated_at"]) if data.get("updated_at") else None
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated
            )
        except Exception as e:
            logger.warning("Failed to load session {}: {}", key, e)
            repaired = self._repair(key)
            if repaired is not None:
                logger.info("Recovered session {} from corrupt file ({} messages)", key, len(repaired.messages))
                return repaired
            # Explicit: nothing recoverable.
            return None

    def _repair(self, key: str) -> Session | None:
        """Attempt to recover a session from a corrupt JSONL file.

        Unparseable lines are skipped (and counted); malformed metadata
        timestamps are ignored rather than aborting recovery. Returns
        ``None`` when nothing usable survives.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None

        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: datetime | None = None
            updated_at: datetime | None = None
            last_consolidated = 0
            skipped = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue

                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        if data.get("created_at"):
                            try:
                                created_at = datetime.fromisoformat(data["created_at"])
                            except (ValueError, TypeError):
                                pass
                        if data.get("updated_at"):
                            try:
                                updated_at = datetime.fromisoformat(data["updated_at"])
                            except (ValueError, TypeError):
                                pass
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            if skipped:
                logger.warning("Skipped {} corrupt lines in session {}", skipped, key)

            if not messages and not metadata:
                return None

            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated
            )
        except Exception as e:
            logger.warning("Repair failed for session {}: {}", key, e)
            return None

    @staticmethod
    def _session_payload(session: Session) -> dict[str, Any]:
        """Serialize a Session into the dict shape used by read-only HTTP views."""
        return {
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "messages": session.messages,
        }

    def save(self, session: Session, *, fsync: bool = False) -> None:
        """Save a session to disk atomically.

        When *fsync* is ``True`` the final file and its parent directory are
        explicitly flushed to durable storage. This is intentionally off by
        default (the OS page-cache is sufficient for normal operation) but
        should be enabled during graceful shutdown so that filesystems with
        write-back caching (e.g. rclone VFS, NFS, FUSE mounts) do not lose
        the most recent writes.
        """
        path = self._get_session_path(session.key)
        tmp_path = path.with_suffix(".jsonl.tmp")

        try:
            # Write metadata line first, then one JSON line per message, to
            # a temp file; os.replace() below makes the swap atomic.
            with open(tmp_path, "w", encoding="utf-8") as f:
                metadata_line = {
                    "_type": "metadata",
                    "key": session.key,
                    "created_at": session.created_at.isoformat(),
                    "updated_at": session.updated_at.isoformat(),
                    "metadata": session.metadata,
                    "last_consolidated": session.last_consolidated
                }
                f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
                for msg in session.messages:
                    f.write(json.dumps(msg, ensure_ascii=False) + "\n")
                if fsync:
                    f.flush()
                    os.fsync(f.fileno())

            os.replace(tmp_path, path)

            if fsync:
                # fsync the directory so the rename is durable.
                # On Windows, opening a directory with O_RDONLY raises
                # PermissionError — skip the dir sync there (NTFS
                # journals metadata synchronously).
                try:
                    fd = os.open(str(path.parent), os.O_RDONLY)
                    try:
                        os.fsync(fd)
                    finally:
                        os.close(fd)
                except PermissionError:
                    pass  # Windows — directory fsync not supported
        except BaseException:
            # Clean up the temp file on any failure (including cancellation)
            # and let the error propagate to the caller.
            tmp_path.unlink(missing_ok=True)
            raise

        self._cache[session.key] = session

    def flush_all(self) -> int:
        """Re-save every cached session with fsync for durable shutdown.

        Returns the number of sessions flushed. Errors on individual
        sessions are logged but do not prevent other sessions from being
        flushed.
        """
        flushed = 0
        for key, session in list(self._cache.items()):
            try:
                self.save(session, fsync=True)
                flushed += 1
            except Exception:
                # Bug fix: loguru does not honor the stdlib ``exc_info``
                # kwarg (extra kwargs only feed str.format), so the old
                # ``exc_info=True`` silently dropped the traceback.
                # ``opt(exception=True)`` attaches it at the same level.
                logger.opt(exception=True).warning("Failed to flush session {}", key)
        return flushed

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def delete_session(self, key: str) -> bool:
        """Remove a session from disk and the in-memory cache.

        Returns True if a JSONL file was found and unlinked.
        """
        path = self._get_session_path(key)
        self.invalidate(key)
        if not path.exists():
            return False
        try:
            path.unlink()
            return True
        except OSError as e:
            logger.warning("Failed to delete session file {}: {}", path, e)
            return False

    def read_session_file(self, key: str) -> dict[str, Any] | None:
        """Load a session from disk without caching; intended for read-only HTTP endpoints.

        Returns ``{"key", "created_at", "updated_at", "metadata", "messages"}`` or
        ``None`` when the session file does not exist or fails to parse.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None
        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: str | None = None
            updated_at: str | None = None
            stored_key: str | None = None
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data = json.loads(line)
                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = data.get("created_at")
                        updated_at = data.get("updated_at")
                        stored_key = data.get("key")
                    else:
                        messages.append(data)
            return {
                "key": stored_key or key,
                "created_at": created_at,
                "updated_at": updated_at,
                "metadata": metadata,
                "messages": messages,
            }
        except Exception as e:
            logger.warning("Failed to read session {}: {}", key, e)
            repaired = self._repair(key)
            if repaired is not None:
                logger.info("Recovered read-only session view {} from corrupt file", key)
                return self._session_payload(repaired)
            return None

    def list_sessions(self) -> list[dict[str, Any]]:
        """
        List all sessions.

        Returns:
            List of session info dicts, most recently updated first.
        """
        sessions = []

        for path in self.sessions_dir.glob("*.jsonl"):
            # Best-effort reverse of safe_key(); only the first "_" can be
            # mapped back to ":" since safe_key is lossy.
            fallback_key = path.stem.replace("_", ":", 1)
            try:
                # Read just the metadata line
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                    if first_line:
                        data = json.loads(first_line)
                        if data.get("_type") == "metadata":
                            key = data.get("key") or fallback_key
                            sessions.append({
                                "key": key,
                                "created_at": data.get("created_at"),
                                "updated_at": data.get("updated_at"),
                                "path": str(path)
                            })
            except Exception:
                # Corrupt first line: try a full repair so the session still
                # shows up in listings instead of vanishing silently.
                repaired = self._repair(fallback_key)
                if repaired is not None:
                    sessions.append({
                        "key": repaired.key,
                        "created_at": repaired.created_at.isoformat(),
                        "updated_at": repaired.updated_at.isoformat(),
                        "path": str(path)
                    })
                continue

        return sorted(sessions, key=lambda x: x.get("updated_at", ""), reverse=True)
|