From 628b250e9a7dcc01f7e4d3ac092a95ff6a466fa7 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Mon, 1 Jun 2026 14:16:38 +0800 Subject: [PATCH] refactor: decouple webui runtime state via events --- nanobot/agent/loop.py | 146 ++++++++--- nanobot/agent/tools/context.py | 1 + nanobot/agent/tools/long_task.py | 69 ++++-- nanobot/bus/progress.py | 84 +++++++ nanobot/bus/runtime_events.py | 116 +++++++++ nanobot/cli/commands.py | 15 +- nanobot/session/webui_turns.py | 232 +++++++++++------- .../test_loop_direct_websocket_status.py | 6 + tests/agent/test_loop_progress.py | 15 ++ tests/agent/test_loop_save_turn.py | 8 +- tests/agent/tools/test_long_task.py | 20 +- 11 files changed, 559 insertions(+), 153 deletions(-) create mode 100644 nanobot/bus/progress.py create mode 100644 nanobot/bus/runtime_events.py diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 9d3f5771f..2e02ffdf4 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -29,7 +29,16 @@ from nanobot.agent.tools.message import MessageTool from nanobot.agent.tools.registry import ToolRegistry from nanobot.agent.tools.self import MyTool from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.progress import build_bus_progress_callback from nanobot.bus.queue import MessageBus +from nanobot.bus.runtime_events import ( + RuntimeEventBus, + RuntimeEventContext, + RuntimeModelChanged, + SessionTurnStarted, + TurnCompleted, + TurnRunStatusChanged, +) from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.config.schema import AgentDefaults, ModelPresetConfig from nanobot.providers.base import LLMProvider @@ -39,18 +48,13 @@ from nanobot.security.workspace_access import ( bind_workspace_scope, reset_workspace_scope, ) +from nanobot.session import turn_continuation from nanobot.session.goal_state import ( goal_state_runtime_lines, runner_wall_llm_timeout_s, sustained_goal_active, ) from nanobot.session.manager import Session, SessionManager -from nanobot.session import turn_continuation -from nanobot.session.webui_turns import ( - WebuiTurnCoordinator, - build_bus_progress_callback, - mark_webui_session, -) from nanobot.utils.document import extract_documents, reference_non_image_attachments from nanobot.utils.helpers import image_placeholder_text from nanobot.utils.helpers import truncate_text as truncate_text_fn @@ -125,6 +129,7 @@ class TurnContext: turn_wall_started_at: float = field(default_factory=time.time) visible_run_started_at: float | None = None turn_latency_ms: int | None = None + llm_runtime: LLMRuntime | None = None trace: list[StateTraceEntry] = field(default_factory=list) @@ -203,6 +208,7 @@ class AgentLoop: model_presets: dict[str, ModelPresetConfig] | None = None, model_preset: str | None = None, preset_snapshot_loader: preset_helpers.PresetSnapshotLoader | None = None, + runtime_events: RuntimeEventBus | None = None, runtime_model_publisher: Callable[[str, str | None], None] | None = None, ): from nanobot.config.schema import ToolsConfig @@ -210,6 +216,7 @@ class AgentLoop: _tc = tools_config or ToolsConfig() defaults = AgentDefaults() self.bus = bus + self.runtime_events = runtime_events or RuntimeEventBus() self.channels_config = channels_config self.provider = provider self._provider_snapshot_loader = provider_snapshot_loader @@ -256,15 +263,11 @@ class AgentLoop: self._start_time = time.time() self._last_usage: dict[str, int] = {} self._pending_turn_latency_ms: dict[str, int] = {} + self._pending_turn_runtime: dict[str, LLMRuntime] = {} self._extra_hooks: list[AgentHook] = hooks or [] self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) self.sessions = session_manager or SessionManager(workspace) - self._webui_turns = WebuiTurnCoordinator( - bus=self.bus, - sessions=self.sessions, - schedule_background=lambda coro: self._schedule_background(coro), - ) self.tools = ToolRegistry() # One file-read/write tracker per logical session. The tool registry is # shared by this loop, so tools resolve the active state via contextvars. @@ -418,6 +421,13 @@ class AgentLoop: self.model, model_preset if model_preset is not None else self.model_preset, ) + if publish_update: + self.runtime_events.publish_nowait( + RuntimeModelChanged( + model=self.model, + model_preset=model_preset if model_preset is not None else self.model_preset, + ) + ) logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) def _refresh_provider_snapshot(self) -> None: @@ -483,6 +493,7 @@ class AgentLoop: image_generation_provider_configs=self._image_generation_provider_configs, timezone=self.context.timezone or "UTC", workspace_sandbox=self.workspace_scopes.sandbox_status, + runtime_events=self.runtime_events, ) loader = ToolLoader() registered = loader.load(ctx, self.tools) @@ -558,6 +569,63 @@ class AgentLoop: return _on_retry_wait + @staticmethod + def _runtime_event_context( + *, + channel: str, + chat_id: str, + session_key: str, + metadata: dict[str, Any] | None, + ) -> RuntimeEventContext: + return RuntimeEventContext( + channel=channel, + chat_id=chat_id, + session_key=session_key, + metadata=dict(metadata or {}), + ) + + async def _publish_run_status_event( + self, + msg: InboundMessage, + session_key: str, + status: str, + *, + started_at: float | None = None, + ) -> None: + await self.runtime_events.publish( + TurnRunStatusChanged( + context=self._runtime_event_context( + channel=msg.channel, + chat_id=msg.chat_id, + session_key=session_key, + metadata=msg.metadata, + ), + status=status, + started_at=started_at, + ) + ) + + async def _publish_turn_completed_event( + self, + *, + channel: str, + chat_id: str, + session_key: str, + metadata: dict[str, Any] | None, + ) -> None: + await self.runtime_events.publish( + TurnCompleted( + context=self._runtime_event_context( + channel=channel, + chat_id=chat_id, + session_key=session_key, + metadata=metadata, + ), + latency_ms=self._pending_turn_latency_ms.pop(session_key, None), + runtime=self._pending_turn_runtime.pop(session_key, None), + ) + ) + def _persist_user_message_early( self, msg: InboundMessage, @@ -959,20 +1027,24 @@ class AgentLoop: msg, on_stream=on_stream, on_stream_end=on_stream_end, pending_queue=pending, ) + completed_channel = msg.channel + completed_chat_id = msg.chat_id if response is not None: await self.bus.publish_outbound(response) + completed_channel = response.channel + completed_chat_id = response.chat_id elif msg.channel == "cli": await self.bus.publish_outbound(OutboundMessage( channel=msg.channel, chat_id=msg.chat_id, content="", metadata=msg.metadata or {}, )) continuing = turn_continuation.internal_continuation_pending(msg.metadata) - if msg.channel == "websocket" and not continuing: - turn_lat = self._pending_turn_latency_ms.pop(session_key, None) - await self._webui_turns.handle_turn_end( - msg, + if not continuing: + await self._publish_turn_completed_event( + channel=completed_channel, + chat_id=completed_chat_id, session_key=session_key, - latency_ms=turn_lat, + metadata=msg.metadata, ) except asyncio.CancelledError: logger.info("Task cancelled for session {}", session_key) @@ -1032,14 +1104,14 @@ class AgentLoop: leftover, session_key, ) if not turn_continuation.internal_continuation_pending(msg.metadata): - await self._webui_turns.publish_run_status(msg, "idle") + await self._publish_run_status_event(msg, session_key, "idle") self._pending_turn_latency_ms.pop(session_key, None) - self._webui_turns.discard(session_key) + self._pending_turn_runtime.pop(session_key, None) finally: if pending is None: - await self._webui_turns.publish_run_status(msg, "idle") + await self._publish_run_status_event(msg, session_key, "idle") self._pending_turn_latency_ms.pop(session_key, None) - self._webui_turns.discard(session_key) + self._pending_turn_runtime.pop(session_key, None) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" @@ -1135,8 +1207,7 @@ class AgentLoop: wall_done = time.time() latency_ms = max(0, int((wall_done - t_wall) * 1000)) self._save_turn(session, all_msgs, 1 + len(history), turn_latency_ms=latency_ms) - if channel == "websocket": - self._pending_turn_latency_ms[key] = latency_ms + self._pending_turn_latency_ms[key] = latency_ms session.enforce_file_cap(on_archive=self.context.memory.raw_archive) self._clear_runtime_checkpoint(session) self.sessions.save(session) @@ -1302,7 +1373,16 @@ class AgentLoop: # ensure it exists in case this handler is invoked independently. if ctx.session is None: ctx.session = self.sessions.get_or_create(ctx.session_key) - mark_webui_session(ctx.session, msg.metadata) + await self.runtime_events.publish( + SessionTurnStarted( + context=self._runtime_event_context( + channel=msg.channel, + chat_id=msg.chat_id, + session_key=ctx.session_key, + metadata=msg.metadata, + ) + ) + ) self.workspace_scopes.persist_message_scope(ctx.session, msg) if self._restore_runtime_checkpoint(ctx.session): @@ -1374,11 +1454,8 @@ class AgentLoop: "include_timestamps": True, } ctx.history = ctx.session.get_history(**_hist_kwargs) - self._webui_turns.capture_title_context( - ctx.session_key, - ctx.msg, - self.llm_runtime(), - ) + ctx.llm_runtime = self.llm_runtime() + self._pending_turn_runtime[ctx.session_key] = ctx.llm_runtime ctx.initial_messages = self._build_initial_messages( ctx.msg, @@ -1400,8 +1477,9 @@ class AgentLoop: async def _state_run(self, ctx: TurnContext) -> str: if ctx.visible_run_started_at is None: ctx.visible_run_started_at = time.time() - await self._webui_turns.publish_run_status( + await self._publish_run_status_event( ctx.msg, + ctx.session_key, "running", started_at=ctx.visible_run_started_at, ) @@ -1448,8 +1526,7 @@ class AgentLoop: ctx.session, ctx.all_messages, ctx.save_skip, turn_latency_ms=ctx.turn_latency_ms, ) - if ctx.msg.channel == "websocket": - self._pending_turn_latency_ms[ctx.session_key] = ctx.turn_latency_ms + self._pending_turn_latency_ms[ctx.session_key] = ctx.turn_latency_ms ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive) self._clear_pending_user_turn(ctx.session) self._clear_runtime_checkpoint(ctx.session) @@ -1718,7 +1795,6 @@ class AgentLoop: on_stream_end=on_stream_end, ) finally: - if channel == "websocket": - await self._webui_turns.publish_run_status(msg, "idle") - self._pending_turn_latency_ms.pop(session_key, None) - self._webui_turns.discard(session_key) + await self._publish_run_status_event(msg, session_key, "idle") + self._pending_turn_latency_ms.pop(session_key, None) + self._pending_turn_runtime.pop(session_key, None) diff --git a/nanobot/agent/tools/context.py b/nanobot/agent/tools/context.py index 61aa8ed7c..619816181 100644 --- a/nanobot/agent/tools/context.py +++ b/nanobot/agent/tools/context.py @@ -57,3 +57,4 @@ class ToolContext: image_generation_provider_configs: dict[str, Any] | None = None timezone: str = "UTC" workspace_sandbox: Any | None = None + runtime_events: Any | None = None diff --git a/nanobot/agent/tools/long_task.py b/nanobot/agent/tools/long_task.py index a58d14ee5..12fcec174 100644 --- a/nanobot/agent/tools/long_task.py +++ b/nanobot/agent/tools/long_task.py @@ -23,12 +23,11 @@ from typing import TYPE_CHECKING, Any from nanobot.agent.tools.base import Tool, tool_parameters from nanobot.agent.tools.context import ContextAware, RequestContext from nanobot.agent.tools.schema import StringSchema, tool_parameters_schema -from nanobot.bus.events import OutboundMessage +from nanobot.bus.runtime_events import GoalStateChanged, RuntimeEventBus, RuntimeEventContext from nanobot.session.goal_state import ( GOAL_STATE_KEY, discard_legacy_goal_state_key, goal_state_raw, - goal_state_ws_blob, parse_goal_state, ) @@ -43,9 +42,13 @@ def _iso_now() -> str: class _GoalToolsMixin(ContextAware): """Shared routing context + Session lookup.""" - def __init__(self, sessions: SessionManager, bus: Any | None = None) -> None: + def __init__( + self, + sessions: SessionManager, + runtime_events: RuntimeEventBus | None = None, + ) -> None: self._sessions = sessions - self._bus = bus + self._runtime_events = runtime_events # Each subclass gets its own ContextVar so concurrent tasks across # different tool types (LongTaskTool vs CompleteGoalTool) do not # interfere with each other. @@ -66,25 +69,25 @@ class _GoalToolsMixin(ContextAware): return None return self._sessions.get_or_create(key) - async def _publish_goal_state_ws(self, metadata: dict[str, Any]) -> None: - """Fan-out authoritative goal snapshot for this WebSocket chat only.""" - bus = self._bus + async def _publish_goal_state_changed(self, metadata: dict[str, Any]) -> None: + """Publish authoritative goal metadata as a runtime event.""" + runtime_events = self._runtime_events rc = self._request_ctx.get() - if bus is None or rc is None or rc.channel != "websocket": + if runtime_events is None or rc is None: return cid = (rc.chat_id or "").strip() if not cid: return - await bus.publish_outbound( - OutboundMessage( - channel="websocket", - chat_id=cid, - content="", - metadata={ - "_goal_state_sync": True, - "goal_state": goal_state_ws_blob(metadata), - }, - ), + await runtime_events.publish( + GoalStateChanged( + context=RuntimeEventContext( + channel=rc.channel, + chat_id=cid, + session_key=rc.session_key or f"{rc.channel}:{cid}", + metadata=dict(rc.metadata or {}), + ), + session_metadata=dict(metadata), + ) ) @@ -108,14 +111,21 @@ class _GoalToolsMixin(ContextAware): class LongTaskTool(Tool, _GoalToolsMixin): """Begin or replace focus on a long-running objective stored on the session.""" - def __init__(self, sessions: Any, bus: Any | None = None) -> None: - _GoalToolsMixin.__init__(self, sessions, bus) + def __init__( + self, + sessions: Any, + runtime_events: RuntimeEventBus | None = None, + ) -> None: + _GoalToolsMixin.__init__(self, sessions, runtime_events) @classmethod def create(cls, ctx: Any) -> Tool: sess = getattr(ctx, "sessions", None) assert sess is not None # guarded by enabled() - return cls(sessions=sess, bus=getattr(ctx, "bus", None)) + return cls( + sessions=sess, + runtime_events=getattr(ctx, "runtime_events", None), + ) @classmethod def enabled(cls, ctx: Any) -> bool: @@ -160,7 +170,7 @@ class LongTaskTool(Tool, _GoalToolsMixin): sess.metadata[GOAL_STATE_KEY] = blob discard_legacy_goal_state_key(sess.metadata) self._sessions.save(sess) - await self._publish_goal_state_ws(sess.metadata) + await self._publish_goal_state_changed(sess.metadata) extra = f"\nSummary line: {summary}" if summary else "" return ( "Goal recorded. Keep working toward the objective using ordinary tools. " @@ -183,14 +193,21 @@ class LongTaskTool(Tool, _GoalToolsMixin): class CompleteGoalTool(Tool, _GoalToolsMixin): """Mark the active sustained goal finished after all required work is verified.""" - def __init__(self, sessions: Any, bus: Any | None = None) -> None: - _GoalToolsMixin.__init__(self, sessions, bus) + def __init__( + self, + sessions: Any, + runtime_events: RuntimeEventBus | None = None, + ) -> None: + _GoalToolsMixin.__init__(self, sessions, runtime_events) @classmethod def create(cls, ctx: Any) -> Tool: sess = getattr(ctx, "sessions", None) assert sess is not None - return cls(sessions=sess, bus=getattr(ctx, "bus", None)) + return cls( + sessions=sess, + runtime_events=getattr(ctx, "runtime_events", None), + ) @classmethod def enabled(cls, ctx: Any) -> bool: @@ -227,7 +244,7 @@ class CompleteGoalTool(Tool, _GoalToolsMixin): } discard_legacy_goal_state_key(sess.metadata) self._sessions.save(sess) - await self._publish_goal_state_ws(sess.metadata) + await self._publish_goal_state_changed(sess.metadata) tail = (recap or "").strip() if tail: return f"Goal marked complete ({ended}). Recap:\n{tail}" diff --git a/nanobot/bus/progress.py b/nanobot/bus/progress.py new file mode 100644 index 000000000..6cdb416d9 --- /dev/null +++ b/nanobot/bus/progress.py @@ -0,0 +1,84 @@ +"""Progress callback helpers that publish through the message bus.""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import Any + +from nanobot.bus.events import InboundMessage, OutboundMessage +from nanobot.bus.queue import MessageBus + + +def build_bus_progress_callback( + bus: MessageBus, + msg: InboundMessage, +) -> Callable[..., Awaitable[None]]: + """Return the bus progress callback for agent runtime events.""" + + async def _publish_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict[str, Any]] | None = None, + file_edit_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, + ) -> None: + meta = dict(msg.metadata or {}) + meta["_progress"] = True + meta["_tool_hint"] = tool_hint + if reasoning: + meta["_reasoning_delta"] = True + if reasoning_end: + meta["_reasoning_end"] = True + if tool_events: + meta["_tool_events"] = tool_events + if file_edit_events: + meta["_file_edit_events"] = file_edit_events + await bus.publish_outbound( + OutboundMessage( + channel=msg.channel, + chat_id=msg.chat_id, + content=content, + metadata=meta, + ) + ) + + if msg.channel == "websocket": + async def _websocket_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict[str, Any]] | None = None, + file_edit_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, + ) -> None: + await _publish_progress( + content, + tool_hint=tool_hint, + tool_events=tool_events, + file_edit_events=file_edit_events, + reasoning=reasoning, + reasoning_end=reasoning_end, + ) + + return _websocket_progress + + async def _bus_progress( + content: str, + *, + tool_hint: bool = False, + tool_events: list[dict[str, Any]] | None = None, + reasoning: bool = False, + reasoning_end: bool = False, + ) -> None: + await _publish_progress( + content, + tool_hint=tool_hint, + tool_events=tool_events, + reasoning=reasoning, + reasoning_end=reasoning_end, + ) + + return _bus_progress diff --git a/nanobot/bus/runtime_events.py b/nanobot/bus/runtime_events.py new file mode 100644 index 000000000..b9e7b9e9d --- /dev/null +++ b/nanobot/bus/runtime_events.py @@ -0,0 +1,116 @@ +"""Runtime event bus for agent state notifications. + +This bus is separate from :mod:`nanobot.bus.queue`: message bus events are +user/chat delivery, while runtime events are in-process state notifications +that optional subscribers such as WebUI adapters may render. +""" + +from __future__ import annotations + +import asyncio +import contextlib +import inspect +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any + +from loguru import logger + + +@dataclass(frozen=True) +class RuntimeEventContext: + """Routing context common to turn-scoped runtime events.""" + + channel: str + chat_id: str + session_key: str + metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class SessionTurnStarted: + """A user/system turn has loaded its session and is about to build context.""" + + context: RuntimeEventContext + + +@dataclass(frozen=True) +class TurnRunStatusChanged: + """Visible run status changed for a turn.""" + + context: RuntimeEventContext + status: str + started_at: float | None = None + + +@dataclass(frozen=True) +class TurnCompleted: + """A turn has delivered its final user-visible response.""" + + context: RuntimeEventContext + latency_ms: int | None = None + runtime: Any | None = None + + +@dataclass(frozen=True) +class GoalStateChanged: + """A session's sustained-goal state changed.""" + + context: RuntimeEventContext + session_metadata: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RuntimeModelChanged: + """The active runtime model/preset changed.""" + + model: str + model_preset: str | None + + +RuntimeEvent = ( + SessionTurnStarted + | TurnRunStatusChanged + | TurnCompleted + | GoalStateChanged + | RuntimeModelChanged +) +RuntimeEventHandler = Callable[[RuntimeEvent], Awaitable[None] | None] + + +class RuntimeEventBus: + """Small in-process pub/sub bus for runtime state. + + Subscribers run in registration order. ``publish`` awaits async handlers so + callers can preserve ordering when a runtime event must follow a user + message. ``publish_nowait`` is available for synchronous call sites. + """ + + def __init__(self) -> None: + self._handlers: list[RuntimeEventHandler] = [] + + def subscribe(self, handler: RuntimeEventHandler) -> Callable[[], None]: + self._handlers.append(handler) + + def _unsubscribe() -> None: + with contextlib.suppress(ValueError): + self._handlers.remove(handler) + + return _unsubscribe + + async def publish(self, event: RuntimeEvent) -> None: + for handler in list(self._handlers): + try: + result = handler(event) + if inspect.isawaitable(result): + await result + except Exception: + logger.exception("runtime event handler failed for {}", type(event).__name__) + + def publish_nowait(self, event: RuntimeEvent) -> None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + logger.debug("dropping runtime event without a running loop: {}", type(event).__name__) + return + loop.create_task(self.publish(event)) diff --git a/nanobot/cli/commands.py b/nanobot/cli/commands.py index bf1dc55a2..2c8808d26 100644 --- a/nanobot/cli/commands.py +++ b/nanobot/cli/commands.py @@ -881,19 +881,21 @@ def _run_gateway( from nanobot.agent.tools.cron import CronTool from nanobot.agent.tools.message import MessageTool from nanobot.bus.queue import MessageBus + from nanobot.bus.runtime_events import RuntimeEventBus from nanobot.channels.manager import ChannelManager - from nanobot.channels.websocket import publish_runtime_model_update from nanobot.cron.service import CronService from nanobot.cron.types import CronJob from nanobot.providers.factory import build_provider_snapshot, load_provider_snapshot from nanobot.providers.image_generation import image_gen_provider_configs from nanobot.session.manager import SessionManager + from nanobot.session.webui_turns import WebuiTurnCoordinator port = port if port is not None else config.gateway.port console.print(f"{__logo__} Starting nanobot gateway version {__version__} on port {port}...") sync_workspace_templates(config.workspace_path) bus = MessageBus() + runtime_events = RuntimeEventBus() try: provider_snapshot = build_provider_snapshot(config) except ValueError as exc: @@ -919,13 +921,14 @@ def _run_gateway( session_manager=session_manager, image_generation_provider_configs=image_gen_provider_configs(config), provider_snapshot_loader=load_provider_snapshot, - runtime_model_publisher=lambda model, preset: publish_runtime_model_update( - bus, - model, - preset, - ), + runtime_events=runtime_events, provider_signature=provider_snapshot.signature, ) + WebuiTurnCoordinator( + bus=bus, + sessions=session_manager, + schedule_background=lambda coro: agent._schedule_background(coro), + ).subscribe(runtime_events) from nanobot.agent.loop import UNIFIED_SESSION_KEY from nanobot.bus.events import OutboundMessage diff --git a/nanobot/session/webui_turns.py b/nanobot/session/webui_turns.py index 47614f0b1..5113fb20d 100644 --- a/nanobot/session/webui_turns.py +++ b/nanobot/session/webui_turns.py @@ -1,8 +1,4 @@ -"""Session turn helpers for WebUI-capable WebSocket sessions. - -AgentLoop uses these without importing a concrete channel plugin; only -``channel == "websocket"`` messages are affected. -""" +"""Session turn helpers for WebUI-capable WebSocket sessions.""" from __future__ import annotations @@ -14,8 +10,19 @@ from typing import Any from loguru import logger +from nanobot.bus import progress as bus_progress from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus +from nanobot.bus.runtime_events import ( + GoalStateChanged, + RuntimeEvent, + RuntimeEventBus, + RuntimeEventContext, + RuntimeModelChanged, + SessionTurnStarted, + TurnCompleted, + TurnRunStatusChanged, +) from nanobot.providers.base import LLMProvider from nanobot.session.goal_state import goal_state_ws_blob from nanobot.session.manager import Session, SessionManager @@ -178,6 +185,14 @@ def websocket_turn_wall_started_at(chat_id: str) -> float | None: return _WEBSOCKET_TURN_WALL_STARTED_AT.get(chat_id) +def build_bus_progress_callback( + bus: MessageBus, + msg: InboundMessage, +) -> Callable[..., Awaitable[None]]: + """Compatibility wrapper for the generic bus progress callback.""" + return bus_progress.build_bus_progress_callback(bus, msg) + + async def publish_turn_run_status( bus: MessageBus, msg: InboundMessage, @@ -212,91 +227,110 @@ async def publish_turn_run_status( ), ) - -def build_bus_progress_callback( - bus: MessageBus, - msg: InboundMessage, -) -> Callable[..., Awaitable[None]]: - """Return the bus progress callback for agent runtime events.""" - - async def _publish_progress( - content: str, - *, - tool_hint: bool = False, - tool_events: list[dict[str, Any]] | None = None, - file_edit_events: list[dict[str, Any]] | None = None, - reasoning: bool = False, - reasoning_end: bool = False, - ) -> None: - meta = dict(msg.metadata or {}) - meta["_progress"] = True - meta["_tool_hint"] = tool_hint - if reasoning: - meta["_reasoning_delta"] = True - if reasoning_end: - meta["_reasoning_end"] = True - if tool_events: - meta["_tool_events"] = tool_events - if file_edit_events: - meta["_file_edit_events"] = file_edit_events - await bus.publish_outbound( - OutboundMessage( - channel=msg.channel, - chat_id=msg.chat_id, - content=content, - metadata=meta, - ) - ) - - if msg.channel == "websocket": - async def _websocket_progress( - content: str, - *, - tool_hint: bool = False, - tool_events: list[dict[str, Any]] | None = None, - file_edit_events: list[dict[str, Any]] | None = None, - reasoning: bool = False, - reasoning_end: bool = False, - ) -> None: - await _publish_progress( - content, - tool_hint=tool_hint, - tool_events=tool_events, - file_edit_events=file_edit_events, - reasoning=reasoning, - reasoning_end=reasoning_end, - ) - - return _websocket_progress - - async def _bus_progress( - content: str, - *, - tool_hint: bool = False, - tool_events: list[dict[str, Any]] | None = None, - reasoning: bool = False, - reasoning_end: bool = False, - ) -> None: - await _publish_progress( - content, - tool_hint=tool_hint, - tool_events=tool_events, - reasoning=reasoning, - reasoning_end=reasoning_end, - ) - - return _bus_progress - - @dataclass class WebuiTurnCoordinator: - """Own the WebUI/WebSocket wire details that hang off AgentLoop turns.""" + """Translate generic runtime events into WebUI/WebSocket wire messages.""" bus: MessageBus sessions: SessionManager schedule_background: Callable[[Awaitable[None]], None] _title_contexts: dict[str, LLMRuntime] = field(default_factory=dict) + def subscribe(self, runtime_events: RuntimeEventBus) -> Callable[[], None]: + """Subscribe this coordinator to runtime events.""" + return runtime_events.subscribe(self.handle_runtime_event) + + async def handle_runtime_event(self, event: RuntimeEvent) -> None: + if isinstance(event, SessionTurnStarted): + self._handle_session_turn_started(event) + return + if isinstance(event, TurnRunStatusChanged): + await self._handle_run_status_changed(event) + return + if isinstance(event, TurnCompleted): + await self._handle_turn_completed_event(event) + return + if isinstance(event, GoalStateChanged): + await self._handle_goal_state_changed(event) + return + if isinstance(event, RuntimeModelChanged): + await self._handle_runtime_model_changed(event) + return + + @staticmethod + def _ctx_msg(ctx: RuntimeEventContext) -> InboundMessage: + return InboundMessage( + channel=ctx.channel, + sender_id="runtime", + chat_id=ctx.chat_id, + content="", + metadata=dict(ctx.metadata or {}), + session_key_override=ctx.session_key, + ) + + @staticmethod + def _is_websocket_event(ctx: RuntimeEventContext) -> bool: + return ctx.channel == "websocket" + + def _handle_session_turn_started(self, event: SessionTurnStarted) -> None: + if not self._is_websocket_event(event.context): + return + session = self.sessions.get_or_create(event.context.session_key) + mark_webui_session(session, event.context.metadata) + + async def _handle_run_status_changed(self, event: TurnRunStatusChanged) -> None: + if not self._is_websocket_event(event.context): + return + await publish_turn_run_status( + self.bus, + self._ctx_msg(event.context), + event.status, + started_at=event.started_at, + ) + + async def _handle_turn_completed_event(self, event: TurnCompleted) -> None: + if not self._is_websocket_event(event.context): + return + msg = self._ctx_msg(event.context) + await self.handle_turn_end( + msg, + session_key=event.context.session_key, + latency_ms=event.latency_ms, + ) + self._schedule_title_update_from_event(event) + + async def _handle_goal_state_changed(self, event: GoalStateChanged) -> None: + if not self._is_websocket_event(event.context): + return + cid = str(event.context.chat_id or "").strip() + if not cid: + return + await self.bus.publish_outbound( + OutboundMessage( + channel=event.context.channel, + chat_id=cid, + content="", + metadata={ + "_goal_state_sync": True, + "goal_state": goal_state_ws_blob(event.session_metadata), + }, + ), + ) + + async def _handle_runtime_model_changed(self, event: RuntimeModelChanged) -> None: + await self.bus.publish_outbound( + OutboundMessage( + channel="websocket", + chat_id="*", + content="", + metadata={ + "_runtime_model_updated": True, + "model": event.model, + "model_preset": event.model_preset, + }, + ) + ) + def capture_title_context( self, session_key: str, @@ -370,3 +404,37 @@ class WebuiTurnCoordinator: )) self.schedule_background(_generate_title_and_notify()) + + def _schedule_title_update_from_event(self, event: TurnCompleted) -> None: + title_context = event.runtime + if ( + event.context.metadata.get("webui") is not True + or title_context is None + or not isinstance(title_context, LLMRuntime) + ): + return + + async def _generate_title_and_notify( + title_llm: LLMRuntime = title_context, + ) -> None: + generated = await maybe_generate_webui_title_after_turn( + channel=event.context.channel, + metadata=event.context.metadata, + sessions=self.sessions, + session_key=event.context.session_key, + provider=title_llm.provider, + model=title_llm.model, + ) + if generated: + await self.bus.publish_outbound(OutboundMessage( + channel=event.context.channel, + chat_id=event.context.chat_id, + content="", + metadata={ + **event.context.metadata, + "_session_updated": True, + "_session_update_scope": "metadata", + }, + )) + + self.schedule_background(_generate_title_and_notify()) diff --git a/tests/agent/test_loop_direct_websocket_status.py b/tests/agent/test_loop_direct_websocket_status.py index dff6df803..879fa23a1 100644 --- a/tests/agent/test_loop_direct_websocket_status.py +++ b/tests/agent/test_loop_direct_websocket_status.py @@ -7,6 +7,7 @@ from nanobot.agent.loop import AgentLoop from nanobot.bus.events import OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import GenerationSettings, LLMResponse +from nanobot.session.webui_turns import WebuiTurnCoordinator def _make_loop(tmp_path): @@ -25,6 +26,11 @@ def _make_loop(tmp_path): workspace=tmp_path, model="test-model", ) + WebuiTurnCoordinator( + bus=bus, + sessions=loop.sessions, + schedule_background=lambda coro: loop._schedule_background(coro), + ).subscribe(loop.runtime_events) loop.tools.get_definitions = MagicMock(return_value=[]) return loop diff --git a/tests/agent/test_loop_progress.py b/tests/agent/test_loop_progress.py index f7bd038ba..2f5486f72 100644 --- a/tests/agent/test_loop_progress.py +++ b/tests/agent/test_loop_progress.py @@ -11,6 +11,7 @@ from nanobot.agent.loop import AgentLoop from nanobot.bus.events import InboundMessage from nanobot.bus.queue import MessageBus from nanobot.providers.base import LLMResponse, ToolCallRequest +from nanobot.session.webui_turns import WebuiTurnCoordinator from nanobot.utils.progress_events import ( invoke_file_edit_progress, on_progress_accepts_file_edit_events, @@ -24,6 +25,15 @@ def _make_loop(tmp_path: Path) -> AgentLoop: return AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") +def _attach_webui_runtime_events(loop: AgentLoop, bus: MessageBus) -> None: + coordinator = WebuiTurnCoordinator( + bus=bus, + sessions=loop.sessions, + schedule_background=lambda coro: loop._schedule_background(coro), + ) + coordinator.subscribe(loop.runtime_events) + + class TestToolEventProgress: """_run_agent_loop emits structured tool_events via on_progress.""" @@ -456,6 +466,7 @@ class TestToolEventProgress: provider.chat_stream_with_retry = chat_stream_with_retry provider.chat_with_retry = AsyncMock() loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="openai-codex/gpt-5.5") + _attach_webui_runtime_events(loop, bus) loop.tools.get_definitions = MagicMock(return_value=[]) loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] @@ -549,6 +560,7 @@ class TestToolEventProgress: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Done", tool_calls=[])) loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + _attach_webui_runtime_events(loop, bus) loop.tools.get_definitions = MagicMock(return_value=[]) loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] @@ -593,6 +605,7 @@ class TestToolEventProgress: provider.chat_with_retry = AsyncMock(side_effect=chat_with_retry) loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + _attach_webui_runtime_events(loop, bus) loop.tools.get_definitions = MagicMock(return_value=[]) loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] @@ -641,6 +654,7 @@ class TestToolEventProgress: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Done", tool_calls=[])) loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + _attach_webui_runtime_events(loop, bus) loop.tools.get_definitions = MagicMock(return_value=[]) loop.consolidator.maybe_consolidate_by_tokens = AsyncMock(return_value=False) # type: ignore[method-assign] @@ -693,6 +707,7 @@ class TestToolEventProgress: provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Done", tool_calls=[])) loop = AgentLoop(bus=bus, provider=provider, workspace=tmp_path, model="test-model") + _attach_webui_runtime_events(loop, bus) async def fake_title_after_turn(**_kwargs: object) -> bool: raise AssertionError("command-only turns should not generate titles") diff --git a/tests/agent/test_loop_save_turn.py b/tests/agent/test_loop_save_turn.py index f9847df8b..874f0b435 100644 --- a/tests/agent/test_loop_save_turn.py +++ b/tests/agent/test_loop_save_turn.py @@ -39,7 +39,13 @@ def _make_full_loop(tmp_path: Path) -> AgentLoop: provider = MagicMock() provider.get_default_model.return_value = "test-model" provider.chat_with_retry = AsyncMock(return_value=LLMResponse(content="Test title")) - return AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") + loop = AgentLoop(bus=MessageBus(), provider=provider, workspace=tmp_path, model="test-model") + WebuiTurnCoordinator( + bus=loop.bus, + sessions=loop.sessions, + schedule_background=lambda coro: loop._schedule_background(coro), + ).subscribe(loop.runtime_events) + return loop def test_agent_loop_llm_runtime_reflects_current_provider_and_model(tmp_path: Path) -> None: diff --git a/tests/agent/tools/test_long_task.py b/tests/agent/tools/test_long_task.py index ef573a473..03bd91d8b 100644 --- a/tests/agent/tools/test_long_task.py +++ b/tests/agent/tools/test_long_task.py @@ -14,8 +14,10 @@ from nanobot.agent.tools.long_task import ( LongTaskTool, ) from nanobot.bus.queue import MessageBus +from nanobot.bus.runtime_events import RuntimeEventBus from nanobot.session.goal_state import GOAL_STATE_KEY from nanobot.session.manager import SessionManager +from nanobot.session.webui_turns import WebuiTurnCoordinator def _tools(sm: SessionManager) -> tuple[LongTaskTool, CompleteGoalTool]: @@ -120,8 +122,14 @@ async def test_goal_tools_context_isolated_across_tool_types(tmp_path): async def test_long_task_publishes_goal_state_ws_after_save(tmp_path): bus = MagicMock() bus.publish_outbound = AsyncMock() + runtime_events = RuntimeEventBus() sm = SessionManager(tmp_path) - lt = LongTaskTool(sessions=sm, bus=bus) + WebuiTurnCoordinator( + bus=bus, + sessions=sm, + schedule_background=lambda _coro: None, + ).subscribe(runtime_events) + lt = LongTaskTool(sessions=sm, runtime_events=runtime_events) rc = RequestContext( channel="websocket", chat_id="chat-99", @@ -148,9 +156,15 @@ async def test_long_task_publishes_goal_state_ws_after_save(tmp_path): async def test_complete_goal_publishes_inactive_goal_state_ws(tmp_path): bus = MagicMock() bus.publish_outbound = AsyncMock() + runtime_events = RuntimeEventBus() sm = SessionManager(tmp_path) - lt = LongTaskTool(sessions=sm, bus=bus) - cg = CompleteGoalTool(sessions=sm, bus=bus) + WebuiTurnCoordinator( + bus=bus, + sessions=sm, + schedule_background=lambda _coro: None, + ).subscribe(runtime_events) + lt = LongTaskTool(sessions=sm, runtime_events=runtime_events) + cg = CompleteGoalTool(sessions=sm, runtime_events=runtime_events) rc = RequestContext( channel="websocket", chat_id="chat-z",