refactor(loop): convert _process_message to functional state machine

- Extract TurnState enum and TurnContext dataclass
- Extract state handlers: _state_restore, _state_compact, _state_command,
  _state_build, _state_run, _state_save, _state_respond
- Extract _process_system_message for system message short-circuit
- Driver loop uses getattr dispatch over explicit state transitions
- Preserve all existing behavior (794 tests passing)
This commit is contained in:
chengyongru 2026-05-09 15:40:11 +08:00
parent 8710524f2b
commit d6d4a7cf67

View File

@ -1072,7 +1072,7 @@ class AgentLoop:
self._running = False self._running = False
logger.info("Agent loop stopping") logger.info("Agent loop stopping")
async def _process_message( async def _process_system_message(
self, self,
msg: InboundMessage, msg: InboundMessage,
session_key: str | None = None, session_key: str | None = None,
@ -1081,17 +1081,11 @@ class AgentLoop:
on_stream_end: Callable[..., Awaitable[None]] | None = None, on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None, pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None: ) -> OutboundMessage | None:
"""Process a single inbound message and return the response.""" """Process a system inbound message (e.g. subagent announce)."""
self._refresh_provider_snapshot()
# System messages: parse origin from chat_id ("channel:chat_id")
if msg.channel == "system":
channel, chat_id = ( channel, chat_id = (
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id) msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
) )
logger.info("Processing system message from {}", msg.sender_id) logger.info("Processing system message from {}", msg.sender_id)
# Honor session_key_override so subagent announces from threaded
# callers route to the originating thread session, not the
# channel-level session derived from chat_id.
key = msg.session_key_override or f"{channel}:{chat_id}" key = msg.session_key_override or f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key) session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session): if self._restore_runtime_checkpoint(session):
@ -1108,11 +1102,6 @@ class AgentLoop:
session_summary=pending, session_summary=pending,
replay_max_messages=self._max_messages, replay_max_messages=self._max_messages,
) )
# Persist subagent follow-ups into durable history BEFORE prompt
# assembly. ContextBuilder merges adjacent same-role messages for
# provider compatibility, which previously caused the follow-up to
# disappear from session.messages while still being visible to the
# LLM via the merged prompt. See _persist_subagent_followup.
is_subagent = msg.sender_id == "subagent" is_subagent = msg.sender_id == "subagent"
if is_subagent and self._persist_subagent_followup(session, msg): if is_subagent and self._persist_subagent_followup(session, msg):
logger.debug("Subagent result persisted for session {}", key) logger.debug("Subagent result persisted for session {}", key)
@ -1129,8 +1118,6 @@ class AgentLoop:
history = session.get_history(**_hist_kwargs) history = session.get_history(**_hist_kwargs)
current_role = "assistant" if is_subagent else "user" current_role = "assistant" if is_subagent else "user"
# Subagent content is already in `history` above; passing it again
# as current_message would double-project it into the prompt.
messages = self.context.build_messages( messages = self.context.build_messages(
history=history, history=history,
current_message="" if is_subagent else msg.content, current_message="" if is_subagent else msg.content,
@ -1163,11 +1150,6 @@ class AgentLoop:
options, options,
channel, channel,
) )
# Reconstruct channel-specific metadata from session.key so the
# outbound reply lands in the originating thread (not the channel
# top-level). The announce InboundMessage carries only
# injected_event metadata; we recover thread_ts from the session
# key, which slack writes as "slack:<chat_id>:<thread_ts>".
outbound_metadata: dict[str, Any] = {} outbound_metadata: dict[str, Any] = {}
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2: if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]} outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
@ -1181,112 +1163,45 @@ class AgentLoop:
metadata=outbound_metadata, metadata=outbound_metadata,
) )
# Extract document text from media at the processing boundary so all async def _process_message(
# channels benefit without format-specific logic in ContextBuilder. self,
if msg.media: msg: InboundMessage,
new_content, image_only = extract_documents(msg.content, msg.media) session_key: str | None = None,
msg = dataclasses.replace(msg, content=new_content, media=image_only) on_progress: Callable[..., Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
self._refresh_provider_snapshot()
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content if msg.channel == "system":
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) return await self._process_system_message(
msg,
key = session_key or msg.session_key session_key=session_key,
session = self.sessions.get_or_create(key) on_progress=on_progress,
mark_webui_session(session, msg.metadata)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
if self._restore_pending_user_turn(session):
self.sessions.save(session)
session, pending = self.auto_compact.prepare_session(session, key)
# Slash commands
raw = msg.content.strip()
ctx = CommandContext(msg=msg, session=session, key=key, raw=raw, loop=self)
if result := await self.commands.dispatch(ctx):
return result
await self.consolidator.maybe_consolidate_by_tokens(
session,
session_summary=pending,
replay_max_messages=self._max_messages,
)
self._set_tool_context(
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
_hist_kwargs: dict[str, Any] = {
"max_messages": self._max_messages,
"max_tokens": self._replay_token_budget(),
"include_timestamps": True,
}
history = session.get_history(**_hist_kwargs)
pending_ask_id = pending_ask_user_id(history)
initial_messages = self._build_initial_messages(
msg, session, history, pending_ask_id, pending
)
_bus_progress = await self._build_bus_progress_callback(msg)
_on_retry_wait = await self._build_retry_wait_callback(msg)
# Persist the triggering user message up front so a mid-turn crash
# doesn't silently lose the prompt on recovery. ``media`` rides along
# as raw on-disk paths — sanitized image blocks are stripped from
# JSONL, and webui replay needs the paths to mint signed URLs.
user_persisted_early = self._persist_user_message_early(msg, session, pending_ask_id)
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
initial_messages,
on_progress=on_progress or _bus_progress,
on_stream=on_stream, on_stream=on_stream,
on_stream_end=on_stream_end, on_stream_end=on_stream_end,
on_retry_wait=_on_retry_wait,
session=session,
channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue, pending_queue=pending_queue,
) )
if final_content is None or not final_content.strip(): key = session_key or msg.session_key
final_content = EMPTY_FINAL_RESPONSE_MESSAGE ctx = TurnContext(
msg=msg,
# Skip the already-persisted user message when saving the turn session=self.sessions.get_or_create(key),
save_skip = 1 + len(history) + (1 if user_persisted_early else 0) session_key=key,
generated_media = generated_image_paths_from_messages(all_msgs[save_skip:]) state=TurnState.RESTORE,
if generated_media and all_msgs and all_msgs[-1].get("role") == "assistant": on_progress=on_progress,
existing_media = all_msgs[-1].get("media") on_stream=on_stream,
media = existing_media if isinstance(existing_media, list) else [] on_stream_end=on_stream_end,
all_msgs[-1]["media"] = list(dict.fromkeys([*media, *generated_media])) pending_queue=pending_queue,
self._save_turn(session, all_msgs, save_skip)
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=self._max_messages,
)
) )
return self._assemble_outbound( while ctx.state is not TurnState.DONE:
msg, handler = getattr(self, f"_state_{ctx.state.name.lower()}")
final_content, ctx.state = await handler(ctx)
all_msgs,
stop_reason, return ctx.outbound
had_injections,
generated_media,
on_stream,
)
def _assemble_outbound( def _assemble_outbound(
self, self,
@ -1325,6 +1240,141 @@ class AgentLoop:
buttons=buttons, buttons=buttons,
) )
async def _state_restore(self, ctx: TurnContext) -> TurnState:
"""Restore checkpoint / pending user turn; extract documents."""
msg = ctx.msg
if msg.media:
new_content, image_only = extract_documents(msg.content, msg.media)
ctx.msg = dataclasses.replace(msg, content=new_content, media=image_only)
msg = ctx.msg
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
ctx.session = self.sessions.get_or_create(ctx.session_key)
mark_webui_session(ctx.session, msg.metadata)
if self._restore_runtime_checkpoint(ctx.session):
self.sessions.save(ctx.session)
if self._restore_pending_user_turn(ctx.session):
self.sessions.save(ctx.session)
return TurnState.COMPACT
async def _state_compact(self, ctx: TurnContext) -> TurnState:
ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key)
ctx.pending_summary = pending
return TurnState.COMMAND
async def _state_command(self, ctx: TurnContext) -> TurnState:
raw = ctx.msg.content.strip()
cmd_ctx = CommandContext(
msg=ctx.msg, session=ctx.session, key=ctx.session_key, raw=raw, loop=self
)
result = await self.commands.dispatch(cmd_ctx)
if result is not None:
ctx.outbound = result
return TurnState.DONE
return TurnState.BUILD
async def _state_build(self, ctx: TurnContext) -> TurnState:
await self.consolidator.maybe_consolidate_by_tokens(
ctx.session,
session_summary=ctx.pending_summary,
replay_max_messages=self._max_messages,
)
self._set_tool_context(
ctx.msg.channel,
ctx.msg.chat_id,
ctx.msg.metadata.get("message_id"),
ctx.msg.metadata,
session_key=ctx.session_key,
)
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
_hist_kwargs: dict[str, Any] = {
"max_messages": self._max_messages,
"max_tokens": self._replay_token_budget(),
"include_timestamps": True,
}
ctx.history = ctx.session.get_history(**_hist_kwargs)
pending_ask_id = pending_ask_user_id(ctx.history)
ctx.initial_messages = self._build_initial_messages(
ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary
)
ctx.user_persisted_early = self._persist_user_message_early(
ctx.msg, ctx.session, pending_ask_id
)
if ctx.on_progress is None:
ctx.on_progress = await self._build_bus_progress_callback(ctx.msg)
if ctx.on_retry_wait is None:
ctx.on_retry_wait = await self._build_retry_wait_callback(ctx.msg)
return TurnState.RUN
async def _state_run(self, ctx: TurnContext) -> TurnState:
final_content, tools_used, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
ctx.initial_messages,
on_progress=ctx.on_progress,
on_stream=ctx.on_stream,
on_stream_end=ctx.on_stream_end,
on_retry_wait=ctx.on_retry_wait,
session=ctx.session,
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
message_id=ctx.msg.metadata.get("message_id"),
metadata=ctx.msg.metadata,
session_key=ctx.session_key,
pending_queue=ctx.pending_queue,
)
ctx.final_content = final_content
ctx.tools_used = tools_used
ctx.all_messages = all_msgs
ctx.stop_reason = stop_reason
ctx.had_injections = had_injections
return TurnState.SAVE
async def _state_save(self, ctx: TurnContext) -> TurnState:
if ctx.final_content is None or not ctx.final_content.strip():
ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE
ctx.save_skip = 1 + len(ctx.history) + (1 if ctx.user_persisted_early else 0)
ctx.generated_media = generated_image_paths_from_messages(ctx.all_messages[ctx.save_skip:])
if ctx.generated_media and ctx.all_messages and ctx.all_messages[-1].get("role") == "assistant":
existing_media = ctx.all_messages[-1].get("media")
media = existing_media if isinstance(existing_media, list) else []
ctx.all_messages[-1]["media"] = list(dict.fromkeys([*media, *ctx.generated_media]))
self._save_turn(ctx.session, ctx.all_messages, ctx.save_skip)
ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_pending_user_turn(ctx.session)
self._clear_runtime_checkpoint(ctx.session)
self.sessions.save(ctx.session)
self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
ctx.session,
replay_max_messages=self._max_messages,
)
)
return TurnState.RESPOND
async def _state_respond(self, ctx: TurnContext) -> TurnState:
ctx.outbound = self._assemble_outbound(
ctx.msg,
ctx.final_content or "",
ctx.all_messages,
ctx.stop_reason,
ctx.had_injections,
ctx.generated_media,
ctx.on_stream,
)
return TurnState.DONE
def _sanitize_persisted_blocks( def _sanitize_persisted_blocks(
self, self,
content: list[dict[str, Any]], content: list[dict[str, Any]],