From 02443ca208c141220f21884a3ccb25da674e1340 Mon Sep 17 00:00:00 2001 From: chengyongru Date: Sat, 9 May 2026 15:40:11 +0800 Subject: [PATCH] refactor(loop): convert _process_message to functional state machine - Extract TurnState enum and TurnContext dataclass - Extract state handlers: _state_restore, _state_compact, _state_command, _state_build, _state_run, _state_save, _state_respond - Extract _process_system_message for system message short-circuit - Driver loop uses getattr dispatch over explicit state transitions - Preserve all existing behavior (794 tests passing) --- nanobot/agent/loop.py | 436 +++++++++++++++++++++++------------------- 1 file changed, 243 insertions(+), 193 deletions(-) diff --git a/nanobot/agent/loop.py b/nanobot/agent/loop.py index b42058a14..800f526af 100644 --- a/nanobot/agent/loop.py +++ b/nanobot/agent/loop.py @@ -1127,6 +1127,97 @@ class AgentLoop: self._running = False logger.info("Agent loop stopping") + async def _process_system_message( + self, + msg: InboundMessage, + session_key: str | None = None, + on_progress: Callable[..., Awaitable[None]] | None = None, + on_stream: Callable[[str], Awaitable[None]] | None = None, + on_stream_end: Callable[..., Awaitable[None]] | None = None, + pending_queue: asyncio.Queue | None = None, + ) -> OutboundMessage | None: + """Process a system inbound message (e.g. subagent announce).""" + channel, chat_id = ( + msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id) + ) + logger.info("Processing system message from {}", msg.sender_id) + key = msg.session_key_override or f"{channel}:{chat_id}" + session = self.sessions.get_or_create(key) + if self._restore_runtime_checkpoint(session): + self.sessions.save(session) + if self._restore_pending_user_turn(session): + self.sessions.save(session) + + session, pending = self.auto_compact.prepare_session(session, key) + if pending: + logger.info("Memory compact triggered for session {}", key) + + await self.consolidator.maybe_consolidate_by_tokens( + session, + session_summary=pending, + replay_max_messages=self._max_messages, + ) + is_subagent = msg.sender_id == "subagent" + if is_subagent and self._persist_subagent_followup(session, msg): + logger.debug("Subagent result persisted for session {}", key) + self.sessions.save(session) + self._set_tool_context( + channel, chat_id, msg.metadata.get("message_id"), + msg.metadata, session_key=key, + ) + _hist_kwargs: dict[str, Any] = { + "max_messages": self._max_messages, + "max_tokens": self._replay_token_budget(), + "include_timestamps": True, + } + history = session.get_history(**_hist_kwargs) + current_role = "assistant" if is_subagent else "user" + + messages = self.context.build_messages( + history=history, + current_message="" if is_subagent else msg.content, + channel=channel, + chat_id=chat_id, + session_summary=pending, + current_role=current_role, + sender_id=msg.sender_id, + ) + final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop( + messages, session=session, channel=channel, chat_id=chat_id, + message_id=msg.metadata.get("message_id"), + metadata=msg.metadata, + session_key=key, + pending_queue=pending_queue, + ) + self._save_turn(session, all_msgs, 1 + len(history)) + session.enforce_file_cap(on_archive=self.context.memory.raw_archive) + self._clear_runtime_checkpoint(session) + self.sessions.save(session) + self._schedule_background( + self.consolidator.maybe_consolidate_by_tokens( + session, + replay_max_messages=self._max_messages, + ) + ) + options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] + content, buttons = ask_user_outbound( + final_content or "Background task completed.", + options, + channel, + ) + outbound_metadata: dict[str, Any] = {} + if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2: + outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]} + if origin_message_id := msg.metadata.get("origin_message_id"): + outbound_metadata["origin_message_id"] = origin_message_id + return OutboundMessage( + channel=channel, + chat_id=chat_id, + content=content, + buttons=buttons, + metadata=outbound_metadata, + ) + async def _process_message( self, msg: InboundMessage, @@ -1138,210 +1229,34 @@ class AgentLoop: ) -> OutboundMessage | None: """Process a single inbound message and return the response.""" self._refresh_provider_snapshot() - # System messages: parse origin from chat_id ("channel:chat_id") + if msg.channel == "system": - channel, chat_id = ( - msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id) - ) - logger.info("Processing system message from {}", msg.sender_id) - # Honor session_key_override so subagent announces from threaded - # callers route to the originating thread session, not the - # channel-level session derived from chat_id. - key = msg.session_key_override or f"{channel}:{chat_id}" - session = self.sessions.get_or_create(key) - if self._restore_runtime_checkpoint(session): - self.sessions.save(session) - if self._restore_pending_user_turn(session): - self.sessions.save(session) - - session, pending = self.auto_compact.prepare_session(session, key) - if pending: - logger.info("Memory compact triggered for session {}", key) - - await self.consolidator.maybe_consolidate_by_tokens( - session, - session_summary=pending, - replay_max_messages=self._max_messages, - ) - # Persist subagent follow-ups into durable history BEFORE prompt - # assembly. ContextBuilder merges adjacent same-role messages for - # provider compatibility, which previously caused the follow-up to - # disappear from session.messages while still being visible to the - # LLM via the merged prompt. See _persist_subagent_followup. - is_subagent = msg.sender_id == "subagent" - if is_subagent and self._persist_subagent_followup(session, msg): - logger.debug("Subagent result persisted for session {}", key) - self.sessions.save(session) - self._set_tool_context( - channel, chat_id, msg.metadata.get("message_id"), - msg.metadata, session_key=key, - ) - _hist_kwargs: dict[str, Any] = { - "max_messages": self._max_messages, - "max_tokens": self._replay_token_budget(), - "include_timestamps": True, - } - history = session.get_history(**_hist_kwargs) - current_role = "assistant" if is_subagent else "user" - - # Subagent content is already in `history` above; passing it again - # as current_message would double-project it into the prompt. - messages = self.context.build_messages( - history=history, - current_message="" if is_subagent else msg.content, - channel=channel, - chat_id=chat_id, - session_summary=pending, - current_role=current_role, - sender_id=msg.sender_id, - ) - final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop( - messages, session=session, channel=channel, chat_id=chat_id, - message_id=msg.metadata.get("message_id"), - metadata=msg.metadata, - session_key=key, + return await self._process_system_message( + msg, + session_key=session_key, + on_progress=on_progress, + on_stream=on_stream, + on_stream_end=on_stream_end, pending_queue=pending_queue, ) - self._save_turn(session, all_msgs, 1 + len(history)) - session.enforce_file_cap(on_archive=self.context.memory.raw_archive) - self._clear_runtime_checkpoint(session) - self.sessions.save(session) - self._schedule_background( - self.consolidator.maybe_consolidate_by_tokens( - session, - replay_max_messages=self._max_messages, - ) - ) - options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else [] - content, buttons = ask_user_outbound( - final_content or "Background task completed.", - options, - channel, - ) - # Reconstruct channel-specific metadata from session.key so the - # outbound reply lands in the originating thread (not the channel - # top-level). The announce InboundMessage carries only - # injected_event metadata; we recover thread_ts from the session - # key, which slack writes as "slack::". - outbound_metadata: dict[str, Any] = {} - if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2: - outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]} - if origin_message_id := msg.metadata.get("origin_message_id"): - outbound_metadata["origin_message_id"] = origin_message_id - return OutboundMessage( - channel=channel, - chat_id=chat_id, - content=content, - buttons=buttons, - metadata=outbound_metadata, - ) - - # Extract document text from media at the processing boundary so all - # channels benefit without format-specific logic in ContextBuilder. - if msg.media: - new_content, image_only = extract_documents(msg.content, msg.media) - msg = dataclasses.replace(msg, content=new_content, media=image_only) - - preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content - logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) key = session_key or msg.session_key - session = self.sessions.get_or_create(key) - mark_webui_session(session, msg.metadata) - if self._restore_runtime_checkpoint(session): - self.sessions.save(session) - if self._restore_pending_user_turn(session): - self.sessions.save(session) - - session, pending = self.auto_compact.prepare_session(session, key) - - # Slash commands - raw = msg.content.strip() - ctx = CommandContext(msg=msg, session=session, key=key, raw=raw, loop=self) - if result := await self.commands.dispatch(ctx): - return result - - await self.consolidator.maybe_consolidate_by_tokens( - session, - session_summary=pending, - replay_max_messages=self._max_messages, - ) - - self._set_tool_context( - msg.channel, msg.chat_id, msg.metadata.get("message_id"), - msg.metadata, session_key=key, - ) - if message_tool := self.tools.get("message"): - if isinstance(message_tool, MessageTool): - message_tool.start_turn() - - _hist_kwargs: dict[str, Any] = { - "max_messages": self._max_messages, - "max_tokens": self._replay_token_budget(), - "include_timestamps": True, - } - history = session.get_history(**_hist_kwargs) - - pending_ask_id = pending_ask_user_id(history) - initial_messages = self._build_initial_messages( - msg, session, history, pending_ask_id, pending - ) - - _bus_progress = await self._build_bus_progress_callback(msg) - _on_retry_wait = await self._build_retry_wait_callback(msg) - - # Persist the triggering user message up front so a mid-turn crash - # doesn't silently lose the prompt on recovery. ``media`` rides along - # as raw on-disk paths — sanitized image blocks are stripped from - # JSONL, and webui replay needs the paths to mint signed URLs. - user_persisted_early = self._persist_user_message_early(msg, session, pending_ask_id) - - final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop( - initial_messages, - on_progress=on_progress or _bus_progress, + ctx = TurnContext( + msg=msg, + session=self.sessions.get_or_create(key), + session_key=key, + state=TurnState.RESTORE, + on_progress=on_progress, on_stream=on_stream, on_stream_end=on_stream_end, - on_retry_wait=_on_retry_wait, - session=session, - channel=msg.channel, - chat_id=msg.chat_id, - message_id=msg.metadata.get("message_id"), - metadata=msg.metadata, - session_key=key, pending_queue=pending_queue, ) - if final_content is None or not final_content.strip(): - final_content = EMPTY_FINAL_RESPONSE_MESSAGE + while ctx.state is not TurnState.DONE: + handler = getattr(self, f"_state_{ctx.state.name.lower()}") + ctx.state = await handler(ctx) - # Skip the already-persisted user message when saving the turn - save_skip = 1 + len(history) + (1 if user_persisted_early else 0) - generated_media = generated_image_paths_from_messages(all_msgs[save_skip:]) - if generated_media and all_msgs and all_msgs[-1].get("role") == "assistant": - existing_media = all_msgs[-1].get("media") - media = existing_media if isinstance(existing_media, list) else [] - all_msgs[-1]["media"] = list(dict.fromkeys([*media, *generated_media])) - self._save_turn(session, all_msgs, save_skip) - session.enforce_file_cap(on_archive=self.context.memory.raw_archive) - self._clear_pending_user_turn(session) - self._clear_runtime_checkpoint(session) - self.sessions.save(session) - self._schedule_background( - self.consolidator.maybe_consolidate_by_tokens( - session, - replay_max_messages=self._max_messages, - ) - ) - - return self._assemble_outbound( - msg, - final_content, - all_msgs, - stop_reason, - had_injections, - generated_media, - on_stream, - ) + return ctx.outbound def _assemble_outbound( self, @@ -1380,6 +1295,141 @@ class AgentLoop: buttons=buttons, ) + async def _state_restore(self, ctx: TurnContext) -> TurnState: + """Restore checkpoint / pending user turn; extract documents.""" + msg = ctx.msg + + if msg.media: + new_content, image_only = extract_documents(msg.content, msg.media) + ctx.msg = dataclasses.replace(msg, content=new_content, media=image_only) + msg = ctx.msg + + preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content + logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) + + ctx.session = self.sessions.get_or_create(ctx.session_key) + mark_webui_session(ctx.session, msg.metadata) + + if self._restore_runtime_checkpoint(ctx.session): + self.sessions.save(ctx.session) + if self._restore_pending_user_turn(ctx.session): + self.sessions.save(ctx.session) + + return TurnState.COMPACT + + async def _state_compact(self, ctx: TurnContext) -> TurnState: + ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key) + ctx.pending_summary = pending + return TurnState.COMMAND + + async def _state_command(self, ctx: TurnContext) -> TurnState: + raw = ctx.msg.content.strip() + cmd_ctx = CommandContext( + msg=ctx.msg, session=ctx.session, key=ctx.session_key, raw=raw, loop=self + ) + result = await self.commands.dispatch(cmd_ctx) + if result is not None: + ctx.outbound = result + return TurnState.DONE + return TurnState.BUILD + + async def _state_build(self, ctx: TurnContext) -> TurnState: + await self.consolidator.maybe_consolidate_by_tokens( + ctx.session, + session_summary=ctx.pending_summary, + replay_max_messages=self._max_messages, + ) + self._set_tool_context( + ctx.msg.channel, + ctx.msg.chat_id, + ctx.msg.metadata.get("message_id"), + ctx.msg.metadata, + session_key=ctx.session_key, + ) + if message_tool := self.tools.get("message"): + if isinstance(message_tool, MessageTool): + message_tool.start_turn() + + _hist_kwargs: dict[str, Any] = { + "max_messages": self._max_messages, + "max_tokens": self._replay_token_budget(), + "include_timestamps": True, + } + ctx.history = ctx.session.get_history(**_hist_kwargs) + + pending_ask_id = pending_ask_user_id(ctx.history) + ctx.initial_messages = self._build_initial_messages( + ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary + ) + ctx.user_persisted_early = self._persist_user_message_early( + ctx.msg, ctx.session, pending_ask_id + ) + + if ctx.on_progress is None: + ctx.on_progress = await self._build_bus_progress_callback(ctx.msg) + if ctx.on_retry_wait is None: + ctx.on_retry_wait = await self._build_retry_wait_callback(ctx.msg) + + return TurnState.RUN + + async def _state_run(self, ctx: TurnContext) -> TurnState: + final_content, tools_used, all_msgs, stop_reason, had_injections = await self._run_agent_loop( + ctx.initial_messages, + on_progress=ctx.on_progress, + on_stream=ctx.on_stream, + on_stream_end=ctx.on_stream_end, + on_retry_wait=ctx.on_retry_wait, + session=ctx.session, + channel=ctx.msg.channel, + chat_id=ctx.msg.chat_id, + message_id=ctx.msg.metadata.get("message_id"), + metadata=ctx.msg.metadata, + session_key=ctx.session_key, + pending_queue=ctx.pending_queue, + ) + ctx.final_content = final_content + ctx.tools_used = tools_used + ctx.all_messages = all_msgs + ctx.stop_reason = stop_reason + ctx.had_injections = had_injections + return TurnState.SAVE + + async def _state_save(self, ctx: TurnContext) -> TurnState: + if ctx.final_content is None or not ctx.final_content.strip(): + ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE + + ctx.save_skip = 1 + len(ctx.history) + (1 if ctx.user_persisted_early else 0) + ctx.generated_media = generated_image_paths_from_messages(ctx.all_messages[ctx.save_skip:]) + if ctx.generated_media and ctx.all_messages and ctx.all_messages[-1].get("role") == "assistant": + existing_media = ctx.all_messages[-1].get("media") + media = existing_media if isinstance(existing_media, list) else [] + ctx.all_messages[-1]["media"] = list(dict.fromkeys([*media, *ctx.generated_media])) + + self._save_turn(ctx.session, ctx.all_messages, ctx.save_skip) + ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive) + self._clear_pending_user_turn(ctx.session) + self._clear_runtime_checkpoint(ctx.session) + self.sessions.save(ctx.session) + self._schedule_background( + self.consolidator.maybe_consolidate_by_tokens( + ctx.session, + replay_max_messages=self._max_messages, + ) + ) + return TurnState.RESPOND + + async def _state_respond(self, ctx: TurnContext) -> TurnState: + ctx.outbound = self._assemble_outbound( + ctx.msg, + ctx.final_content or "", + ctx.all_messages, + ctx.stop_reason, + ctx.had_injections, + ctx.generated_media, + ctx.on_stream, + ) + return TurnState.DONE + def _sanitize_persisted_blocks( self, content: list[dict[str, Any]],