"""Session management for conversation history."""

import json
import os
import shutil
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any

from loguru import logger

from nanobot.config.paths import get_legacy_sessions_dir
from nanobot.utils.helpers import (
    ensure_dir,
    find_legal_message_start,
    image_placeholder_text,
    safe_filename,
)


@dataclass
class Session:
    """A conversation session.

    Holds the ordered message list for one conversation together with its
    creation/update timestamps, free-form metadata, and the count of leading
    messages that have already been consolidated elsewhere.
    """

    key: str  # channel:chat_id
    messages: list[dict[str, Any]] = field(default_factory=list)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    metadata: dict[str, Any] = field(default_factory=dict)
    last_consolidated: int = 0  # Number of messages already consolidated to files

    @staticmethod
    def _annotate_message_time(message: dict[str, Any], content: Any) -> Any:
        """Expose persisted turn timestamps to the model for relative-date reasoning.

        Returns *content* unchanged unless the message has a truthy
        ``timestamp``, its role is ``user`` or ``assistant``, and *content*
        is a string; in that case a ``[Message Time: ...]`` header line is
        prepended.
        """
        timestamp = message.get("timestamp")
        if (
            not timestamp
            or message.get("role") not in {"user", "assistant"}
            or not isinstance(content, str)
        ):
            return content
        return f"[Message Time: {timestamp}]\n{content}"

    def add_message(self, role: str, content: str, **kwargs: Any) -> None:
        """Add a message to the session.

        Extra keyword arguments are stored on the message dict. Because they
        are unpacked last, they override ``role``, ``content`` and
        ``timestamp`` if given under those names.
        """
        msg = {
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
            **kwargs,
        }
        self.messages.append(msg)
        self.updated_at = datetime.now()

    def get_history(
        self,
        max_messages: int = 500,
        *,
        include_timestamps: bool = False,
    ) -> list[dict[str, Any]]:
        """Return unconsolidated messages for LLM input, aligned to a legal tool-call boundary.

        Args:
            max_messages: Size of the trailing window taken from the
                unconsolidated messages. The window is taken with a negative
                slice, so ``0`` yields *all* unconsolidated messages rather
                than none.
            include_timestamps: When true, string content of user/assistant
                messages is prefixed with its persisted timestamp.

        Returns:
            New dicts containing ``role``, ``content`` and, when present on
            the stored message, ``tool_calls``, ``tool_call_id``, ``name``
            and ``reasoning_content``.
        """
        unconsolidated = self.messages[self.last_consolidated:]
        sliced = unconsolidated[-max_messages:]

        # Avoid starting mid-turn when possible, except for proactive
        # assistant deliveries that the user may be replying to.
        for i, message in enumerate(sliced):
            if message.get("role") == "user":
                start = i
                if i > 0 and sliced[i - 1].get("_channel_delivery"):
                    start = i - 1
                sliced = sliced[start:]
                break

        # Drop orphan tool results at the front.
        start = find_legal_message_start(sliced)
        if start:
            sliced = sliced[start:]

        out: list[dict[str, Any]] = []
        for message in sliced:
            content = message.get("content", "")
            # Synthesize an ``[image: path]`` breadcrumb from the persisted
            # ``media`` kwarg so LLM replay still sees *something* where the
            # image used to be. Without this, an image-only user turn
            # replays as an empty user message — the assistant's reply then
            # looks like it's responding to nothing.
            media = message.get("media")
            if isinstance(media, list) and media and isinstance(content, str):
                breadcrumbs = "\n".join(
                    image_placeholder_text(p) for p in media if isinstance(p, str) and p
                )
                # Only merge when at least one usable path produced a
                # breadcrumb; otherwise the text would gain a stray
                # trailing newline.
                if breadcrumbs:
                    content = f"{content}\n{breadcrumbs}" if content else breadcrumbs
            if include_timestamps:
                content = self._annotate_message_time(message, content)
            entry: dict[str, Any] = {"role": message["role"], "content": content}
            for key in ("tool_calls", "tool_call_id", "name", "reasoning_content"):
                if key in message:
                    entry[key] = message[key]
            out.append(entry)
        return out

    def clear(self) -> None:
        """Clear all messages and reset session to initial state."""
        self.messages = []
        self.last_consolidated = 0
        self.updated_at = datetime.now()

    def retain_recent_legal_suffix(self, max_messages: int) -> None:
        """Keep a legal recent suffix, mirroring get_history boundary rules.

        A non-positive *max_messages* clears the session. If the session
        already holds at most *max_messages* messages nothing changes.
        Otherwise the oldest messages are dropped and ``last_consolidated``
        is shifted down by the number of dropped messages (never below 0).
        """
        if max_messages <= 0:
            self.clear()
            return
        if len(self.messages) <= max_messages:
            return

        start_idx = max(0, len(self.messages) - max_messages)

        # If the cutoff lands mid-turn, extend backward to the nearest user turn.
        while start_idx > 0 and self.messages[start_idx].get("role") != "user":
            start_idx -= 1

        retained = self.messages[start_idx:]

        # Mirror get_history(): avoid persisting orphan tool results at the front.
        start = find_legal_message_start(retained)
        if start:
            retained = retained[start:]

        dropped = len(self.messages) - len(retained)
        self.messages = retained
        self.last_consolidated = max(0, self.last_consolidated - dropped)
        self.updated_at = datetime.now()


class SessionManager:
    """
    Manages conversation sessions.

    Sessions are stored as JSONL files in the sessions directory.
    """

    def __init__(self, workspace: Path):
        """Set up the sessions directory under *workspace* and an empty cache."""
        self.workspace = workspace
        self.sessions_dir = ensure_dir(self.workspace / "sessions")
        self.legacy_sessions_dir = get_legacy_sessions_dir()
        self._cache: dict[str, Session] = {}

    @staticmethod
    def safe_key(key: str) -> str:
        """Public helper used by HTTP handlers to map an arbitrary key to a stable filename stem."""
        return safe_filename(key.replace(":", "_"))

    def _get_session_path(self, key: str) -> Path:
        """Get the file path for a session."""
        return self.sessions_dir / f"{self.safe_key(key)}.jsonl"

    def _get_legacy_session_path(self, key: str) -> Path:
        """Legacy global session path (~/.nanobot/sessions/)."""
        return self.legacy_sessions_dir / f"{self.safe_key(key)}.jsonl"

    def get_or_create(self, key: str) -> Session:
        """
        Get an existing session or create a new one.

        Args:
            key: Session key (usually channel:chat_id).

        Returns:
            The session.
        """
        if key in self._cache:
            return self._cache[key]

        session = self._load(key)
        if session is None:
            session = Session(key=key)

        self._cache[key] = session
        return session

    def _load(self, key: str) -> Session | None:
        """Load a session from disk.

        If no file exists at the current path but one exists at the legacy
        path, the legacy file is moved into place first. A file that fails
        strict parsing is handed to ``_repair`` for best-effort recovery.

        Returns:
            The loaded (or recovered) session, or ``None`` when no file
            exists or nothing could be recovered.
        """
        path = self._get_session_path(key)
        if not path.exists():
            legacy_path = self._get_legacy_session_path(key)
            if legacy_path.exists():
                try:
                    shutil.move(str(legacy_path), str(path))
                    logger.info("Migrated session {} from legacy path", key)
                except Exception:
                    logger.exception("Failed to migrate session {}", key)

        if not path.exists():
            return None

        try:
            messages = []
            metadata = {}
            created_at = None
            updated_at = None
            last_consolidated = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    data = json.loads(line)

                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = (
                            datetime.fromisoformat(data["created_at"])
                            if data.get("created_at")
                            else None
                        )
                        updated_at = (
                            datetime.fromisoformat(data["updated_at"])
                            if data.get("updated_at")
                            else None
                        )
                        last_consolidated = data.get("last_consolidated", 0)
                    else:
                        messages.append(data)

            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
        except Exception as e:
            logger.warning("Failed to load session {}: {}", key, e)
            repaired = self._repair(key)
            if repaired is not None:
                logger.info(
                    "Recovered session {} from corrupt file ({} messages)",
                    key,
                    len(repaired.messages),
                )
            return repaired

    def _repair(self, key: str) -> Session | None:
        """Attempt to recover a session from a corrupt JSONL file.

        Lines that are not valid JSON, or that decode to something other
        than a JSON object, are counted as corrupt and skipped so that one
        bad line cannot abort the recovery of the rest of the file.
        Unparseable timestamps fall back to the current time.

        Returns:
            The recovered session, or ``None`` when the file is missing,
            yields neither messages nor metadata, or cannot be read.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None

        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: datetime | None = None
            updated_at: datetime | None = None
            last_consolidated = 0
            skipped = 0

            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        skipped += 1
                        continue
                    # Valid JSON that is not an object (e.g. a bare number
                    # left by a truncated write) is just as unusable.
                    if not isinstance(data, dict):
                        skipped += 1
                        continue

                    if data.get("_type") == "metadata":
                        raw_metadata = data.get("metadata", {})
                        metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
                        if data.get("created_at"):
                            try:
                                created_at = datetime.fromisoformat(data["created_at"])
                            except (ValueError, TypeError):
                                pass
                        if data.get("updated_at"):
                            try:
                                updated_at = datetime.fromisoformat(data["updated_at"])
                            except (ValueError, TypeError):
                                pass
                        raw_consolidated = data.get("last_consolidated", 0)
                        last_consolidated = (
                            raw_consolidated if isinstance(raw_consolidated, int) else 0
                        )
                    else:
                        messages.append(data)

            if skipped:
                logger.warning("Skipped {} corrupt lines in session {}", skipped, key)

            if not messages and not metadata:
                return None

            return Session(
                key=key,
                messages=messages,
                created_at=created_at or datetime.now(),
                updated_at=updated_at or datetime.now(),
                metadata=metadata,
                last_consolidated=last_consolidated,
            )
        except Exception as e:
            logger.warning("Repair failed for session {}: {}", key, e)
            return None

    @staticmethod
    def _session_payload(session: Session) -> dict[str, Any]:
        """Build the plain-dict view of *session* returned by ``read_session_file``."""
        return {
            "key": session.key,
            "created_at": session.created_at.isoformat(),
            "updated_at": session.updated_at.isoformat(),
            "metadata": session.metadata,
            "messages": session.messages,
        }

    def save(self, session: Session, *, fsync: bool = False) -> None:
        """Save a session to disk atomically.

        The session is written to a sibling ``.jsonl.tmp`` file which then
        replaces the real file; on any failure the temporary file is removed
        and the error re-raised. On success the session is stored in the
        in-memory cache.

        When *fsync* is ``True`` the final file and its parent directory are
        explicitly flushed to durable storage.  This is intentionally off by
        default (the OS page-cache is sufficient for normal operation) but
        should be enabled during graceful shutdown so that filesystems with
        write-back caching (e.g. rclone VFS, NFS, FUSE mounts) do not lose
        the most recent writes.
        """
        path = self._get_session_path(session.key)
        tmp_path = path.with_suffix(".jsonl.tmp")

        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                metadata_line = {
                    "_type": "metadata",
                    "key": session.key,
                    "created_at": session.created_at.isoformat(),
                    "updated_at": session.updated_at.isoformat(),
                    "metadata": session.metadata,
                    "last_consolidated": session.last_consolidated,
                }
                f.write(json.dumps(metadata_line, ensure_ascii=False) + "\n")
                for msg in session.messages:
                    f.write(json.dumps(msg, ensure_ascii=False) + "\n")
                if fsync:
                    f.flush()
                    os.fsync(f.fileno())

            os.replace(tmp_path, path)

            if fsync:
                # fsync the directory so the rename is durable.
                # On Windows, opening a directory with O_RDONLY raises
                # PermissionError — skip the dir sync there (NTFS
                # journals metadata synchronously).
                try:
                    fd = os.open(str(path.parent), os.O_RDONLY)
                    try:
                        os.fsync(fd)
                    finally:
                        os.close(fd)
                except PermissionError:
                    pass  # Windows — directory fsync not supported
        except BaseException:
            tmp_path.unlink(missing_ok=True)
            raise

        self._cache[session.key] = session

    def flush_all(self) -> int:
        """Re-save every cached session with fsync for durable shutdown.

        Returns the number of sessions flushed.  Errors on individual
        sessions are logged but do not prevent other sessions from being
        flushed.
        """
        flushed = 0
        for key, session in list(self._cache.items()):
            try:
                self.save(session, fsync=True)
                flushed += 1
            except Exception:
                # loguru does not honour stdlib-style ``exc_info=True``;
                # ``opt(exception=True)`` is what attaches the traceback.
                logger.opt(exception=True).warning("Failed to flush session {}", key)
        return flushed

    def invalidate(self, key: str) -> None:
        """Remove a session from the in-memory cache."""
        self._cache.pop(key, None)

    def delete_session(self, key: str) -> bool:
        """Remove a session from disk and the in-memory cache.

        Returns True if a JSONL file was found and unlinked.
        """
        path = self._get_session_path(key)
        self.invalidate(key)
        if not path.exists():
            return False
        try:
            path.unlink()
            return True
        except OSError as e:
            logger.warning("Failed to delete session file {}: {}", path, e)
            return False

    def read_session_file(self, key: str) -> dict[str, Any] | None:
        """Load a session from disk without caching; intended for read-only HTTP endpoints.

        Returns ``{"key", "created_at", "updated_at", "metadata", "messages"}`` or
        ``None`` when the session file does not exist or fails to parse.
        A file that fails strict parsing is retried through ``_repair``
        before giving up.
        """
        path = self._get_session_path(key)
        if not path.exists():
            return None
        try:
            messages: list[dict[str, Any]] = []
            metadata: dict[str, Any] = {}
            created_at: str | None = None
            updated_at: str | None = None
            stored_key: str | None = None
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data = json.loads(line)
                    if data.get("_type") == "metadata":
                        metadata = data.get("metadata", {})
                        created_at = data.get("created_at")
                        updated_at = data.get("updated_at")
                        stored_key = data.get("key")
                    else:
                        messages.append(data)
            return {
                "key": stored_key or key,
                "created_at": created_at,
                "updated_at": updated_at,
                "metadata": metadata,
                "messages": messages,
            }
        except Exception as e:
            logger.warning("Failed to read session {}: {}", key, e)
            repaired = self._repair(key)
            if repaired is not None:
                logger.info("Recovered read-only session view {} from corrupt file", key)
                return self._session_payload(repaired)
            return None

    def list_sessions(self) -> list[dict[str, Any]]:
        """
        List all sessions.

        Only the first line of each file is read. Files whose first line is
        empty or is not a metadata record are omitted; files whose first
        line cannot be parsed are retried through ``_repair``.

        Returns:
            List of session info dicts, most recently updated first.
            Entries without an ``updated_at`` value sort last.
        """
        sessions: list[dict[str, Any]] = []

        for path in self.sessions_dir.glob("*.jsonl"):
            # Best guess at the original key when the file does not record it.
            fallback_key = path.stem.replace("_", ":", 1)
            try:
                # Read just the metadata line
                with open(path, encoding="utf-8") as f:
                    first_line = f.readline().strip()
                    if first_line:
                        data = json.loads(first_line)
                        if data.get("_type") == "metadata":
                            key = data.get("key") or fallback_key
                            sessions.append({
                                "key": key,
                                "created_at": data.get("created_at"),
                                "updated_at": data.get("updated_at"),
                                "path": str(path),
                            })
            except Exception:
                repaired = self._repair(fallback_key)
                if repaired is not None:
                    sessions.append({
                        "key": repaired.key,
                        "created_at": repaired.created_at.isoformat(),
                        "updated_at": repaired.updated_at.isoformat(),
                        "path": str(path),
                    })
                continue

        # ``updated_at`` may be present but null in the metadata line;
        # coerce to a string so the sort never compares None with str.
        return sorted(
            sessions,
            key=lambda x: str(x.get("updated_at") or ""),
            reverse=True,
        )