refactor(agent): surgical extraction from AgentRunner.run()

Extract 3 focused helpers while keeping the iteration control flow
inline and readable:

- _prepare_context(ctx) – context governance pipeline (orphan cleanup,
  backfill, microcompact, budget, snip)
- _try_drain_injections(ctx, ...) – simplified signature using RunContext
  + bool return (was tuple[bool, int])
- _terminate_with_error(ctx, ...) – deduplicates 3 error-path blocks

Keep RunContext as a minimal mutable state bag; remove LoopAction enum
and the over-extracted response dispatchers (_respond, _handle_tool_calls,
_handle_empty, _handle_length, _handle_final, _handle_max_iterations).
This commit is contained in:
chengyongru 2026-05-18 16:40:32 +08:00
parent d4ade8f680
commit e977a43445

View File

@ -108,6 +108,49 @@ class AgentRunResult:
had_injections: bool = False had_injections: bool = False
@dataclass(slots=True)
class RunContext:
"""Mutable accumulator for state that lives across loop iterations."""
spec: AgentRunSpec
hook: AgentHook
messages: list[dict[str, Any]] # OpenAI-format chat messages (mutated in place)
iteration: int = 0
# Cached by _prepare_context for downstream retry access.
# Stale after _prepare_context returns — do not use elsewhere.
messages_for_model: list[dict[str, Any]] | None = None
final_content: str | None = None
tools_used: list[str] = field(default_factory=list)
usage: dict[str, int] = field(
default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0},
)
error: str | None = None
stop_reason: str = "completed"
tool_events: list[dict[str, str]] = field(default_factory=list)
# Per-run throttle counters (keyed by tool name)
external_lookup_counts: dict[str, int] = field(default_factory=dict)
workspace_violation_counts: dict[str, int] = field(default_factory=dict)
empty_content_retries: int = 0
length_recovery_count: int = 0
had_injections: bool = False
injection_cycles: int = 0
def to_result(self) -> AgentRunResult:
return AgentRunResult(
final_content=self.final_content,
messages=self.messages,
tools_used=self.tools_used,
usage=self.usage,
stop_reason=self.stop_reason,
error=self.error,
tool_events=self.tool_events,
had_injections=self.had_injections,
)
class AgentRunner: class AgentRunner:
"""Run a tool-capable LLM loop without product-layer concerns.""" """Run a tool-capable LLM loop without product-layer concerns."""
@ -155,47 +198,42 @@ class AgentRunner:
async def _try_drain_injections( async def _try_drain_injections(
self, self,
spec: AgentRunSpec, ctx: RunContext,
messages: list[dict[str, Any]],
assistant_message: dict[str, Any] | None,
injection_cycles: int,
*, *,
assistant_message: dict[str, Any] | None = None,
phase: str = "after error", phase: str = "after error",
iteration: int | None = None, ) -> bool:
) -> tuple[bool, int]: """Drain pending injections into ctx.messages.
"""Drain pending injections. Returns (should_continue, updated_cycles).
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES, Returns True if injections were appended (caller should continue).
append them to *messages* (and emit a checkpoint if *assistant_message* Mutates ctx.injection_cycles and ctx.had_injections directly.
and *iteration* are both provided) and return (True, cycles+1) so the
caller continues the iteration loop. Otherwise return (False, cycles).
""" """
if injection_cycles >= _MAX_INJECTION_CYCLES: if ctx.injection_cycles >= _MAX_INJECTION_CYCLES:
return False, injection_cycles return False
injections = await self._drain_injections(spec) injections = await self._drain_injections(ctx.spec)
if not injections: if not injections:
return False, injection_cycles return False
injection_cycles += 1 ctx.injection_cycles += 1
if assistant_message is not None: if assistant_message is not None:
messages.append(assistant_message) ctx.messages.append(assistant_message)
if iteration is not None: await self._emit_checkpoint(
await self._emit_checkpoint( ctx.spec,
spec, {
{ "phase": "final_response",
"phase": "final_response", "iteration": ctx.iteration,
"iteration": iteration, "model": ctx.spec.model,
"model": spec.model, "assistant_message": assistant_message,
"assistant_message": assistant_message, "completed_tool_results": [],
"completed_tool_results": [], "pending_tool_calls": [],
"pending_tool_calls": [], },
}, )
) self._append_injected_messages(ctx.messages, injections)
self._append_injected_messages(messages, injections) ctx.had_injections = True
logger.info( logger.info(
"Injected {} follow-up message(s) {} ({}/{})", "Injected {} follow-up message(s) {} ({}/{})",
len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES, len(injections), phase, ctx.injection_cycles, _MAX_INJECTION_CYCLES,
) )
return True, injection_cycles return True
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]: async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback. """Drain pending user messages via the injection callback.
@ -243,55 +281,23 @@ class AgentRunner:
return injected_messages return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult: async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook() ctx = RunContext(
messages = list(spec.initial_messages) spec=spec,
final_content: str | None = None hook=spec.hook or AgentHook(),
tools_used: list[str] = [] messages=list(spec.initial_messages),
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0} )
error: str | None = None for ctx.iteration in range(spec.max_iterations):
stop_reason = "completed" messages_for_model = await self._prepare_context(ctx)
tool_events: list[dict[str, str]] = []
external_lookup_counts: dict[str, int] = {} hook_context = AgentHookContext(iteration=ctx.iteration, messages=ctx.messages)
# Per-turn throttle for repeated attempts against the same outside target. await ctx.hook.before_iteration(hook_context)
workspace_violation_counts: dict[str, int] = {} response = await self._request_model(ctx.spec, messages_for_model, ctx.hook, hook_context)
empty_content_retries = 0
length_recovery_count = 0
had_injections = False
injection_cycles = 0
for iteration in range(spec.max_iterations):
try:
# Keep the persisted conversation untouched. Context governance
# may repair or compact historical messages for the model, but
# those synthetic edits must not shift the append boundary used
# later when the caller saves only the new turn.
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
messages_for_model = self._microcompact(messages_for_model)
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
messages_for_model = self._snip_history(spec, messages_for_model)
# Snipping may have created new orphans; clean them up.
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
except Exception:
logger.exception(
"Context governance failed on turn {} for {}; applying minimal repair",
iteration,
spec.session_key or "default",
)
try:
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
except Exception:
messages_for_model = messages
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
response = await self._request_model(spec, messages_for_model, hook, context)
raw_usage = self._usage_dict(response.usage) raw_usage = self._usage_dict(response.usage)
context.response = response hook_context.response = response
context.usage = dict(raw_usage) hook_context.usage = dict(raw_usage)
context.tool_calls = list(response.tool_calls) hook_context.tool_calls = list(response.tool_calls)
self._accumulate_usage(usage, raw_usage) self._accumulate_usage(ctx.usage, raw_usage)
reasoning_text, cleaned_content = extract_reasoning( reasoning_text, cleaned_content = extract_reasoning(
response.reasoning_content, response.reasoning_content,
@ -299,47 +305,49 @@ class AgentRunner:
response.content, response.content,
) )
response.content = cleaned_content response.content = cleaned_content
if reasoning_text and not context.streamed_reasoning: if reasoning_text and not hook_context.streamed_reasoning:
await hook.emit_reasoning(reasoning_text) await ctx.hook.emit_reasoning(reasoning_text)
await hook.emit_reasoning_end() await ctx.hook.emit_reasoning_end()
context.streamed_reasoning = True hook_context.streamed_reasoning = True
if response.should_execute_tools: if response.should_execute_tools:
context.tool_calls = list(response.tool_calls) hook_context.tool_calls = list(response.tool_calls)
if hook.wants_streaming(): if ctx.hook.wants_streaming():
await hook.on_stream_end(context, resuming=True) await ctx.hook.on_stream_end(hook_context, resuming=True)
openai_tool_calls = [tc.to_openai_tool_call() for tc in response.tool_calls]
assistant_message = build_assistant_message( assistant_message = build_assistant_message(
response.content or "", response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls], tool_calls=openai_tool_calls,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
) )
messages.append(assistant_message) ctx.messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls) ctx.tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint( await self._emit_checkpoint(
spec, ctx.spec,
{ {
"phase": "awaiting_tools", "phase": "awaiting_tools",
"iteration": iteration, "iteration": ctx.iteration,
"model": spec.model, "model": ctx.spec.model,
"assistant_message": assistant_message, "assistant_message": assistant_message,
"completed_tool_results": [], "completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls], "pending_tool_calls": openai_tool_calls,
}, },
) )
await hook.before_execute_tools(context) await ctx.hook.before_execute_tools(hook_context)
results, new_events, fatal_error = await self._execute_tools( results, new_events, fatal_error = await self._execute_tools(
spec, ctx.spec,
response.tool_calls, response.tool_calls,
external_lookup_counts, ctx.external_lookup_counts,
workspace_violation_counts, ctx.workspace_violation_counts,
) )
tool_events.extend(new_events) ctx.tool_events.extend(new_events)
context.tool_results = list(results) hook_context.tool_results = list(results)
context.tool_events = list(new_events) hook_context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = [] completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results): for tool_call, result in zip(response.tool_calls, results):
tool_message = { tool_message = {
@ -347,112 +355,101 @@ class AgentRunner:
"tool_call_id": tool_call.id, "tool_call_id": tool_call.id,
"name": tool_call.name, "name": tool_call.name,
"content": self._normalize_tool_result( "content": self._normalize_tool_result(
spec, ctx.spec,
tool_call.id, tool_call.id,
tool_call.name, tool_call.name,
result, result,
), ),
} }
messages.append(tool_message) ctx.messages.append(tool_message)
completed_tool_results.append(tool_message) completed_tool_results.append(tool_message)
if fatal_error is not None: if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}" if await self._terminate_with_error(
final_content = error ctx, hook_context,
stop_reason = "tool_error" error=f"Error: {type(fatal_error).__name__}: {fatal_error}",
self._append_final_message(messages, final_content) stop_reason="tool_error",
context.final_content = final_content append_fn="final",
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after tool error", phase="after tool error",
) ):
if should_continue:
had_injections = True
continue continue
break break
await self._emit_checkpoint( await self._emit_checkpoint(
spec, ctx.spec,
{ {
"phase": "tools_completed", "phase": "tools_completed",
"iteration": iteration, "iteration": ctx.iteration,
"model": spec.model, "model": ctx.spec.model,
"assistant_message": assistant_message, "assistant_message": assistant_message,
"completed_tool_results": completed_tool_results, "completed_tool_results": completed_tool_results,
"pending_tool_calls": [], "pending_tool_calls": [],
}, },
) )
empty_content_retries = 0 ctx.empty_content_retries = 0
length_recovery_count = 0 ctx.length_recovery_count = 0
# Checkpoint 1: drain injections after tools, before next LLM call await self._try_drain_injections(ctx, phase="after tool execution")
_drained, injection_cycles = await self._try_drain_injections( await ctx.hook.after_iteration(hook_context)
spec, messages, None, injection_cycles,
phase="after tool execution",
)
if _drained:
had_injections = True
await hook.after_iteration(context)
continue continue
if response.has_tool_calls: if response.has_tool_calls:
logger.warning( logger.warning(
"Ignoring tool calls under finish_reason='{}' for {}", "Ignoring tool calls under finish_reason='{}' for {}",
response.finish_reason, response.finish_reason,
spec.session_key or "default", ctx.spec.session_key or "default",
) )
clean = hook.finalize_content(context, response.content) clean = ctx.hook.finalize_content(hook_context, response.content)
if response.finish_reason != "error" and is_blank_text(clean): if response.finish_reason != "error" and is_blank_text(clean):
empty_content_retries += 1 ctx.empty_content_retries += 1
if empty_content_retries < _MAX_EMPTY_RETRIES: if ctx.empty_content_retries < _MAX_EMPTY_RETRIES:
logger.warning( logger.warning(
"Empty response on turn {} for {} ({}/{}); retrying", "Empty response on turn {} for {} ({}/{}); retrying",
iteration, ctx.iteration,
spec.session_key or "default", ctx.spec.session_key or "default",
empty_content_retries, ctx.empty_content_retries,
_MAX_EMPTY_RETRIES, _MAX_EMPTY_RETRIES,
) )
if hook.wants_streaming(): if ctx.hook.wants_streaming():
await hook.on_stream_end(context, resuming=False) await ctx.hook.on_stream_end(hook_context, resuming=False)
await hook.after_iteration(context) await ctx.hook.after_iteration(hook_context)
continue continue
logger.warning( logger.warning(
"Empty response on turn {} for {} after {} retries; attempting finalization", "Empty response on turn {} for {} after {} retries; attempting finalization",
iteration, ctx.iteration,
spec.session_key or "default", ctx.spec.session_key or "default",
empty_content_retries, ctx.empty_content_retries,
) )
if hook.wants_streaming(): if ctx.hook.wants_streaming():
await hook.on_stream_end(context, resuming=False) await ctx.hook.on_stream_end(hook_context, resuming=False)
response = await self._request_finalization_retry(spec, messages_for_model) response = await self._request_finalization_retry(ctx.spec, messages_for_model)
retry_usage = self._usage_dict(response.usage) retry_usage = self._usage_dict(response.usage)
self._accumulate_usage(usage, retry_usage) self._accumulate_usage(ctx.usage, retry_usage)
raw_usage = self._merge_usage(raw_usage, retry_usage) hook_context.response = response
context.response = response hook_context.usage = dict(self._merge_usage(hook_context.usage, retry_usage))
context.usage = dict(raw_usage) hook_context.tool_calls = list(response.tool_calls)
context.tool_calls = list(response.tool_calls) clean = ctx.hook.finalize_content(hook_context, response.content)
clean = hook.finalize_content(context, response.content)
if response.finish_reason == "length" and not is_blank_text(clean): if response.finish_reason == "length" and not is_blank_text(clean):
length_recovery_count += 1 ctx.length_recovery_count += 1
if length_recovery_count <= _MAX_LENGTH_RECOVERIES: if ctx.length_recovery_count <= _MAX_LENGTH_RECOVERIES:
logger.info( logger.info(
"Output truncated on turn {} for {} ({}/{}); continuing", "Output truncated on turn {} for {} ({}/{}); continuing",
iteration, ctx.iteration,
spec.session_key or "default", ctx.spec.session_key or "default",
length_recovery_count, ctx.length_recovery_count,
_MAX_LENGTH_RECOVERIES, _MAX_LENGTH_RECOVERIES,
) )
if hook.wants_streaming(): if ctx.hook.wants_streaming():
await hook.on_stream_end(context, resuming=True) await ctx.hook.on_stream_end(hook_context, resuming=True)
messages.append(build_assistant_message( ctx.messages.append(build_assistant_message(
clean, clean,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
)) ))
messages.append(build_length_recovery_message()) ctx.messages.append(build_length_recovery_message())
await hook.after_iteration(context) await ctx.hook.after_iteration(hook_context)
continue continue
assistant_message: dict[str, Any] | None = None assistant_message: dict[str, Any] | None = None
@ -463,115 +460,131 @@ class AgentRunner:
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
) )
# Check for mid-turn injections BEFORE signaling stream end. should_continue = await self._try_drain_injections(
# If injections are found we keep the stream alive (resuming=True) ctx,
# so streaming channels don't prematurely finalize the card. assistant_message=assistant_message,
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, assistant_message, injection_cycles,
phase="after final response", phase="after final response",
iteration=iteration,
) )
if should_continue:
had_injections = True
if hook.wants_streaming(): if ctx.hook.wants_streaming():
await hook.on_stream_end(context, resuming=should_continue) await ctx.hook.on_stream_end(hook_context, resuming=should_continue)
if should_continue: if should_continue:
await hook.after_iteration(context) await ctx.hook.after_iteration(hook_context)
continue continue
if response.finish_reason == "error": if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE if await self._terminate_with_error(
stop_reason = "error" ctx, hook_context,
error = final_content error=clean or ctx.spec.error_message or _DEFAULT_ERROR_MESSAGE,
self._append_model_error_placeholder(messages) stop_reason="error",
context.final_content = final_content append_fn="model_error",
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after LLM error", phase="after LLM error",
) ):
if should_continue:
had_injections = True
continue
break
if is_blank_text(clean):
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
stop_reason = "empty_final_response"
error = final_content
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after empty response",
)
if should_continue:
had_injections = True
continue continue
break break
messages.append(assistant_message or build_assistant_message( if is_blank_text(clean):
if await self._terminate_with_error(
ctx, hook_context,
error=EMPTY_FINAL_RESPONSE_MESSAGE,
stop_reason="empty_final_response",
append_fn="final",
phase="after empty response",
):
continue
break
ctx.messages.append(assistant_message or build_assistant_message(
clean, clean,
reasoning_content=response.reasoning_content, reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks, thinking_blocks=response.thinking_blocks,
)) ))
await self._emit_checkpoint( await self._emit_checkpoint(
spec, ctx.spec,
{ {
"phase": "final_response", "phase": "final_response",
"iteration": iteration, "iteration": ctx.iteration,
"model": spec.model, "model": ctx.spec.model,
"assistant_message": messages[-1], "assistant_message": ctx.messages[-1],
"completed_tool_results": [], "completed_tool_results": [],
"pending_tool_calls": [], "pending_tool_calls": [],
}, },
) )
final_content = clean ctx.final_content = clean
context.final_content = final_content hook_context.final_content = ctx.final_content
context.stop_reason = stop_reason hook_context.stop_reason = ctx.stop_reason
await hook.after_iteration(context) await ctx.hook.after_iteration(hook_context)
break break
else: else:
stop_reason = "max_iterations" ctx.stop_reason = "max_iterations"
if spec.max_iterations_message: if ctx.spec.max_iterations_message:
final_content = spec.max_iterations_message.format( ctx.final_content = ctx.spec.max_iterations_message.format(
max_iterations=spec.max_iterations, max_iterations=ctx.spec.max_iterations,
) )
else: else:
final_content = render_template( ctx.final_content = render_template(
"agent/max_iterations_message.md", "agent/max_iterations_message.md",
strip=True, strip=True,
max_iterations=spec.max_iterations, max_iterations=ctx.spec.max_iterations,
) )
self._append_final_message(messages, final_content) self._append_final_message(ctx.messages, ctx.final_content)
# Drain any remaining injections so they are appended to the await self._try_drain_injections(ctx, phase="after max_iterations")
# conversation history instead of being re-published as
# independent inbound messages by _dispatch's finally block.
# We ignore should_continue here because the for-loop has already
# exhausted all iterations.
drained_after_max_iterations, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after max_iterations",
)
if drained_after_max_iterations:
had_injections = True
return AgentRunResult( return ctx.to_result()
final_content=final_content,
messages=messages, async def _prepare_context(self, ctx: RunContext) -> list[dict[str, Any]]:
tools_used=tools_used, """Build a derived message list safe for the LLM provider.
usage=usage,
stop_reason=stop_reason, Applies orphan cleanup, backfill, microcompaction, budget truncation,
error=error, and history snipping. Never mutates ctx.messages.
tool_events=tool_events, Stores the result in ctx.messages_for_model for downstream retry access.
had_injections=had_injections, """
) try:
msgs = self._drop_orphan_tool_results(ctx.messages)
msgs = self._backfill_missing_tool_results(msgs)
msgs = self._microcompact(msgs)
msgs = self._apply_tool_result_budget(ctx.spec, msgs)
msgs = self._snip_history(ctx.spec, msgs)
msgs = self._drop_orphan_tool_results(msgs)
msgs = self._backfill_missing_tool_results(msgs)
except Exception:
logger.exception(
"Context governance failed on turn {} for {}; applying minimal repair",
ctx.iteration,
ctx.spec.session_key or "default",
)
try:
msgs = self._drop_orphan_tool_results(ctx.messages)
msgs = self._backfill_missing_tool_results(msgs)
except Exception:
msgs = ctx.messages
ctx.messages_for_model = msgs
return msgs
async def _terminate_with_error(
self,
ctx: RunContext,
hook_context: AgentHookContext,
*,
error: str,
stop_reason: str,
append_fn: str,
phase: str,
) -> bool:
"""Set error state, append a message, and optionally continue via injections."""
ctx.final_content = error
ctx.error = error
ctx.stop_reason = stop_reason
if append_fn == "model_error":
self._append_model_error_placeholder(ctx.messages)
else:
self._append_final_message(ctx.messages, error)
hook_context.final_content = ctx.final_content
hook_context.error = ctx.error
hook_context.stop_reason = ctx.stop_reason
await ctx.hook.after_iteration(hook_context)
return await self._try_drain_injections(ctx, phase=phase)
def _build_request_kwargs( def _build_request_kwargs(
self, self,