diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index 2e98136bf..0f13ce8e2 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -33,11 +33,8 @@ from nanobot.bus.progress import build_bus_progress_callback from nanobot.bus.queue import MessageBus from nanobot.bus.runtime_events import ( RuntimeEventBus, - RuntimeEventContext, - RuntimeModelChanged, - SessionTurnStarted, - TurnCompleted, - TurnRunStatusChanged, + RuntimeEventPublisher, + ensure_runtime_event_publisher, ) from nanobot.command import CommandContext, CommandRouter, register_builtin_commands from nanobot.config.schema import AgentDefaults, ModelPresetConfig @@ -129,7 +126,6 @@ class TurnContext: turn_wall_started_at: float = field(default_factory=time.time) visible_run_started_at: float | None = None turn_latency_ms: int | None = None - llm_runtime: LLMRuntime | None = None trace: list[StateTraceEntry] = field(default_factory=list) @@ -217,6 +213,7 @@ class AgentLoop: defaults = AgentDefaults() self.bus = bus self.runtime_events = runtime_events or RuntimeEventBus() + self.runtime_event_publisher = RuntimeEventPublisher(self.runtime_events) self.channels_config = channels_config self.provider = provider self._provider_snapshot_loader = provider_snapshot_loader @@ -262,8 +259,6 @@ class AgentLoop: ) self._start_time = time.time() self._last_usage: dict[str, int] = {} - self._pending_turn_latency_ms: dict[str, int] = {} - self._pending_turn_runtime: dict[str, LLMRuntime] = {} self._extra_hooks: list[AgentHook] = hooks or [] self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills) @@ -422,11 +417,9 @@ class AgentLoop: model_preset if model_preset is not None else self.model_preset, ) if publish_update: - self._runtime_event_bus().publish_nowait( - RuntimeModelChanged( - model=self.model, - model_preset=model_preset if model_preset is not None else self.model_preset, - ) + self._runtime_events().runtime_model_changed( + self.model, + model_preset if model_preset is not None else self.model_preset, ) logger.info("Runtime model switched for next turn: {} -> {}", old_model, model) @@ -569,85 +562,8 @@ class AgentLoop: return _on_retry_wait - @staticmethod - def _runtime_event_context( - *, - channel: str, - chat_id: str, - session_key: str, - metadata: dict[str, Any] | None, - ) -> RuntimeEventContext: - return RuntimeEventContext( - channel=channel, - chat_id=chat_id, - session_key=session_key, - metadata=dict(metadata or {}), - ) - - def _runtime_event_bus(self) -> RuntimeEventBus: - bus = getattr(self, "runtime_events", None) - if bus is None: - bus = RuntimeEventBus() - self.runtime_events = bus - return bus - - def _pop_pending_turn_latency(self, session_key: str) -> int | None: - pending = getattr(self, "_pending_turn_latency_ms", None) - if not isinstance(pending, dict): - return None - return pending.pop(session_key, None) - - def _pop_pending_turn_runtime(self, session_key: str) -> LLMRuntime | None: - pending = getattr(self, "_pending_turn_runtime", None) - if not isinstance(pending, dict): - return None - return pending.pop(session_key, None) - - def _clear_pending_turn_runtime(self, session_key: str) -> None: - self._pop_pending_turn_latency(session_key) - self._pop_pending_turn_runtime(session_key) - - async def _publish_run_status_event( - self, - msg: InboundMessage, - session_key: str, - status: str, - *, - started_at: float | None = None, - ) -> None: - await self._runtime_event_bus().publish( - TurnRunStatusChanged( - context=self._runtime_event_context( - channel=msg.channel, - chat_id=msg.chat_id, - session_key=session_key, - metadata=msg.metadata, - ), - status=status, - started_at=started_at, - ) - ) - - async def _publish_turn_completed_event( - self, - *, - channel: str, - chat_id: str, - session_key: str, - metadata: dict[str, Any] | None, - ) -> None: - await self._runtime_event_bus().publish( - TurnCompleted( - context=self._runtime_event_context( - channel=channel, - chat_id=chat_id, - session_key=session_key, - metadata=metadata, - ), - latency_ms=self._pop_pending_turn_latency(session_key), - runtime=self._pop_pending_turn_runtime(session_key), - ) - ) + def _runtime_events(self) -> RuntimeEventPublisher: + return ensure_runtime_event_publisher(self) def _persist_user_message_early( self, @@ -1063,7 +979,7 @@ class AgentLoop: )) continuing = turn_continuation.internal_continuation_pending(msg.metadata) if not continuing: - await self._publish_turn_completed_event( + await self._runtime_events().turn_completed( channel=completed_channel, chat_id=completed_chat_id, session_key=session_key, @@ -1127,12 +1043,16 @@ class AgentLoop: leftover, session_key, ) if not turn_continuation.internal_continuation_pending(msg.metadata): - await self._publish_run_status_event(msg, session_key, "idle") - self._clear_pending_turn_runtime(session_key) + await self._runtime_events().run_status_changed( + msg, session_key, "idle" + ) + self._runtime_events().clear_turn(session_key) finally: if pending is None: - await self._publish_run_status_event(msg, session_key, "idle") - self._clear_pending_turn_runtime(session_key) + await self._runtime_events().run_status_changed( + msg, session_key, "idle" + ) + self._runtime_events().clear_turn(session_key) async def close_mcp(self) -> None: """Drain pending background archives, then close MCP connections.""" @@ -1228,7 +1148,7 @@ class AgentLoop: wall_done = time.time() latency_ms = max(0, int((wall_done - t_wall) * 1000)) self._save_turn(session, all_msgs, 1 + len(history), turn_latency_ms=latency_ms) - self._pending_turn_latency_ms[key] = latency_ms + self._runtime_events().record_turn_latency(key, latency_ms) session.enforce_file_cap(on_archive=self.context.memory.raw_archive) self._clear_runtime_checkpoint(session) self.sessions.save(session) @@ -1394,16 +1314,7 @@ class AgentLoop: # ensure it exists in case this handler is invoked independently. if ctx.session is None: ctx.session = self.sessions.get_or_create(ctx.session_key) - await self._runtime_event_bus().publish( - SessionTurnStarted( - context=self._runtime_event_context( - channel=msg.channel, - chat_id=msg.chat_id, - session_key=ctx.session_key, - metadata=msg.metadata, - ) - ) - ) + await self._runtime_events().session_turn_started(msg, ctx.session_key) self.workspace_scopes.persist_message_scope(ctx.session, msg) if self._restore_runtime_checkpoint(ctx.session): @@ -1475,8 +1386,10 @@ class AgentLoop: "include_timestamps": True, } ctx.history = ctx.session.get_history(**_hist_kwargs) - ctx.llm_runtime = self.llm_runtime() - self._pending_turn_runtime[ctx.session_key] = ctx.llm_runtime + self._runtime_events().record_turn_runtime( + ctx.session_key, + self.llm_runtime(), + ) ctx.initial_messages = self._build_initial_messages( ctx.msg, @@ -1498,7 +1411,7 @@ class AgentLoop: async def _state_run(self, ctx: TurnContext) -> str: if ctx.visible_run_started_at is None: ctx.visible_run_started_at = time.time() - await self._publish_run_status_event( + await self._runtime_events().run_status_changed( ctx.msg, ctx.session_key, "running", @@ -1547,7 +1460,10 @@ class AgentLoop: ctx.session, ctx.all_messages, ctx.save_skip, turn_latency_ms=ctx.turn_latency_ms, ) - self._pending_turn_latency_ms[ctx.session_key] = ctx.turn_latency_ms + self._runtime_events().record_turn_latency( + ctx.session_key, + ctx.turn_latency_ms, + ) ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive) self._clear_pending_user_turn(ctx.session) self._clear_runtime_checkpoint(ctx.session) @@ -1816,5 +1732,5 @@ class AgentLoop: on_stream_end=on_stream_end, ) finally: - await self._publish_run_status_event(msg, session_key, "idle") - self._clear_pending_turn_runtime(session_key) + await self._runtime_events().run_status_changed(msg, session_key, "idle") + self._runtime_events().clear_turn(session_key) diff --git a/nanobot/bus/runtime_events.py b/nanobot/bus/runtime_events.py index ccb4eb145..fabe3c9b9 100644 --- a/nanobot/bus/runtime_events.py +++ b/nanobot/bus/runtime_events.py @@ -16,6 +16,8 @@ from typing import Any from loguru import logger +from nanobot.bus.events import InboundMessage + @dataclass(frozen=True) class RuntimeEventContext: @@ -129,3 +131,121 @@ class RuntimeEventBus: logger.debug("dropping runtime event without a running loop: {}", type(event).__name__) return loop.create_task(self.publish(event)) + + +class RuntimeEventPublisher: + """Convenience publisher for turn-scoped runtime events. + + Agent code should decide when state transitions happen; this helper owns + the mechanics of building event contexts and carrying per-turn metadata. + """ + + def __init__(self, bus: RuntimeEventBus | None = None) -> None: + self.bus = bus or RuntimeEventBus() + self._turn_latency_ms: dict[str, int] = {} + self._turn_runtime: dict[str, Any] = {} + + @staticmethod + def _context( + *, + channel: str, + chat_id: str, + session_key: str, + metadata: dict[str, Any] | None, + ) -> RuntimeEventContext: + return RuntimeEventContext( + channel=channel, + chat_id=chat_id, + session_key=session_key, + metadata=dict(metadata or {}), + ) + + def record_turn_runtime(self, session_key: str, runtime: Any) -> None: + self._turn_runtime[session_key] = runtime + + def record_turn_latency(self, session_key: str, latency_ms: int | None) -> None: + if latency_ms is not None: + self._turn_latency_ms[session_key] = int(latency_ms) + + def clear_turn(self, session_key: str) -> None: + self._turn_latency_ms.pop(session_key, None) + self._turn_runtime.pop(session_key, None) + + async def session_turn_started( + self, + msg: InboundMessage, + session_key: str, + ) -> None: + await self.bus.publish( + SessionTurnStarted( + context=self._context( + channel=msg.channel, + chat_id=msg.chat_id, + session_key=session_key, + metadata=msg.metadata, + ) + ) + ) + + async def run_status_changed( + self, + msg: InboundMessage, + session_key: str, + status: str, + *, + started_at: float | None = None, + ) -> None: + await self.bus.publish( + TurnRunStatusChanged( + context=self._context( + channel=msg.channel, + chat_id=msg.chat_id, + session_key=session_key, + metadata=msg.metadata, + ), + status=status, + started_at=started_at, + ) + ) + + async def turn_completed( + self, + *, + channel: str, + chat_id: str, + session_key: str, + metadata: dict[str, Any] | None, + ) -> None: + await self.bus.publish( + TurnCompleted( + context=self._context( + channel=channel, + chat_id=chat_id, + session_key=session_key, + metadata=metadata, + ), + latency_ms=self._turn_latency_ms.pop(session_key, None), + runtime=self._turn_runtime.pop(session_key, None), + ) + ) + + def runtime_model_changed(self, model: str, model_preset: str | None) -> None: + self.bus.publish_nowait( + RuntimeModelChanged(model=model, model_preset=model_preset) + ) + + +def ensure_runtime_event_publisher(owner: Any) -> RuntimeEventPublisher: + """Return an owner's runtime publisher, creating missing state lazily.""" + publisher = getattr(owner, "runtime_event_publisher", None) + if isinstance(publisher, RuntimeEventPublisher): + return publisher + + bus = getattr(owner, "runtime_events", None) + if not isinstance(bus, RuntimeEventBus): + bus = RuntimeEventBus() + owner.runtime_events = bus + + publisher = RuntimeEventPublisher(bus) + owner.runtime_event_publisher = publisher + return publisher diff --git a/tests/bus/test_runtime_events.py b/tests/bus/test_runtime_events.py index dd5842108..f5438541f 100644 --- a/tests/bus/test_runtime_events.py +++ b/tests/bus/test_runtime_events.py @@ -1,9 +1,13 @@ import pytest +from nanobot.bus.events import InboundMessage from nanobot.bus.runtime_events import ( RuntimeEventBus, RuntimeEventContext, + RuntimeEventPublisher, RuntimeModelChanged, + SessionTurnStarted, + TurnCompleted, TurnRunStatusChanged, ) @@ -46,3 +50,73 @@ async def test_runtime_event_bus_keeps_catch_all_subscription() -> None: await bus.publish(RuntimeModelChanged(model="m", model_preset=None)) assert seen == ["RuntimeModelChanged"] + + +@pytest.mark.asyncio +async def test_runtime_event_publisher_builds_context_from_inbound_message() -> None: + bus = RuntimeEventBus() + seen: list[object] = [] + publisher = RuntimeEventPublisher(bus) + msg = InboundMessage( + channel="websocket", + sender_id="user", + chat_id="chat-a", + content="hello", + metadata={"trace_id": "turn-1"}, + ) + + bus.subscribe(seen.append) + + await publisher.session_turn_started(msg, "websocket:chat-a") + await publisher.run_status_changed( + msg, + "websocket:chat-a", + "running", + started_at=12.5, + ) + + started = seen[0] + running = seen[1] + assert isinstance(started, SessionTurnStarted) + assert started.context.channel == "websocket" + assert started.context.chat_id == "chat-a" + assert started.context.session_key == "websocket:chat-a" + assert started.context.metadata == {"trace_id": "turn-1"} + assert started.context.metadata is not msg.metadata + assert isinstance(running, TurnRunStatusChanged) + assert running.status == "running" + assert running.started_at == 12.5 + + +@pytest.mark.asyncio +async def test_runtime_event_publisher_consumes_turn_metadata_on_complete() -> None: + bus = RuntimeEventBus() + seen: list[object] = [] + publisher = RuntimeEventPublisher(bus) + + bus.subscribe(seen.append) + publisher.record_turn_runtime("cli:direct", "runtime") + publisher.record_turn_latency("cli:direct", 123) + + await publisher.turn_completed( + channel="cli", + chat_id="direct", + session_key="cli:direct", + metadata={"source": "test"}, + ) + await publisher.turn_completed( + channel="cli", + chat_id="direct", + session_key="cli:direct", + metadata=None, + ) + + first = seen[0] + second = seen[1] + assert isinstance(first, TurnCompleted) + assert first.context.metadata == {"source": "test"} + assert first.latency_ms == 123 + assert first.runtime == "runtime" + assert isinstance(second, TurnCompleted) + assert second.latency_ms is None + assert second.runtime is None