refactor(loop): event-driven state transitions + trace logging

- State handlers now return event strings ('ok', 'dispatch', 'shortcut')
- Driver loop uses _TRANSITIONS lookup table: (state, event) -> next_state
- State graph is centralized and visible at a glance
- Added StateTraceEntry to record per-state timing and events
- Driver loop logs state duration + event at debug level
- Exception paths are traced with error field for observability
This commit is contained in:
chengyongru 2026-05-09 16:47:54 +08:00 committed by Xubin Ren
parent 6ef1b2c842
commit 5327f5e1a0

View File

@ -195,6 +195,15 @@ class TurnState(Enum):
DONE = auto()
@dataclass
class StateTraceEntry:
state: TurnState
started_at: float
duration_ms: float
event: str
error: str | None = None
@dataclass
class TurnContext:
msg: InboundMessage
@ -225,6 +234,8 @@ class TurnContext:
pending_queue: asyncio.Queue | None = None
pending_summary: str | None = None
trace: list[StateTraceEntry] = field(default_factory=list)
class AgentLoop:
"""
@ -241,6 +252,19 @@ class AgentLoop:
_RUNTIME_CHECKPOINT_KEY = "runtime_checkpoint"
_PENDING_USER_TURN_KEY = "pending_user_turn"
# Event-driven state transition table.
# Handlers return an event string; the driver looks up the next state here.
_TRANSITIONS: dict[tuple[TurnState, str], TurnState] = {
(TurnState.RESTORE, "ok"): TurnState.COMPACT,
(TurnState.COMPACT, "ok"): TurnState.COMMAND,
(TurnState.COMMAND, "dispatch"): TurnState.BUILD,
(TurnState.COMMAND, "shortcut"): TurnState.DONE,
(TurnState.BUILD, "ok"): TurnState.RUN,
(TurnState.RUN, "ok"): TurnState.SAVE,
(TurnState.SAVE, "ok"): TurnState.RESPOND,
(TurnState.RESPOND, "ok"): TurnState.DONE,
}
def __init__(
self,
bus: MessageBus,
@ -1257,7 +1281,45 @@ class AgentLoop:
handler = getattr(self, handler_name, None)
if handler is None:
raise RuntimeError(f"Missing state handler for {ctx.state}")
ctx.state = await handler(ctx)
t0 = time.perf_counter()
try:
event = await handler(ctx)
except Exception:
duration = (time.perf_counter() - t0) * 1000
ctx.trace.append(
StateTraceEntry(
state=ctx.state,
started_at=t0,
duration_ms=duration,
event="",
error="exception",
)
)
raise
duration = (time.perf_counter() - t0) * 1000
ctx.trace.append(
StateTraceEntry(
state=ctx.state,
started_at=t0,
duration_ms=duration,
event=event,
)
)
logger.debug(
"State {} took {:.1f}ms -> event {}",
ctx.state.name,
duration,
event,
)
next_state = self._TRANSITIONS.get((ctx.state, event))
if next_state is None:
raise RuntimeError(
f"No transition from {ctx.state} on event {event!r}"
)
ctx.state = next_state
return ctx.outbound
@ -1321,14 +1383,14 @@ class AgentLoop:
if self._restore_pending_user_turn(ctx.session):
self.sessions.save(ctx.session)
return TurnState.COMPACT
return "ok"
async def _state_compact(self, ctx: TurnContext) -> TurnState:
async def _state_compact(self, ctx: TurnContext) -> str:
ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key)
ctx.pending_summary = pending
return TurnState.COMMAND
return "ok"
async def _state_command(self, ctx: TurnContext) -> TurnState:
async def _state_command(self, ctx: TurnContext) -> str:
raw = ctx.msg.content.strip()
cmd_ctx = CommandContext(
msg=ctx.msg, session=ctx.session, key=ctx.session_key, raw=raw, loop=self
@ -1336,10 +1398,10 @@ class AgentLoop:
result = await self.commands.dispatch(cmd_ctx)
if result is not None:
ctx.outbound = result
return TurnState.DONE
return TurnState.BUILD
return "shortcut"
return "dispatch"
async def _state_build(self, ctx: TurnContext) -> TurnState:
async def _state_build(self, ctx: TurnContext) -> str:
await self.consolidator.maybe_consolidate_by_tokens(
ctx.session,
session_summary=ctx.pending_summary,
@ -1376,9 +1438,9 @@ class AgentLoop:
if ctx.on_retry_wait is None:
ctx.on_retry_wait = await self._build_retry_wait_callback(ctx.msg)
return TurnState.RUN
return "ok"
async def _state_run(self, ctx: TurnContext) -> TurnState:
async def _state_run(self, ctx: TurnContext) -> str:
result = await self._run_agent_loop(
ctx.initial_messages,
on_progress=ctx.on_progress,
@ -1399,9 +1461,9 @@ class AgentLoop:
ctx.all_messages = all_msgs
ctx.stop_reason = stop_reason
ctx.had_injections = had_injections
return TurnState.SAVE
return "ok"
async def _state_save(self, ctx: TurnContext) -> TurnState:
async def _state_save(self, ctx: TurnContext) -> str:
if ctx.final_content is None or not ctx.final_content.strip():
ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE
@ -1425,9 +1487,9 @@ class AgentLoop:
replay_max_messages=self._max_messages,
)
)
return TurnState.RESPOND
return "ok"
async def _state_respond(self, ctx: TurnContext) -> TurnState:
async def _state_respond(self, ctx: TurnContext) -> str:
ctx.outbound = self._assemble_outbound(
ctx.msg,
ctx.final_content,
@ -1437,7 +1499,7 @@ class AgentLoop:
ctx.generated_media,
ctx.on_stream,
)
return TurnState.DONE
return "ok"
def _sanitize_persisted_blocks(
self,