feat(webui): persist fork boundary metadata

This commit is contained in:
Xubin Ren 2026-06-10 02:07:14 +08:00
parent 03bca4c0a9
commit 73d4b1cb2f
6 changed files with 182 additions and 8 deletions

View File

@ -28,7 +28,11 @@ from nanobot.security.workspace_access import (
WorkspaceScopeError,
)
from nanobot.session.goal_state import goal_state_ws_blob
from nanobot.session.webui_turns import websocket_turn_wall_started_at
from nanobot.session.webui_turns import (
WEBUI_TITLE_METADATA_KEY,
clean_generated_title,
websocket_turn_wall_started_at,
)
from nanobot.utils.media_decode import (
FileSizeExceeded,
save_base64_data_url,
@ -46,6 +50,7 @@ from nanobot.webui.http_utils import (
)
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.transcript import (
append_fork_marker,
delete_webui_transcript,
fork_transcript_before_user_index,
write_session_messages_as_transcript,
@ -709,6 +714,13 @@ class WebSocketChannel(BaseChannel):
)
if not transcript_ok:
write_session_messages_as_transcript(target_key, forked.messages)
append_fork_marker(target_key)
fork_title = clean_generated_title(
envelope.get("title") if isinstance(envelope.get("title"), str) else None,
)
if fork_title:
forked.metadata[WEBUI_TITLE_METADATA_KEY] = fork_title
self.gateway.session_manager.save(forked, fsync=True)
except Exception as exc:
delete_webui_transcript(target_key)
self.gateway.session_manager.delete_session(target_key)

View File

@ -648,8 +648,8 @@ class SessionManager:
``before_user_index`` is zero-based over user messages in the full session:
``0`` means "before the first user message", ``1`` means "before the
second user message", and so on. A value equal to the total user-message
count copies the full session prefix. The target user message itself is
not copied; the WebUI pre-fills it in the composer for editing and resend.
count copies the full session prefix. WebUI assistant-reply forks pass
the next user index so the selected completed assistant turn is included.
"""
if before_user_index < 0:
return None

View File

@ -17,6 +17,7 @@ from nanobot.config.paths import get_webui_dir
from nanobot.session.manager import SessionManager
WEBUI_TRANSCRIPT_SCHEMA_VERSION = 3
WEBUI_FORK_MARKER_EVENT = "fork_marker"
_MAX_TRANSCRIPT_FILE_BYTES = 8 * 1024 * 1024
_WEBUI_TURN_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$")
WEBUI_TURN_METADATA_KEY = "webui_turn_id"
@ -306,6 +307,8 @@ def fork_transcript_before_user_index(
user_index = 0
found_target = False
for row in lines:
if row.get("event") == WEBUI_FORK_MARKER_EVENT:
continue
if _is_user_transcript_row(row):
if user_index == before_user_index:
found_target = True
@ -340,6 +343,17 @@ def fork_transcript_before_user_index(
return True
def append_fork_marker(session_key: str) -> None:
"""Mark the UI-only boundary where a WebUI fork starts accepting new turns."""
append_transcript_object(
session_key,
{
"event": WEBUI_FORK_MARKER_EVENT,
"chat_id": _chat_id_from_session_key(session_key),
},
)
def write_session_messages_as_transcript(
target_key: str,
messages: list[dict[str, Any]],
@ -1397,6 +1411,28 @@ def replay_transcript_to_ui_messages(
return messages
def fork_boundary_message_count(
lines: list[dict[str, Any]],
*,
augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
augment_assistant_text: Callable[[str], str] | None = None,
) -> int | None:
"""Return the replayed UI message count before the first fork marker, if any."""
for idx, rec in enumerate(lines):
if rec.get("event") != WEBUI_FORK_MARKER_EVENT:
continue
return len(
replay_transcript_to_ui_messages(
lines[:idx],
augment_user_media=augment_user_media,
augment_assistant_media=augment_assistant_media,
augment_assistant_text=augment_assistant_text,
),
)
return None
def build_webui_thread_response(
session_key: str,
*,
@ -1410,14 +1446,23 @@ def build_webui_thread_response(
if not lines:
return None
lines = inject_missing_user_events_from_session(session_key, lines, session_messages)
fork_boundary = fork_boundary_message_count(
lines,
augment_user_media=augment_user_media,
augment_assistant_media=augment_assistant_media,
augment_assistant_text=augment_assistant_text,
)
msgs = replay_transcript_to_ui_messages(
lines,
augment_user_media=augment_user_media,
augment_assistant_media=augment_assistant_media,
augment_assistant_text=augment_assistant_text,
)
return {
payload = {
"schemaVersion": WEBUI_TRANSCRIPT_SCHEMA_VERSION,
"sessionKey": session_key,
"messages": msgs,
}
if fork_boundary is not None:
payload["fork_boundary_message_count"] = fork_boundary
return payload

View File

@ -454,6 +454,34 @@ def test_fork_session_before_user_index_copies_only_prefix(tmp_path):
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
def test_fork_session_from_middle_assistant_reply_keeps_selected_turn(tmp_path):
manager = SessionManager(tmp_path)
source = manager.get_or_create("websocket:source")
source.add_message("user", "round1")
source.add_message("assistant", "answer1")
source.add_message("user", "round2")
source.add_message("assistant", "answer2")
source.add_message("user", "round3 must not appear")
source.add_message("assistant", "answer3 must not appear")
manager.save(source)
forked = manager.fork_session_before_user_index(
"websocket:source",
"websocket:fork",
2,
)
assert forked is not None
assert [m["content"] for m in forked.messages] == [
"round1",
"answer1",
"round2",
"answer2",
]
saved = manager.read_session_file("websocket:fork")
assert "round3 must not appear" not in str(saved)
def test_fork_session_rejects_negative_missing_and_out_of_range(tmp_path):
manager = SessionManager(tmp_path)
source = manager.get_or_create("websocket:source")

View File

@ -2422,7 +2422,12 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript(
await channel._dispatch_envelope(
conn,
"webui-client",
{"type": "fork_chat", "source_chat_id": "source", "before_user_index": 1},
{
"type": "fork_chat",
"source_chat_id": "source",
"before_user_index": 1,
"title": "Fork: Old title",
},
)
sent = [json.loads(call.args[0]) for call in conn.send.await_args_list]
@ -2430,8 +2435,10 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript(
fork_id = attached["chat_id"]
saved = sessions.read_session_file(f"websocket:{fork_id}")
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
assert saved["metadata"]["title"] == "Fork: Old title"
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None, None]
assert fork_lines[-1]["event"] == "fork_marker"
assert all(line.get("chat_id") == fork_id for line in fork_lines)
assert "round3 must not appear" not in json.dumps(saved, ensure_ascii=False)
bus.publish_inbound.assert_not_awaited()
@ -2477,7 +2484,8 @@ async def test_fork_chat_falls_back_to_session_prefix_when_transcript_lacks_user
saved = sessions.read_session_file(f"websocket:{fork_id}")
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
assert [line.get("text") for line in fork_lines] == ["round1", "answer1"]
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
assert fork_lines[-1]["event"] == "fork_marker"
assert "round3 must not appear" not in json.dumps(fork_lines, ensure_ascii=False)
bus.publish_inbound.assert_not_awaited()
@ -2520,7 +2528,8 @@ async def test_fork_chat_allows_index_equal_to_user_count(
saved = sessions.read_session_file(f"websocket:{fork_id}")
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
assert [line.get("text") for line in fork_lines] == ["round1", "answer1"]
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
assert fork_lines[-1]["event"] == "fork_marker"
bus.publish_inbound.assert_not_awaited()

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from nanobot.webui.transcript import (
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
append_fork_marker,
append_transcript_object,
build_webui_thread_response,
fork_transcript_before_user_index,
@ -45,6 +46,33 @@ def test_fork_transcript_before_user_index_copies_only_prefix(tmp_path, monkeypa
assert "round3 must not appear" not in "\n".join(str(line.get("text")) for line in lines)
def test_fork_transcript_from_middle_assistant_reply_keeps_selected_turn(
tmp_path,
monkeypatch,
) -> None:
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
source = "websocket:source"
for ev in (
{"event": "user", "chat_id": "source", "text": "round1"},
{"event": "message", "chat_id": "source", "text": "answer1"},
{"event": "user", "chat_id": "source", "text": "round2"},
{"event": "message", "chat_id": "source", "text": "answer2"},
{"event": "user", "chat_id": "source", "text": "round3 must not appear"},
{"event": "message", "chat_id": "source", "text": "answer3 must not appear"},
):
append_transcript_object(source, ev)
ok = fork_transcript_before_user_index(source, "websocket:fork", 2)
assert ok is True
assert [line.get("text") for line in read_transcript_lines("websocket:fork")] == [
"round1",
"answer1",
"round2",
"answer2",
]
def test_fork_transcript_rejects_out_of_range_user_index(tmp_path, monkeypatch) -> None:
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
source = "websocket:source"
@ -72,6 +100,58 @@ def test_fork_transcript_allows_index_equal_to_user_count(tmp_path, monkeypatch)
]
def test_build_response_reports_fork_boundary_from_marker(tmp_path, monkeypatch) -> None:
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
key = "websocket:fork"
for ev in (
{"event": "user", "chat_id": "fork", "text": "round1"},
{"event": "message", "chat_id": "fork", "text": "answer1"},
):
append_transcript_object(key, ev)
append_fork_marker(key)
append_transcript_object(key, {"event": "user", "chat_id": "fork", "text": "new branch"})
out = build_webui_thread_response(key)
assert out is not None
assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "new branch"]
assert out["fork_boundary_message_count"] == 2
def test_nested_fork_drops_inherited_fork_marker(tmp_path, monkeypatch) -> None:
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
source = "websocket:source"
for ev in (
{"event": "user", "chat_id": "source", "text": "round1"},
{"event": "message", "chat_id": "source", "text": "answer1"},
):
append_transcript_object(source, ev)
append_fork_marker(source)
for ev in (
{"event": "user", "chat_id": "source", "text": "round2"},
{"event": "message", "chat_id": "source", "text": "answer2"},
):
append_transcript_object(source, ev)
ok = fork_transcript_before_user_index(source, "websocket:nested", 2)
append_fork_marker("websocket:nested")
lines = read_transcript_lines("websocket:nested")
out = build_webui_thread_response("websocket:nested")
assert ok is True
assert [line.get("event") for line in lines] == [
"user",
"message",
"user",
"message",
"fork_marker",
]
assert out is not None
assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "round2", "answer2"]
assert out["fork_boundary_message_count"] == 4
def test_write_session_messages_as_transcript_builds_canonical_prefix(
tmp_path,
monkeypatch,