fix(webui): persist user messages for refresh

This commit is contained in:
chengyongru 2026-06-05 15:50:32 +08:00 committed by Xubin Ren
parent 3da68ac7fe
commit 710d00a179
5 changed files with 427 additions and 2 deletions

View File

@ -48,7 +48,7 @@ from nanobot.webui.http_utils import (
query_first as _query_first,
)
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
from nanobot.webui.transcript import append_transcript_object
from nanobot.webui.transcript import append_transcript_object, build_user_transcript_event
from nanobot.webui.websocket_logging import websockets_server_logger
@ -768,6 +768,14 @@ class WebSocketChannel(BaseChannel):
"enabled": True,
"aspect_ratio": aspect_ratio if isinstance(aspect_ratio, str) else None,
}
if envelope.get("webui") is True and self.is_allowed(client_id):
self._try_append_webui_user_transcript(
cid,
content,
media_paths=media_paths,
cli_apps=cli_apps,
mcp_presets=mcp_presets,
)
await self._handle_message(
sender_id=client_id,
chat_id=cid,
@ -833,9 +841,31 @@ class WebSocketChannel(BaseChannel):
try:
dup = json.loads(json.dumps(wire, ensure_ascii=False))
append_transcript_object(sk, dup)
except (ValueError, TypeError) as e:
except (OSError, ValueError, TypeError) as e:
self.logger.warning("webui transcript append failed: {}", e)
def _try_append_webui_user_transcript(
    self,
    chat_id: str,
    content: str,
    *,
    media_paths: list[str],
    cli_apps: list[dict[str, Any]],
    mcp_presets: list[dict[str, Any]],
) -> None:
    """Append a ``user`` row for an inbound WebUI message to the transcript.

    Nothing is written when the message is a bare ``/stop`` (no media
    attached) or when ``build_user_transcript_event`` returns ``None``.
    The actual write is delegated to ``_try_append_webui_transcript``.
    """
    is_bare_stop = content.strip() == "/stop" and not media_paths
    if is_bare_stop:
        return

    user_event = build_user_transcript_event(
        chat_id,
        content,
        media_paths=media_paths,
        cli_apps=cli_apps,
        mcp_presets=mcp_presets,
    )
    if user_event is not None:
        self._try_append_webui_transcript(chat_id, user_event)
async def send(self, msg: OutboundMessage) -> None:
if msg.metadata.get("_runtime_model_updated"):
await self.send_runtime_model_updated(

View File

@ -40,6 +40,15 @@ _FILE_EDIT_TOOL_NAMES: frozenset[str] = frozenset({
"edit_file",
"apply_patch",
})
# Transcript event names used by ``_with_backfilled_user`` to pick the insertion
# point for a backfilled user row: the row goes immediately before the first
# record in a turn whose "event" value is a member of this set.
_TURN_DISPLAY_EVENTS: frozenset[str] = frozenset({
"reasoning_delta",
"reasoning_end",
"delta",
"stream_end",
"message",
"file_edit",
"turn_end",
})
def rewrite_local_markdown_images(
@ -155,6 +164,165 @@ def delete_webui_transcript(session_key: str) -> bool:
return False
def build_user_transcript_event(
chat_id: str,
text: str,
*,
media_paths: list[Any] | None = None,
cli_apps: list[Any] | None = None,
mcp_presets: list[Any] | None = None,
) -> dict[str, Any] | None:
paths = [str(path) for path in (media_paths or []) if path]
if not text and not paths:
return None
event: dict[str, Any] = {
"event": "user",
"chat_id": chat_id,
"text": text,
}
if paths:
event["media_paths"] = paths
apps = [dict(app) for app in (cli_apps or []) if isinstance(app, Mapping)]
if apps:
event["cli_apps"] = apps
presets = [dict(preset) for preset in (mcp_presets or []) if isinstance(preset, Mapping)]
if presets:
event["mcp_presets"] = presets
return event
def _session_user_event(
    session_key: str,
    message: dict[str, Any],
) -> dict[str, Any] | None:
    """Convert a stored session message into a ``user`` transcript record.

    Returns ``None`` for any message whose ``"role"`` is not ``"user"``.
    The chat id is the part of ``session_key`` after the first ``":"``, or
    the whole key when it contains no colon. Non-string ``"content"`` is
    treated as empty text, and the ``"media"``, ``"cli_apps"`` and
    ``"mcp_presets"`` fields are forwarded only when they are lists.
    """
    if message.get("role") != "user":
        return None

    def list_or_none(field: str) -> list[Any] | None:
        value = message.get(field)
        return value if isinstance(value, list) else None

    body = message.get("content")
    _prefix, separator, remainder = session_key.partition(":")
    return build_user_transcript_event(
        remainder if separator else session_key,
        body if isinstance(body, str) else "",
        media_paths=list_or_none("media"),
        cli_apps=list_or_none("cli_apps"),
        mcp_presets=list_or_none("mcp_presets"),
    )
def _assistant_text_signature(value: Any) -> str:
return value.strip() if isinstance(value, str) else ""
def _session_backfill_turns(
    session_key: str,
    session_messages: list[dict[str, Any]],
) -> list[tuple[dict[str, Any], tuple[str, ...]]]:
    """Group session messages into ``(user_event, assistant_texts)`` turns.

    A turn opens at each ``user`` message and collects the non-empty,
    stripped text of the ``assistant`` messages that follow it. A turn is
    emitted only when it has a user event and at least one assistant text;
    assistant messages seen while no user event is pending are ignored.
    """
    completed: list[tuple[dict[str, Any], tuple[str, ...]]] = []
    pending_user: dict[str, Any] | None = None
    replies: list[str] = []

    for entry in session_messages:
        kind = entry.get("role")
        if kind == "user":
            # Close the previous turn before starting a new one.
            if pending_user is not None and replies:
                completed.append((pending_user, tuple(replies)))
            pending_user = _session_user_event(session_key, entry)
            replies = []
        elif kind == "assistant" and pending_user is not None:
            stripped = _assistant_text_signature(entry.get("content"))
            if stripped:
                replies.append(stripped)

    # The last turn has no following user message to close it.
    if pending_user is not None and replies:
        completed.append((pending_user, tuple(replies)))
    return completed
def _split_transcript_turns(lines: list[dict[str, Any]]) -> list[list[dict[str, Any]]]:
turns: list[list[dict[str, Any]]] = []
current: list[dict[str, Any]] = []
for rec in lines:
current.append(rec)
if rec.get("event") == "turn_end":
turns.append(current)
current = []
if current:
turns.append(current)
return turns
def _transcript_turn_signature(records: list[dict[str, Any]]) -> tuple[str, ...]:
    """Return the assistant texts of one transcript turn as a tuple.

    The records are replayed through ``replay_transcript_to_ui_messages``.
    Only messages with role ``"assistant"`` and a kind other than
    ``"trace"`` contribute, and empty texts are left out.
    """
    replayed = replay_transcript_to_ui_messages(records)
    candidates = (
        _assistant_text_signature(item.get("content"))
        for item in replayed
        if item.get("role") == "assistant" and item.get("kind") != "trace"
    )
    return tuple(text for text in candidates if text)
def _find_unique_session_turn(
session_turns: list[tuple[dict[str, Any], tuple[str, ...]]],
signature: tuple[str, ...],
start: int,
) -> int | None:
if not signature:
return None
found: int | None = None
for index in range(start, len(session_turns)):
if session_turns[index][1] != signature:
continue
if found is not None:
return None
found = index
return found
def _with_backfilled_user(
    records: list[dict[str, Any]],
    user_event: dict[str, Any],
) -> list[dict[str, Any]]:
    """Insert a copy of ``user_event`` before the turn's first display event.

    A display event is a record whose ``"event"`` is in
    ``_TURN_DISPLAY_EVENTS``. When the turn has none, ``records`` is
    returned unchanged (the same list object).
    """
    insert_at = next(
        (
            position
            for position, entry in enumerate(records)
            if entry.get("event") in _TURN_DISPLAY_EVENTS
        ),
        None,
    )
    if insert_at is None:
        return records
    head, tail = records[:insert_at], records[insert_at:]
    return [*head, dict(user_event), *tail]
def inject_missing_user_events_from_session(
    session_key: str,
    lines: list[dict[str, Any]],
    session_messages: list[dict[str, Any]] | None,
) -> list[dict[str, Any]]:
    """Add missing ``user`` rows to a transcript using the stored session.

    Each transcript turn is matched to a session turn by its assistant-text
    signature, searching only session turns after the previous match. A
    turn with a unique match and no ``user`` row gets the matching user
    event inserted. Turns without a unique match are passed through as-is.

    Returns ``lines`` itself when ``lines`` or ``session_messages`` is
    empty, or when the session yields no usable turns.
    """
    if not (lines and session_messages):
        return lines
    candidates = _session_backfill_turns(session_key, session_messages)
    if not candidates:
        return lines

    merged: list[dict[str, Any]] = []
    next_candidate = 0
    for chunk in _split_transcript_turns(lines):
        already_has_user = any(entry.get("event") == "user" for entry in chunk)
        hit = _find_unique_session_turn(
            candidates,
            _transcript_turn_signature(chunk),
            next_candidate,
        )
        if hit is not None:
            # Matching moves forward only, even when the turn needs no backfill.
            next_candidate = hit + 1
            if not already_has_user:
                chunk = _with_backfilled_user(chunk, candidates[hit][0])
        merged.extend(chunk)
    return merged
def _format_tool_call_trace(call: Any) -> str | None:
if not call or not isinstance(call, dict):
return None
@ -904,11 +1072,13 @@ def build_webui_thread_response(
augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
augment_assistant_text: Callable[[str], str] | None = None,
session_messages: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Return a payload compatible with ``WebuiThreadPersistedPayload``."""
lines = read_transcript_lines(session_key)
if not lines:
return None
lines = inject_missing_user_events_from_session(session_key, lines, session_messages)
msgs = replay_transcript_to_ui_messages(
lines,
augment_user_media=augment_user_media,

View File

@ -348,6 +348,12 @@ class GatewayHTTPHandler:
if not _is_websocket_channel_session_key(decoded_key):
return _http_error(404, "session not found")
scope = self.workspaces.scope_for_session_key(decoded_key)
session_messages: list[dict[str, Any]] | None = None
if self.session_manager is not None:
session_data = self.session_manager.read_session_file(decoded_key)
raw_messages = session_data.get("messages") if isinstance(session_data, dict) else None
if isinstance(raw_messages, list):
session_messages = [m for m in raw_messages if isinstance(m, dict)]
data = build_webui_thread_response(
decoded_key,
augment_user_media=self.media.augment_transcript_media,
@ -356,6 +362,7 @@ class GatewayHTTPHandler:
text,
workspace_path=scope.project_path,
),
session_messages=session_messages,
)
if data is None:
return _http_error(404, "webui thread not found")

View File

@ -294,6 +294,87 @@ async def test_webui_message_envelope_marks_inbound_metadata(bus: MagicMock) ->
assert msg.metadata["_wants_stream"] is True
@pytest.mark.asyncio
async def test_webui_message_envelope_persists_user_transcript_for_refresh(
    bus: MagicMock,
    tmp_path,
    monkeypatch,
) -> None:
    """The user row lands in the transcript before a reply sent during publish."""
    from nanobot.webui.transcript import build_webui_thread_response, read_transcript_lines

    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    channel = _ch(bus)
    socket = AsyncMock()
    socket.remote_address = ("127.0.0.1", 50123)

    async def reply_while_publishing(_inbound: Any) -> None:
        reply = OutboundMessage(channel="websocket", chat_id="chat-1", content="hi back")
        await channel.send(reply)

    bus.publish_inbound.side_effect = reply_while_publishing

    envelope = {"type": "message", "chat_id": "chat-1", "content": "hello", "webui": True}
    await channel._dispatch_envelope(socket, "webui-client", envelope)

    stored_events = [row["event"] for row in read_transcript_lines("websocket:chat-1")]
    assert stored_events == ["user", "message"]

    thread = build_webui_thread_response("websocket:chat-1")
    assert thread is not None
    roles = [entry["role"] for entry in thread["messages"]]
    contents = [entry["content"] for entry in thread["messages"]]
    assert roles == ["user", "assistant"]
    assert contents == ["hello", "hi back"]
@pytest.mark.asyncio
async def test_webui_stop_control_message_is_not_persisted_as_user_bubble(
    bus: MagicMock,
    tmp_path,
    monkeypatch,
) -> None:
    """``/stop`` is still published to the bus but leaves the transcript empty."""
    from nanobot.webui.transcript import read_transcript_lines

    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    channel = _ch(bus)
    socket = AsyncMock()
    socket.remote_address = ("127.0.0.1", 50123)

    envelope = {"type": "message", "chat_id": "chat-1", "content": "/stop", "webui": True}
    await channel._dispatch_envelope(socket, "webui-client", envelope)

    published = bus.publish_inbound.await_args.args[0]
    assert published.content == "/stop"
    assert read_transcript_lines("websocket:chat-1") == []
@pytest.mark.asyncio
async def test_webui_user_transcript_append_failure_does_not_block_inbound(
    bus: MagicMock,
    monkeypatch,
) -> None:
    """An ``OSError`` from the transcript append does not stop the inbound publish."""

    def raise_disk_full(_session_key: str, _obj: dict[str, Any]) -> None:
        raise OSError("disk full")

    monkeypatch.setattr(
        "nanobot.channels.websocket.append_transcript_object",
        raise_disk_full,
    )
    channel = _ch(bus)
    socket = AsyncMock()
    socket.remote_address = ("127.0.0.1", 50123)

    envelope = {"type": "message", "chat_id": "chat-1", "content": "hello", "webui": True}
    await channel._dispatch_envelope(socket, "webui-client", envelope)

    published = bus.publish_inbound.await_args.args[0]
    assert published.chat_id == "chat-1"
    assert published.content == "hello"
@pytest.mark.asyncio
async def test_plain_websocket_message_does_not_mark_webui(bus: MagicMock) -> None:
channel = _ch(bus)
@ -2411,3 +2492,47 @@ def test_handle_webui_thread_get_returns_json(tmp_path, monkeypatch) -> None:
assert len(body["messages"]) == 1
assert body["messages"][0]["role"] == "user"
assert body["messages"][0]["content"] == "hi"
def test_handle_webui_thread_get_backfills_legacy_missing_user_rows(
    tmp_path,
    monkeypatch,
) -> None:
    """The webui-thread endpoint restores a user row from the saved session.

    The transcript holds only the assistant ``message`` event. The user
    question exists only in the session file.
    """
    from urllib.parse import quote

    from websockets.datastructures import Headers
    from websockets.http11 import Request

    from nanobot.webui.transcript import append_transcript_object

    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    workspace = tmp_path / "workspace"
    key = "websocket:c-legacy"

    # Session file: one full user/assistant exchange.
    sessions = SessionManager(workspace)
    session = sessions.get_or_create(key)
    for role, content in (("user", "legacy question"), ("assistant", "legacy answer")):
        session.add_message(role, content)
    sessions.save(session)

    # Legacy transcript: only the assistant side was recorded.
    legacy_event = {"event": "message", "chat_id": "c-legacy", "text": "legacy answer"}
    append_transcript_object(key, legacy_event)

    bus = MagicMock()
    gateway = _basic_handler(bus, session_manager=sessions, workspace_path=workspace)
    channel = WebSocketChannel({"enabled": True, "allowFrom": ["*"]}, bus, gateway=gateway)
    channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0

    encoded_key = quote(key, safe="")
    request = Request(
        f"/api/sessions/{encoded_key}/webui-thread",
        Headers([("Authorization", "Bearer tok")]),
    )
    response = channel.gateway.http._handle_webui_thread_get(request, encoded_key)

    assert response.status_code == 200
    messages = json.loads(response.body.decode())["messages"]
    assert [entry["role"] for entry in messages] == ["user", "assistant"]
    assert [entry["content"] for entry in messages] == [
        "legacy question",
        "legacy answer",
    ]

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from nanobot.webui.transcript import (
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
append_transcript_object,
build_webui_thread_response,
read_transcript_lines,
replay_transcript_to_ui_messages,
)
@ -66,6 +67,98 @@ def test_replay_uses_stream_end_final_text() -> None:
assert msgs[1]["content"] == "![Diagram](/api/media/sig/payload)"
def test_build_response_backfills_legacy_sse_only_transcripts(tmp_path, monkeypatch) -> None:
    """Two assistant-only turns each get their user question back from the session."""
    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    key = "websocket:t-legacy"
    legacy_events = [
        {"event": "delta", "chat_id": "t-legacy", "text": "first answer"},
        {"event": "stream_end", "chat_id": "t-legacy"},
        {"event": "turn_end", "chat_id": "t-legacy"},
        {"event": "message", "chat_id": "t-legacy", "text": "second answer"},
        {"event": "turn_end", "chat_id": "t-legacy"},
    ]
    for event in legacy_events:
        append_transcript_object(key, event)

    session_history = [
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "second question"},
        {"role": "assistant", "content": "second answer"},
    ]
    out = build_webui_thread_response(key, session_messages=session_history)

    assert out is not None
    roles = [entry["role"] for entry in out["messages"]]
    contents = [entry["content"] for entry in out["messages"]]
    assert roles == ["user", "assistant", "user", "assistant"]
    assert contents == [
        "first question",
        "first answer",
        "second question",
        "second answer",
    ]
def test_backfill_does_not_duplicate_existing_user_transcript(tmp_path, monkeypatch) -> None:
    """A transcript that already has its ``user`` row is replayed without a second one."""
    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    key = "websocket:t-current"
    stored_events = [
        {"event": "user", "chat_id": "t-current", "text": "already stored"},
        {"event": "message", "chat_id": "t-current", "text": "answer"},
        {"event": "turn_end", "chat_id": "t-current"},
    ]
    for event in stored_events:
        append_transcript_object(key, event)

    session_history = [{"role": "user", "content": "already stored"}]
    out = build_webui_thread_response(key, session_messages=session_history)

    assert out is not None
    roles = [entry["role"] for entry in out["messages"]]
    assert roles == ["user", "assistant"]
    assert out["messages"][0]["content"] == "already stored"
def test_backfill_does_not_misalign_when_session_only_has_transcript_tail(
    tmp_path,
    monkeypatch,
) -> None:
    """Only the turn whose answer appears in the session gets a user row.

    The session covers just the last transcript turn, so the older turn
    stays assistant-only.
    """
    monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
    key = "websocket:t-tail"
    stored_events = [
        {"event": "message", "chat_id": "t-tail", "text": "old answer"},
        {"event": "turn_end", "chat_id": "t-tail"},
        {"event": "message", "chat_id": "t-tail", "text": "tail answer"},
        {"event": "turn_end", "chat_id": "t-tail"},
    ]
    for event in stored_events:
        append_transcript_object(key, event)

    session_history = [
        {"role": "user", "content": "tail question"},
        {"role": "assistant", "content": "tail answer"},
    ]
    out = build_webui_thread_response(key, session_messages=session_history)

    assert out is not None
    roles = [entry["role"] for entry in out["messages"]]
    contents = [entry["content"] for entry in out["messages"]]
    assert roles == ["assistant", "user", "assistant"]
    assert contents == ["old answer", "tail question", "tail answer"]
def test_replay_infers_video_media_from_attachment_name() -> None:
msgs = replay_transcript_to_ui_messages(
[