mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor: subscribe to runtime event types
This commit is contained in:
parent
2f0e638bd1
commit
81370565e0
@ -75,7 +75,15 @@ RuntimeEvent = (
|
|||||||
| GoalStateChanged
|
| GoalStateChanged
|
||||||
| RuntimeModelChanged
|
| RuntimeModelChanged
|
||||||
)
|
)
|
||||||
RuntimeEventHandler = Callable[[RuntimeEvent], Awaitable[None] | None]
|
RuntimeEventType = (
|
||||||
|
type[SessionTurnStarted]
|
||||||
|
| type[TurnRunStatusChanged]
|
||||||
|
| type[TurnCompleted]
|
||||||
|
| type[GoalStateChanged]
|
||||||
|
| type[RuntimeModelChanged]
|
||||||
|
)
|
||||||
|
RuntimeEventHandler = Callable[[Any], Awaitable[None] | None]
|
||||||
|
_HandlerEntry = tuple[RuntimeEventType | None, RuntimeEventHandler]
|
||||||
|
|
||||||
|
|
||||||
class RuntimeEventBus:
|
class RuntimeEventBus:
|
||||||
@ -87,19 +95,26 @@ class RuntimeEventBus:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._handlers: list[RuntimeEventHandler] = []
|
self._handlers: list[_HandlerEntry] = []
|
||||||
|
|
||||||
def subscribe(self, handler: RuntimeEventHandler) -> Callable[[], None]:
|
def subscribe(
|
||||||
self._handlers.append(handler)
|
self,
|
||||||
|
handler: RuntimeEventHandler,
|
||||||
|
event_type: RuntimeEventType | None = None,
|
||||||
|
) -> Callable[[], None]:
|
||||||
|
entry = (event_type, handler)
|
||||||
|
self._handlers.append(entry)
|
||||||
|
|
||||||
def _unsubscribe() -> None:
|
def _unsubscribe() -> None:
|
||||||
with contextlib.suppress(ValueError):
|
with contextlib.suppress(ValueError):
|
||||||
self._handlers.remove(handler)
|
self._handlers.remove(entry)
|
||||||
|
|
||||||
return _unsubscribe
|
return _unsubscribe
|
||||||
|
|
||||||
async def publish(self, event: RuntimeEvent) -> None:
|
async def publish(self, event: RuntimeEvent) -> None:
|
||||||
for handler in list(self._handlers):
|
for event_type, handler in list(self._handlers):
|
||||||
|
if event_type is not None and not isinstance(event, event_type):
|
||||||
|
continue
|
||||||
try:
|
try:
|
||||||
result = handler(event)
|
result = handler(event)
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
|
|||||||
@ -15,7 +15,6 @@ from nanobot.bus.events import InboundMessage, OutboundMessage
|
|||||||
from nanobot.bus.queue import MessageBus
|
from nanobot.bus.queue import MessageBus
|
||||||
from nanobot.bus.runtime_events import (
|
from nanobot.bus.runtime_events import (
|
||||||
GoalStateChanged,
|
GoalStateChanged,
|
||||||
RuntimeEvent,
|
|
||||||
RuntimeEventBus,
|
RuntimeEventBus,
|
||||||
RuntimeEventContext,
|
RuntimeEventContext,
|
||||||
RuntimeModelChanged,
|
RuntimeModelChanged,
|
||||||
@ -238,24 +237,34 @@ class WebuiTurnCoordinator:
|
|||||||
|
|
||||||
def subscribe(self, runtime_events: RuntimeEventBus) -> Callable[[], None]:
|
def subscribe(self, runtime_events: RuntimeEventBus) -> Callable[[], None]:
|
||||||
"""Subscribe this coordinator to runtime events."""
|
"""Subscribe this coordinator to runtime events."""
|
||||||
return runtime_events.subscribe(self.handle_runtime_event)
|
unsubscribe = [
|
||||||
|
runtime_events.subscribe(
|
||||||
|
self._handle_session_turn_started,
|
||||||
|
SessionTurnStarted,
|
||||||
|
),
|
||||||
|
runtime_events.subscribe(
|
||||||
|
self._handle_run_status_changed,
|
||||||
|
TurnRunStatusChanged,
|
||||||
|
),
|
||||||
|
runtime_events.subscribe(
|
||||||
|
self._handle_turn_completed_event,
|
||||||
|
TurnCompleted,
|
||||||
|
),
|
||||||
|
runtime_events.subscribe(
|
||||||
|
self._handle_goal_state_changed,
|
||||||
|
GoalStateChanged,
|
||||||
|
),
|
||||||
|
runtime_events.subscribe(
|
||||||
|
self._handle_runtime_model_changed,
|
||||||
|
RuntimeModelChanged,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
async def handle_runtime_event(self, event: RuntimeEvent) -> None:
|
def _unsubscribe() -> None:
|
||||||
if isinstance(event, SessionTurnStarted):
|
for fn in reversed(unsubscribe):
|
||||||
self._handle_session_turn_started(event)
|
fn()
|
||||||
return
|
|
||||||
if isinstance(event, TurnRunStatusChanged):
|
return _unsubscribe
|
||||||
await self._handle_run_status_changed(event)
|
|
||||||
return
|
|
||||||
if isinstance(event, TurnCompleted):
|
|
||||||
await self._handle_turn_completed_event(event)
|
|
||||||
return
|
|
||||||
if isinstance(event, GoalStateChanged):
|
|
||||||
await self._handle_goal_state_changed(event)
|
|
||||||
return
|
|
||||||
if isinstance(event, RuntimeModelChanged):
|
|
||||||
await self._handle_runtime_model_changed(event)
|
|
||||||
return
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _ctx_msg(ctx: RuntimeEventContext) -> InboundMessage:
|
def _ctx_msg(ctx: RuntimeEventContext) -> InboundMessage:
|
||||||
|
|||||||
48
tests/bus/test_runtime_events.py
Normal file
48
tests/bus/test_runtime_events.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from nanobot.bus.runtime_events import (
|
||||||
|
RuntimeEventBus,
|
||||||
|
RuntimeEventContext,
|
||||||
|
RuntimeModelChanged,
|
||||||
|
TurnRunStatusChanged,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_event_bus_filters_by_event_type() -> None:
|
||||||
|
bus = RuntimeEventBus()
|
||||||
|
seen: list[str] = []
|
||||||
|
|
||||||
|
async def handle_run_status(event: TurnRunStatusChanged) -> None:
|
||||||
|
seen.append(event.status)
|
||||||
|
|
||||||
|
bus.subscribe(handle_run_status, TurnRunStatusChanged)
|
||||||
|
|
||||||
|
await bus.publish(RuntimeModelChanged(model="m", model_preset=None))
|
||||||
|
await bus.publish(
|
||||||
|
TurnRunStatusChanged(
|
||||||
|
context=RuntimeEventContext(
|
||||||
|
channel="cli",
|
||||||
|
chat_id="direct",
|
||||||
|
session_key="cli:direct",
|
||||||
|
),
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seen == ["running"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_runtime_event_bus_keeps_catch_all_subscription() -> None:
|
||||||
|
bus = RuntimeEventBus()
|
||||||
|
seen: list[str] = []
|
||||||
|
|
||||||
|
def handle_any(event) -> None:
|
||||||
|
seen.append(type(event).__name__)
|
||||||
|
|
||||||
|
bus.subscribe(handle_any)
|
||||||
|
|
||||||
|
await bus.publish(RuntimeModelChanged(model="m", model_preset=None))
|
||||||
|
|
||||||
|
assert seen == ["RuntimeModelChanged"]
|
||||||
Loading…
x
Reference in New Issue
Block a user