mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-15 15:24:06 +00:00
feat(webui): persist fork boundary metadata
This commit is contained in:
parent
03bca4c0a9
commit
73d4b1cb2f
@ -28,7 +28,11 @@ from nanobot.security.workspace_access import (
|
|||||||
WorkspaceScopeError,
|
WorkspaceScopeError,
|
||||||
)
|
)
|
||||||
from nanobot.session.goal_state import goal_state_ws_blob
|
from nanobot.session.goal_state import goal_state_ws_blob
|
||||||
from nanobot.session.webui_turns import websocket_turn_wall_started_at
|
from nanobot.session.webui_turns import (
|
||||||
|
WEBUI_TITLE_METADATA_KEY,
|
||||||
|
clean_generated_title,
|
||||||
|
websocket_turn_wall_started_at,
|
||||||
|
)
|
||||||
from nanobot.utils.media_decode import (
|
from nanobot.utils.media_decode import (
|
||||||
FileSizeExceeded,
|
FileSizeExceeded,
|
||||||
save_base64_data_url,
|
save_base64_data_url,
|
||||||
@ -46,6 +50,7 @@ from nanobot.webui.http_utils import (
|
|||||||
)
|
)
|
||||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||||
from nanobot.webui.transcript import (
|
from nanobot.webui.transcript import (
|
||||||
|
append_fork_marker,
|
||||||
delete_webui_transcript,
|
delete_webui_transcript,
|
||||||
fork_transcript_before_user_index,
|
fork_transcript_before_user_index,
|
||||||
write_session_messages_as_transcript,
|
write_session_messages_as_transcript,
|
||||||
@ -709,6 +714,13 @@ class WebSocketChannel(BaseChannel):
|
|||||||
)
|
)
|
||||||
if not transcript_ok:
|
if not transcript_ok:
|
||||||
write_session_messages_as_transcript(target_key, forked.messages)
|
write_session_messages_as_transcript(target_key, forked.messages)
|
||||||
|
append_fork_marker(target_key)
|
||||||
|
fork_title = clean_generated_title(
|
||||||
|
envelope.get("title") if isinstance(envelope.get("title"), str) else None,
|
||||||
|
)
|
||||||
|
if fork_title:
|
||||||
|
forked.metadata[WEBUI_TITLE_METADATA_KEY] = fork_title
|
||||||
|
self.gateway.session_manager.save(forked, fsync=True)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
delete_webui_transcript(target_key)
|
delete_webui_transcript(target_key)
|
||||||
self.gateway.session_manager.delete_session(target_key)
|
self.gateway.session_manager.delete_session(target_key)
|
||||||
|
|||||||
@ -648,8 +648,8 @@ class SessionManager:
|
|||||||
``before_user_index`` is zero-based over user messages in the full session:
|
``before_user_index`` is zero-based over user messages in the full session:
|
||||||
``0`` means "before the first user message", ``1`` means "before the
|
``0`` means "before the first user message", ``1`` means "before the
|
||||||
second user message", and so on. A value equal to the total user-message
|
second user message", and so on. A value equal to the total user-message
|
||||||
count copies the full session prefix. The target user message itself is
|
count copies the full session prefix. WebUI assistant-reply forks pass
|
||||||
not copied; the WebUI pre-fills it in the composer for editing and resend.
|
the next user index so the selected completed assistant turn is included.
|
||||||
"""
|
"""
|
||||||
if before_user_index < 0:
|
if before_user_index < 0:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from nanobot.config.paths import get_webui_dir
|
|||||||
from nanobot.session.manager import SessionManager
|
from nanobot.session.manager import SessionManager
|
||||||
|
|
||||||
WEBUI_TRANSCRIPT_SCHEMA_VERSION = 3
|
WEBUI_TRANSCRIPT_SCHEMA_VERSION = 3
|
||||||
|
WEBUI_FORK_MARKER_EVENT = "fork_marker"
|
||||||
_MAX_TRANSCRIPT_FILE_BYTES = 8 * 1024 * 1024
|
_MAX_TRANSCRIPT_FILE_BYTES = 8 * 1024 * 1024
|
||||||
_WEBUI_TURN_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$")
|
_WEBUI_TURN_ID_RE = re.compile(r"^[A-Za-z0-9._:-]{1,128}$")
|
||||||
WEBUI_TURN_METADATA_KEY = "webui_turn_id"
|
WEBUI_TURN_METADATA_KEY = "webui_turn_id"
|
||||||
@ -306,6 +307,8 @@ def fork_transcript_before_user_index(
|
|||||||
user_index = 0
|
user_index = 0
|
||||||
found_target = False
|
found_target = False
|
||||||
for row in lines:
|
for row in lines:
|
||||||
|
if row.get("event") == WEBUI_FORK_MARKER_EVENT:
|
||||||
|
continue
|
||||||
if _is_user_transcript_row(row):
|
if _is_user_transcript_row(row):
|
||||||
if user_index == before_user_index:
|
if user_index == before_user_index:
|
||||||
found_target = True
|
found_target = True
|
||||||
@ -340,6 +343,17 @@ def fork_transcript_before_user_index(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def append_fork_marker(session_key: str) -> None:
|
||||||
|
"""Mark the UI-only boundary where a WebUI fork starts accepting new turns."""
|
||||||
|
append_transcript_object(
|
||||||
|
session_key,
|
||||||
|
{
|
||||||
|
"event": WEBUI_FORK_MARKER_EVENT,
|
||||||
|
"chat_id": _chat_id_from_session_key(session_key),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def write_session_messages_as_transcript(
|
def write_session_messages_as_transcript(
|
||||||
target_key: str,
|
target_key: str,
|
||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
@ -1397,6 +1411,28 @@ def replay_transcript_to_ui_messages(
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def fork_boundary_message_count(
|
||||||
|
lines: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
||||||
|
augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
||||||
|
augment_assistant_text: Callable[[str], str] | None = None,
|
||||||
|
) -> int | None:
|
||||||
|
"""Return the replayed UI message count before the first fork marker, if any."""
|
||||||
|
for idx, rec in enumerate(lines):
|
||||||
|
if rec.get("event") != WEBUI_FORK_MARKER_EVENT:
|
||||||
|
continue
|
||||||
|
return len(
|
||||||
|
replay_transcript_to_ui_messages(
|
||||||
|
lines[:idx],
|
||||||
|
augment_user_media=augment_user_media,
|
||||||
|
augment_assistant_media=augment_assistant_media,
|
||||||
|
augment_assistant_text=augment_assistant_text,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def build_webui_thread_response(
|
def build_webui_thread_response(
|
||||||
session_key: str,
|
session_key: str,
|
||||||
*,
|
*,
|
||||||
@ -1410,14 +1446,23 @@ def build_webui_thread_response(
|
|||||||
if not lines:
|
if not lines:
|
||||||
return None
|
return None
|
||||||
lines = inject_missing_user_events_from_session(session_key, lines, session_messages)
|
lines = inject_missing_user_events_from_session(session_key, lines, session_messages)
|
||||||
|
fork_boundary = fork_boundary_message_count(
|
||||||
|
lines,
|
||||||
|
augment_user_media=augment_user_media,
|
||||||
|
augment_assistant_media=augment_assistant_media,
|
||||||
|
augment_assistant_text=augment_assistant_text,
|
||||||
|
)
|
||||||
msgs = replay_transcript_to_ui_messages(
|
msgs = replay_transcript_to_ui_messages(
|
||||||
lines,
|
lines,
|
||||||
augment_user_media=augment_user_media,
|
augment_user_media=augment_user_media,
|
||||||
augment_assistant_media=augment_assistant_media,
|
augment_assistant_media=augment_assistant_media,
|
||||||
augment_assistant_text=augment_assistant_text,
|
augment_assistant_text=augment_assistant_text,
|
||||||
)
|
)
|
||||||
return {
|
payload = {
|
||||||
"schemaVersion": WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
"schemaVersion": WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
||||||
"sessionKey": session_key,
|
"sessionKey": session_key,
|
||||||
"messages": msgs,
|
"messages": msgs,
|
||||||
}
|
}
|
||||||
|
if fork_boundary is not None:
|
||||||
|
payload["fork_boundary_message_count"] = fork_boundary
|
||||||
|
return payload
|
||||||
|
|||||||
@ -454,6 +454,34 @@ def test_fork_session_before_user_index_copies_only_prefix(tmp_path):
|
|||||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_fork_session_from_middle_assistant_reply_keeps_selected_turn(tmp_path):
|
||||||
|
manager = SessionManager(tmp_path)
|
||||||
|
source = manager.get_or_create("websocket:source")
|
||||||
|
source.add_message("user", "round1")
|
||||||
|
source.add_message("assistant", "answer1")
|
||||||
|
source.add_message("user", "round2")
|
||||||
|
source.add_message("assistant", "answer2")
|
||||||
|
source.add_message("user", "round3 must not appear")
|
||||||
|
source.add_message("assistant", "answer3 must not appear")
|
||||||
|
manager.save(source)
|
||||||
|
|
||||||
|
forked = manager.fork_session_before_user_index(
|
||||||
|
"websocket:source",
|
||||||
|
"websocket:fork",
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert forked is not None
|
||||||
|
assert [m["content"] for m in forked.messages] == [
|
||||||
|
"round1",
|
||||||
|
"answer1",
|
||||||
|
"round2",
|
||||||
|
"answer2",
|
||||||
|
]
|
||||||
|
saved = manager.read_session_file("websocket:fork")
|
||||||
|
assert "round3 must not appear" not in str(saved)
|
||||||
|
|
||||||
|
|
||||||
def test_fork_session_rejects_negative_missing_and_out_of_range(tmp_path):
|
def test_fork_session_rejects_negative_missing_and_out_of_range(tmp_path):
|
||||||
manager = SessionManager(tmp_path)
|
manager = SessionManager(tmp_path)
|
||||||
source = manager.get_or_create("websocket:source")
|
source = manager.get_or_create("websocket:source")
|
||||||
|
|||||||
@ -2422,7 +2422,12 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript(
|
|||||||
await channel._dispatch_envelope(
|
await channel._dispatch_envelope(
|
||||||
conn,
|
conn,
|
||||||
"webui-client",
|
"webui-client",
|
||||||
{"type": "fork_chat", "source_chat_id": "source", "before_user_index": 1},
|
{
|
||||||
|
"type": "fork_chat",
|
||||||
|
"source_chat_id": "source",
|
||||||
|
"before_user_index": 1,
|
||||||
|
"title": "Fork: Old title",
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
sent = [json.loads(call.args[0]) for call in conn.send.await_args_list]
|
sent = [json.loads(call.args[0]) for call in conn.send.await_args_list]
|
||||||
@ -2430,8 +2435,10 @@ async def test_fork_chat_copies_only_prefix_session_and_transcript(
|
|||||||
fork_id = attached["chat_id"]
|
fork_id = attached["chat_id"]
|
||||||
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
||||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||||
|
assert saved["metadata"]["title"] == "Fork: Old title"
|
||||||
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
||||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
|
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None, None]
|
||||||
|
assert fork_lines[-1]["event"] == "fork_marker"
|
||||||
assert all(line.get("chat_id") == fork_id for line in fork_lines)
|
assert all(line.get("chat_id") == fork_id for line in fork_lines)
|
||||||
assert "round3 must not appear" not in json.dumps(saved, ensure_ascii=False)
|
assert "round3 must not appear" not in json.dumps(saved, ensure_ascii=False)
|
||||||
bus.publish_inbound.assert_not_awaited()
|
bus.publish_inbound.assert_not_awaited()
|
||||||
@ -2477,7 +2484,8 @@ async def test_fork_chat_falls_back_to_session_prefix_when_transcript_lacks_user
|
|||||||
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
||||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||||
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
||||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1"]
|
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
|
||||||
|
assert fork_lines[-1]["event"] == "fork_marker"
|
||||||
assert "round3 must not appear" not in json.dumps(fork_lines, ensure_ascii=False)
|
assert "round3 must not appear" not in json.dumps(fork_lines, ensure_ascii=False)
|
||||||
bus.publish_inbound.assert_not_awaited()
|
bus.publish_inbound.assert_not_awaited()
|
||||||
|
|
||||||
@ -2520,7 +2528,8 @@ async def test_fork_chat_allows_index_equal_to_user_count(
|
|||||||
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
saved = sessions.read_session_file(f"websocket:{fork_id}")
|
||||||
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
assert [m["content"] for m in saved["messages"]] == ["round1", "answer1"]
|
||||||
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
fork_lines = read_transcript_lines(f"websocket:{fork_id}")
|
||||||
assert [line.get("text") for line in fork_lines] == ["round1", "answer1"]
|
assert [line.get("text") for line in fork_lines] == ["round1", "answer1", None]
|
||||||
|
assert fork_lines[-1]["event"] == "fork_marker"
|
||||||
bus.publish_inbound.assert_not_awaited()
|
bus.publish_inbound.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from nanobot.webui.transcript import (
|
from nanobot.webui.transcript import (
|
||||||
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
||||||
|
append_fork_marker,
|
||||||
append_transcript_object,
|
append_transcript_object,
|
||||||
build_webui_thread_response,
|
build_webui_thread_response,
|
||||||
fork_transcript_before_user_index,
|
fork_transcript_before_user_index,
|
||||||
@ -45,6 +46,33 @@ def test_fork_transcript_before_user_index_copies_only_prefix(tmp_path, monkeypa
|
|||||||
assert "round3 must not appear" not in "\n".join(str(line.get("text")) for line in lines)
|
assert "round3 must not appear" not in "\n".join(str(line.get("text")) for line in lines)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fork_transcript_from_middle_assistant_reply_keeps_selected_turn(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
source = "websocket:source"
|
||||||
|
for ev in (
|
||||||
|
{"event": "user", "chat_id": "source", "text": "round1"},
|
||||||
|
{"event": "message", "chat_id": "source", "text": "answer1"},
|
||||||
|
{"event": "user", "chat_id": "source", "text": "round2"},
|
||||||
|
{"event": "message", "chat_id": "source", "text": "answer2"},
|
||||||
|
{"event": "user", "chat_id": "source", "text": "round3 must not appear"},
|
||||||
|
{"event": "message", "chat_id": "source", "text": "answer3 must not appear"},
|
||||||
|
):
|
||||||
|
append_transcript_object(source, ev)
|
||||||
|
|
||||||
|
ok = fork_transcript_before_user_index(source, "websocket:fork", 2)
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert [line.get("text") for line in read_transcript_lines("websocket:fork")] == [
|
||||||
|
"round1",
|
||||||
|
"answer1",
|
||||||
|
"round2",
|
||||||
|
"answer2",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_fork_transcript_rejects_out_of_range_user_index(tmp_path, monkeypatch) -> None:
|
def test_fork_transcript_rejects_out_of_range_user_index(tmp_path, monkeypatch) -> None:
|
||||||
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
source = "websocket:source"
|
source = "websocket:source"
|
||||||
@ -72,6 +100,58 @@ def test_fork_transcript_allows_index_equal_to_user_count(tmp_path, monkeypatch)
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_response_reports_fork_boundary_from_marker(tmp_path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
key = "websocket:fork"
|
||||||
|
for ev in (
|
||||||
|
{"event": "user", "chat_id": "fork", "text": "round1"},
|
||||||
|
{"event": "message", "chat_id": "fork", "text": "answer1"},
|
||||||
|
):
|
||||||
|
append_transcript_object(key, ev)
|
||||||
|
append_fork_marker(key)
|
||||||
|
append_transcript_object(key, {"event": "user", "chat_id": "fork", "text": "new branch"})
|
||||||
|
|
||||||
|
out = build_webui_thread_response(key)
|
||||||
|
|
||||||
|
assert out is not None
|
||||||
|
assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "new branch"]
|
||||||
|
assert out["fork_boundary_message_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_fork_drops_inherited_fork_marker(tmp_path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
source = "websocket:source"
|
||||||
|
for ev in (
|
||||||
|
{"event": "user", "chat_id": "source", "text": "round1"},
|
||||||
|
{"event": "message", "chat_id": "source", "text": "answer1"},
|
||||||
|
):
|
||||||
|
append_transcript_object(source, ev)
|
||||||
|
append_fork_marker(source)
|
||||||
|
for ev in (
|
||||||
|
{"event": "user", "chat_id": "source", "text": "round2"},
|
||||||
|
{"event": "message", "chat_id": "source", "text": "answer2"},
|
||||||
|
):
|
||||||
|
append_transcript_object(source, ev)
|
||||||
|
|
||||||
|
ok = fork_transcript_before_user_index(source, "websocket:nested", 2)
|
||||||
|
append_fork_marker("websocket:nested")
|
||||||
|
|
||||||
|
lines = read_transcript_lines("websocket:nested")
|
||||||
|
out = build_webui_thread_response("websocket:nested")
|
||||||
|
|
||||||
|
assert ok is True
|
||||||
|
assert [line.get("event") for line in lines] == [
|
||||||
|
"user",
|
||||||
|
"message",
|
||||||
|
"user",
|
||||||
|
"message",
|
||||||
|
"fork_marker",
|
||||||
|
]
|
||||||
|
assert out is not None
|
||||||
|
assert [m["content"] for m in out["messages"]] == ["round1", "answer1", "round2", "answer2"]
|
||||||
|
assert out["fork_boundary_message_count"] == 4
|
||||||
|
|
||||||
|
|
||||||
def test_write_session_messages_as_transcript_builds_canonical_prefix(
|
def test_write_session_messages_as_transcript_builds_canonical_prefix(
|
||||||
tmp_path,
|
tmp_path,
|
||||||
monkeypatch,
|
monkeypatch,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user