From 73d4b1cb2f2229eb7045852ae0566642fc3c9a5c Mon Sep 17 00:00:00 2001 From: Xubin Ren <52506698+Re-bin@users.noreply.github.com> Date: Wed, 10 Jun 2026 02:07:14 +0800 Subject: [PATCH] feat(webui): persist fork boundary metadata --- nanobot/channels/websocket.py | 14 +++- nanobot/session/manager.py | 4 +- nanobot/webui/transcript.py | 47 +++++++++++- tests/agent/test_session_manager_history.py | 28 ++++++++ tests/channels/test_websocket_channel.py | 17 +++-- tests/utils/test_webui_transcript.py | 80 +++++++++++++++++++++ 6 files changed, 182 insertions(+), 8 deletions(-) diff --git a/nanobot/channels/websocket.py b/nanobot/channels/websocket.py index 20aaac097..ec26198e6 100644 --- a/nanobot/channels/websocket.py +++ b/nanobot/channels/websocket.py @@ -28,7 +28,11 @@ from nanobot.security.workspace_access import ( WorkspaceScopeError, ) from nanobot.session.goal_state import goal_state_ws_blob -from nanobot.session.webui_turns import websocket_turn_wall_started_at +from nanobot.session.webui_turns import ( + WEBUI_TITLE_METADATA_KEY, + clean_generated_title, + websocket_turn_wall_started_at, +) from nanobot.utils.media_decode import ( FileSizeExceeded, save_base64_data_url, @@ -46,6 +50,7 @@ from nanobot.webui.http_utils import ( ) from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions from nanobot.webui.transcript import ( + append_fork_marker, delete_webui_transcript, fork_transcript_before_user_index, write_session_messages_as_transcript, @@ -709,6 +714,13 @@ class WebSocketChannel(BaseChannel): ) if not transcript_ok: write_session_messages_as_transcript(target_key, forked.messages) + append_fork_marker(target_key) + fork_title = clean_generated_title( + envelope.get("title") if isinstance(envelope.get("title"), str) else None, + ) + if fork_title: + forked.metadata[WEBUI_TITLE_METADATA_KEY] = fork_title + self.gateway.session_manager.save(forked, fsync=True) except Exception as exc: delete_webui_transcript(target_key) self.gateway.session_manager.delete_session(target_key) diff --git a/nanobot/session/manager.py b/nanobot/session/manager.py index 6c92fe753..73fb52cec 100644 --- a/nanobot/session/manager.py +++ b/nanobot/session/manager.py @@ -648,8 +648,8 @@ class SessionManager: ``before_user_index`` is zero-based over user messages in the full session: ``0`` means "before the first user message", ``1`` means "before the second user message", and so on. A value equal to the total user-message - count copies the full session prefix. The target user message itself is - not copied; the WebUI pre-fills it in the composer for editing and resend. + count copies the full session prefix. WebUI assistant-reply forks pass + the next user index so the selected completed assistant turn is included. """ if before_user_index < 0: return None diff --git a/nanobot/webui/transcript.py b/nanobot/webui/transcript.py index 59b7a2fd9..a5f5175d7 100644 --- a/nanobot/webui/transcript.py +++ b/nanobot/webui/transcript.py @@ -17,6 +17,7 @@ from nanobot.config.paths import get_webui_dir from nanobot.session.manager import SessionManager WEBUI_TRANSCRIPT_SCHEMA_VERSION = 3 +WEBUI_FORK_MARKER_EVENT = "fork_marker" _MAX_TRANSCRIPT_FILE_BYTES = 8 * 1024 * 1024 _WEBUI_TURN_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$") WEBUI_TURN_METADATA_KEY = "webui_turn_id" @@ -306,6 +307,8 @@ def fork_transcript_before_user_index( user_index = 0 found_target = False for row in lines: + if row.get("event") == WEBUI_FORK_MARKER_EVENT: + continue if _is_user_transcript_row(row): if user_index == before_user_index: found_target = True @@ -340,6 +343,17 @@ def fork_transcript_before_user_index( return True +def append_fork_marker(session_key: str) -> None: + """Mark the UI-only boundary where a WebUI fork starts accepting new turns.""" + append_transcript_object( + session_key, + { + "event": WEBUI_FORK_MARKER_EVENT, + "chat_id": _chat_id_from_session_key(session_key), + }, + ) + + def write_session_messages_as_transcript( target_key: str, messages: list[dict[str, Any]], @@ -1397,6 +1411,28 @@ def replay_transcript_to_ui_messages( return messages +def fork_boundary_message_count( + lines: list[dict[str, Any]], + *, + augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None, + augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None, + augment_assistant_text: Callable[[str], str] | None = None, +) -> int | None: + """Return the replayed UI message count before the first fork marker, if any.""" + for idx, rec in enumerate(lines): + if rec.get("event") != WEBUI_FORK_MARKER_EVENT: + continue + return len( + replay_transcript_to_ui_messages( + lines[:idx], + augment_user_media=augment_user_media, + augment_assistant_media=augment_assistant_media, + augment_assistant_text=augment_assistant_text, + ), + ) + return None + + def build_webui_thread_response( session_key: str, *, @@ -1410,14 +1446,23 @@ def build_webui_thread_response( if not lines: return None lines = inject_missing_user_events_from_session(session_key, lines, session_messages) + fork_boundary = fork_boundary_message_count( + lines, + augment_user_media=augment_user_media, + augment_assistant_media=augment_assistant_media, + augment_assistant_text=augment_assistant_text, + ) msgs = replay_transcript_to_ui_messages( lines, augment_user_media=augment_user_media, augment_assistant_media=augment_assistant_media, augment_assistant_text=augment_assistant_text, ) - return { + payload = { "schemaVersion": WEBUI_TRANSCRIPT_SCHEMA_VERSION, "sessionKey": session_key, "messages": msgs, } + if fork_boundary is not None: + payload["fork_boundary_message_count"] = fork_boundary + return payload diff --git a/tests/agent/test_session_manager_history.py b/tests/agent/test_session_manager_history.py index 3441c4833..6f123de32 100644 --- a/tests/agent/test_session_manager_history.py +++ b/tests/agent/test_session_manager_history.py @@ -454,6 +454,34 @@ def test_fork_session_before_user_index_copies_only_prefix(tmp_path): assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"] +def test_fork_session_from_middle_assistant_reply_keeps_selected_turn(tmp_path): + manager = SessionManager(tmp_path) + source = manager.get_or_create("websocket:source") + source.add_message("user", "round1") + source.add_message("assistant", "answer1") + source.add_message("user", "round2") + source.add_message("assistant", "answer2") + source.add_message("user", "round3 must not appear") + source.add_message("assistant", "answer3 must not appear") + manager.save(source) + + forked = manager.fork_session_before_user_index( + "websocket:source", + "websocket:fork", + 2, + ) + + assert forked is not None + assert [m["content"] for m in forked.messages] == [ + "round1", + "answer1", + "round2", + "answer2", + ] + saved = manager.read_session_file("websocket:fork") + assert "round3 must not appear" not in str(saved) + + def test_fork_session_rejects_negative_missing_and_out_of_range(tmp_path): manager = SessionManager(tmp_path) source = manager.get_or_create("websocket:source") diff --git a/tests/channels/test_websocket_channel.py b/tests/channels/test_websocket_channel.py index f8e8ea2e9..901d58664 100644 --- a/tests/channels/test_websocket_channel.py +++ b/tests/channels/test_websocket_channel.py @@ -2422,7 +2422,12 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript( await channel._dispatch_envelope( conn, "webui-client", - {"type": "fork_chat", "source_chat_id": "source", "before_user_index": 1}, + { + "type": "fork_chat", + "source_chat_id": "source", + "before_user_index": 1, + "title": "Fork: Old title", + }, ) sent = [json.loads(call.args[0]) for call in conn.send.await_args_list] @@ -2430,8 +2435,10 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript( fork_id = attached["chat_id"] saved = sessions.read_session_file(f"websocket:{fork_id}") assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"] + assert saved["metadata"]["title"] == "Fork: Old title" fork_lines = read_transcript_lines(f"websocket:{fork_id}") - assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None] + assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None, None] + assert fork_lines[-1]["event"] == "fork_marker" assert all(line.get("chat_id") == fork_id for line in fork_lines) assert "round3 must not appear" not in json.dumps(saved, ensure_ascii=False) bus.publish_inbound.assert_not_awaited() @@ -2477,7 +2484,8 @@ async def test_fork_chat_falls_back_to_session_prefix_when_transcript_lacks_user saved = sessions.read_session_file(f"websocket:{fork_id}") assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"] fork_lines = read_transcript_lines(f"websocket:{fork_id}") - assert [line.get("text") for line in fork_lines] == ["round1", "answer1"] + assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None] + assert fork_lines[-1]["event"] == "fork_marker" assert "round3 must not appear" not in json.dumps(fork_lines, ensure_ascii=False) bus.publish_inbound.assert_not_awaited() @@ -2520,7 +2528,8 @@ async def test_fork_chat_allows_index_equal_to_user_count( saved = sessions.read_session_file(f"websocket:{fork_id}") assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"] fork_lines = read_transcript_lines(f"websocket:{fork_id}") - assert [line.get("text") for line in fork_lines] == ["round1", "answer1"] + assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None] + assert fork_lines[-1]["event"] == "fork_marker" bus.publish_inbound.assert_not_awaited() diff --git a/tests/utils/test_webui_transcript.py b/tests/utils/test_webui_transcript.py index 37876e30a..595e75330 100644 --- a/tests/utils/test_webui_transcript.py +++ b/tests/utils/test_webui_transcript.py @@ -4,6 +4,7 @@ from __future__ import annotations from nanobot.webui.transcript import ( WEBUI_TRANSCRIPT_SCHEMA_VERSION, + append_fork_marker, append_transcript_object, build_webui_thread_response, fork_transcript_before_user_index, @@ -45,6 +46,33 @@ def test_fork_transcript_before_user_index_copies_only_prefix(tmp_path, monkeypa assert "round3 must not appear" not in "\n".join(str(line.get("text")) for line in lines) +def test_fork_transcript_from_middle_assistant_reply_keeps_selected_turn( + tmp_path, + monkeypatch, +) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + source = "websocket:source" + for ev in ( + {"event": "user", "chat_id": "source", "text": "round1"}, + {"event": "message", "chat_id": "source", "text": "answer1"}, + {"event": "user", "chat_id": "source", "text": "round2"}, + {"event": "message", "chat_id": "source", "text": "answer2"}, + {"event": "user", "chat_id": "source", "text": "round3 must not appear"}, + {"event": "message", "chat_id": "source", "text": "answer3 must not appear"}, + ): + append_transcript_object(source, ev) + + ok = fork_transcript_before_user_index(source, "websocket:fork", 2) + + assert ok is True + assert [line.get("text") for line in read_transcript_lines("websocket:fork")] == [ + "round1", + "answer1", + "round2", + "answer2", + ] + + def test_fork_transcript_rejects_out_of_range_user_index(tmp_path, monkeypatch) -> None: monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) source = "websocket:source" @@ -72,6 +100,58 @@ def test_fork_transcript_allows_index_equal_to_user_count(tmp_path, monkeypatch) ] +def test_build_response_reports_fork_boundary_from_marker(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + key = "websocket:fork" + for ev in ( + {"event": "user", "chat_id": "fork", "text": "round1"}, + {"event": "message", "chat_id": "fork", "text": "answer1"}, + ): + append_transcript_object(key, ev) + append_fork_marker(key) + append_transcript_object(key, {"event": "user", "chat_id": "fork", "text": "new branch"}) + + out = build_webui_thread_response(key) + + assert out is not None + assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "new branch"] + assert out["fork_boundary_message_count"] == 2 + + +def test_nested_fork_drops_inherited_fork_marker(tmp_path, monkeypatch) -> None: + monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path) + source = "websocket:source" + for ev in ( + {"event": "user", "chat_id": "source", "text": "round1"}, + {"event": "message", "chat_id": "source", "text": "answer1"}, + ): + append_transcript_object(source, ev) + append_fork_marker(source) + for ev in ( + {"event": "user", "chat_id": "source", "text": "round2"}, + {"event": "message", "chat_id": "source", "text": "answer2"}, + ): + append_transcript_object(source, ev) + + ok = fork_transcript_before_user_index(source, "websocket:nested", 2) + append_fork_marker("websocket:nested") + + lines = read_transcript_lines("websocket:nested") + out = build_webui_thread_response("websocket:nested") + + assert ok is True + assert [line.get("event") for line in lines] == [ + "user", + "message", + "user", + "message", + "fork_marker", + ] + assert out is not None + assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "round2", "answer2"] + assert out["fork_boundary_message_count"] == 4 + + def test_write_session_messages_as_transcript_builds_canonical_prefix( tmp_path, monkeypatch,