diff --git a/nanobot/bus/runtime_events.py b/nanobot/bus/runtime_events.py index b9e7b9e9d..ccb4eb145 100644 --- a/nanobot/bus/runtime_events.py +++ b/nanobot/bus/runtime_events.py @@ -75,7 +75,15 @@ RuntimeEvent = ( | GoalStateChanged | RuntimeModelChanged ) -RuntimeEventHandler = Callable[[RuntimeEvent], Awaitable[None] | None] +RuntimeEventType = ( + type[SessionTurnStarted] + | type[TurnRunStatusChanged] + | type[TurnCompleted] + | type[GoalStateChanged] + | type[RuntimeModelChanged] +) +RuntimeEventHandler = Callable[[Any], Awaitable[None] | None] +_HandlerEntry = tuple[RuntimeEventType | None, RuntimeEventHandler] class RuntimeEventBus: @@ -87,19 +95,26 @@ class RuntimeEventBus: """ def __init__(self) -> None: - self._handlers: list[RuntimeEventHandler] = [] + self._handlers: list[_HandlerEntry] = [] - def subscribe(self, handler: RuntimeEventHandler) -> Callable[[], None]: - self._handlers.append(handler) + def subscribe( + self, + handler: RuntimeEventHandler, + event_type: RuntimeEventType | None = None, + ) -> Callable[[], None]: + entry = (event_type, handler) + self._handlers.append(entry) def _unsubscribe() -> None: with contextlib.suppress(ValueError): - self._handlers.remove(handler) + self._handlers.remove(entry) return _unsubscribe async def publish(self, event: RuntimeEvent) -> None: - for handler in list(self._handlers): + for event_type, handler in list(self._handlers): + if event_type is not None and not isinstance(event, event_type): + continue try: result = handler(event) if inspect.isawaitable(result): diff --git a/nanobot/session/webui_turns.py b/nanobot/session/webui_turns.py index 5113fb20d..8d4163f32 100644 --- a/nanobot/session/webui_turns.py +++ b/nanobot/session/webui_turns.py @@ -15,7 +15,6 @@ from nanobot.bus.events import InboundMessage, OutboundMessage from nanobot.bus.queue import MessageBus from nanobot.bus.runtime_events import ( GoalStateChanged, - RuntimeEvent, RuntimeEventBus, RuntimeEventContext, RuntimeModelChanged, @@ -238,24 +237,34 @@ class WebuiTurnCoordinator: def subscribe(self, runtime_events: RuntimeEventBus) -> Callable[[], None]: """Subscribe this coordinator to runtime events.""" - return runtime_events.subscribe(self.handle_runtime_event) + unsubscribe = [ + runtime_events.subscribe( + self._handle_session_turn_started, + SessionTurnStarted, + ), + runtime_events.subscribe( + self._handle_run_status_changed, + TurnRunStatusChanged, + ), + runtime_events.subscribe( + self._handle_turn_completed_event, + TurnCompleted, + ), + runtime_events.subscribe( + self._handle_goal_state_changed, + GoalStateChanged, + ), + runtime_events.subscribe( + self._handle_runtime_model_changed, + RuntimeModelChanged, + ), + ] - async def handle_runtime_event(self, event: RuntimeEvent) -> None: - if isinstance(event, SessionTurnStarted): - self._handle_session_turn_started(event) - return - if isinstance(event, TurnRunStatusChanged): - await self._handle_run_status_changed(event) - return - if isinstance(event, TurnCompleted): - await self._handle_turn_completed_event(event) - return - if isinstance(event, GoalStateChanged): - await self._handle_goal_state_changed(event) - return - if isinstance(event, RuntimeModelChanged): - await self._handle_runtime_model_changed(event) - return + def _unsubscribe() -> None: + for fn in reversed(unsubscribe): + fn() + + return _unsubscribe @staticmethod def _ctx_msg(ctx: RuntimeEventContext) -> InboundMessage: diff --git a/tests/bus/test_runtime_events.py b/tests/bus/test_runtime_events.py new file mode 100644 index 000000000..dd5842108 --- /dev/null +++ b/tests/bus/test_runtime_events.py @@ -0,0 +1,48 @@ +import pytest + +from nanobot.bus.runtime_events import ( + RuntimeEventBus, + RuntimeEventContext, + RuntimeModelChanged, + TurnRunStatusChanged, +) + + +@pytest.mark.asyncio +async def test_runtime_event_bus_filters_by_event_type() -> None: + bus = RuntimeEventBus() + seen: list[str] = [] + + async def handle_run_status(event: TurnRunStatusChanged) -> None: + seen.append(event.status) + + bus.subscribe(handle_run_status, TurnRunStatusChanged) + + await bus.publish(RuntimeModelChanged(model="m", model_preset=None)) + await bus.publish( + TurnRunStatusChanged( + context=RuntimeEventContext( + channel="cli", + chat_id="direct", + session_key="cli:direct", + ), + status="running", + ) + ) + + assert seen == ["running"] + + +@pytest.mark.asyncio +async def test_runtime_event_bus_keeps_catch_all_subscription() -> None: + bus = RuntimeEventBus() + seen: list[str] = [] + + def handle_any(event) -> None: + seen.append(type(event).__name__) + + bus.subscribe(handle_any) + + await bus.publish(RuntimeModelChanged(model="m", model_preset=None)) + + assert seen == ["RuntimeModelChanged"]