mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 07:14:08 +00:00
feat(webui): persist fork boundary metadata
This commit is contained in:
parent
03bca4c0a9
commit
73d4b1cb2f
@ -28,7 +28,11 @@ from nanobot.security.workspace_access import (
|
||||
WorkspaceScopeError,
|
||||
)
|
||||
from nanobot.session.goal_state import goal_state_ws_blob
|
||||
from nanobot.session.webui_turns import websocket_turn_wall_started_at
|
||||
from nanobot.session.webui_turns import (
|
||||
WEBUI_TITLE_METADATA_KEY,
|
||||
clean_generated_title,
|
||||
websocket_turn_wall_started_at,
|
||||
)
|
||||
from nanobot.utils.media_decode import (
|
||||
FileSizeExceeded,
|
||||
save_base64_data_url,
|
||||
@ -46,6 +50,7 @@ from nanobot.webui.http_utils import (
|
||||
)
|
||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||
from nanobot.webui.transcript import (
|
||||
append_fork_marker,
|
||||
delete_webui_transcript,
|
||||
fork_transcript_before_user_index,
|
||||
write_session_messages_as_transcript,
|
||||
@ -709,6 +714,13 @@ class WebSocketChannel(BaseChannel):
|
||||
)
|
||||
if not transcript_ok:
|
||||
write_session_messages_as_transcript(target_key, forked.messages)
|
||||
append_fork_marker(target_key)
|
||||
fork_title = clean_generated_title(
|
||||
envelope.get("title") if isinstance(envelope.get("title"), str) else None,
|
||||
)
|
||||
if fork_title:
|
||||
forked.metadata[WEBUI_TITLE_METADATA_KEY] = fork_title
|
||||
self.gateway.session_manager.save(forked, fsync=True)
|
||||
except Exception as exc:
|
||||
delete_webui_transcript(target_key)
|
||||
self.gateway.session_manager.delete_session(target_key)
|
||||
|
||||
@ -648,8 +648,8 @@ class SessionManager:
|
||||
``before_user_index`` is zero-based over user messages in the full session:
|
||||
``0`` means "before the first user message", ``1`` means "before the
|
||||
second user message", and so on. A value equal to the total user-message
|
||||
count copies the full session prefix. The target user message itself is
|
||||
not copied; the WebUI pre-fills it in the composer for editing and resend.
|
||||
count copies the full session prefix. WebUI assistant-reply forks pass
|
||||
the next user index so the selected completed assistant turn is included.
|
||||
"""
|
||||
if before_user_index < 0:
|
||||
return None
|
||||
|
||||
@ -17,6 +17,7 @@ from nanobot.config.paths import get_webui_dir
|
||||
from nanobot.session.manager import SessionManager
|
||||
|
||||
WEBUI_TRANSCRIPT_SCHEMA_VERSION = 3
|
||||
WEBUI_FORK_MARKER_EVENT = "fork_marker"
|
||||
_MAX_TRANSCRIPT_FILE_BYTES = 8 * 1024 * 1024
|
||||
_WEBUI_TURN_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$")
|
||||
WEBUI_TURN_METADATA_KEY = "webui_turn_id"
|
||||
@ -306,6 +307,8 @@ def fork_transcript_before_user_index(
|
||||
user_index = 0
|
||||
found_target = False
|
||||
for row in lines:
|
||||
if row.get("event") == WEBUI_FORK_MARKER_EVENT:
|
||||
continue
|
||||
if _is_user_transcript_row(row):
|
||||
if user_index == before_user_index:
|
||||
found_target = True
|
||||
@ -340,6 +343,17 @@ def fork_transcript_before_user_index(
|
||||
return True
|
||||
|
||||
|
||||
def append_fork_marker(session_key: str) -> None:
|
||||
"""Mark the UI-only boundary where a WebUI fork starts accepting new turns."""
|
||||
append_transcript_object(
|
||||
session_key,
|
||||
{
|
||||
"event": WEBUI_FORK_MARKER_EVENT,
|
||||
"chat_id": _chat_id_from_session_key(session_key),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def write_session_messages_as_transcript(
|
||||
target_key: str,
|
||||
messages: list[dict[str, Any]],
|
||||
@ -1397,6 +1411,28 @@ def replay_transcript_to_ui_messages(
|
||||
return messages
|
||||
|
||||
|
||||
def fork_boundary_message_count(
|
||||
lines: list[dict[str, Any]],
|
||||
*,
|
||||
augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
||||
augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
||||
augment_assistant_text: Callable[[str], str] | None = None,
|
||||
) -> int | None:
|
||||
"""Return the replayed UI message count before the first fork marker, if any."""
|
||||
for idx, rec in enumerate(lines):
|
||||
if rec.get("event") != WEBUI_FORK_MARKER_EVENT:
|
||||
continue
|
||||
return len(
|
||||
replay_transcript_to_ui_messages(
|
||||
lines[:idx],
|
||||
augment_user_media=augment_user_media,
|
||||
augment_assistant_media=augment_assistant_media,
|
||||
augment_assistant_text=augment_assistant_text,
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def build_webui_thread_response(
|
||||
session_key: str,
|
||||
*,
|
||||
@ -1410,14 +1446,23 @@ def build_webui_thread_response(
|
||||
if not lines:
|
||||
return None
|
||||
lines = inject_missing_user_events_from_session(session_key, lines, session_messages)
|
||||
fork_boundary = fork_boundary_message_count(
|
||||
lines,
|
||||
augment_user_media=augment_user_media,
|
||||
augment_assistant_media=augment_assistant_media,
|
||||
augment_assistant_text=augment_assistant_text,
|
||||
)
|
||||
msgs = replay_transcript_to_ui_messages(
|
||||
lines,
|
||||
augment_user_media=augment_user_media,
|
||||
augment_assistant_media=augment_assistant_media,
|
||||
augment_assistant_text=augment_assistant_text,
|
||||
)
|
||||
return {
|
||||
payload = {
|
||||
"schemaVersion": WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
||||
"sessionKey": session_key,
|
||||
"messages": msgs,
|
||||
}
|
||||
if fork_boundary is not None:
|
||||
payload["fork_boundary_message_count"] = fork_boundary
|
||||
return payload
|
||||
|
||||
@ -454,6 +454,34 @@ def test_fork_session_before_user_index_copies_only_prefix(tmp_path):
|
||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||
|
||||
|
||||
def test_fork_session_from_middle_assistant_reply_keeps_selected_turn(tmp_path):
|
||||
manager = SessionManager(tmp_path)
|
||||
source = manager.get_or_create("websocket:source")
|
||||
source.add_message("user", "round1")
|
||||
source.add_message("assistant", "answer1")
|
||||
source.add_message("user", "round2")
|
||||
source.add_message("assistant", "answer2")
|
||||
source.add_message("user", "round3 must not appear")
|
||||
source.add_message("assistant", "answer3 must not appear")
|
||||
manager.save(source)
|
||||
|
||||
forked = manager.fork_session_before_user_index(
|
||||
"websocket:source",
|
||||
"websocket:fork",
|
||||
2,
|
||||
)
|
||||
|
||||
assert forked is not None
|
||||
assert [m["content"] for m in forked.messages] == [
|
||||
"round1",
|
||||
"answer1",
|
||||
"round2",
|
||||
"answer2",
|
||||
]
|
||||
saved = manager.read_session_file("websocket:fork")
|
||||
assert "round3 must not appear" not in str(saved)
|
||||
|
||||
|
||||
def test_fork_session_rejects_negative_missing_and_out_of_range(tmp_path):
|
||||
manager = SessionManager(tmp_path)
|
||||
source = manager.get_or_create("websocket:source")
|
||||
|
||||
@ -2422,7 +2422,12 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript(
|
||||
await channel._dispatch_envelope(
|
||||
conn,
|
||||
"webui-client",
|
||||
{"type": "fork_chat", "source_chat_id": "source", "before_user_index": 1},
|
||||
{
|
||||
"type": "fork_chat",
|
||||
"source_chat_id": "source",
|
||||
"before_user_index": 1,
|
||||
"title": "Fork: Old title",
|
||||
},
|
||||
)
|
||||
|
||||
sent = [json.loads(call.args[0]) for call in conn.send.await_args_list]
|
||||
@ -2430,8 +2435,10 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript(
|
||||
fork_id = attached["chat_id"]
|
||||
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||
assert saved["metadata"]["title"] == "Fork: Old title"
|
||||
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
|
||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None, None]
|
||||
assert fork_lines[-1]["event"] == "fork_marker"
|
||||
assert all(line.get("chat_id") == fork_id for line in fork_lines)
|
||||
assert "round3 must not appear" not in json.dumps(saved, ensure_ascii=False)
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
@ -2477,7 +2484,8 @@ async def test_fork_chat_falls_back_to_session_prefix_when_transcript_lacks_user
|
||||
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1"]
|
||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
|
||||
assert fork_lines[-1]["event"] == "fork_marker"
|
||||
assert "round3 must not appear" not in json.dumps(fork_lines, ensure_ascii=False)
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
|
||||
@ -2520,7 +2528,8 @@ async def test_fork_chat_allows_index_equal_to_user_count(
|
||||
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1"]
|
||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
|
||||
assert fork_lines[-1]["event"] == "fork_marker"
|
||||
bus.publish_inbound.assert_not_awaited()
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from nanobot.webui.transcript import (
|
||||
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
||||
append_fork_marker,
|
||||
append_transcript_object,
|
||||
build_webui_thread_response,
|
||||
fork_transcript_before_user_index,
|
||||
@ -45,6 +46,33 @@ def test_fork_transcript_before_user_index_copies_only_prefix(tmp_path, monkeypa
|
||||
assert "round3 must not appear" not in "\n".join(str(line.get("text")) for line in lines)
|
||||
|
||||
|
||||
def test_fork_transcript_from_middle_assistant_reply_keeps_selected_turn(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||
source = "websocket:source"
|
||||
for ev in (
|
||||
{"event": "user", "chat_id": "source", "text": "round1"},
|
||||
{"event": "message", "chat_id": "source", "text": "answer1"},
|
||||
{"event": "user", "chat_id": "source", "text": "round2"},
|
||||
{"event": "message", "chat_id": "source", "text": "answer2"},
|
||||
{"event": "user", "chat_id": "source", "text": "round3 must not appear"},
|
||||
{"event": "message", "chat_id": "source", "text": "answer3 must not appear"},
|
||||
):
|
||||
append_transcript_object(source, ev)
|
||||
|
||||
ok = fork_transcript_before_user_index(source, "websocket:fork", 2)
|
||||
|
||||
assert ok is True
|
||||
assert [line.get("text") for line in read_transcript_lines("websocket:fork")] == [
|
||||
"round1",
|
||||
"answer1",
|
||||
"round2",
|
||||
"answer2",
|
||||
]
|
||||
|
||||
|
||||
def test_fork_transcript_rejects_out_of_range_user_index(tmp_path, monkeypatch) -> None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||
source = "websocket:source"
|
||||
@ -72,6 +100,58 @@ def test_fork_transcript_allows_index_equal_to_user_count(tmp_path, monkeypatch)
|
||||
]
|
||||
|
||||
|
||||
def test_build_response_reports_fork_boundary_from_marker(tmp_path, monkeypatch) -> None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||
key = "websocket:fork"
|
||||
for ev in (
|
||||
{"event": "user", "chat_id": "fork", "text": "round1"},
|
||||
{"event": "message", "chat_id": "fork", "text": "answer1"},
|
||||
):
|
||||
append_transcript_object(key, ev)
|
||||
append_fork_marker(key)
|
||||
append_transcript_object(key, {"event": "user", "chat_id": "fork", "text": "new branch"})
|
||||
|
||||
out = build_webui_thread_response(key)
|
||||
|
||||
assert out is not None
|
||||
assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "new branch"]
|
||||
assert out["fork_boundary_message_count"] == 2
|
||||
|
||||
|
||||
def test_nested_fork_drops_inherited_fork_marker(tmp_path, monkeypatch) -> None:
|
||||
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||
source = "websocket:source"
|
||||
for ev in (
|
||||
{"event": "user", "chat_id": "source", "text": "round1"},
|
||||
{"event": "message", "chat_id": "source", "text": "answer1"},
|
||||
):
|
||||
append_transcript_object(source, ev)
|
||||
append_fork_marker(source)
|
||||
for ev in (
|
||||
{"event": "user", "chat_id": "source", "text": "round2"},
|
||||
{"event": "message", "chat_id": "source", "text": "answer2"},
|
||||
):
|
||||
append_transcript_object(source, ev)
|
||||
|
||||
ok = fork_transcript_before_user_index(source, "websocket:nested", 2)
|
||||
append_fork_marker("websocket:nested")
|
||||
|
||||
lines = read_transcript_lines("websocket:nested")
|
||||
out = build_webui_thread_response("websocket:nested")
|
||||
|
||||
assert ok is True
|
||||
assert [line.get("event") for line in lines] == [
|
||||
"user",
|
||||
"message",
|
||||
"user",
|
||||
"message",
|
||||
"fork_marker",
|
||||
]
|
||||
assert out is not None
|
||||
assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "round2", "answer2"]
|
||||
assert out["fork_boundary_message_count"] == 4
|
||||
|
||||
|
||||
def test_write_session_messages_as_transcript_builds_canonical_prefix(
|
||||
tmp_path,
|
||||
monkeypatch,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user