refactor(loop): convert _process_message to functional state machine

- Extract TurnState enum and TurnContext dataclass
- Extract state handlers: _state_restore, _state_compact, _state_command,
  _state_build, _state_run, _state_save, _state_respond
- Extract _process_system_message for system message short-circuit
- Driver loop uses getattr dispatch over explicit state transitions
- Preserve all existing behavior (794 tests passing)
This commit is contained in:
chengyongru 2026-05-09 15:40:11 +08:00 committed by Xubin Ren
parent 9fb9f53147
commit 02443ca208

View File

@ -1127,6 +1127,97 @@ class AgentLoop:
self._running = False
logger.info("Agent loop stopping")
async def _process_system_message(
self,
msg: InboundMessage,
session_key: str | None = None,
on_progress: Callable[..., Awaitable[None]] | None = None,
on_stream: Callable[[str], Awaitable[None]] | None = None,
on_stream_end: Callable[..., Awaitable[None]] | None = None,
pending_queue: asyncio.Queue | None = None,
) -> OutboundMessage | None:
"""Process a system inbound message (e.g. subagent announce)."""
channel, chat_id = (
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
)
logger.info("Processing system message from {}", msg.sender_id)
key = msg.session_key_override or f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
if self._restore_pending_user_turn(session):
self.sessions.save(session)
session, pending = self.auto_compact.prepare_session(session, key)
if pending:
logger.info("Memory compact triggered for session {}", key)
await self.consolidator.maybe_consolidate_by_tokens(
session,
session_summary=pending,
replay_max_messages=self._max_messages,
)
is_subagent = msg.sender_id == "subagent"
if is_subagent and self._persist_subagent_followup(session, msg):
logger.debug("Subagent result persisted for session {}", key)
self.sessions.save(session)
self._set_tool_context(
channel, chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
_hist_kwargs: dict[str, Any] = {
"max_messages": self._max_messages,
"max_tokens": self._replay_token_budget(),
"include_timestamps": True,
}
history = session.get_history(**_hist_kwargs)
current_role = "assistant" if is_subagent else "user"
messages = self.context.build_messages(
history=history,
current_message="" if is_subagent else msg.content,
channel=channel,
chat_id=chat_id,
session_summary=pending,
current_role=current_role,
sender_id=msg.sender_id,
)
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue,
)
self._save_turn(session, all_msgs, 1 + len(history))
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=self._max_messages,
)
)
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
content, buttons = ask_user_outbound(
final_content or "Background task completed.",
options,
channel,
)
outbound_metadata: dict[str, Any] = {}
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
if origin_message_id := msg.metadata.get("origin_message_id"):
outbound_metadata["origin_message_id"] = origin_message_id
return OutboundMessage(
channel=channel,
chat_id=chat_id,
content=content,
buttons=buttons,
metadata=outbound_metadata,
)
async def _process_message(
self,
msg: InboundMessage,
@ -1138,210 +1229,34 @@ class AgentLoop:
) -> OutboundMessage | None:
"""Process a single inbound message and return the response."""
self._refresh_provider_snapshot()
# System messages: parse origin from chat_id ("channel:chat_id")
if msg.channel == "system":
channel, chat_id = (
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
)
logger.info("Processing system message from {}", msg.sender_id)
# Honor session_key_override so subagent announces from threaded
# callers route to the originating thread session, not the
# channel-level session derived from chat_id.
key = msg.session_key_override or f"{channel}:{chat_id}"
session = self.sessions.get_or_create(key)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
if self._restore_pending_user_turn(session):
self.sessions.save(session)
session, pending = self.auto_compact.prepare_session(session, key)
if pending:
logger.info("Memory compact triggered for session {}", key)
await self.consolidator.maybe_consolidate_by_tokens(
session,
session_summary=pending,
replay_max_messages=self._max_messages,
)
# Persist subagent follow-ups into durable history BEFORE prompt
# assembly. ContextBuilder merges adjacent same-role messages for
# provider compatibility, which previously caused the follow-up to
# disappear from session.messages while still being visible to the
# LLM via the merged prompt. See _persist_subagent_followup.
is_subagent = msg.sender_id == "subagent"
if is_subagent and self._persist_subagent_followup(session, msg):
logger.debug("Subagent result persisted for session {}", key)
self.sessions.save(session)
self._set_tool_context(
channel, chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
_hist_kwargs: dict[str, Any] = {
"max_messages": self._max_messages,
"max_tokens": self._replay_token_budget(),
"include_timestamps": True,
}
history = session.get_history(**_hist_kwargs)
current_role = "assistant" if is_subagent else "user"
# Subagent content is already in `history` above; passing it again
# as current_message would double-project it into the prompt.
messages = self.context.build_messages(
history=history,
current_message="" if is_subagent else msg.content,
channel=channel,
chat_id=chat_id,
session_summary=pending,
current_role=current_role,
sender_id=msg.sender_id,
)
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
messages, session=session, channel=channel, chat_id=chat_id,
message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
return await self._process_system_message(
msg,
session_key=session_key,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
pending_queue=pending_queue,
)
self._save_turn(session, all_msgs, 1 + len(history))
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=self._max_messages,
)
)
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
content, buttons = ask_user_outbound(
final_content or "Background task completed.",
options,
channel,
)
# Reconstruct channel-specific metadata from session.key so the
# outbound reply lands in the originating thread (not the channel
# top-level). The announce InboundMessage carries only
# injected_event metadata; we recover thread_ts from the session
# key, which slack writes as "slack:<chat_id>:<thread_ts>".
outbound_metadata: dict[str, Any] = {}
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
if origin_message_id := msg.metadata.get("origin_message_id"):
outbound_metadata["origin_message_id"] = origin_message_id
return OutboundMessage(
channel=channel,
chat_id=chat_id,
content=content,
buttons=buttons,
metadata=outbound_metadata,
)
# Extract document text from media at the processing boundary so all
# channels benefit without format-specific logic in ContextBuilder.
if msg.media:
new_content, image_only = extract_documents(msg.content, msg.media)
msg = dataclasses.replace(msg, content=new_content, media=image_only)
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
key = session_key or msg.session_key
session = self.sessions.get_or_create(key)
mark_webui_session(session, msg.metadata)
if self._restore_runtime_checkpoint(session):
self.sessions.save(session)
if self._restore_pending_user_turn(session):
self.sessions.save(session)
session, pending = self.auto_compact.prepare_session(session, key)
# Slash commands
raw = msg.content.strip()
ctx = CommandContext(msg=msg, session=session, key=key, raw=raw, loop=self)
if result := await self.commands.dispatch(ctx):
return result
await self.consolidator.maybe_consolidate_by_tokens(
session,
session_summary=pending,
replay_max_messages=self._max_messages,
)
self._set_tool_context(
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
msg.metadata, session_key=key,
)
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
_hist_kwargs: dict[str, Any] = {
"max_messages": self._max_messages,
"max_tokens": self._replay_token_budget(),
"include_timestamps": True,
}
history = session.get_history(**_hist_kwargs)
pending_ask_id = pending_ask_user_id(history)
initial_messages = self._build_initial_messages(
msg, session, history, pending_ask_id, pending
)
_bus_progress = await self._build_bus_progress_callback(msg)
_on_retry_wait = await self._build_retry_wait_callback(msg)
# Persist the triggering user message up front so a mid-turn crash
# doesn't silently lose the prompt on recovery. ``media`` rides along
# as raw on-disk paths — sanitized image blocks are stripped from
# JSONL, and webui replay needs the paths to mint signed URLs.
user_persisted_early = self._persist_user_message_early(msg, session, pending_ask_id)
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
initial_messages,
on_progress=on_progress or _bus_progress,
ctx = TurnContext(
msg=msg,
session=self.sessions.get_or_create(key),
session_key=key,
state=TurnState.RESTORE,
on_progress=on_progress,
on_stream=on_stream,
on_stream_end=on_stream_end,
on_retry_wait=_on_retry_wait,
session=session,
channel=msg.channel,
chat_id=msg.chat_id,
message_id=msg.metadata.get("message_id"),
metadata=msg.metadata,
session_key=key,
pending_queue=pending_queue,
)
if final_content is None or not final_content.strip():
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
while ctx.state is not TurnState.DONE:
handler = getattr(self, f"_state_{ctx.state.name.lower()}")
ctx.state = await handler(ctx)
# Skip the already-persisted user message when saving the turn
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
generated_media = generated_image_paths_from_messages(all_msgs[save_skip:])
if generated_media and all_msgs and all_msgs[-1].get("role") == "assistant":
existing_media = all_msgs[-1].get("media")
media = existing_media if isinstance(existing_media, list) else []
all_msgs[-1]["media"] = list(dict.fromkeys([*media, *generated_media]))
self._save_turn(session, all_msgs, save_skip)
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_pending_user_turn(session)
self._clear_runtime_checkpoint(session)
self.sessions.save(session)
self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
session,
replay_max_messages=self._max_messages,
)
)
return self._assemble_outbound(
msg,
final_content,
all_msgs,
stop_reason,
had_injections,
generated_media,
on_stream,
)
return ctx.outbound
def _assemble_outbound(
self,
@ -1380,6 +1295,141 @@ class AgentLoop:
buttons=buttons,
)
async def _state_restore(self, ctx: TurnContext) -> TurnState:
"""Restore checkpoint / pending user turn; extract documents."""
msg = ctx.msg
if msg.media:
new_content, image_only = extract_documents(msg.content, msg.media)
ctx.msg = dataclasses.replace(msg, content=new_content, media=image_only)
msg = ctx.msg
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
ctx.session = self.sessions.get_or_create(ctx.session_key)
mark_webui_session(ctx.session, msg.metadata)
if self._restore_runtime_checkpoint(ctx.session):
self.sessions.save(ctx.session)
if self._restore_pending_user_turn(ctx.session):
self.sessions.save(ctx.session)
return TurnState.COMPACT
async def _state_compact(self, ctx: TurnContext) -> TurnState:
ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key)
ctx.pending_summary = pending
return TurnState.COMMAND
async def _state_command(self, ctx: TurnContext) -> TurnState:
raw = ctx.msg.content.strip()
cmd_ctx = CommandContext(
msg=ctx.msg, session=ctx.session, key=ctx.session_key, raw=raw, loop=self
)
result = await self.commands.dispatch(cmd_ctx)
if result is not None:
ctx.outbound = result
return TurnState.DONE
return TurnState.BUILD
async def _state_build(self, ctx: TurnContext) -> TurnState:
await self.consolidator.maybe_consolidate_by_tokens(
ctx.session,
session_summary=ctx.pending_summary,
replay_max_messages=self._max_messages,
)
self._set_tool_context(
ctx.msg.channel,
ctx.msg.chat_id,
ctx.msg.metadata.get("message_id"),
ctx.msg.metadata,
session_key=ctx.session_key,
)
if message_tool := self.tools.get("message"):
if isinstance(message_tool, MessageTool):
message_tool.start_turn()
_hist_kwargs: dict[str, Any] = {
"max_messages": self._max_messages,
"max_tokens": self._replay_token_budget(),
"include_timestamps": True,
}
ctx.history = ctx.session.get_history(**_hist_kwargs)
pending_ask_id = pending_ask_user_id(ctx.history)
ctx.initial_messages = self._build_initial_messages(
ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary
)
ctx.user_persisted_early = self._persist_user_message_early(
ctx.msg, ctx.session, pending_ask_id
)
if ctx.on_progress is None:
ctx.on_progress = await self._build_bus_progress_callback(ctx.msg)
if ctx.on_retry_wait is None:
ctx.on_retry_wait = await self._build_retry_wait_callback(ctx.msg)
return TurnState.RUN
async def _state_run(self, ctx: TurnContext) -> TurnState:
final_content, tools_used, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
ctx.initial_messages,
on_progress=ctx.on_progress,
on_stream=ctx.on_stream,
on_stream_end=ctx.on_stream_end,
on_retry_wait=ctx.on_retry_wait,
session=ctx.session,
channel=ctx.msg.channel,
chat_id=ctx.msg.chat_id,
message_id=ctx.msg.metadata.get("message_id"),
metadata=ctx.msg.metadata,
session_key=ctx.session_key,
pending_queue=ctx.pending_queue,
)
ctx.final_content = final_content
ctx.tools_used = tools_used
ctx.all_messages = all_msgs
ctx.stop_reason = stop_reason
ctx.had_injections = had_injections
return TurnState.SAVE
async def _state_save(self, ctx: TurnContext) -> TurnState:
if ctx.final_content is None or not ctx.final_content.strip():
ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE
ctx.save_skip = 1 + len(ctx.history) + (1 if ctx.user_persisted_early else 0)
ctx.generated_media = generated_image_paths_from_messages(ctx.all_messages[ctx.save_skip:])
if ctx.generated_media and ctx.all_messages and ctx.all_messages[-1].get("role") == "assistant":
existing_media = ctx.all_messages[-1].get("media")
media = existing_media if isinstance(existing_media, list) else []
ctx.all_messages[-1]["media"] = list(dict.fromkeys([*media, *ctx.generated_media]))
self._save_turn(ctx.session, ctx.all_messages, ctx.save_skip)
ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
self._clear_pending_user_turn(ctx.session)
self._clear_runtime_checkpoint(ctx.session)
self.sessions.save(ctx.session)
self._schedule_background(
self.consolidator.maybe_consolidate_by_tokens(
ctx.session,
replay_max_messages=self._max_messages,
)
)
return TurnState.RESPOND
async def _state_respond(self, ctx: TurnContext) -> TurnState:
ctx.outbound = self._assemble_outbound(
ctx.msg,
ctx.final_content or "",
ctx.all_messages,
ctx.stop_reason,
ctx.had_injections,
ctx.generated_media,
ctx.on_stream,
)
return TurnState.DONE
def _sanitize_persisted_blocks(
self,
content: list[dict[str, Any]],