mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-05-19 16:12:30 +00:00
refactor(loop): convert _process_message to functional state machine
- Extract TurnState enum and TurnContext dataclass - Extract state handlers: _state_restore, _state_compact, _state_command, _state_build, _state_run, _state_save, _state_respond - Extract _process_system_message for system message short-circuit - Driver loop uses getattr dispatch over explicit state transitions - Preserve all existing behavior (794 tests passing)
This commit is contained in:
parent
9fb9f53147
commit
02443ca208
@ -1127,6 +1127,97 @@ class AgentLoop:
|
||||
self._running = False
|
||||
logger.info("Agent loop stopping")
|
||||
|
||||
async def _process_system_message(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
session_key: str | None = None,
|
||||
on_progress: Callable[..., Awaitable[None]] | None = None,
|
||||
on_stream: Callable[[str], Awaitable[None]] | None = None,
|
||||
on_stream_end: Callable[..., Awaitable[None]] | None = None,
|
||||
pending_queue: asyncio.Queue | None = None,
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a system inbound message (e.g. subagent announce)."""
|
||||
channel, chat_id = (
|
||||
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||
)
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
key = msg.session_key_override or f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
if pending:
|
||||
logger.info("Memory compact triggered for session {}", key)
|
||||
|
||||
await self.consolidator.maybe_consolidate_by_tokens(
|
||||
session,
|
||||
session_summary=pending,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
is_subagent = msg.sender_id == "subagent"
|
||||
if is_subagent and self._persist_subagent_followup(session, msg):
|
||||
logger.debug("Subagent result persisted for session {}", key)
|
||||
self.sessions.save(session)
|
||||
self._set_tool_context(
|
||||
channel, chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
_hist_kwargs: dict[str, Any] = {
|
||||
"max_messages": self._max_messages,
|
||||
"max_tokens": self._replay_token_budget(),
|
||||
"include_timestamps": True,
|
||||
}
|
||||
history = session.get_history(**_hist_kwargs)
|
||||
current_role = "assistant" if is_subagent else "user"
|
||||
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message="" if is_subagent else msg.content,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
sender_id=msg.sender_id,
|
||||
)
|
||||
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(
|
||||
self.consolidator.maybe_consolidate_by_tokens(
|
||||
session,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
)
|
||||
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content or "Background task completed.",
|
||||
options,
|
||||
channel,
|
||||
)
|
||||
outbound_metadata: dict[str, Any] = {}
|
||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||
if origin_message_id := msg.metadata.get("origin_message_id"):
|
||||
outbound_metadata["origin_message_id"] = origin_message_id
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
buttons=buttons,
|
||||
metadata=outbound_metadata,
|
||||
)
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
msg: InboundMessage,
|
||||
@ -1138,210 +1229,34 @@ class AgentLoop:
|
||||
) -> OutboundMessage | None:
|
||||
"""Process a single inbound message and return the response."""
|
||||
self._refresh_provider_snapshot()
|
||||
# System messages: parse origin from chat_id ("channel:chat_id")
|
||||
|
||||
if msg.channel == "system":
|
||||
channel, chat_id = (
|
||||
msg.chat_id.split(":", 1) if ":" in msg.chat_id else ("cli", msg.chat_id)
|
||||
)
|
||||
logger.info("Processing system message from {}", msg.sender_id)
|
||||
# Honor session_key_override so subagent announces from threaded
|
||||
# callers route to the originating thread session, not the
|
||||
# channel-level session derived from chat_id.
|
||||
key = msg.session_key_override or f"{channel}:{chat_id}"
|
||||
session = self.sessions.get_or_create(key)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
if pending:
|
||||
logger.info("Memory compact triggered for session {}", key)
|
||||
|
||||
await self.consolidator.maybe_consolidate_by_tokens(
|
||||
session,
|
||||
session_summary=pending,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
# Persist subagent follow-ups into durable history BEFORE prompt
|
||||
# assembly. ContextBuilder merges adjacent same-role messages for
|
||||
# provider compatibility, which previously caused the follow-up to
|
||||
# disappear from session.messages while still being visible to the
|
||||
# LLM via the merged prompt. See _persist_subagent_followup.
|
||||
is_subagent = msg.sender_id == "subagent"
|
||||
if is_subagent and self._persist_subagent_followup(session, msg):
|
||||
logger.debug("Subagent result persisted for session {}", key)
|
||||
self.sessions.save(session)
|
||||
self._set_tool_context(
|
||||
channel, chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
_hist_kwargs: dict[str, Any] = {
|
||||
"max_messages": self._max_messages,
|
||||
"max_tokens": self._replay_token_budget(),
|
||||
"include_timestamps": True,
|
||||
}
|
||||
history = session.get_history(**_hist_kwargs)
|
||||
current_role = "assistant" if is_subagent else "user"
|
||||
|
||||
# Subagent content is already in `history` above; passing it again
|
||||
# as current_message would double-project it into the prompt.
|
||||
messages = self.context.build_messages(
|
||||
history=history,
|
||||
current_message="" if is_subagent else msg.content,
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
session_summary=pending,
|
||||
current_role=current_role,
|
||||
sender_id=msg.sender_id,
|
||||
)
|
||||
final_content, _, all_msgs, stop_reason, _ = await self._run_agent_loop(
|
||||
messages, session=session, channel=channel, chat_id=chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
return await self._process_system_message(
|
||||
msg,
|
||||
session_key=session_key,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
self._save_turn(session, all_msgs, 1 + len(history))
|
||||
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(
|
||||
self.consolidator.maybe_consolidate_by_tokens(
|
||||
session,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
)
|
||||
options = ask_user_options_from_messages(all_msgs) if stop_reason == "ask_user" else []
|
||||
content, buttons = ask_user_outbound(
|
||||
final_content or "Background task completed.",
|
||||
options,
|
||||
channel,
|
||||
)
|
||||
# Reconstruct channel-specific metadata from session.key so the
|
||||
# outbound reply lands in the originating thread (not the channel
|
||||
# top-level). The announce InboundMessage carries only
|
||||
# injected_event metadata; we recover thread_ts from the session
|
||||
# key, which slack writes as "slack:<chat_id>:<thread_ts>".
|
||||
outbound_metadata: dict[str, Any] = {}
|
||||
if channel == "slack" and key.startswith("slack:") and key.count(":") >= 2:
|
||||
outbound_metadata["slack"] = {"thread_ts": key.split(":", 2)[2]}
|
||||
if origin_message_id := msg.metadata.get("origin_message_id"):
|
||||
outbound_metadata["origin_message_id"] = origin_message_id
|
||||
return OutboundMessage(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
buttons=buttons,
|
||||
metadata=outbound_metadata,
|
||||
)
|
||||
|
||||
# Extract document text from media at the processing boundary so all
|
||||
# channels benefit without format-specific logic in ContextBuilder.
|
||||
if msg.media:
|
||||
new_content, image_only = extract_documents(msg.content, msg.media)
|
||||
msg = dataclasses.replace(msg, content=new_content, media=image_only)
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
key = session_key or msg.session_key
|
||||
session = self.sessions.get_or_create(key)
|
||||
mark_webui_session(session, msg.metadata)
|
||||
if self._restore_runtime_checkpoint(session):
|
||||
self.sessions.save(session)
|
||||
if self._restore_pending_user_turn(session):
|
||||
self.sessions.save(session)
|
||||
|
||||
session, pending = self.auto_compact.prepare_session(session, key)
|
||||
|
||||
# Slash commands
|
||||
raw = msg.content.strip()
|
||||
ctx = CommandContext(msg=msg, session=session, key=key, raw=raw, loop=self)
|
||||
if result := await self.commands.dispatch(ctx):
|
||||
return result
|
||||
|
||||
await self.consolidator.maybe_consolidate_by_tokens(
|
||||
session,
|
||||
session_summary=pending,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
|
||||
self._set_tool_context(
|
||||
msg.channel, msg.chat_id, msg.metadata.get("message_id"),
|
||||
msg.metadata, session_key=key,
|
||||
)
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
_hist_kwargs: dict[str, Any] = {
|
||||
"max_messages": self._max_messages,
|
||||
"max_tokens": self._replay_token_budget(),
|
||||
"include_timestamps": True,
|
||||
}
|
||||
history = session.get_history(**_hist_kwargs)
|
||||
|
||||
pending_ask_id = pending_ask_user_id(history)
|
||||
initial_messages = self._build_initial_messages(
|
||||
msg, session, history, pending_ask_id, pending
|
||||
)
|
||||
|
||||
_bus_progress = await self._build_bus_progress_callback(msg)
|
||||
_on_retry_wait = await self._build_retry_wait_callback(msg)
|
||||
|
||||
# Persist the triggering user message up front so a mid-turn crash
|
||||
# doesn't silently lose the prompt on recovery. ``media`` rides along
|
||||
# as raw on-disk paths — sanitized image blocks are stripped from
|
||||
# JSONL, and webui replay needs the paths to mint signed URLs.
|
||||
user_persisted_early = self._persist_user_message_early(msg, session, pending_ask_id)
|
||||
|
||||
final_content, _, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||
initial_messages,
|
||||
on_progress=on_progress or _bus_progress,
|
||||
ctx = TurnContext(
|
||||
msg=msg,
|
||||
session=self.sessions.get_or_create(key),
|
||||
session_key=key,
|
||||
state=TurnState.RESTORE,
|
||||
on_progress=on_progress,
|
||||
on_stream=on_stream,
|
||||
on_stream_end=on_stream_end,
|
||||
on_retry_wait=_on_retry_wait,
|
||||
session=session,
|
||||
channel=msg.channel,
|
||||
chat_id=msg.chat_id,
|
||||
message_id=msg.metadata.get("message_id"),
|
||||
metadata=msg.metadata,
|
||||
session_key=key,
|
||||
pending_queue=pending_queue,
|
||||
)
|
||||
|
||||
if final_content is None or not final_content.strip():
|
||||
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
while ctx.state is not TurnState.DONE:
|
||||
handler = getattr(self, f"_state_{ctx.state.name.lower()}")
|
||||
ctx.state = await handler(ctx)
|
||||
|
||||
# Skip the already-persisted user message when saving the turn
|
||||
save_skip = 1 + len(history) + (1 if user_persisted_early else 0)
|
||||
generated_media = generated_image_paths_from_messages(all_msgs[save_skip:])
|
||||
if generated_media and all_msgs and all_msgs[-1].get("role") == "assistant":
|
||||
existing_media = all_msgs[-1].get("media")
|
||||
media = existing_media if isinstance(existing_media, list) else []
|
||||
all_msgs[-1]["media"] = list(dict.fromkeys([*media, *generated_media]))
|
||||
self._save_turn(session, all_msgs, save_skip)
|
||||
session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_pending_user_turn(session)
|
||||
self._clear_runtime_checkpoint(session)
|
||||
self.sessions.save(session)
|
||||
self._schedule_background(
|
||||
self.consolidator.maybe_consolidate_by_tokens(
|
||||
session,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
)
|
||||
|
||||
return self._assemble_outbound(
|
||||
msg,
|
||||
final_content,
|
||||
all_msgs,
|
||||
stop_reason,
|
||||
had_injections,
|
||||
generated_media,
|
||||
on_stream,
|
||||
)
|
||||
return ctx.outbound
|
||||
|
||||
def _assemble_outbound(
|
||||
self,
|
||||
@ -1380,6 +1295,141 @@ class AgentLoop:
|
||||
buttons=buttons,
|
||||
)
|
||||
|
||||
async def _state_restore(self, ctx: TurnContext) -> TurnState:
|
||||
"""Restore checkpoint / pending user turn; extract documents."""
|
||||
msg = ctx.msg
|
||||
|
||||
if msg.media:
|
||||
new_content, image_only = extract_documents(msg.content, msg.media)
|
||||
ctx.msg = dataclasses.replace(msg, content=new_content, media=image_only)
|
||||
msg = ctx.msg
|
||||
|
||||
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
|
||||
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
|
||||
|
||||
ctx.session = self.sessions.get_or_create(ctx.session_key)
|
||||
mark_webui_session(ctx.session, msg.metadata)
|
||||
|
||||
if self._restore_runtime_checkpoint(ctx.session):
|
||||
self.sessions.save(ctx.session)
|
||||
if self._restore_pending_user_turn(ctx.session):
|
||||
self.sessions.save(ctx.session)
|
||||
|
||||
return TurnState.COMPACT
|
||||
|
||||
async def _state_compact(self, ctx: TurnContext) -> TurnState:
|
||||
ctx.session, pending = self.auto_compact.prepare_session(ctx.session, ctx.session_key)
|
||||
ctx.pending_summary = pending
|
||||
return TurnState.COMMAND
|
||||
|
||||
async def _state_command(self, ctx: TurnContext) -> TurnState:
|
||||
raw = ctx.msg.content.strip()
|
||||
cmd_ctx = CommandContext(
|
||||
msg=ctx.msg, session=ctx.session, key=ctx.session_key, raw=raw, loop=self
|
||||
)
|
||||
result = await self.commands.dispatch(cmd_ctx)
|
||||
if result is not None:
|
||||
ctx.outbound = result
|
||||
return TurnState.DONE
|
||||
return TurnState.BUILD
|
||||
|
||||
async def _state_build(self, ctx: TurnContext) -> TurnState:
|
||||
await self.consolidator.maybe_consolidate_by_tokens(
|
||||
ctx.session,
|
||||
session_summary=ctx.pending_summary,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
self._set_tool_context(
|
||||
ctx.msg.channel,
|
||||
ctx.msg.chat_id,
|
||||
ctx.msg.metadata.get("message_id"),
|
||||
ctx.msg.metadata,
|
||||
session_key=ctx.session_key,
|
||||
)
|
||||
if message_tool := self.tools.get("message"):
|
||||
if isinstance(message_tool, MessageTool):
|
||||
message_tool.start_turn()
|
||||
|
||||
_hist_kwargs: dict[str, Any] = {
|
||||
"max_messages": self._max_messages,
|
||||
"max_tokens": self._replay_token_budget(),
|
||||
"include_timestamps": True,
|
||||
}
|
||||
ctx.history = ctx.session.get_history(**_hist_kwargs)
|
||||
|
||||
pending_ask_id = pending_ask_user_id(ctx.history)
|
||||
ctx.initial_messages = self._build_initial_messages(
|
||||
ctx.msg, ctx.session, ctx.history, pending_ask_id, ctx.pending_summary
|
||||
)
|
||||
ctx.user_persisted_early = self._persist_user_message_early(
|
||||
ctx.msg, ctx.session, pending_ask_id
|
||||
)
|
||||
|
||||
if ctx.on_progress is None:
|
||||
ctx.on_progress = await self._build_bus_progress_callback(ctx.msg)
|
||||
if ctx.on_retry_wait is None:
|
||||
ctx.on_retry_wait = await self._build_retry_wait_callback(ctx.msg)
|
||||
|
||||
return TurnState.RUN
|
||||
|
||||
async def _state_run(self, ctx: TurnContext) -> TurnState:
|
||||
final_content, tools_used, all_msgs, stop_reason, had_injections = await self._run_agent_loop(
|
||||
ctx.initial_messages,
|
||||
on_progress=ctx.on_progress,
|
||||
on_stream=ctx.on_stream,
|
||||
on_stream_end=ctx.on_stream_end,
|
||||
on_retry_wait=ctx.on_retry_wait,
|
||||
session=ctx.session,
|
||||
channel=ctx.msg.channel,
|
||||
chat_id=ctx.msg.chat_id,
|
||||
message_id=ctx.msg.metadata.get("message_id"),
|
||||
metadata=ctx.msg.metadata,
|
||||
session_key=ctx.session_key,
|
||||
pending_queue=ctx.pending_queue,
|
||||
)
|
||||
ctx.final_content = final_content
|
||||
ctx.tools_used = tools_used
|
||||
ctx.all_messages = all_msgs
|
||||
ctx.stop_reason = stop_reason
|
||||
ctx.had_injections = had_injections
|
||||
return TurnState.SAVE
|
||||
|
||||
async def _state_save(self, ctx: TurnContext) -> TurnState:
|
||||
if ctx.final_content is None or not ctx.final_content.strip():
|
||||
ctx.final_content = EMPTY_FINAL_RESPONSE_MESSAGE
|
||||
|
||||
ctx.save_skip = 1 + len(ctx.history) + (1 if ctx.user_persisted_early else 0)
|
||||
ctx.generated_media = generated_image_paths_from_messages(ctx.all_messages[ctx.save_skip:])
|
||||
if ctx.generated_media and ctx.all_messages and ctx.all_messages[-1].get("role") == "assistant":
|
||||
existing_media = ctx.all_messages[-1].get("media")
|
||||
media = existing_media if isinstance(existing_media, list) else []
|
||||
ctx.all_messages[-1]["media"] = list(dict.fromkeys([*media, *ctx.generated_media]))
|
||||
|
||||
self._save_turn(ctx.session, ctx.all_messages, ctx.save_skip)
|
||||
ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
|
||||
self._clear_pending_user_turn(ctx.session)
|
||||
self._clear_runtime_checkpoint(ctx.session)
|
||||
self.sessions.save(ctx.session)
|
||||
self._schedule_background(
|
||||
self.consolidator.maybe_consolidate_by_tokens(
|
||||
ctx.session,
|
||||
replay_max_messages=self._max_messages,
|
||||
)
|
||||
)
|
||||
return TurnState.RESPOND
|
||||
|
||||
async def _state_respond(self, ctx: TurnContext) -> TurnState:
|
||||
ctx.outbound = self._assemble_outbound(
|
||||
ctx.msg,
|
||||
ctx.final_content or "",
|
||||
ctx.all_messages,
|
||||
ctx.stop_reason,
|
||||
ctx.had_injections,
|
||||
ctx.generated_media,
|
||||
ctx.on_stream,
|
||||
)
|
||||
return TurnState.DONE
|
||||
|
||||
def _sanitize_persisted_blocks(
|
||||
self,
|
||||
content: list[dict[str, Any]],
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user