mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
fix(webui): persist user messages for refresh
This commit is contained in:
parent
3da68ac7fe
commit
710d00a179
@ -48,7 +48,7 @@ from nanobot.webui.http_utils import (
|
|||||||
query_first as _query_first,
|
query_first as _query_first,
|
||||||
)
|
)
|
||||||
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
from nanobot.webui.mcp_presets_api import normalize_mcp_preset_mentions
|
||||||
from nanobot.webui.transcript import append_transcript_object
|
from nanobot.webui.transcript import append_transcript_object, build_user_transcript_event
|
||||||
from nanobot.webui.websocket_logging import websockets_server_logger
|
from nanobot.webui.websocket_logging import websockets_server_logger
|
||||||
|
|
||||||
|
|
||||||
@ -768,6 +768,14 @@ class WebSocketChannel(BaseChannel):
|
|||||||
"enabled": True,
|
"enabled": True,
|
||||||
"aspect_ratio": aspect_ratio if isinstance(aspect_ratio, str) else None,
|
"aspect_ratio": aspect_ratio if isinstance(aspect_ratio, str) else None,
|
||||||
}
|
}
|
||||||
|
if envelope.get("webui") is True and self.is_allowed(client_id):
|
||||||
|
self._try_append_webui_user_transcript(
|
||||||
|
cid,
|
||||||
|
content,
|
||||||
|
media_paths=media_paths,
|
||||||
|
cli_apps=cli_apps,
|
||||||
|
mcp_presets=mcp_presets,
|
||||||
|
)
|
||||||
await self._handle_message(
|
await self._handle_message(
|
||||||
sender_id=client_id,
|
sender_id=client_id,
|
||||||
chat_id=cid,
|
chat_id=cid,
|
||||||
@ -833,9 +841,31 @@ class WebSocketChannel(BaseChannel):
|
|||||||
try:
|
try:
|
||||||
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
dup = json.loads(json.dumps(wire, ensure_ascii=False))
|
||||||
append_transcript_object(sk, dup)
|
append_transcript_object(sk, dup)
|
||||||
except (ValueError, TypeError) as e:
|
except (OSError, ValueError, TypeError) as e:
|
||||||
self.logger.warning("webui transcript append failed: {}", e)
|
self.logger.warning("webui transcript append failed: {}", e)
|
||||||
|
|
||||||
|
def _try_append_webui_user_transcript(
|
||||||
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
media_paths: list[str],
|
||||||
|
cli_apps: list[dict[str, Any]],
|
||||||
|
mcp_presets: list[dict[str, Any]],
|
||||||
|
) -> None:
|
||||||
|
if content.strip() == "/stop" and not media_paths:
|
||||||
|
return
|
||||||
|
payload = build_user_transcript_event(
|
||||||
|
chat_id,
|
||||||
|
content,
|
||||||
|
media_paths=media_paths,
|
||||||
|
cli_apps=cli_apps,
|
||||||
|
mcp_presets=mcp_presets,
|
||||||
|
)
|
||||||
|
if payload is None:
|
||||||
|
return
|
||||||
|
self._try_append_webui_transcript(chat_id, payload)
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage) -> None:
|
async def send(self, msg: OutboundMessage) -> None:
|
||||||
if msg.metadata.get("_runtime_model_updated"):
|
if msg.metadata.get("_runtime_model_updated"):
|
||||||
await self.send_runtime_model_updated(
|
await self.send_runtime_model_updated(
|
||||||
|
|||||||
@ -40,6 +40,15 @@ _FILE_EDIT_TOOL_NAMES: frozenset[str] = frozenset({
|
|||||||
"edit_file",
|
"edit_file",
|
||||||
"apply_patch",
|
"apply_patch",
|
||||||
})
|
})
|
||||||
|
_TURN_DISPLAY_EVENTS: frozenset[str] = frozenset({
|
||||||
|
"reasoning_delta",
|
||||||
|
"reasoning_end",
|
||||||
|
"delta",
|
||||||
|
"stream_end",
|
||||||
|
"message",
|
||||||
|
"file_edit",
|
||||||
|
"turn_end",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def rewrite_local_markdown_images(
|
def rewrite_local_markdown_images(
|
||||||
@ -155,6 +164,165 @@ def delete_webui_transcript(session_key: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def build_user_transcript_event(
|
||||||
|
chat_id: str,
|
||||||
|
text: str,
|
||||||
|
*,
|
||||||
|
media_paths: list[Any] | None = None,
|
||||||
|
cli_apps: list[Any] | None = None,
|
||||||
|
mcp_presets: list[Any] | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
paths = [str(path) for path in (media_paths or []) if path]
|
||||||
|
if not text and not paths:
|
||||||
|
return None
|
||||||
|
event: dict[str, Any] = {
|
||||||
|
"event": "user",
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"text": text,
|
||||||
|
}
|
||||||
|
if paths:
|
||||||
|
event["media_paths"] = paths
|
||||||
|
apps = [dict(app) for app in (cli_apps or []) if isinstance(app, Mapping)]
|
||||||
|
if apps:
|
||||||
|
event["cli_apps"] = apps
|
||||||
|
presets = [dict(preset) for preset in (mcp_presets or []) if isinstance(preset, Mapping)]
|
||||||
|
if presets:
|
||||||
|
event["mcp_presets"] = presets
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
def _session_user_event(
|
||||||
|
session_key: str,
|
||||||
|
message: dict[str, Any],
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
if message.get("role") != "user":
|
||||||
|
return None
|
||||||
|
content = message.get("content")
|
||||||
|
text = content if isinstance(content, str) else ""
|
||||||
|
media = message.get("media")
|
||||||
|
cli_apps = message.get("cli_apps")
|
||||||
|
mcp_presets = message.get("mcp_presets")
|
||||||
|
chat_id = session_key.split(":", 1)[1] if ":" in session_key else session_key
|
||||||
|
return build_user_transcript_event(
|
||||||
|
chat_id,
|
||||||
|
text,
|
||||||
|
media_paths=media if isinstance(media, list) else None,
|
||||||
|
cli_apps=cli_apps if isinstance(cli_apps, list) else None,
|
||||||
|
mcp_presets=mcp_presets if isinstance(mcp_presets, list) else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _assistant_text_signature(value: Any) -> str:
|
||||||
|
return value.strip() if isinstance(value, str) else ""
|
||||||
|
|
||||||
|
|
||||||
|
def _session_backfill_turns(
|
||||||
|
session_key: str,
|
||||||
|
session_messages: list[dict[str, Any]],
|
||||||
|
) -> list[tuple[dict[str, Any], tuple[str, ...]]]:
|
||||||
|
turns: list[tuple[dict[str, Any], tuple[str, ...]]] = []
|
||||||
|
current_user: dict[str, Any] | None = None
|
||||||
|
assistant_texts: list[str] = []
|
||||||
|
|
||||||
|
def flush() -> None:
|
||||||
|
if current_user is None:
|
||||||
|
return
|
||||||
|
signature = tuple(text for text in assistant_texts if text)
|
||||||
|
if signature:
|
||||||
|
turns.append((current_user, signature))
|
||||||
|
|
||||||
|
for message in session_messages:
|
||||||
|
role = message.get("role")
|
||||||
|
if role == "user":
|
||||||
|
flush()
|
||||||
|
current_user = _session_user_event(session_key, message)
|
||||||
|
assistant_texts = []
|
||||||
|
continue
|
||||||
|
if role == "assistant" and current_user is not None:
|
||||||
|
text = _assistant_text_signature(message.get("content"))
|
||||||
|
if text:
|
||||||
|
assistant_texts.append(text)
|
||||||
|
flush()
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
|
def _split_transcript_turns(lines: list[dict[str, Any]]) -> list[list[dict[str, Any]]]:
|
||||||
|
turns: list[list[dict[str, Any]]] = []
|
||||||
|
current: list[dict[str, Any]] = []
|
||||||
|
for rec in lines:
|
||||||
|
current.append(rec)
|
||||||
|
if rec.get("event") == "turn_end":
|
||||||
|
turns.append(current)
|
||||||
|
current = []
|
||||||
|
if current:
|
||||||
|
turns.append(current)
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
|
def _transcript_turn_signature(records: list[dict[str, Any]]) -> tuple[str, ...]:
|
||||||
|
texts: list[str] = []
|
||||||
|
for message in replay_transcript_to_ui_messages(records):
|
||||||
|
if message.get("role") != "assistant" or message.get("kind") == "trace":
|
||||||
|
continue
|
||||||
|
text = _assistant_text_signature(message.get("content"))
|
||||||
|
if text:
|
||||||
|
texts.append(text)
|
||||||
|
return tuple(texts)
|
||||||
|
|
||||||
|
|
||||||
|
def _find_unique_session_turn(
|
||||||
|
session_turns: list[tuple[dict[str, Any], tuple[str, ...]]],
|
||||||
|
signature: tuple[str, ...],
|
||||||
|
start: int,
|
||||||
|
) -> int | None:
|
||||||
|
if not signature:
|
||||||
|
return None
|
||||||
|
found: int | None = None
|
||||||
|
for index in range(start, len(session_turns)):
|
||||||
|
if session_turns[index][1] != signature:
|
||||||
|
continue
|
||||||
|
if found is not None:
|
||||||
|
return None
|
||||||
|
found = index
|
||||||
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
def _with_backfilled_user(
|
||||||
|
records: list[dict[str, Any]],
|
||||||
|
user_event: dict[str, Any],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
for index, rec in enumerate(records):
|
||||||
|
if rec.get("event") in _TURN_DISPLAY_EVENTS:
|
||||||
|
return [*records[:index], dict(user_event), *records[index:]]
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
def inject_missing_user_events_from_session(
|
||||||
|
session_key: str,
|
||||||
|
lines: list[dict[str, Any]],
|
||||||
|
session_messages: list[dict[str, Any]] | None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Backfill user rows for legacy WebUI transcripts that only stored assistant streams."""
|
||||||
|
if not lines or not session_messages:
|
||||||
|
return lines
|
||||||
|
session_turns = _session_backfill_turns(session_key, session_messages)
|
||||||
|
if not session_turns:
|
||||||
|
return lines
|
||||||
|
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
session_cursor = 0
|
||||||
|
for turn in _split_transcript_turns(lines):
|
||||||
|
has_user = any(rec.get("event") == "user" for rec in turn)
|
||||||
|
signature = _transcript_turn_signature(turn)
|
||||||
|
match_index = _find_unique_session_turn(session_turns, signature, session_cursor)
|
||||||
|
if match_index is None:
|
||||||
|
out.extend(turn)
|
||||||
|
continue
|
||||||
|
out.extend(turn if has_user else _with_backfilled_user(turn, session_turns[match_index][0]))
|
||||||
|
session_cursor = match_index + 1
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def _format_tool_call_trace(call: Any) -> str | None:
|
def _format_tool_call_trace(call: Any) -> str | None:
|
||||||
if not call or not isinstance(call, dict):
|
if not call or not isinstance(call, dict):
|
||||||
return None
|
return None
|
||||||
@ -904,11 +1072,13 @@ def build_webui_thread_response(
|
|||||||
augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
augment_user_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
||||||
augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
augment_assistant_media: Callable[[list[str]], list[dict[str, Any]]] | None = None,
|
||||||
augment_assistant_text: Callable[[str], str] | None = None,
|
augment_assistant_text: Callable[[str], str] | None = None,
|
||||||
|
session_messages: list[dict[str, Any]] | None = None,
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Return a payload compatible with ``WebuiThreadPersistedPayload``."""
|
"""Return a payload compatible with ``WebuiThreadPersistedPayload``."""
|
||||||
lines = read_transcript_lines(session_key)
|
lines = read_transcript_lines(session_key)
|
||||||
if not lines:
|
if not lines:
|
||||||
return None
|
return None
|
||||||
|
lines = inject_missing_user_events_from_session(session_key, lines, session_messages)
|
||||||
msgs = replay_transcript_to_ui_messages(
|
msgs = replay_transcript_to_ui_messages(
|
||||||
lines,
|
lines,
|
||||||
augment_user_media=augment_user_media,
|
augment_user_media=augment_user_media,
|
||||||
|
|||||||
@ -348,6 +348,12 @@ class GatewayHTTPHandler:
|
|||||||
if not _is_websocket_channel_session_key(decoded_key):
|
if not _is_websocket_channel_session_key(decoded_key):
|
||||||
return _http_error(404, "session not found")
|
return _http_error(404, "session not found")
|
||||||
scope = self.workspaces.scope_for_session_key(decoded_key)
|
scope = self.workspaces.scope_for_session_key(decoded_key)
|
||||||
|
session_messages: list[dict[str, Any]] | None = None
|
||||||
|
if self.session_manager is not None:
|
||||||
|
session_data = self.session_manager.read_session_file(decoded_key)
|
||||||
|
raw_messages = session_data.get("messages") if isinstance(session_data, dict) else None
|
||||||
|
if isinstance(raw_messages, list):
|
||||||
|
session_messages = [m for m in raw_messages if isinstance(m, dict)]
|
||||||
data = build_webui_thread_response(
|
data = build_webui_thread_response(
|
||||||
decoded_key,
|
decoded_key,
|
||||||
augment_user_media=self.media.augment_transcript_media,
|
augment_user_media=self.media.augment_transcript_media,
|
||||||
@ -356,6 +362,7 @@ class GatewayHTTPHandler:
|
|||||||
text,
|
text,
|
||||||
workspace_path=scope.project_path,
|
workspace_path=scope.project_path,
|
||||||
),
|
),
|
||||||
|
session_messages=session_messages,
|
||||||
)
|
)
|
||||||
if data is None:
|
if data is None:
|
||||||
return _http_error(404, "webui thread not found")
|
return _http_error(404, "webui thread not found")
|
||||||
|
|||||||
@ -294,6 +294,87 @@ async def test_webui_message_envelope_marks_inbound_metadata(bus: MagicMock) ->
|
|||||||
assert msg.metadata["_wants_stream"] is True
|
assert msg.metadata["_wants_stream"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webui_message_envelope_persists_user_transcript_for_refresh(
|
||||||
|
bus: MagicMock,
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
from nanobot.webui.transcript import build_webui_thread_response, read_transcript_lines
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
channel = _ch(bus)
|
||||||
|
conn = AsyncMock()
|
||||||
|
conn.remote_address = ("127.0.0.1", 50123)
|
||||||
|
|
||||||
|
async def answer_during_publish(_msg: Any) -> None:
|
||||||
|
await channel.send(OutboundMessage(channel="websocket", chat_id="chat-1", content="hi back"))
|
||||||
|
|
||||||
|
bus.publish_inbound.side_effect = answer_during_publish
|
||||||
|
|
||||||
|
await channel._dispatch_envelope(
|
||||||
|
conn,
|
||||||
|
"webui-client",
|
||||||
|
{"type": "message", "chat_id": "chat-1", "content": "hello", "webui": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
lines = read_transcript_lines("websocket:chat-1")
|
||||||
|
assert [line["event"] for line in lines] == ["user", "message"]
|
||||||
|
|
||||||
|
body = build_webui_thread_response("websocket:chat-1")
|
||||||
|
assert body is not None
|
||||||
|
assert [message["role"] for message in body["messages"]] == ["user", "assistant"]
|
||||||
|
assert [message["content"] for message in body["messages"]] == ["hello", "hi back"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webui_stop_control_message_is_not_persisted_as_user_bubble(
|
||||||
|
bus: MagicMock,
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
from nanobot.webui.transcript import read_transcript_lines
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
channel = _ch(bus)
|
||||||
|
conn = AsyncMock()
|
||||||
|
conn.remote_address = ("127.0.0.1", 50123)
|
||||||
|
|
||||||
|
await channel._dispatch_envelope(
|
||||||
|
conn,
|
||||||
|
"webui-client",
|
||||||
|
{"type": "message", "chat_id": "chat-1", "content": "/stop", "webui": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = bus.publish_inbound.await_args.args[0]
|
||||||
|
assert msg.content == "/stop"
|
||||||
|
assert read_transcript_lines("websocket:chat-1") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_webui_user_transcript_append_failure_does_not_block_inbound(
|
||||||
|
bus: MagicMock,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
def fail_append(_session_key: str, _obj: dict[str, Any]) -> None:
|
||||||
|
raise OSError("disk full")
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.channels.websocket.append_transcript_object", fail_append)
|
||||||
|
channel = _ch(bus)
|
||||||
|
conn = AsyncMock()
|
||||||
|
conn.remote_address = ("127.0.0.1", 50123)
|
||||||
|
|
||||||
|
await channel._dispatch_envelope(
|
||||||
|
conn,
|
||||||
|
"webui-client",
|
||||||
|
{"type": "message", "chat_id": "chat-1", "content": "hello", "webui": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = bus.publish_inbound.await_args.args[0]
|
||||||
|
assert msg.chat_id == "chat-1"
|
||||||
|
assert msg.content == "hello"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_plain_websocket_message_does_not_mark_webui(bus: MagicMock) -> None:
|
async def test_plain_websocket_message_does_not_mark_webui(bus: MagicMock) -> None:
|
||||||
channel = _ch(bus)
|
channel = _ch(bus)
|
||||||
@ -2411,3 +2492,47 @@ def test_handle_webui_thread_get_returns_json(tmp_path, monkeypatch) -> None:
|
|||||||
assert len(body["messages"]) == 1
|
assert len(body["messages"]) == 1
|
||||||
assert body["messages"][0]["role"] == "user"
|
assert body["messages"][0]["role"] == "user"
|
||||||
assert body["messages"][0]["content"] == "hi"
|
assert body["messages"][0]["content"] == "hi"
|
||||||
|
|
||||||
|
|
||||||
|
def test_handle_webui_thread_get_backfills_legacy_missing_user_rows(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
from websockets.datastructures import Headers
|
||||||
|
from websockets.http11 import Request
|
||||||
|
|
||||||
|
from nanobot.webui.transcript import append_transcript_object
|
||||||
|
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
workspace = tmp_path / "workspace"
|
||||||
|
sessions = SessionManager(workspace)
|
||||||
|
key = "websocket:c-legacy"
|
||||||
|
session = sessions.get_or_create(key)
|
||||||
|
session.add_message("user", "legacy question")
|
||||||
|
session.add_message("assistant", "legacy answer")
|
||||||
|
sessions.save(session)
|
||||||
|
append_transcript_object(
|
||||||
|
key,
|
||||||
|
{"event": "message", "chat_id": "c-legacy", "text": "legacy answer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
bus = MagicMock()
|
||||||
|
channel = WebSocketChannel(
|
||||||
|
{"enabled": True, "allowFrom": ["*"]},
|
||||||
|
bus,
|
||||||
|
gateway=_basic_handler(bus, session_manager=sessions, workspace_path=workspace),
|
||||||
|
)
|
||||||
|
channel.gateway.tokens.api_tokens["tok"] = time.monotonic() + 300.0
|
||||||
|
enc = quote(key, safe="")
|
||||||
|
req = Request(f"/api/sessions/{enc}/webui-thread", Headers([("Authorization", "Bearer tok")]))
|
||||||
|
resp = channel.gateway.http._handle_webui_thread_get(req, enc)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = json.loads(resp.body.decode())
|
||||||
|
assert [message["role"] for message in body["messages"]] == ["user", "assistant"]
|
||||||
|
assert [message["content"] for message in body["messages"]] == [
|
||||||
|
"legacy question",
|
||||||
|
"legacy answer",
|
||||||
|
]
|
||||||
|
|||||||
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||||||
from nanobot.webui.transcript import (
|
from nanobot.webui.transcript import (
|
||||||
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
WEBUI_TRANSCRIPT_SCHEMA_VERSION,
|
||||||
append_transcript_object,
|
append_transcript_object,
|
||||||
|
build_webui_thread_response,
|
||||||
read_transcript_lines,
|
read_transcript_lines,
|
||||||
replay_transcript_to_ui_messages,
|
replay_transcript_to_ui_messages,
|
||||||
)
|
)
|
||||||
@ -66,6 +67,98 @@ def test_replay_uses_stream_end_final_text() -> None:
|
|||||||
assert msgs[1]["content"] == ""
|
assert msgs[1]["content"] == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_response_backfills_legacy_sse_only_transcripts(tmp_path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
key = "websocket:t-legacy"
|
||||||
|
for ev in (
|
||||||
|
{"event": "delta", "chat_id": "t-legacy", "text": "first answer"},
|
||||||
|
{"event": "stream_end", "chat_id": "t-legacy"},
|
||||||
|
{"event": "turn_end", "chat_id": "t-legacy"},
|
||||||
|
{"event": "message", "chat_id": "t-legacy", "text": "second answer"},
|
||||||
|
{"event": "turn_end", "chat_id": "t-legacy"},
|
||||||
|
):
|
||||||
|
append_transcript_object(key, ev)
|
||||||
|
|
||||||
|
out = build_webui_thread_response(
|
||||||
|
key,
|
||||||
|
session_messages=[
|
||||||
|
{"role": "user", "content": "first question"},
|
||||||
|
{"role": "assistant", "content": "first answer"},
|
||||||
|
{"role": "user", "content": "second question"},
|
||||||
|
{"role": "assistant", "content": "second answer"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert out is not None
|
||||||
|
assert [message["role"] for message in out["messages"]] == [
|
||||||
|
"user",
|
||||||
|
"assistant",
|
||||||
|
"user",
|
||||||
|
"assistant",
|
||||||
|
]
|
||||||
|
assert [message["content"] for message in out["messages"]] == [
|
||||||
|
"first question",
|
||||||
|
"first answer",
|
||||||
|
"second question",
|
||||||
|
"second answer",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_backfill_does_not_duplicate_existing_user_transcript(tmp_path, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
key = "websocket:t-current"
|
||||||
|
for ev in (
|
||||||
|
{"event": "user", "chat_id": "t-current", "text": "already stored"},
|
||||||
|
{"event": "message", "chat_id": "t-current", "text": "answer"},
|
||||||
|
{"event": "turn_end", "chat_id": "t-current"},
|
||||||
|
):
|
||||||
|
append_transcript_object(key, ev)
|
||||||
|
|
||||||
|
out = build_webui_thread_response(
|
||||||
|
key,
|
||||||
|
session_messages=[{"role": "user", "content": "already stored"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert out is not None
|
||||||
|
assert [message["role"] for message in out["messages"]] == ["user", "assistant"]
|
||||||
|
assert out["messages"][0]["content"] == "already stored"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backfill_does_not_misalign_when_session_only_has_transcript_tail(
|
||||||
|
tmp_path,
|
||||||
|
monkeypatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setattr("nanobot.config.paths.get_data_dir", lambda: tmp_path)
|
||||||
|
key = "websocket:t-tail"
|
||||||
|
for ev in (
|
||||||
|
{"event": "message", "chat_id": "t-tail", "text": "old answer"},
|
||||||
|
{"event": "turn_end", "chat_id": "t-tail"},
|
||||||
|
{"event": "message", "chat_id": "t-tail", "text": "tail answer"},
|
||||||
|
{"event": "turn_end", "chat_id": "t-tail"},
|
||||||
|
):
|
||||||
|
append_transcript_object(key, ev)
|
||||||
|
|
||||||
|
out = build_webui_thread_response(
|
||||||
|
key,
|
||||||
|
session_messages=[
|
||||||
|
{"role": "user", "content": "tail question"},
|
||||||
|
{"role": "assistant", "content": "tail answer"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert out is not None
|
||||||
|
assert [message["role"] for message in out["messages"]] == [
|
||||||
|
"assistant",
|
||||||
|
"user",
|
||||||
|
"assistant",
|
||||||
|
]
|
||||||
|
assert [message["content"] for message in out["messages"]] == [
|
||||||
|
"old answer",
|
||||||
|
"tail question",
|
||||||
|
"tail answer",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_replay_infers_video_media_from_attachment_name() -> None:
|
def test_replay_infers_video_media_from_attachment_name() -> None:
|
||||||
msgs = replay_transcript_to_ui_messages(
|
msgs = replay_transcript_to_ui_messages(
|
||||||
[
|
[
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user