mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-06-13 22:34:06 +00:00
refactor: move runtime event publishing out of loop
This commit is contained in:
parent
81370565e0
commit
f78700fe69
@ -33,11 +33,8 @@ from nanobot.bus.progress import build_bus_progress_callback
|
||||
from nanobot.bus.queue import MessageBus
|
||||
from nanobot.bus.runtime_events import (
|
||||
RuntimeEventBus,
|
||||
RuntimeEventContext,
|
||||
RuntimeModelChanged,
|
||||
SessionTurnStarted,
|
||||
TurnCompleted,
|
||||
TurnRunStatusChanged,
|
||||
RuntimeEventPublisher,
|
||||
ensure_runtime_event_publisher,
|
||||
)
|
||||
from nanobot.command import CommandContext, CommandRouter, register_builtin_commands
|
||||
from nanobot.config.schema import AgentDefaults, ModelPresetConfig
|
||||
@ -129,7 +126,6 @@ class TurnContext:
|
||||
turn_wall_started_at: float = field(default_factory=time.time)
|
||||
visible_run_started_at: float | None = None
|
||||
turn_latency_ms: int | None = None
|
||||
llm_runtime: LLMRuntime | None = None
|
||||
|
||||
trace: list[StateTraceEntry] = field(default_factory=list)
|
||||
|
||||
@ -217,6 +213,7 @@ class AgentLoop:
|
||||
defaults = AgentDefaults()
|
||||
self.bus = bus
|
||||
self.runtime_events = runtime_events or RuntimeEventBus()
|
||||
self.runtime_event_publisher = RuntimeEventPublisher(self.runtime_events)
|
||||
self.channels_config = channels_config
|
||||
self.provider = provider
|
||||
self._provider_snapshot_loader = provider_snapshot_loader
|
||||
@ -262,8 +259,6 @@ class AgentLoop:
|
||||
)
|
||||
self._start_time = time.time()
|
||||
self._last_usage: dict[str, int] = {}
|
||||
self._pending_turn_latency_ms: dict[str, int] = {}
|
||||
self._pending_turn_runtime: dict[str, LLMRuntime] = {}
|
||||
self._extra_hooks: list[AgentHook] = hooks or []
|
||||
|
||||
self.context = ContextBuilder(workspace, timezone=timezone, disabled_skills=disabled_skills)
|
||||
@ -422,11 +417,9 @@ class AgentLoop:
|
||||
model_preset if model_preset is not None else self.model_preset,
|
||||
)
|
||||
if publish_update:
|
||||
self._runtime_event_bus().publish_nowait(
|
||||
RuntimeModelChanged(
|
||||
model=self.model,
|
||||
model_preset=model_preset if model_preset is not None else self.model_preset,
|
||||
)
|
||||
self._runtime_events().runtime_model_changed(
|
||||
self.model,
|
||||
model_preset if model_preset is not None else self.model_preset,
|
||||
)
|
||||
logger.info("Runtime model switched for next turn: {} -> {}", old_model, model)
|
||||
|
||||
@ -569,85 +562,8 @@ class AgentLoop:
|
||||
|
||||
return _on_retry_wait
|
||||
|
||||
@staticmethod
|
||||
def _runtime_event_context(
|
||||
*,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
session_key: str,
|
||||
metadata: dict[str, Any] | None,
|
||||
) -> RuntimeEventContext:
|
||||
return RuntimeEventContext(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
session_key=session_key,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
def _runtime_event_bus(self) -> RuntimeEventBus:
|
||||
bus = getattr(self, "runtime_events", None)
|
||||
if bus is None:
|
||||
bus = RuntimeEventBus()
|
||||
self.runtime_events = bus
|
||||
return bus
|
||||
|
||||
def _pop_pending_turn_latency(self, session_key: str) -> int | None:
|
||||
pending = getattr(self, "_pending_turn_latency_ms", None)
|
||||
if not isinstance(pending, dict):
|
||||
return None
|
||||
return pending.pop(session_key, None)
|
||||
|
||||
def _pop_pending_turn_runtime(self, session_key: str) -> LLMRuntime | None:
|
||||
pending = getattr(self, "_pending_turn_runtime", None)
|
||||
if not isinstance(pending, dict):
|
||||
return None
|
||||
return pending.pop(session_key, None)
|
||||
|
||||
def _clear_pending_turn_runtime(self, session_key: str) -> None:
|
||||
self._pop_pending_turn_latency(session_key)
|
||||
self._pop_pending_turn_runtime(session_key)
|
||||
|
||||
async def _publish_run_status_event(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str,
|
||||
status: str,
|
||||
*,
|
||||
started_at: float | None = None,
|
||||
) -> None:
|
||||
await self._runtime_event_bus().publish(
|
||||
TurnRunStatusChanged(
|
||||
context=self._runtime_event_context(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
session_key=session_key,
|
||||
metadata=msg.metadata,
|
||||
),
|
||||
status=status,
|
||||
started_at=started_at,
|
||||
)
|
||||
)
|
||||
|
||||
async def _publish_turn_completed_event(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
session_key: str,
|
||||
metadata: dict[str, Any] | None,
|
||||
) -> None:
|
||||
await self._runtime_event_bus().publish(
|
||||
TurnCompleted(
|
||||
context=self._runtime_event_context(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
session_key=session_key,
|
||||
metadata=metadata,
|
||||
),
|
||||
latency_ms=self._pop_pending_turn_latency(session_key),
|
||||
runtime=self._pop_pending_turn_runtime(session_key),
|
||||
)
|
||||
)
|
||||
def _runtime_events(self) -> RuntimeEventPublisher:
|
||||
return ensure_runtime_event_publisher(self)
|
||||
|
||||
def _persist_user_message_early(
|
||||
self,
|
||||
@ -1063,7 +979,7 @@ class AgentLoop:
|
||||
))
|
||||
continuing = turn_continuation.internal_continuation_pending(msg.metadata)
|
||||
if not continuing:
|
||||
await self._publish_turn_completed_event(
|
||||
await self._runtime_events().turn_completed(
|
||||
channel=completed_channel,
|
||||
chat_id=completed_chat_id,
|
||||
session_key=session_key,
|
||||
@ -1127,12 +1043,16 @@ class AgentLoop:
|
||||
leftover, session_key,
|
||||
)
|
||||
if not turn_continuation.internal_continuation_pending(msg.metadata):
|
||||
await self._publish_run_status_event(msg, session_key, "idle")
|
||||
self._clear_pending_turn_runtime(session_key)
|
||||
await self._runtime_events().run_status_changed(
|
||||
msg, session_key, "idle"
|
||||
)
|
||||
self._runtime_events().clear_turn(session_key)
|
||||
finally:
|
||||
if pending is None:
|
||||
await self._publish_run_status_event(msg, session_key, "idle")
|
||||
self._clear_pending_turn_runtime(session_key)
|
||||
await self._runtime_events().run_status_changed(
|
||||
msg, session_key, "idle"
|
||||
)
|
||||
self._runtime_events().clear_turn(session_key)
|
||||
|
||||
async def close_mcp(self) -> None:
|
||||
"""Drain pending background archives, then close MCP connections."""
|
||||
@ -1228,7 +1148,7 @@ class AgentLoop:
|
||||
wall_done = time.time()
|
||||
latency_ms = max(0, int((wall_done - t_wall) * 1000))
|
||||
self._save_turn(session, all_msgs, 1 + len(history), turn_latency_ms=latency_ms)
|
||||
self._pending_turn_latency_ms[key] = latency_ms
|
||||
self._runtime_events().record_turn_latency(key, latency_ms)
|
||||
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
@ -1394,16 +1314,7 @@ class AgentLoop:
|
||||
# ensure it exists in case this handler is invoked independently.
|
||||
if ctx.session is None:
|
||||
ctx.session = self.sessions.get_or_create(ctx.session_key)
|
||||
await self._runtime_event_bus().publish(
|
||||
SessionTurnStarted(
|
||||
context=self._runtime_event_context(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
session_key=ctx.session_key,
|
||||
metadata=msg.metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
await self._runtime_events().session_turn_started(msg, ctx.session_key)
|
||||
self.workspace_scopes.persist_message_scope(ctx.session, msg)
|
||||
|
||||
if self._restore_runtime_checkpoint(ctx.session):
|
||||
@ -1475,8 +1386,10 @@ class AgentLoop:
|
||||
"include_timestamps": True,
|
||||
}
|
||||
ctx.history = ctx.session.get_history(**_hist_kwargs)
|
||||
ctx.llm_runtime = self.llm_runtime()
|
||||
self._pending_turn_runtime[ctx.session_key] = ctx.llm_runtime
|
||||
self._runtime_events().record_turn_runtime(
|
||||
ctx.session_key,
|
||||
self.llm_runtime(),
|
||||
)
|
||||
|
||||
ctx.initial_messages = self._build_initial_messages(
|
||||
ctx.msg,
|
||||
@ -1498,7 +1411,7 @@ class AgentLoop:
|
||||
async def _state_run(self, ctx: TurnContext) -> str:
|
||||
if ctx.visible_run_started_at is None:
|
||||
ctx.visible_run_started_at = time.time()
|
||||
await self._publish_run_status_event(
|
||||
await self._runtime_events().run_status_changed(
|
||||
ctx.msg,
|
||||
ctx.session_key,
|
||||
"running",
|
||||
@ -1547,7 +1460,10 @@ class AgentLoop:
|
||||
ctx.session, ctx.all_messages, ctx.save_skip,
|
||||
turn_latency_ms=ctx.turn_latency_ms,
|
||||
)
|
||||
self._pending_turn_latency_ms[ctx.session_key] = ctx.turn_latency_ms
|
||||
self._runtime_events().record_turn_latency(
|
||||
ctx.session_key,
|
||||
ctx.turn_latency_ms,
|
||||
)
|
||||
ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_pending_user_turn(ctx.session)
|
||||
self._clear_runtime_checkpoint(ctx.session)
|
||||
@ -1816,5 +1732,5 @@ class AgentLoop:
|
||||
on_stream_end=on_stream_end,
|
||||
)
|
||||
finally:
|
||||
await self._publish_run_status_event(msg, session_key, "idle")
|
||||
self._clear_pending_turn_runtime(session_key)
|
||||
await self._runtime_events().run_status_changed(msg, session_key, "idle")
|
||||
self._runtime_events().clear_turn(session_key)
|
||||
|
||||
@ -16,6 +16,8 @@ from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from nanobot.bus.events import InboundMessage
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RuntimeEventContext:
|
||||
@ -129,3 +131,121 @@ class RuntimeEventBus:
|
||||
logger.debug("dropping runtime event without a running loop: {}", type(event).__name__)
|
||||
return
|
||||
loop.create_task(self.publish(event))
|
||||
|
||||
|
||||
class RuntimeEventPublisher:
|
||||
"""Convenience publisher for turn-scoped runtime events.
|
||||
|
||||
Agent code should decide when state transitions happen; this helper owns
|
||||
the mechanics of building event contexts and carrying per-turn metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, bus: RuntimeEventBus | None = None) -> None:
|
||||
self.bus = bus or RuntimeEventBus()
|
||||
self._turn_latency_ms: dict[str, int] = {}
|
||||
self._turn_runtime: dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def _context(
|
||||
*,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
session_key: str,
|
||||
metadata: dict[str, Any] | None,
|
||||
) -> RuntimeEventContext:
|
||||
return RuntimeEventContext(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
session_key=session_key,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
def record_turn_runtime(self, session_key: str, runtime: Any) -> None:
|
||||
self._turn_runtime[session_key] = runtime
|
||||
|
||||
def record_turn_latency(self, session_key: str, latency_ms: int | None) -> None:
|
||||
if latency_ms is not None:
|
||||
self._turn_latency_ms[session_key] = int(latency_ms)
|
||||
|
||||
def clear_turn(self, session_key: str) -> None:
|
||||
self._turn_latency_ms.pop(session_key, None)
|
||||
self._turn_runtime.pop(session_key, None)
|
||||
|
||||
async def session_turn_started(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str,
|
||||
) -> None:
|
||||
await self.bus.publish(
|
||||
SessionTurnStarted(
|
||||
context=self._context(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
session_key=session_key,
|
||||
metadata=msg.metadata,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def run_status_changed(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str,
|
||||
status: str,
|
||||
*,
|
||||
started_at: float | None = None,
|
||||
) -> None:
|
||||
await self.bus.publish(
|
||||
TurnRunStatusChanged(
|
||||
context=self._context(
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
session_key=session_key,
|
||||
metadata=msg.metadata,
|
||||
),
|
||||
status=status,
|
||||
started_at=started_at,
|
||||
)
|
||||
)
|
||||
|
||||
async def turn_completed(
|
||||
self,
|
||||
*,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
session_key: str,
|
||||
metadata: dict[str, Any] | None,
|
||||
) -> None:
|
||||
await self.bus.publish(
|
||||
TurnCompleted(
|
||||
context=self._context(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
session_key=session_key,
|
||||
metadata=metadata,
|
||||
),
|
||||
latency_ms=self._turn_latency_ms.pop(session_key, None),
|
||||
runtime=self._turn_runtime.pop(session_key, None),
|
||||
)
|
||||
)
|
||||
|
||||
def runtime_model_changed(self, model: str, model_preset: str | None) -> None:
|
||||
self.bus.publish_nowait(
|
||||
RuntimeModelChanged(model=model, model_preset=model_preset)
|
||||
)
|
||||
|
||||
|
||||
def ensure_runtime_event_publisher(owner: Any) -> RuntimeEventPublisher:
|
||||
"""Return an owner's runtime publisher, creating missing state lazily."""
|
||||
publisher = getattr(owner, "runtime_event_publisher", None)
|
||||
if isinstance(publisher, RuntimeEventPublisher):
|
||||
return publisher
|
||||
|
||||
bus = getattr(owner, "runtime_events", None)
|
||||
if not isinstance(bus, RuntimeEventBus):
|
||||
bus = RuntimeEventBus()
|
||||
owner.runtime_events = bus
|
||||
|
||||
publisher = RuntimeEventPublisher(bus)
|
||||
owner.runtime_event_publisher = publisher
|
||||
return publisher
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from nanobot.bus.events import InboundMessage
|
||||
from nanobot.bus.runtime_events import (
|
||||
RuntimeEventBus,
|
||||
RuntimeEventContext,
|
||||
RuntimeEventPublisher,
|
||||
RuntimeModelChanged,
|
||||
SessionTurnStarted,
|
||||
TurnCompleted,
|
||||
TurnRunStatusChanged,
|
||||
)
|
||||
|
||||
@ -46,3 +50,73 @@ async def test_runtime_event_bus_keeps_catch_all_subscription() -> None:
|
||||
await bus.publish(RuntimeModelChanged(model="m", model_preset=None))
|
||||
|
||||
assert seen == ["RuntimeModelChanged"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_event_publisher_builds_context_from_inbound_message() -> None:
|
||||
bus = RuntimeEventBus()
|
||||
seen: list[object] = []
|
||||
publisher = RuntimeEventPublisher(bus)
|
||||
msg = InboundMessage(
|
||||
channel="websocket",
|
||||
sender_id="user",
|
||||
chat_id="chat-a",
|
||||
content="hello",
|
||||
metadata={"trace_id": "turn-1"},
|
||||
)
|
||||
|
||||
bus.subscribe(seen.append)
|
||||
|
||||
await publisher.session_turn_started(msg, "websocket:chat-a")
|
||||
await publisher.run_status_changed(
|
||||
msg,
|
||||
"websocket:chat-a",
|
||||
"running",
|
||||
started_at=12.5,
|
||||
)
|
||||
|
||||
started = seen[0]
|
||||
running = seen[1]
|
||||
assert isinstance(started, SessionTurnStarted)
|
||||
assert started.context.channel == "websocket"
|
||||
assert started.context.chat_id == "chat-a"
|
||||
assert started.context.session_key == "websocket:chat-a"
|
||||
assert started.context.metadata == {"trace_id": "turn-1"}
|
||||
assert started.context.metadata is not msg.metadata
|
||||
assert isinstance(running, TurnRunStatusChanged)
|
||||
assert running.status == "running"
|
||||
assert running.started_at == 12.5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_event_publisher_consumes_turn_metadata_on_complete() -> None:
|
||||
bus = RuntimeEventBus()
|
||||
seen: list[object] = []
|
||||
publisher = RuntimeEventPublisher(bus)
|
||||
|
||||
bus.subscribe(seen.append)
|
||||
publisher.record_turn_runtime("cli:direct", "runtime")
|
||||
publisher.record_turn_latency("cli:direct", 123)
|
||||
|
||||
await publisher.turn_completed(
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
session_key="cli:direct",
|
||||
metadata={"source": "test"},
|
||||
)
|
||||
await publisher.turn_completed(
|
||||
channel="cli",
|
||||
chat_id="direct",
|
||||
session_key="cli:direct",
|
||||
metadata=None,
|
||||
)
|
||||
|
||||
first = seen[0]
|
||||
second = seen[1]
|
||||
assert isinstance(first, TurnCompleted)
|
||||
assert first.context.metadata == {"source": "test"}
|
||||
assert first.latency_ms == 123
|
||||
assert first.runtime == "runtime"
|
||||
assert isinstance(second, TurnCompleted)
|
||||
assert second.latency_ms is None
|
||||
assert second.runtime is None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user