refactor(loop): address code review nits

- Fix _assemble_outbound on_stream type annotation (Callable[[str], Awaitable[None]] | None)
- Use last_msg consistently in _state_save instead of re-indexing
- Remove dead  fallback in _state_respond (guaranteed non-None by _state_save)
- Change pending_summary type from Any to str | None
- Make session optional in TurnContext to avoid redundant fetch
- Add defensive dispatch with RuntimeError for missing handlers
This commit is contained in:
chengyongru 2026-05-09 16:29:55 +08:00
parent 1d9e74a358
commit 7669fccfad

View File

@ -198,9 +198,9 @@ class TurnState(Enum):
@dataclass @dataclass
class TurnContext: class TurnContext:
msg: InboundMessage msg: InboundMessage
session: Session
session_key: str session_key: str
state: TurnState state: TurnState
session: Session | None = None
history: list[dict[str, Any]] = field(default_factory=list) history: list[dict[str, Any]] = field(default_factory=list)
initial_messages: list[dict[str, Any]] = field(default_factory=list) initial_messages: list[dict[str, Any]] = field(default_factory=list)
@ -223,7 +223,7 @@ class TurnContext:
on_retry_wait: Callable[[str], Awaitable[None]] | None = None on_retry_wait: Callable[[str], Awaitable[None]] | None = None
pending_queue: asyncio.Queue | None = None pending_queue: asyncio.Queue | None = None
pending_summary: Any = None pending_summary: str | None = None
class AgentLoop: class AgentLoop:
@ -1198,7 +1198,10 @@ class AgentLoop:
) )
while ctx.state is not TurnState.DONE: while ctx.state is not TurnState.DONE:
handler = getattr(self, f"_state_{ctx.state.name.lower()}") handler_name = f"_state_{ctx.state.name.lower()}"
handler = getattr(self, handler_name, None)
if handler is None:
raise RuntimeError(f"Missing state handler for {ctx.state}")
ctx.state = await handler(ctx) ctx.state = await handler(ctx)
return ctx.outbound return ctx.outbound
@ -1211,7 +1214,7 @@ class AgentLoop:
stop_reason: str, stop_reason: str,
had_injections: bool, had_injections: bool,
generated_media: list[str], generated_media: list[str],
on_stream: Callable | None, on_stream: Callable[[str], Awaitable[None]] | None,
) -> OutboundMessage | None: ) -> OutboundMessage | None:
"""Assemble the final outbound message from turn results.""" """Assemble the final outbound message from turn results."""
# MessageTool suppression # MessageTool suppression
@ -1252,6 +1255,9 @@ class AgentLoop:
preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content preview = msg.content[:80] + "..." if len(msg.content) > 80 else msg.content
logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview) logger.info("Processing message from {}:{}: {}", msg.channel, msg.sender_id, preview)
# Session is already fetched by the caller (_process_message) but
# ensure it exists in case this handler is invoked independently.
if ctx.session is None:
ctx.session = self.sessions.get_or_create(ctx.session_key) ctx.session = self.sessions.get_or_create(ctx.session_key)
mark_webui_session(ctx.session, msg.metadata) mark_webui_session(ctx.session, msg.metadata)
@ -1349,9 +1355,9 @@ class AgentLoop:
ctx.generated_media = generated_image_paths_from_messages(skip_msgs) ctx.generated_media = generated_image_paths_from_messages(skip_msgs)
last_msg = ctx.all_messages[-1] if ctx.all_messages else None last_msg = ctx.all_messages[-1] if ctx.all_messages else None
if ctx.generated_media and last_msg and last_msg.get("role") == "assistant": if ctx.generated_media and last_msg and last_msg.get("role") == "assistant":
existing_media = ctx.all_messages[-1].get("media") existing_media = last_msg.get("media")
media = existing_media if isinstance(existing_media, list) else [] media = existing_media if isinstance(existing_media, list) else []
ctx.all_messages[-1]["media"] = list(dict.fromkeys([*media, *ctx.generated_media])) last_msg["media"] = list(dict.fromkeys([*media, *ctx.generated_media]))
self._save_turn(ctx.session, ctx.all_messages, ctx.save_skip) self._save_turn(ctx.session, ctx.all_messages, ctx.save_skip)
ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive) ctx.session.enforce_file_cap(on_archive=self.context.memory.raw_archive)
@ -1369,7 +1375,7 @@ class AgentLoop:
async def _state_respond(self, ctx: TurnContext) -> TurnState: async def _state_respond(self, ctx: TurnContext) -> TurnState:
ctx.outbound = self._assemble_outbound( ctx.outbound = self._assemble_outbound(
ctx.msg, ctx.msg,
ctx.final_content or "", ctx.final_content,
ctx.all_messages, ctx.all_messages,
ctx.stop_reason, ctx.stop_reason,
ctx.had_injections, ctx.had_injections,