nanobot/nanobot/session/webui_turns.py
2026-06-01 23:00:53 +08:00

450 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Session turn helpers for WebUI-capable WebSocket sessions."""
from __future__ import annotations
import re
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from loguru import logger
from nanobot.bus import progress as bus_progress
from nanobot.bus.events import InboundMessage, OutboundMessage
from nanobot.bus.queue import MessageBus
from nanobot.bus.runtime_events import (
GoalStateChanged,
RuntimeEventBus,
RuntimeEventContext,
RuntimeModelChanged,
SessionTurnStarted,
TurnCompleted,
TurnRunStatusChanged,
)
from nanobot.providers.base import LLMProvider
from nanobot.session.goal_state import goal_state_ws_blob
from nanobot.session.manager import Session, SessionManager
from nanobot.utils.helpers import strip_think, truncate_text
from nanobot.utils.llm_runtime import LLMRuntime
WEBUI_SESSION_METADATA_KEY = "webui"
WEBUI_TITLE_METADATA_KEY = "title"
WEBUI_TITLE_USER_EDITED_METADATA_KEY = "title_user_edited"
TITLE_MAX_CHARS = 60
TITLE_GENERATION_MAX_TOKENS = 96
TITLE_GENERATION_REASONING_EFFORT = "none"
# Wall-clock turn start per ``chat_id`` (websocket only). Survives browser refresh while the
# gateway process stays up; cleared on idle/stop and implicitly dropped on restart.
_WEBSOCKET_TURN_WALL_STARTED_AT: dict[str, float] = {}
def mark_webui_session(session: Session, metadata: dict[str, Any]) -> bool:
"""Persist a WebUI marker only when the inbound websocket frame opted in."""
if metadata.get(WEBUI_SESSION_METADATA_KEY) is not True:
return False
session.metadata[WEBUI_SESSION_METADATA_KEY] = True
return True
def clean_generated_title(raw: str | None) -> str:
text = (raw or "").strip()
if not text:
return ""
text = re.sub(r"^\s*(title|标题)\s*[:]\s*", "", text, flags=re.IGNORECASE)
text = text.strip().strip("\"'`“”‘’")
text = strip_think(text)
text = re.sub(r"\s+", " ", text).strip()
text = text.rstrip("。.!?,;:")
if len(text) > TITLE_MAX_CHARS:
text = text[: TITLE_MAX_CHARS - 1].rstrip() + ""
return text
def _title_inputs(session: Session) -> tuple[str, str]:
user_text = ""
assistant_text = ""
for message in session.messages:
if message.get("_command") is True:
continue
role = message.get("role")
content = message.get("content")
if not isinstance(content, str) or not content.strip():
continue
content = strip_think(content)
if not content:
continue
if role == "user" and not user_text:
user_text = content.strip()
elif role == "assistant" and not assistant_text:
assistant_text = content.strip()
if user_text and assistant_text:
break
return user_text, assistant_text
async def maybe_generate_webui_title(
*,
sessions: SessionManager,
session_key: str,
provider: LLMProvider,
model: str,
) -> bool:
"""Generate and persist a short title for WebUI-owned sessions only."""
session = sessions.get_or_create(session_key)
if session.metadata.get(WEBUI_SESSION_METADATA_KEY) is not True:
return False
if session.metadata.get(WEBUI_TITLE_USER_EDITED_METADATA_KEY) is True:
return False
current_title = session.metadata.get(WEBUI_TITLE_METADATA_KEY)
if isinstance(current_title, str) and current_title.strip():
cleaned_current_title = clean_generated_title(current_title)
if cleaned_current_title:
if cleaned_current_title != current_title:
session.metadata[WEBUI_TITLE_METADATA_KEY] = cleaned_current_title
sessions.save(session)
return False
session.metadata.pop(WEBUI_TITLE_METADATA_KEY, None)
user_text, assistant_text = _title_inputs(session)
if not user_text:
return False
prompt = (
"Generate a concise title for this chat.\n"
"Rules:\n"
"- Use the same language as the user when practical.\n"
"- 3 to 8 words.\n"
"- No quotes.\n"
"- No punctuation at the end.\n"
"- Return only the title.\n\n"
f"User: {truncate_text(user_text, 1_000)}"
)
if assistant_text:
prompt += f"\nAssistant: {truncate_text(assistant_text, 1_000)}"
try:
response = await provider.chat_with_retry(
[
{
"role": "system",
"content": (
"You write short, neutral chat titles. "
"Return only the title text."
),
},
{"role": "user", "content": prompt},
],
tools=None,
model=model,
max_tokens=TITLE_GENERATION_MAX_TOKENS,
temperature=0.2,
reasoning_effort=TITLE_GENERATION_REASONING_EFFORT,
retry_mode="standard",
)
except Exception:
logger.debug("Failed to generate webui session title for {}", session_key, exc_info=True)
return False
title = clean_generated_title(response.content)
if not title or title.lower().startswith("error"):
logger.debug(
"WebUI title generation returned no usable title for {} (finish_reason={})",
session_key,
response.finish_reason,
)
return False
session.metadata[WEBUI_TITLE_METADATA_KEY] = title
sessions.save(session)
return True
async def maybe_generate_webui_title_after_turn(
*,
channel: str,
metadata: dict[str, Any],
sessions: SessionManager,
session_key: str,
provider: LLMProvider,
model: str,
) -> bool:
if channel != "websocket" or metadata.get(WEBUI_SESSION_METADATA_KEY) is not True:
return False
return await maybe_generate_webui_title(
sessions=sessions,
session_key=session_key,
provider=provider,
model=model,
)
def websocket_turn_wall_started_at(chat_id: str) -> float | None:
"""Return ``time.time()`` when the active user turn began, if still running."""
return _WEBSOCKET_TURN_WALL_STARTED_AT.get(chat_id)
def build_bus_progress_callback(
bus: MessageBus,
msg: InboundMessage,
) -> Callable[..., Awaitable[None]]:
"""Compatibility wrapper for the generic bus progress callback."""
return bus_progress.build_bus_progress_callback(bus, msg)
async def publish_turn_run_status(
bus: MessageBus,
msg: InboundMessage,
status: str,
*,
started_at: float | None = None,
) -> None:
"""Notify WebSocket clients while a user turn is executing (timing strip)."""
if msg.channel != "websocket":
return
cid = str(msg.chat_id)
meta: dict[str, Any] = {
**dict(msg.metadata or {}),
"_goal_status": True,
"goal_status": status,
}
if status == "running":
if isinstance(started_at, int | float) and started_at > 0:
t0 = float(started_at)
else:
t0 = time.time()
meta["started_at"] = t0
_WEBSOCKET_TURN_WALL_STARTED_AT[cid] = t0
else:
_WEBSOCKET_TURN_WALL_STARTED_AT.pop(cid, None)
await bus.publish_outbound(
OutboundMessage(
channel=msg.channel,
chat_id=cid,
content="",
metadata=meta,
),
)
@dataclass
class WebuiTurnCoordinator:
"""Translate generic runtime events into WebUI/WebSocket wire messages."""
bus: MessageBus
sessions: SessionManager
schedule_background: Callable[[Awaitable[None]], None]
_title_contexts: dict[str, LLMRuntime] = field(default_factory=dict)
def subscribe(self, runtime_events: RuntimeEventBus) -> Callable[[], None]:
"""Subscribe this coordinator to runtime events."""
unsubscribe = [
runtime_events.subscribe(
self._handle_session_turn_started,
SessionTurnStarted,
),
runtime_events.subscribe(
self._handle_run_status_changed,
TurnRunStatusChanged,
),
runtime_events.subscribe(
self._handle_turn_completed_event,
TurnCompleted,
),
runtime_events.subscribe(
self._handle_goal_state_changed,
GoalStateChanged,
),
runtime_events.subscribe(
self._handle_runtime_model_changed,
RuntimeModelChanged,
),
]
def _unsubscribe() -> None:
for fn in reversed(unsubscribe):
fn()
return _unsubscribe
@staticmethod
def _ctx_msg(ctx: RuntimeEventContext) -> InboundMessage:
return InboundMessage(
channel=ctx.channel,
sender_id="runtime",
chat_id=ctx.chat_id,
content="",
metadata=dict(ctx.metadata or {}),
session_key_override=ctx.session_key,
)
@staticmethod
def _is_websocket_event(ctx: RuntimeEventContext) -> bool:
return ctx.channel == "websocket"
def _handle_session_turn_started(self, event: SessionTurnStarted) -> None:
if not self._is_websocket_event(event.context):
return
session = self.sessions.get_or_create(event.context.session_key)
mark_webui_session(session, event.context.metadata)
async def _handle_run_status_changed(self, event: TurnRunStatusChanged) -> None:
if not self._is_websocket_event(event.context):
return
await publish_turn_run_status(
self.bus,
self._ctx_msg(event.context),
event.status,
started_at=event.started_at,
)
async def _handle_turn_completed_event(self, event: TurnCompleted) -> None:
if not self._is_websocket_event(event.context):
return
msg = self._ctx_msg(event.context)
await self.handle_turn_end(
msg,
session_key=event.context.session_key,
latency_ms=event.latency_ms,
)
self._schedule_title_update_from_event(event)
async def _handle_goal_state_changed(self, event: GoalStateChanged) -> None:
if not self._is_websocket_event(event.context):
return
cid = str(event.context.chat_id or "").strip()
if not cid:
return
await self.bus.publish_outbound(
OutboundMessage(
channel=event.context.channel,
chat_id=cid,
content="",
metadata={
"_goal_state_sync": True,
"goal_state": goal_state_ws_blob(event.session_metadata),
},
),
)
async def _handle_runtime_model_changed(self, event: RuntimeModelChanged) -> None:
await self.bus.publish_outbound(
OutboundMessage(
channel="websocket",
chat_id="*",
content="",
metadata={
"_runtime_model_updated": True,
"model": event.model,
"model_preset": event.model_preset,
},
)
)
def capture_title_context(
self,
session_key: str,
msg: InboundMessage,
llm: LLMRuntime,
) -> None:
if msg.channel == "websocket" and msg.metadata.get("webui") is True:
self._title_contexts[session_key] = llm
def discard(self, session_key: str) -> None:
self._title_contexts.pop(session_key, None)
async def publish_run_status(
self,
msg: InboundMessage,
status: str,
*,
started_at: float | None = None,
) -> None:
await publish_turn_run_status(self.bus, msg, status, started_at=started_at)
async def handle_turn_end(
self,
msg: InboundMessage,
*,
session_key: str,
latency_ms: int | None,
) -> None:
if msg.channel != "websocket":
return
turn_metadata: dict[str, Any] = {**msg.metadata, "_turn_end": True}
if latency_ms is not None:
turn_metadata["latency_ms"] = int(latency_ms)
session = self.sessions.get_or_create(session_key)
turn_metadata["goal_state"] = goal_state_ws_blob(session.metadata)
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="",
metadata=turn_metadata,
))
self._schedule_title_update(msg, session_key=session_key)
def _schedule_title_update(self, msg: InboundMessage, *, session_key: str) -> None:
title_context = self._title_contexts.pop(session_key, None)
if msg.metadata.get("webui") is not True or title_context is None:
return
async def _generate_title_and_notify(
title_llm: LLMRuntime = title_context,
) -> None:
generated = await maybe_generate_webui_title_after_turn(
channel=msg.channel,
metadata=msg.metadata,
sessions=self.sessions,
session_key=session_key,
provider=title_llm.provider,
model=title_llm.model,
)
if generated:
await self.bus.publish_outbound(OutboundMessage(
channel=msg.channel,
chat_id=msg.chat_id,
content="",
metadata={
**msg.metadata,
"_session_updated": True,
"_session_update_scope": "metadata",
},
))
self.schedule_background(_generate_title_and_notify())
def _schedule_title_update_from_event(self, event: TurnCompleted) -> None:
title_context = event.runtime
if (
event.context.metadata.get("webui") is not True
or title_context is None
or not isinstance(title_context, LLMRuntime)
):
return
async def _generate_title_and_notify(
title_llm: LLMRuntime = title_context,
) -> None:
generated = await maybe_generate_webui_title_after_turn(
channel=event.context.channel,
metadata=event.context.metadata,
sessions=self.sessions,
session_key=event.context.session_key,
provider=title_llm.provider,
model=title_llm.model,
)
if generated:
await self.bus.publish_outbound(OutboundMessage(
channel=event.context.channel,
chat_id=event.context.chat_id,
content="",
metadata={
**event.context.metadata,
"_session_updated": True,
"_session_update_scope": "metadata",
},
))
self.schedule_background(_generate_title_and_notify())