mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor: subscribe to runtime event types
This commit is contained in:
parent
2f0e638bd1
commit
81370565e0
@ -75,7 +75,15 @@ RuntimeEvent = (
|
||||
| GoalStateChanged
|
||||
| RuntimeModelChanged
|
||||
)
|
||||
RuntimeEventHandler = Callable[[RuntimeEvent], Awaitable[None] | None]
|
||||
RuntimeEventType = (
|
||||
type[SessionTurnStarted]
|
||||
| type[TurnRunStatusChanged]
|
||||
| type[TurnCompleted]
|
||||
| type[GoalStateChanged]
|
||||
| type[RuntimeModelChanged]
|
||||
)
|
||||
RuntimeEventHandler = Callable[[Any], Awaitable[None] | None]
|
||||
_HandlerEntry = tuple[RuntimeEventType | None, RuntimeEventHandler]
|
||||
|
||||
|
||||
class RuntimeEventBus:
|
||||
@ -87,19 +95,26 @@ class RuntimeEventBus:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._handlers: list[RuntimeEventHandler] = []
|
||||
self._handlers: list[_HandlerEntry] = []
|
||||
|
||||
def subscribe(self, handler: RuntimeEventHandler) -> Callable[[], None]:
|
||||
self._handlers.append(handler)
|
||||
def subscribe(
|
||||
self,
|
||||
handler: RuntimeEventHandler,
|
||||
event_type: RuntimeEventType | None = None,
|
||||
) -> Callable[[], None]:
|
||||
entry = (event_type, handler)
|
||||
self._handlers.append(entry)
|
||||
|
||||
def _unsubscribe() -> None:
|
||||
with contextlib.suppress(ValueError):
|
||||
self._handlers.remove(handler)
|
||||
self._handlers.remove(entry)
|
||||
|
||||
return _unsubscribe
|
||||
|
||||
async def publish(self, event: RuntimeEvent) -> None:
|
||||
for handler in list(self._handlers):
|
||||
for event_type, handler in list(self._handlers):
|
||||
if event_type is not None and not isinstance(event, event_type):
|
||||
continue
|
||||
try:
|
||||
result = handler(event)
|
||||
if inspect.isawaitable(result):
|
||||
|
||||
@ -15,7 +15,6 @@ from nanobot.bus.events import InboundMessage, OutboundMessage
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.bus.runtime_events import (
|
||||
GoalStateChanged,
|
||||
RuntimeEvent,
|
||||
RuntimeEventBus,
|
||||
RuntimeEventContext,
|
||||
RuntimeModelChanged,
|
||||
@ -238,24 +237,34 @@ class WebuiTurnCoordinator:
|
||||
|
||||
def subscribe(self, runtime_events: RuntimeEventBus) -> Callable[[], None]:
|
||||
"""Subscribe this coordinator to runtime events."""
|
||||
return runtime_events.subscribe(self.handle_runtime_event)
|
||||
unsubscribe = [
|
||||
runtime_events.subscribe(
|
||||
self._handle_session_turn_started,
|
||||
SessionTurnStarted,
|
||||
),
|
||||
runtime_events.subscribe(
|
||||
self._handle_run_status_changed,
|
||||
TurnRunStatusChanged,
|
||||
),
|
||||
runtime_events.subscribe(
|
||||
self._handle_turn_completed_event,
|
||||
TurnCompleted,
|
||||
),
|
||||
runtime_events.subscribe(
|
||||
self._handle_goal_state_changed,
|
||||
GoalStateChanged,
|
||||
),
|
||||
runtime_events.subscribe(
|
||||
self._handle_runtime_model_changed,
|
||||
RuntimeModelChanged,
|
||||
),
|
||||
]
|
||||
|
||||
async def handle_runtime_event(self, event: RuntimeEvent) -> None:
|
||||
if isinstance(event, SessionTurnStarted):
|
||||
self._handle_session_turn_started(event)
|
||||
return
|
||||
if isinstance(event, TurnRunStatusChanged):
|
||||
await self._handle_run_status_changed(event)
|
||||
return
|
||||
if isinstance(event, TurnCompleted):
|
||||
await self._handle_turn_completed_event(event)
|
||||
return
|
||||
if isinstance(event, GoalStateChanged):
|
||||
await self._handle_goal_state_changed(event)
|
||||
return
|
||||
if isinstance(event, RuntimeModelChanged):
|
||||
await self._handle_runtime_model_changed(event)
|
||||
return
|
||||
def _unsubscribe() -> None:
|
||||
for fn in reversed(unsubscribe):
|
||||
fn()
|
||||
|
||||
return _unsubscribe
|
||||
|
||||
@staticmethod
|
||||
def _ctx_msg(ctx: RuntimeEventContext) -> InboundMessage:
|
||||
|
||||
48
tests/bus/test_runtime_events.py
Normal file
48
tests/bus/test_runtime_events.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.runtime_events import (
|
||||
RuntimeEventBus,
|
||||
RuntimeEventContext,
|
||||
RuntimeModelChanged,
|
||||
TurnRunStatusChanged,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_event_bus_filters_by_event_type() -> None:
|
||||
bus = RuntimeEventBus()
|
||||
seen: list[str] = []
|
||||
|
||||
async def handle_run_status(event: TurnRunStatusChanged) -> None:
|
||||
seen.append(event.status)
|
||||
|
||||
bus.subscribe(handle_run_status, TurnRunStatusChanged)
|
||||
|
||||
await bus.publish(RuntimeModelChanged(model="m", model_preset=None))
|
||||
await bus.publish(
|
||||
TurnRunStatusChanged(
|
||||
context=RuntimeEventContext(
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
session_key="cli:direct",
|
||||
),
|
||||
status="running",
|
||||
)
|
||||
)
|
||||
|
||||
assert seen == ["running"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_event_bus_keeps_catch_all_subscription() -> None:
|
||||
bus = RuntimeEventBus()
|
||||
seen: list[str] = []
|
||||
|
||||
def handle_any(event) -> None:
|
||||
seen.append(type(event).__name__)
|
||||
|
||||
bus.subscribe(handle_any)
|
||||
|
||||
await bus.publish(RuntimeModelChanged(model="m", model_preset=None))
|
||||
|
||||
assert seen == ["RuntimeModelChanged"]
|
||||
Loading…
x
Reference in New Issue
Block a user