diff --git a/nanobot/webui/transcript.py b/nanobot/webui/transcript.py index 40f865046..ee2734283 100644 --- a/nanobot/webui/transcript.py +++ b/nanobot/webui/transcript.py @@ -2,13 +2,16 @@ from __future__ import annotations +import base64 +import binascii import json import os import re +import shutil import time import uuid from pathlib import Path -from typing import Any, Callable, Mapping +from typing import Any, Callable, Mapping, NamedTuple from urllib.parse import unquote, urlparse from loguru import logger @@ -19,6 +22,12 @@ from nanobot.session.manager import SessionManager WEBUI_TRANSCRIPT_SCHEMA_VERSION = 3 WEBUI_FORK_MARKER_EVENT = "fork_marker" _MAX_TRANSCRIPT_FILE_BYTES = 8 * 1024 * 1024 +_TARGET_ACTIVE_TRANSCRIPT_BYTES = _MAX_TRANSCRIPT_FILE_BYTES // 2 +_TRANSCRIPT_SEGMENT_MANIFEST_VERSION = 2 +_TRANSCRIPT_ACTIVE_CHUNK_ID = "active" +_TRANSCRIPT_SEGMENT_RE = re.compile(r"^\d{6}\.jsonl$") +_DEFAULT_TRANSCRIPT_PAGE_LIMIT = 160 +_MAX_TRANSCRIPT_PAGE_LIMIT = 1000 _WEBUI_TURN_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$") WEBUI_TURN_METADATA_KEY = "webui_turn_id" WEBUI_MESSAGE_SOURCE_METADATA_KEY = "_webui_message_source" @@ -114,14 +123,37 @@ def webui_transcript_path(session_key: str) -> Path: return get_webui_dir() / f"{stem}.jsonl" -def read_transcript_lines(session_key: str) -> list[dict[str, Any]]: - path = webui_transcript_path(session_key) - if not path.is_file(): - return [] - size = path.stat().st_size - if size > _MAX_TRANSCRIPT_FILE_BYTES: - logger.warning("webui transcript too large, skipping: {}", path) - return [] +def webui_transcript_segments_dir(session_key: str) -> Path: + stem = SessionManager.safe_key(session_key) + return get_webui_dir() / f"{stem}.segments" + + +def _webui_transcript_manifest_path(session_key: str) -> Path: + return webui_transcript_segments_dir(session_key) / "manifest.json" + + +def _legacy_webui_thread_path(session_key: str) -> Path: + stem = SessionManager.safe_key(session_key) + return get_webui_dir() / f"{stem}.json" + + +class _TranscriptTurnRef(NamedTuple): + ordinal: int + records: list[dict[str, Any]] + + +class _TranscriptChunkRef(NamedTuple): + chunk_id: str + start_ordinal: int + turn_count: int + user_count: int + + +def _record_json_line(record: dict[str, Any]) -> str: + return json.dumps(record, ensure_ascii=False, separators=(",", ":")) + + +def _read_transcript_file(path: Path) -> list[dict[str, Any]]: lines_out: list[dict[str, Any]] = [] try: with open(path, encoding="utf-8") as f: @@ -142,8 +174,402 @@ def read_transcript_lines(session_key: str) -> list[dict[str, Any]]: return lines_out -def append_transcript_object(session_key: str, obj: dict[str, Any]) -> None: - raw = json.dumps(obj, ensure_ascii=False, separators=(",", ":")) +def _records_bytes(records: list[dict[str, Any]]) -> int: + total = 0 + for record in records: + total += len(_record_json_line(record).encode("utf-8")) + 1 + return total + + +def _flatten_turns(turns: list[list[dict[str, Any]]]) -> list[dict[str, Any]]: + return [record for turn in turns for record in turn] + + +def _write_records_to_path(path: Path, rows: list[dict[str, Any]]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_suffix(path.suffix + ".tmp") + try: + with open(tmp_path, "w", encoding="utf-8") as f: + for row in rows: + raw = _record_json_line(row) + if len(raw.encode("utf-8")) > _MAX_TRANSCRIPT_FILE_BYTES: + raise ValueError("webui transcript line too large") + f.write(raw + "\n") + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, path) + except BaseException: + tmp_path.unlink(missing_ok=True) + raise + + +def _segment_file_path(session_key: str, segment_id: str) -> Path: + return webui_transcript_segments_dir(session_key) / f"{segment_id}.jsonl" + + +def _segment_ids_on_disk(session_key: str) -> list[str]: + directory = webui_transcript_segments_dir(session_key) + if not directory.is_dir(): + return [] + return sorted( + path.stem + for path in directory.iterdir() + if path.is_file() and _TRANSCRIPT_SEGMENT_RE.fullmatch(path.name) + ) + + +def _segment_manifest_entry(session_key: str, segment_id: str) -> dict[str, Any]: + path = _segment_file_path(session_key, segment_id) + lines = _read_transcript_file(path) + return { + "id": segment_id, + "bytes": path.stat().st_size if path.exists() else 0, + "turn_count": len(_split_transcript_turns(lines)), + "user_count": sum(1 for line in lines if _is_user_transcript_row(line)), + } + + +def _non_negative_int(value: Any) -> int | None: + if isinstance(value, bool) or not isinstance(value, int) or value < 0: + return None + return value + + +def _normalize_manifest_entry(session_key: str, entry: Any) -> dict[str, Any] | None: + if not isinstance(entry, dict): + return None + segment_id = entry.get("id") + if not isinstance(segment_id, str) or not _TRANSCRIPT_SEGMENT_RE.fullmatch(f"{segment_id}.jsonl"): + return None + segment_path = _segment_file_path(session_key, segment_id) + values = { + key: _non_negative_int(entry.get(key)) + for key in ("bytes", "turn_count", "user_count") + } + if not segment_path.is_file() or values["bytes"] != segment_path.stat().st_size: + return None + if values["turn_count"] is None or values["user_count"] is None: + return None + return { + "id": segment_id, + "bytes": values["bytes"], + "turn_count": values["turn_count"], + "user_count": values["user_count"], + } + + +def _write_segment_manifest(session_key: str, segment_ids: list[str]) -> None: + directory = webui_transcript_segments_dir(session_key) + directory.mkdir(parents=True, exist_ok=True) + data = { + "version": _TRANSCRIPT_SEGMENT_MANIFEST_VERSION, + "segments": [_segment_manifest_entry(session_key, segment_id) for segment_id in segment_ids], + } + path = _webui_transcript_manifest_path(session_key) + tmp_path = path.with_suffix(".json.tmp") + try: + tmp_path.write_text(json.dumps(data, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + os.replace(tmp_path, path) + except BaseException: + tmp_path.unlink(missing_ok=True) + raise + + +def _rebuild_segment_manifest(session_key: str) -> list[str]: + segment_ids = _segment_ids_on_disk(session_key) + if segment_ids: + _write_segment_manifest(session_key, segment_ids) + else: + _webui_transcript_manifest_path(session_key).unlink(missing_ok=True) + return segment_ids + + +def _rebuilt_segment_manifest_entries(session_key: str) -> list[dict[str, Any]]: + return [_segment_manifest_entry(session_key, segment_id) for segment_id in _rebuild_segment_manifest(session_key)] + + +def _read_segment_manifest_entries(session_key: str) -> list[dict[str, Any]]: + directory = webui_transcript_segments_dir(session_key) + if not directory.is_dir(): + return [] + path = _webui_transcript_manifest_path(session_key) + if not path.is_file(): + return _rebuilt_segment_manifest_entries(session_key) + try: + data = json.loads(path.read_text(encoding="utf-8")) + raw_segments = data.get("segments") if isinstance(data, dict) else None + if data.get("version") != _TRANSCRIPT_SEGMENT_MANIFEST_VERSION or not isinstance(raw_segments, list): + return _rebuilt_segment_manifest_entries(session_key) + entries: list[dict[str, Any]] = [] + for entry in raw_segments: + normalized = _normalize_manifest_entry(session_key, entry) + if normalized is None: + return _rebuilt_segment_manifest_entries(session_key) + entries.append(normalized) + if [entry["id"] for entry in entries] != _segment_ids_on_disk(session_key): + return _rebuilt_segment_manifest_entries(session_key) + return entries + except (OSError, json.JSONDecodeError, TypeError, AttributeError): + return _rebuilt_segment_manifest_entries(session_key) + + +def _read_segment_ids(session_key: str) -> list[str]: + return [entry["id"] for entry in _read_segment_manifest_entries(session_key)] + + +def _append_segment_turns(session_key: str, turns: list[list[dict[str, Any]]]) -> None: + if not turns: + return + segment_ids = _read_segment_ids(session_key) + next_id = int(segment_ids[-1]) + 1 if segment_ids else 1 + batch: list[list[dict[str, Any]]] = [] + batch_bytes = 0 + for turn in turns: + turn_bytes = _records_bytes(turn) + if batch and batch_bytes + turn_bytes > _MAX_TRANSCRIPT_FILE_BYTES: + segment_id = f"{next_id:06d}" + _write_records_to_path(_segment_file_path(session_key, segment_id), _flatten_turns(batch)) + segment_ids.append(segment_id) + next_id += 1 + batch = [] + batch_bytes = 0 + batch.append(turn) + batch_bytes += turn_bytes + if batch: + segment_id = f"{next_id:06d}" + _write_records_to_path(_segment_file_path(session_key, segment_id), _flatten_turns(batch)) + segment_ids.append(segment_id) + _write_segment_manifest(session_key, segment_ids) + + +def _rotate_active_transcript_if_needed(session_key: str) -> None: + path = webui_transcript_path(session_key) + if not path.is_file(): + return + try: + if path.stat().st_size <= _MAX_TRANSCRIPT_FILE_BYTES: + return + except OSError: + return + + lines = _read_transcript_file(path) + if not lines: + return + turns = _split_transcript_turns(lines) + if len(turns) <= 1: + return + + keep_start = len(turns) - 1 + keep_bytes = 0 + for idx in range(len(turns) - 1, -1, -1): + turn_bytes = _records_bytes(turns[idx]) + if idx == len(turns) - 1 or keep_bytes + turn_bytes <= _TARGET_ACTIVE_TRANSCRIPT_BYTES: + keep_start = idx + keep_bytes += turn_bytes + continue + break + + moved = turns[:keep_start] + kept = turns[keep_start:] + if not moved: + return + _append_segment_turns(session_key, moved) + _write_records_to_path(path, _flatten_turns(kept)) + + +def _chunk_ids(session_key: str) -> list[str]: + _rotate_active_transcript_if_needed(session_key) + ids = _read_segment_ids(session_key) + if webui_transcript_path(session_key).is_file(): + ids.append(_TRANSCRIPT_ACTIVE_CHUNK_ID) + return ids + + +def _read_chunk_turns(session_key: str, chunk_id: str) -> list[list[dict[str, Any]]]: + if chunk_id == _TRANSCRIPT_ACTIVE_CHUNK_ID: + path = webui_transcript_path(session_key) + else: + path = _segment_file_path(session_key, chunk_id) + if not path.is_file(): + return [] + return _split_transcript_turns(_read_transcript_file(path)) + + +def _encode_page_cursor(before_turn_ordinal: int) -> str: + raw = json.dumps( + {"before_turn": before_turn_ordinal}, + separators=(",", ":"), + ensure_ascii=False, + ).encode("utf-8") + return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=") + + +def _decode_page_cursor(value: str | None) -> int | None: + if not value: + return None + try: + padded = value + "=" * (-len(value) % 4) + data = json.loads(base64.urlsafe_b64decode(padded.encode("ascii")).decode("utf-8")) + except (binascii.Error, json.JSONDecodeError, UnicodeDecodeError, ValueError): + return None + if not isinstance(data, dict): + return None + before_turn = data.get("before_turn") + if ( + isinstance(before_turn, bool) + or not isinstance(before_turn, int) + or before_turn < 0 + ): + return None + return before_turn + + +def _coerce_page_limit(limit: int | None) -> int: + if limit is None: + return _DEFAULT_TRANSCRIPT_PAGE_LIMIT + return max(1, min(_MAX_TRANSCRIPT_PAGE_LIMIT, int(limit))) + + +def _chunk_turn_refs(session_key: str) -> list[_TranscriptChunkRef]: + _rotate_active_transcript_if_needed(session_key) + refs: list[_TranscriptChunkRef] = [] + ordinal = 0 + for entry in _read_segment_manifest_entries(session_key): + chunk_id = str(entry["id"]) + turn_count = int(entry["turn_count"]) + if turn_count <= 0: + continue + refs.append(_TranscriptChunkRef(chunk_id, ordinal, turn_count, int(entry["user_count"]))) + ordinal += turn_count + if webui_transcript_path(session_key).is_file(): + active_turns = _read_chunk_turns(session_key, _TRANSCRIPT_ACTIVE_CHUNK_ID) + active_turn_count = len(active_turns) + if active_turn_count > 0: + refs.append( + _TranscriptChunkRef( + _TRANSCRIPT_ACTIVE_CHUNK_ID, + ordinal, + active_turn_count, + sum(1 for turn in active_turns for row in turn if _is_user_transcript_row(row)), + ), + ) + return refs + + +def _count_user_messages_before_ordinal( + session_key: str, + chunks: list[_TranscriptChunkRef], + before_ordinal: int, +) -> int: + total = 0 + for chunk in chunks: + if before_ordinal <= chunk.start_ordinal: + break + local_end = min(chunk.turn_count, before_ordinal - chunk.start_ordinal) + if local_end <= 0: + continue + if local_end >= chunk.turn_count: + total += chunk.user_count + continue + turns = _read_chunk_turns(session_key, chunk.chunk_id) + total += sum( + 1 + for turn in turns[:local_end] + for row in turn + if _is_user_transcript_row(row) + ) + return total + + +def _select_transcript_page( + session_key: str, + *, + limit: int | None, + before: str | None, + _manifest_rebuilt: bool = False, +) -> tuple[list[dict[str, Any]], dict[str, Any]]: + page_limit = _coerce_page_limit(limit) + chunks = _chunk_turn_refs(session_key) + total_turns = sum(chunk.turn_count for chunk in chunks) + before_ordinal = _decode_page_cursor(before) + upper_ordinal = total_turns if before_ordinal is None else min(before_ordinal, total_turns) + selected: list[_TranscriptTurnRef] = [] + selected_message_count = 0 + + for chunk in reversed(chunks): + if chunk.start_ordinal >= upper_ordinal: + continue + local_upper = min(chunk.turn_count, upper_ordinal - chunk.start_ordinal) + if local_upper <= 0: + continue + turns = _read_chunk_turns(session_key, chunk.chunk_id) + if ( + chunk.chunk_id != _TRANSCRIPT_ACTIVE_CHUNK_ID + and len(turns) != chunk.turn_count + and not _manifest_rebuilt + ): + _rebuild_segment_manifest(session_key) + return _select_transcript_page( + session_key, + limit=limit, + before=before, + _manifest_rebuilt=True, + ) + local_upper = min(local_upper, len(turns)) + for turn_index in range(local_upper - 1, -1, -1): + ordinal = chunk.start_ordinal + turn_index + turn = turns[turn_index] + selected.append(_TranscriptTurnRef(ordinal, turn)) + selected_message_count += len(replay_transcript_to_ui_messages(turn)) + if selected_message_count >= page_limit: + break + if selected_message_count >= page_limit: + break + + selected_chronological = list(reversed(selected)) + lines = [record for ref in selected_chronological for record in ref.records] + if not selected_chronological: + return [], { + "before_cursor": None, + "has_more_before": False, + "loaded_message_count": 0, + "user_message_offset": 0, + } + + first_ref = selected_chronological[0] + has_more = first_ref.ordinal > 0 + page = { + "before_cursor": _encode_page_cursor(first_ref.ordinal) if has_more else None, + "has_more_before": has_more, + "loaded_message_count": 0, + "user_message_offset": _count_user_messages_before_ordinal( + session_key, + chunks, + first_ref.ordinal, + ), + } + return lines, page + + +def read_transcript_lines(session_key: str) -> list[dict[str, Any]]: + lines: list[dict[str, Any]] = [] + for chunk_id in _chunk_ids(session_key): + if chunk_id == _TRANSCRIPT_ACTIVE_CHUNK_ID: + lines.extend(_read_transcript_file(webui_transcript_path(session_key))) + else: + lines.extend(_read_transcript_file(_segment_file_path(session_key, chunk_id))) + return lines + + +def _write_transcript_lines(session_key: str, rows: list[dict[str, Any]]) -> None: + delete_webui_transcript(session_key) + path = webui_transcript_path(session_key) + _write_records_to_path(path, rows) + _rotate_active_transcript_if_needed(session_key) + + +def _append_to_active_transcript(session_key: str, obj: dict[str, Any]) -> None: + raw = _record_json_line(obj) if len(raw.encode("utf-8")) > _MAX_TRANSCRIPT_FILE_BYTES: msg = "webui transcript line too large" raise ValueError(msg) @@ -156,6 +582,12 @@ def append_transcript_object(session_key: str, obj: dict[str, Any]) -> None: os.fsync(f.fileno()) +def append_transcript_object(session_key: str, obj: dict[str, Any]) -> None: + _append_to_active_transcript(session_key, obj) + if obj.get("event") == "turn_end": + _rotate_active_transcript_if_needed(session_key) + + def normalize_webui_turn_id(value: Any) -> str: if isinstance(value, str): candidate = value.strip() @@ -286,25 +718,6 @@ def _is_user_transcript_row(row: dict[str, Any]) -> bool: return row.get("event") == "user" or row.get("role") == "user" -def _write_transcript_lines(session_key: str, rows: list[dict[str, Any]]) -> None: - path = webui_transcript_path(session_key) - path.parent.mkdir(parents=True, exist_ok=True) - tmp_path = path.with_suffix(".jsonl.tmp") - try: - with open(tmp_path, "w", encoding="utf-8") as f: - for row in rows: - raw = json.dumps(row, ensure_ascii=False, separators=(",", ":")) - if len(raw.encode("utf-8")) > _MAX_TRANSCRIPT_FILE_BYTES: - raise ValueError("webui transcript line too large") - f.write(raw + "\n") - f.flush() - os.fsync(f.fileno()) - os.replace(tmp_path, path) - except BaseException: - tmp_path.unlink(missing_ok=True) - raise - - def fork_transcript_before_user_index( source_key: str, target_key: str, @@ -390,15 +803,23 @@ def write_session_messages_as_transcript( def delete_webui_transcript(session_key: str) -> bool: - path = webui_transcript_path(session_key) - if not path.is_file(): - return False - try: - path.unlink() - return True - except OSError as e: - logger.warning("Failed to delete webui transcript {}: {}", path, e) - return False + removed = False + for path in (webui_transcript_path(session_key), _legacy_webui_thread_path(session_key)): + if not path.is_file(): + continue + try: + path.unlink() + removed = True + except OSError as e: + logger.warning("Failed to delete webui transcript {}: {}", path, e) + segments_dir = webui_transcript_segments_dir(session_key) + if segments_dir.is_dir(): + try: + shutil.rmtree(segments_dir) + removed = True + except OSError as e: + logger.warning("Failed to delete webui transcript segments {}: {}", segments_dir, e) + return removed def build_user_transcript_event( @@ -1409,9 +1830,17 @@ def build_webui_thread_response( augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None, augment_assistant_text: Callable[[str], str] | None = None, session_messages: list[dict[str, Any]] | None = None, + limit: int | None = None, + direction: str | None = None, + before: str | None = None, ) -> dict[str, Any] | None: """Return a payload compatible with ``WebuiThreadPersistedPayload``.""" - lines = read_transcript_lines(session_key) + paginated = limit is not None or direction is not None or before is not None + page: dict[str, Any] | None = None + if paginated: + lines, page = _select_transcript_page(session_key, limit=limit, before=before) + else: + lines = read_transcript_lines(session_key) if not lines: return None lines = inject_missing_user_events_from_session(session_key, lines, session_messages) @@ -1427,6 +1856,9 @@ def build_webui_thread_response( "sessionKey": session_key, "messages": msgs, } + if page is not None: + page["loaded_message_count"] = len(msgs) + payload["page"] = page if fork_boundary is not None: payload["fork_boundary_message_count"] = fork_boundary return payload diff --git a/nanobot/webui/ws_http.py b/nanobot/webui/ws_http.py index d21261681..f04642e04 100644 --- a/nanobot/webui/ws_http.py +++ b/nanobot/webui/ws_http.py @@ -375,6 +375,18 @@ class GatewayHTTPHandler: raw_messages = session_data.get("messages") if isinstance(session_data, dict) else None if isinstance(raw_messages, list): session_messages = [m for m in raw_messages if isinstance(m, dict)] + query = _parse_query(request.path) + raw_limit = _query_first(query, "limit") + limit: int | None = None + if raw_limit is not None and raw_limit.strip(): + try: + limit = int(raw_limit) + except ValueError: + return _http_error(400, "invalid limit") + direction = _query_first(query, "direction") + if direction is not None and direction not in {"latest"}: + return _http_error(400, "invalid direction") + before = _query_first(query, "before") data = build_webui_thread_response( decoded_key, augment_user_media=self.media.augment_transcript_media, @@ -384,6 +396,9 @@ class GatewayHTTPHandler: workspace_path=scope.project_path, ), session_messages=session_messages, + limit=limit, + direction=direction, + before=before, ) if data is None: return _http_error(404, "webui thread not found") diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index b74b54ad6..cf6a15455 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -2718,6 +2718,45 @@ def test_handle_webui_thread_get_returns_json(tmp_path, monkeypatch) -> None: assert body["messages"][0]["content"] == "hi" +def test_handle_webui_thread_get_accepts_pagination_query(tmp_path, monkeypatch) -> None: + from urllib.parse import quote + + from websockets.datastructures import Headers + from websockets.http11 import Request + + from nanobot.webui.transcript import append_transcript_object + + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + key = "websocket:paged-route" + for idx in range(1, 4): + append_transcript_object( + key, + {"event": "user", "chat_id": "paged-route", "text": f"q{idx}"}, + ) + append_transcript_object( + key, + {"event": "message", "chat_id": "paged-route", "text": f"a{idx}"}, + ) + append_transcript_object(key, {"event": "turn_end", "chat_id": "paged-route"}) + + bus = MagicMock() + channel = _ch(bus) + channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0 + enc = quote(key, safe="") + req = Request( + f"/api/sessions/{enc}/webui-thread?limit=2&direction=latest", + Headers([("Authorization", "Bearer tok")]), + ) + + resp = channel.gateway.http._handle_webui_thread_get(req, enc) + + assert resp.status_code == 200 + body = json.loads(resp.body.decode()) + assert [message["content"] for message in body["messages"]] == ["q3", "a3"] + assert body["page"]["has_more_before"] is True + assert body["page"]["before_cursor"] + + def test_handle_file_preview_returns_workspace_file(tmp_path) -> None: from urllib.parse import quote diff --git a/tests/utils/test_webui_thread_disk.py b/tests/utils/test_webui_thread_disk.py index 53094d65b..ee825dc42 100644 --- a/tests/utils/test_webui_thread_disk.py +++ b/tests/utils/test_webui_thread_disk.py @@ -3,18 +3,35 @@ from __future__ import annotations from nanobot.webui.thread_disk import delete_webui_thread, webui_thread_file_path -from nanobot.webui.transcript import append_transcript_object, webui_transcript_path +from nanobot.webui.transcript import ( + append_transcript_object, + webui_transcript_path, + webui_transcript_segments_dir, +) def test_delete_webui_thread_removes_legacy_json_and_transcript(tmp_path, monkeypatch) -> None: monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + monkeypatch.setattr("nanobot.webui.transcript._MAX_TRANSCRIPT_FILE_BYTES", 520) + monkeypatch.setattr("nanobot.webui.transcript._TARGET_ACTIVE_TRANSCRIPT_BYTES", 260) key = "websocket:k1" json_path = webui_thread_file_path(key) json_path.parent.mkdir(parents=True, exist_ok=True) json_path.write_text('{"x":1}', encoding="utf-8") - append_transcript_object(key, {"event": "user", "chat_id": "k1", "text": "hi"}) + for idx in range(1, 5): + append_transcript_object( + key, + {"event": "user", "chat_id": "k1", "text": f"question {idx} " + ("x" * 24)}, + ) + append_transcript_object( + key, + {"event": "message", "chat_id": "k1", "text": f"answer {idx} " + ("y" * 24)}, + ) + append_transcript_object(key, {"event": "turn_end", "chat_id": "k1"}) assert webui_transcript_path(key).is_file() + assert webui_transcript_segments_dir(key).is_dir() assert delete_webui_thread(key) is True assert not json_path.is_file() assert not webui_transcript_path(key).is_file() + assert not webui_transcript_segments_dir(key).exists() assert delete_webui_thread(key) is False diff --git a/tests/utils/test_webui_transcript.py b/tests/utils/test_webui_transcript.py index e44d7eb3f..0675b659a 100644 --- a/tests/utils/test_webui_transcript.py +++ b/tests/utils/test_webui_transcript.py @@ -10,6 +10,7 @@ from nanobot.webui.transcript import ( fork_transcript_before_user_index, read_transcript_lines, replay_transcript_to_ui_messages, + webui_transcript_segments_dir, write_session_messages_as_transcript, ) @@ -23,6 +24,142 @@ def test_append_and_read_roundtrip(tmp_path, monkeypatch) -> None: assert lines[0]["text"] == "hello" +def _force_small_transcript_budget(monkeypatch, *, limit: int = 520, target: int = 260) -> None: + monkeypatch.setattr("nanobot.webui.transcript._MAX_TRANSCRIPT_FILE_BYTES", limit) + monkeypatch.setattr("nanobot.webui.transcript._TARGET_ACTIVE_TRANSCRIPT_BYTES", target) + + +def _append_numbered_turn(key: str, chat_id: str, idx: int) -> None: + append_transcript_object( + key, + {"event": "user", "chat_id": chat_id, "text": f"question {idx} " + ("x" * 24)}, + ) + append_transcript_object( + key, + {"event": "message", "chat_id": chat_id, "text": f"answer {idx} " + ("y" * 24)}, + ) + append_transcript_object(key, {"event": "turn_end", "chat_id": chat_id}) + + +def _write_segmented_turns(tmp_path, monkeypatch, key: str, chat_id: str, count: int) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + _force_small_transcript_budget(monkeypatch) + for idx in range(1, count + 1): + _append_numbered_turn(key, chat_id, idx) + + +def _message_contents(payload: dict) -> list[str]: + return [str(message.get("content") or "") for message in payload["messages"]] + + +def _numbered_turn_texts(start: int, end: int) -> list[str]: + return [ + text + for idx in range(start, end + 1) + for text in (f"question {idx} " + ("x" * 24), f"answer {idx} " + ("y" * 24)) + ] + + +def test_segmented_transcript_rotation_preserves_full_history(tmp_path, monkeypatch) -> None: + key = "websocket:segmented" + _write_segmented_turns(tmp_path, monkeypatch, key, "segmented", 6) + + segment_dir = webui_transcript_segments_dir(key) + assert segment_dir.is_dir() + assert (segment_dir / "manifest.json").is_file() + + lines = read_transcript_lines(key) + contents = [str(line.get("text") or "") for line in lines if line.get("event") in {"user", "message"}] + assert contents == _numbered_turn_texts(1, 6) + + +def test_segmented_transcript_paginates_latest_and_older_without_overlap( + tmp_path, + monkeypatch, +) -> None: + key = "websocket:paged" + _write_segmented_turns(tmp_path, monkeypatch, key, "paged", 6) + + latest = build_webui_thread_response(key, limit=4, direction="latest") + assert latest is not None + assert latest["page"]["has_more_before"] is True + assert latest["page"]["user_message_offset"] == 4 + assert _message_contents(latest) == _numbered_turn_texts(5, 6) + + older = build_webui_thread_response( + key, + limit=4, + before=latest["page"]["before_cursor"], + ) + assert older is not None + assert older["page"]["user_message_offset"] == 2 + assert _message_contents(older) == _numbered_turn_texts(3, 4) + + +def test_page_cursor_survives_active_rotation_after_latest_page( + tmp_path, + monkeypatch, +) -> None: + key = "websocket:stable-cursor" + _write_segmented_turns(tmp_path, monkeypatch, key, "stable-cursor", 7) + + latest = build_webui_thread_response(key, limit=4, direction="latest") + assert latest is not None + cursor = latest["page"]["before_cursor"] + assert cursor + assert _message_contents(latest) == _numbered_turn_texts(6, 7) + + for idx in range(8, 13): + _append_numbered_turn(key, "stable-cursor", idx) + + older = build_webui_thread_response(key, limit=4, before=cursor) + + assert older is not None + assert _message_contents(older) == _numbered_turn_texts(4, 5) + + +def test_segment_manifest_can_be_rebuilt_when_missing_or_corrupt(tmp_path, monkeypatch) -> None: + key = "websocket:manifest" + _write_segmented_turns(tmp_path, monkeypatch, key, "manifest", 4) + + manifest = webui_transcript_segments_dir(key) / "manifest.json" + manifest.write_text("{not json", encoding="utf-8") + + lines = read_transcript_lines(key) + + assert len([line for line in lines if line.get("event") == "user"]) == 4 + assert manifest.read_text(encoding="utf-8").lstrip().startswith("{") + + +def test_delete_webui_transcript_removes_segments(tmp_path, monkeypatch) -> None: + from nanobot.webui.thread_disk import webui_thread_file_path + from nanobot.webui.transcript import delete_webui_transcript, webui_transcript_path + + key = "websocket:delete-segments" + _write_segmented_turns(tmp_path, monkeypatch, key, "delete-segments", 4) + legacy_path = webui_thread_file_path(key) + legacy_path.parent.mkdir(parents=True, exist_ok=True) + legacy_path.write_text('{"messages":[]}', encoding="utf-8") + + assert webui_transcript_segments_dir(key).is_dir() + assert delete_webui_transcript(key) is True + assert not legacy_path.exists() + assert not webui_transcript_path(key).exists() + assert not webui_transcript_segments_dir(key).exists() + + +def test_fork_transcript_reads_across_segments(tmp_path, monkeypatch) -> None: + source = "websocket:seg-source" + _write_segmented_turns(tmp_path, monkeypatch, source, "seg-source", 5) + + ok = fork_transcript_before_user_index(source, "websocket:seg-fork", 3) + + assert ok is True + forked = build_webui_thread_response("websocket:seg-fork") + assert forked is not None + assert _message_contents(forked) == _numbered_turn_texts(1, 3) + + def test_fork_transcript_before_user_index_copies_only_prefix(tmp_path, monkeypatch) -> None: monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) source = "websocket:source" diff --git a/webui/src/components/thread/ThreadMessages.tsx b/webui/src/components/thread/ThreadMessages.tsx index f6122ca48..b75460a67 100644 --- a/webui/src/components/thread/ThreadMessages.tsx +++ b/webui/src/components/thread/ThreadMessages.tsx @@ -1,6 +1,5 @@ import { Fragment, useMemo } from "react"; import { useTranslation } from "react-i18next"; - import { MessageBubble } from "@/components/MessageBubble"; import { AgentActivityCluster } from "@/components/thread/AgentActivityCluster"; import { normalizeActivityTimeline, type TurnUnit } from "@/lib/activity-timeline"; @@ -10,9 +9,7 @@ interface ThreadMessagesProps { messages: UIMessage[]; /** When true, agent turn still in flight — keeps activity timeline expanded. */ isStreaming?: boolean; - hiddenMessageCount?: number; hiddenUserMessageCount?: number; - onLoadEarlier?: () => void; cliApps?: CliAppInfo[]; mcpPresets?: McpPresetInfo[]; forkBoundaryMessageCount?: number | null; @@ -66,9 +63,7 @@ export function assistantCopyFlags(units: DisplayUnit[]): boolean[] { export function ThreadMessages({ messages, isStreaming = false, - hiddenMessageCount = 0, hiddenUserMessageCount = 0, - onLoadEarlier, cliApps = [], mcpPresets = [], forkBoundaryMessageCount = null, @@ -90,20 +85,6 @@ export function ThreadMessages({ return (
- {hiddenMessageCount > 0 && onLoadEarlier ? ( -
- -
- ) : null} {units.map((unit, index) => { const prev = units[index - 1]; const marginTop = diff --git a/webui/src/components/thread/ThreadShell.tsx b/webui/src/components/thread/ThreadShell.tsx index dfb516c2d..3d9d332fe 100644 --- a/webui/src/components/thread/ThreadShell.tsx +++ b/webui/src/components/thread/ThreadShell.tsx @@ -250,6 +250,10 @@ export function ThreadShell({ const { messages: historical, loading, + loadingOlder, + loadOlder, + hasMoreBefore, + userMessageOffset, hasPendingToolCalls, refresh: refreshHistory, version: historyVersion, @@ -415,6 +419,14 @@ export function ThreadShell({ } if (cached && cached.length > 0) { const normalizedCached = projectWebuiThreadMessages(cached); + if ( + normalizedHistory.length > normalizedCached.length + && !isStaleThreadSnapshot(prev, normalizedHistory) + ) { + messageCacheRef.current.set(chatId, normalizedHistory); + appliedHistoryVersionRef.current.set(chatId, historyVersion); + return normalizedHistory; + } if (isStaleThreadSnapshot(prev, normalizedCached)) return keepLiveMessages(prev); return normalizedCached; } @@ -752,6 +764,10 @@ export function ThreadShell({ cliApps={cliApps} mcpPresets={mcpPresets} forkBoundaryMessageCount={forkBoundaryMessageCount} + hasMoreBefore={hasMoreBefore} + loadingOlder={loadingOlder} + userMessageOffset={userMessageOffset} + onLoadOlder={loadOlder} onOpenFilePreview={historyKey ? handleOpenFilePreview : undefined} onForkFromMessage={onForkChat ? handleForkFromMessage : undefined} /> diff --git a/webui/src/components/thread/ThreadViewport.tsx b/webui/src/components/thread/ThreadViewport.tsx index 42ac3b379..55df4ecb0 100644 --- a/webui/src/components/thread/ThreadViewport.tsx +++ b/webui/src/components/thread/ThreadViewport.tsx @@ -38,11 +38,16 @@ interface ThreadViewportProps { cliApps?: CliAppInfo[]; mcpPresets?: McpPresetInfo[]; forkBoundaryMessageCount?: number | null; + hasMoreBefore?: boolean; + loadingOlder?: boolean; + userMessageOffset?: number; + onLoadOlder?: () => Promise | void; onOpenFilePreview?: (path: string) => void; onForkFromMessage?: (beforeUserIndex: number) => void; } const NEAR_BOTTOM_PX = 48; +const NEAR_TOP_PX = 96; const DEFAULT_SCROLL_BUTTON_BOTTOM_PX = 192; const SCROLL_BUTTON_COMPOSER_GAP_PX = 16; export const INITIAL_HISTORY_WINDOW = 160; @@ -72,6 +77,10 @@ export const ThreadViewport = forwardRef 0 + userMessageOffset + + (hiddenMessageCount > 0 ? messages.slice(0, hiddenMessageCount).filter((message) => message.role === "user").length - : 0; + : 0); const visibleForkBoundaryMessageCount = forkBoundaryMessageCount !== null && forkBoundaryMessageCount > hiddenMessageCount ? forkBoundaryMessageCount - hiddenMessageCount @@ -126,6 +136,7 @@ export const ThreadViewport = forwardRef - Math.min(messages.length, count + HISTORY_WINDOW_INCREMENT), - ); - }, [messages.length]); + if (hiddenMessageCount > 0) { + setVisibleMessageCount((count) => + Math.min(messages.length, count + HISTORY_WINDOW_INCREMENT), + ); + return; + } + if (hasMoreBefore && onLoadOlder && !loadingOlder) { + setVisibleMessageCount((count) => count + HISTORY_WINDOW_INCREMENT); + void onLoadOlder(); + } + }, [hasMoreBefore, hiddenMessageCount, loadingOlder, messages.length, onLoadOlder]); + + const maybeLoadEarlierFromScroll = useCallback(() => { + const el = scrollRef.current; + if (!el || !hasMessages || pendingConversationScrollRef.current) return; + if (!userReadingHistoryRef.current) return; + if (el.scrollTop > NEAR_TOP_PX) return; + if (hiddenMessageCount <= 0 && !hasMoreBefore) return; + loadEarlierMessages(); + }, [hasMessages, hasMoreBefore, hiddenMessageCount, loadEarlierMessages]); const jumpToUserPrompt = useCallback((promptId: string) => { const scrollEl = scrollRef.current; @@ -218,8 +245,17 @@ export const ThreadViewport = forwardRef { const promptId = pendingPromptJumpRef.current; @@ -271,17 +307,19 @@ export const ThreadViewport = forwardRef { + const onScroll = (allowHistoryLoad = true) => { const distance = el.scrollHeight - el.scrollTop - el.clientHeight; const near = distance < NEAR_BOTTOM_PX; setAtBottom(near); userReadingHistoryRef.current = !near; + if (allowHistoryLoad && !near) maybeLoadEarlierFromScroll(); }; - onScroll(); - el.addEventListener("scroll", onScroll, { passive: true }); - return () => el.removeEventListener("scroll", onScroll); - }, []); + onScroll(false); + const handleScroll = () => onScroll(true); + el.addEventListener("scroll", handleScroll, { passive: true }); + return () => el.removeEventListener("scroll", handleScroll); + }, [maybeLoadEarlierFromScroll]); return (
@@ -302,9 +340,7 @@ export const ThreadViewport = forwardRef ({ + ...m, + id: m.id ?? `hist-${idx}`, + createdAt: typeof m.createdAt === "number" ? m.createdAt : Date.now(), + })); +} /** Sidebar state: fetches the full session list and exposes create / delete actions. */ export function useSessions(): { @@ -129,14 +139,19 @@ export function useSessions(): { export function useSessionHistory(key: string | null): { messages: UIMessage[]; loading: boolean; + loadingOlder: boolean; error: string | null; refresh: () => void; + loadOlder: () => Promise; + hasMoreBefore: boolean; + userMessageOffset: number; version: number; forkBoundaryMessageCount: number | null; /** ``true`` when the replayed transcript ends with a trace row (turn still in flight). */ hasPendingToolCalls: boolean; } { const { token } = useClient(); + const loadingOlderRef = useRef(false); const [refreshSeq, setRefreshSeq] = useState(0); const refresh = useCallback(() => { setRefreshSeq((value) => value + 1); @@ -145,17 +160,25 @@ export function useSessionHistory(key: string | null): { key: string | null; messages: UIMessage[]; loading: boolean; + loadingOlder: boolean; error: string | null; hasPendingToolCalls: boolean; forkBoundaryMessageCount: number | null; + beforeCursor: string | null; + hasMoreBefore: boolean; + userMessageOffset: number; version: number; }>({ key: null, messages: [], loading: false, + loadingOlder: false, error: null, hasPendingToolCalls: false, forkBoundaryMessageCount: null, + beforeCursor: null, + hasMoreBefore: false, + userMessageOffset: 0, version: 0, }); @@ -165,9 +188,13 @@ export function useSessionHistory(key: string | null): { key: null, messages: [], loading: false, + loadingOlder: false, error: null, hasPendingToolCalls: false, forkBoundaryMessageCount: null, + beforeCursor: null, + hasMoreBefore: false, + userMessageOffset: 0, version: 0, }); return; @@ -176,37 +203,44 @@ export function useSessionHistory(key: string | null): { // Mark the new key as loading immediately so callers never see stale // messages from the previous session during the render right after a switch. setState((prev) => prev.key === key - ? { ...prev, loading: true, error: null } + ? { ...prev, loading: true, loadingOlder: false, error: null } : { key, messages: [], loading: true, + loadingOlder: false, error: null, hasPendingToolCalls: false, forkBoundaryMessageCount: null, + beforeCursor: null, + hasMoreBefore: false, + userMessageOffset: 0, version: 0, }); (async () => { try { - const body = await fetchWebuiThread(token, key); + const body = await fetchWebuiThread(token, key, { + limit: INITIAL_HISTORY_PAGE_LIMIT, + direction: "latest", + }); if (cancelled) return; if (!body?.messages?.length) { setState((prev) => ({ key, messages: [], loading: false, + loadingOlder: false, error: null, hasPendingToolCalls: false, forkBoundaryMessageCount: null, + beforeCursor: null, + hasMoreBefore: false, + userMessageOffset: 0, version: prev.key === key ? prev.version + 1 : 1, })); return; } - const ui: UIMessage[] = body.messages.map((m, idx) => ({ - ...m, - id: m.id ?? `hist-${idx}`, - createdAt: typeof m.createdAt === "number" ? m.createdAt : Date.now(), - })); + const ui = persistedMessagesToUi(body.messages); const last = ui[ui.length - 1]; const hasPending = last?.kind === "trace"; const forkBoundary = typeof body.fork_boundary_message_count === "number" @@ -216,9 +250,13 @@ export function useSessionHistory(key: string | null): { key, messages: ui, loading: false, + loadingOlder: false, error: null, hasPendingToolCalls: hasPending, forkBoundaryMessageCount: forkBoundary, + beforeCursor: body.page?.before_cursor ?? null, + hasMoreBefore: body.page?.has_more_before === true, + userMessageOffset: Math.max(0, body.page?.user_message_offset ?? 0), version: prev.key === key ? prev.version + 1 : 1, })); } catch (e) { @@ -228,9 +266,13 @@ export function useSessionHistory(key: string | null): { key, messages: [], loading: false, + loadingOlder: false, error: null, hasPendingToolCalls: false, forkBoundaryMessageCount: null, + beforeCursor: null, + hasMoreBefore: false, + userMessageOffset: 0, version: prev.key === key ? prev.version + 1 : 1, })); } else { @@ -238,9 +280,13 @@ export function useSessionHistory(key: string | null): { key, messages: [], loading: false, + loadingOlder: false, error: (e as Error).message, hasPendingToolCalls: false, forkBoundaryMessageCount: null, + beforeCursor: null, + hasMoreBefore: false, + userMessageOffset: 0, version: prev.key === key ? prev.version : 0, })); } @@ -251,12 +297,78 @@ export function useSessionHistory(key: string | null): { }; }, [key, token, refreshSeq]); + const loadOlder = useCallback(async () => { + if (!key || loadingOlderRef.current) return; + const before = state.key === key ? state.beforeCursor : null; + if (!before || !state.hasMoreBefore) return; + loadingOlderRef.current = true; + setState((prev) => prev.key === key ? { ...prev, loadingOlder: true, error: null } : prev); + try { + const body = await fetchWebuiThread(token, key, { + limit: OLDER_HISTORY_PAGE_LIMIT, + before, + }); + setState((prev) => { + if (prev.key !== key) return prev; + if (!body?.messages?.length) { + return { + ...prev, + loadingOlder: false, + hasMoreBefore: false, + beforeCursor: null, + }; + } + const older = persistedMessagesToUi(body.messages); + const olderBoundary = typeof body.fork_boundary_message_count === "number" + ? Math.max(0, Math.min(body.fork_boundary_message_count, older.length)) + : null; + const shiftedBoundary = prev.forkBoundaryMessageCount === null + ? null + : prev.forkBoundaryMessageCount + older.length; + const nextMessages = [...older, ...prev.messages]; + const last = nextMessages[nextMessages.length - 1]; + return { + ...prev, + messages: nextMessages, + loadingOlder: false, + error: null, + hasPendingToolCalls: last?.kind === "trace", + forkBoundaryMessageCount: olderBoundary ?? shiftedBoundary, + beforeCursor: body.page?.before_cursor ?? null, + hasMoreBefore: body.page?.has_more_before === true, + userMessageOffset: Math.max(0, body.page?.user_message_offset ?? 0), + version: prev.version + 1, + }; + }); + } catch (e) { + setState((prev) => prev.key === key + ? { + ...prev, + loadingOlder: false, + error: (e as Error).message, + } + : prev); + } finally { + loadingOlderRef.current = false; + } + }, [ + key, + state.beforeCursor, + state.hasMoreBefore, + state.key, + token, + ]); + if (!key) { return { messages: EMPTY_MESSAGES, loading: false, + loadingOlder: false, error: null, refresh, + loadOlder, + hasMoreBefore: false, + userMessageOffset: 0, version: 0, forkBoundaryMessageCount: null, hasPendingToolCalls: false, @@ -269,8 +381,12 @@ export function useSessionHistory(key: string | null): { return { messages: EMPTY_MESSAGES, loading: true, + loadingOlder: false, error: null, refresh, + loadOlder, + hasMoreBefore: false, + userMessageOffset: 0, version: 0, forkBoundaryMessageCount: null, hasPendingToolCalls: false, @@ -280,8 +396,12 @@ export function useSessionHistory(key: string | null): { return { messages: state.messages, loading: state.loading, + loadingOlder: state.loadingOlder, error: state.error, refresh, + loadOlder, + hasMoreBefore: state.hasMoreBefore, + userMessageOffset: state.userMessageOffset, version: state.version, forkBoundaryMessageCount: state.forkBoundaryMessageCount, hasPendingToolCalls: state.hasPendingToolCalls, diff --git a/webui/src/lib/api.ts b/webui/src/lib/api.ts index 1342a102b..63a74e06e 100644 --- a/webui/src/lib/api.ts +++ b/webui/src/lib/api.ts @@ -124,12 +124,27 @@ export async function listSessions( } /** Disk-backed WebUI display thread snapshot (separate from agent session). */ +export interface FetchWebuiThreadOptions { + limit?: number; + direction?: "latest"; + before?: string | null; +} + export async function fetchWebuiThread( token: string, key: string, + optionsOrBase?: FetchWebuiThreadOptions | string, base: string = "", ): Promise { - const url = `${base}/api/sessions/${encodeURIComponent(key)}/webui-thread`; + const options = typeof optionsOrBase === "string" ? undefined : optionsOrBase; + const resolvedBase = typeof optionsOrBase === "string" ? optionsOrBase : base; + const params = new URLSearchParams(); + if (options?.limit !== undefined) params.set("limit", String(options.limit)); + if (options?.direction) params.set("direction", options.direction); + if (options?.before) params.set("before", options.before); + const query = params.toString(); + const suffix = query ? `?${query}` : ""; + const url = `${resolvedBase}/api/sessions/${encodeURIComponent(key)}/webui-thread${suffix}`; const res = await fetchWithTimeout(url, { headers: { Authorization: `Bearer ${token}` }, credentials: "same-origin", diff --git a/webui/src/lib/types.ts b/webui/src/lib/types.ts index 438373a1f..ae21b98b3 100644 --- a/webui/src/lib/types.ts +++ b/webui/src/lib/types.ts @@ -857,12 +857,21 @@ export interface OutboundMcpPresetMention { } /** Response shape for ``GET .../webui-thread`` (server-built transcript replay). */ +export interface WebuiThreadPagePayload { + before_cursor?: string | null; + has_more_before?: boolean; + loaded_message_count?: number; + total_known_message_count?: number; + user_message_offset?: number; +} + export interface WebuiThreadPersistedPayload { schemaVersion: number; sessionKey?: string; savedAt?: string; messages: UIMessage[]; fork_boundary_message_count?: number; + page?: WebuiThreadPagePayload; workspace_scope?: WorkspaceScopePayload; } diff --git a/webui/src/tests/api.test.ts b/webui/src/tests/api.test.ts index d48483615..f4c5972f2 100644 --- a/webui/src/tests/api.test.ts +++ b/webui/src/tests/api.test.ts @@ -60,6 +60,21 @@ describe("webui API helpers", () => { ); }); + it("passes pagination params when fetching a WebUI thread page", async () => { + await fetchWebuiThread("tok", "websocket:chat-1", { + limit: 120, + before: "abc+/=", + }); + + expect(fetch).toHaveBeenCalledWith( + "/api/sessions/websocket%3Achat-1/webui-thread?limit=120&before=abc%2B%2F%3D", + expect.objectContaining({ + headers: { Authorization: "Bearer tok" }, + credentials: "same-origin", + }), + ); + }); + it("percent-encodes websocket keys and paths when fetching file previews", async () => { await fetchFilePreview("tok", "websocket:chat-1", "/tmp/project/hook.py:12"); diff --git a/webui/src/tests/thread-shell.test.tsx b/webui/src/tests/thread-shell.test.tsx index f80640056..5d026e767 100644 --- a/webui/src/tests/thread-shell.test.tsx +++ b/webui/src/tests/thread-shell.test.tsx @@ -725,16 +725,24 @@ describe("ThreadShell", () => { it("forks assistant replies using the global user message index rather than the visible window index", async () => { const client = makeClient(); const onForkChat = vi.fn().mockResolvedValue("chat-fork"); - const rows = Array.from({ length: 165 }, (_, index) => [ - { role: "user" as const, content: `question ${index}` }, - { role: "assistant" as const, content: `answer ${index}` }, - ]).flat(); + const rows = [ + { role: "user" as const, content: "question 100" }, + { role: "assistant" as const, content: "answer 100" }, + ]; vi.stubGlobal( "fetch", vi.fn(async (input: RequestInfo | URL) => { const url = String(input); if (url.includes("websocket%3Along-chat/webui-thread")) { - return httpJson(transcriptFromSimpleMessages(rows)); + return httpJson({ + ...transcriptFromSimpleMessages(rows), + page: { + before_cursor: "before-question-100", + has_more_before: true, + loaded_message_count: 2, + user_message_offset: 100, + }, + }); } return { ok: false, diff --git a/webui/src/tests/thread-viewport.test.tsx b/webui/src/tests/thread-viewport.test.tsx index e7d72fb1b..6a442db4e 100644 --- a/webui/src/tests/thread-viewport.test.tsx +++ b/webui/src/tests/thread-viewport.test.tsx @@ -143,7 +143,7 @@ describe("ThreadViewport", () => { Object.defineProperties(scroller, { scrollHeight: { configurable: true, value: 2400 }, clientHeight: { configurable: true, value: 600 }, - scrollTop: { configurable: true, value: 0 }, + scrollTop: { configurable: true, writable: true, value: 0 }, }); act(() => { @@ -167,13 +167,13 @@ describe("ThreadViewport", () => { expect(screen.queryByText("message 139")).not.toBeInTheDocument(); expect(screen.getByText("message 140")).toBeInTheDocument(); expect(screen.getByText("message 299")).toBeInTheDocument(); - expect(screen.getByRole("button", { name: "Load earlier messages" })).toBeInTheDocument(); + expect(screen.queryByRole("button", { name: "Load earlier messages" })).not.toBeInTheDocument(); }); - it("loads earlier history in fixed increments without rendering the whole transcript", () => { + it("automatically expands earlier local history near the top", () => { const longMessages = makeLongMessages(300); - render( + const { container } = render( { />, ); - fireEvent.click(screen.getByRole("button", { name: "Load earlier messages" })); + const scroller = container.firstElementChild?.firstElementChild as HTMLElement; + Object.defineProperties(scroller, { + scrollHeight: { configurable: true, value: 2400 }, + clientHeight: { configurable: true, value: 600 }, + scrollTop: { configurable: true, writable: true, value: 0 }, + }); + + act(() => { + scroller.dispatchEvent(new Event("scroll")); + }); const firstVisible = 300 - INITIAL_HISTORY_WINDOW - HISTORY_WINDOW_INCREMENT; @@ -193,6 +202,33 @@ describe("ThreadViewport", () => { expect(screen.getByText("message 299")).toBeInTheDocument(); }); + it("automatically requests older transcript pages near the top", () => { + const onLoadOlder = vi.fn(); + + const { container } = render( + } + hasMoreBefore + onLoadOlder={onLoadOlder} + />, + ); + + const scroller = container.firstElementChild?.firstElementChild as HTMLElement; + Object.defineProperties(scroller, { + scrollHeight: { configurable: true, value: 1800 }, + clientHeight: { configurable: true, value: 600 }, + scrollTop: { configurable: true, writable: true, value: 0 }, + }); + + act(() => { + scroller.dispatchEvent(new Event("scroll")); + }); + + expect(onLoadOlder).toHaveBeenCalledTimes(1); + }); + it("renders a prompt rail that jumps to user messages", async () => { const promptMessages = makeLongMessages(5); const { container } = render( diff --git a/webui/src/tests/useSessions.test.tsx b/webui/src/tests/useSessions.test.tsx index 1d79b4673..a606b249a 100644 --- a/webui/src/tests/useSessions.test.tsx +++ b/webui/src/tests/useSessions.test.tsx @@ -414,6 +414,65 @@ describe("useSessions", () => { expect(result.current.hasPendingToolCalls).toBe(false); }); + it("loads older transcript pages before the current history", async () => { + vi.mocked(api.fetchWebuiThread) + .mockResolvedValueOnce({ + schemaVersion: 3, + messages: [ + { id: "u2", role: "user", content: "new question", createdAt: 2 }, + { id: "a2", role: "assistant", content: "new answer", createdAt: 3 }, + ], + page: { + before_cursor: "cursor-2", + has_more_before: true, + loaded_message_count: 2, + user_message_offset: 1, + }, + }) + .mockResolvedValueOnce({ + schemaVersion: 3, + messages: [ + { id: "u1", role: "user", content: "old question", createdAt: 0 }, + { id: "a1", role: "assistant", content: "old answer", createdAt: 1 }, + ], + page: { + before_cursor: null, + has_more_before: false, + loaded_message_count: 2, + user_message_offset: 0, + }, + }); + + const { result } = renderHook(() => useSessionHistory("websocket:paged"), { + wrapper: wrap(fakeClient()), + }); + + await waitFor(() => expect(result.current.loading).toBe(false)); + expect(api.fetchWebuiThread).toHaveBeenCalledWith("tok", "websocket:paged", { + limit: 160, + direction: "latest", + }); + expect(result.current.hasMoreBefore).toBe(true); + expect(result.current.userMessageOffset).toBe(1); + + await act(async () => { + await result.current.loadOlder(); + }); + + expect(api.fetchWebuiThread).toHaveBeenLastCalledWith("tok", "websocket:paged", { + limit: 120, + before: "cursor-2", + }); + expect(result.current.messages.map((message) => message.content)).toEqual([ + "old question", + "old answer", + "new question", + "new answer", + ]); + expect(result.current.hasMoreBefore).toBe(false); + expect(result.current.userMessageOffset).toBe(0); + }); + it("keeps the session in the list when delete fails", async () => { vi.mocked(api.listSessions).mockResolvedValue([ {