refactor(agent): surgical extraction from AgentRunner.run()

Extract 3 focused helpers while keeping the iteration control flow
inline and readable:

- _prepare_context(ctx) – context governance pipeline (orphan cleanup,
  backfill, microcompact, budget, snip)
- _try_drain_injections(ctx, ...) – simplified signature using RunContext
  + bool return (was tuple[bool, int])
- _terminate_with_error(ctx, ...) – deduplicates 3 error-path blocks

Keep RunContext as a minimal mutable state bag; remove LoopAction enum
and the over-extracted response dispatchers (_respond, _handle_tool_calls,
_handle_empty, _handle_length, _handle_final, _handle_max_iterations).
This commit is contained in:
chengyongru 2026-05-18 16:40:32 +08:00
parent d4ade8f680
commit e977a43445

View File

@ -108,6 +108,49 @@ class AgentRunResult:
had_injections: bool = False
@dataclass(slots=True)
class RunContext:
"""Mutable accumulator for state that lives across loop iterations."""
spec: AgentRunSpec
hook: AgentHook
messages: list[dict[str, Any]] # OpenAI-format chat messages (mutated in place)
iteration: int = 0
# Cached by _prepare_context for downstream retry access.
# Stale after _prepare_context returns — do not use elsewhere.
messages_for_model: list[dict[str, Any]] | None = None
final_content: str | None = None
tools_used: list[str] = field(default_factory=list)
usage: dict[str, int] = field(
default_factory=lambda: {"prompt_tokens": 0, "completion_tokens": 0},
)
error: str | None = None
stop_reason: str = "completed"
tool_events: list[dict[str, str]] = field(default_factory=list)
# Per-run throttle counters (keyed by tool name)
external_lookup_counts: dict[str, int] = field(default_factory=dict)
workspace_violation_counts: dict[str, int] = field(default_factory=dict)
empty_content_retries: int = 0
length_recovery_count: int = 0
had_injections: bool = False
injection_cycles: int = 0
def to_result(self) -> AgentRunResult:
return AgentRunResult(
final_content=self.final_content,
messages=self.messages,
tools_used=self.tools_used,
usage=self.usage,
stop_reason=self.stop_reason,
error=self.error,
tool_events=self.tool_events,
had_injections=self.had_injections,
)
class AgentRunner:
"""Run a tool-capable LLM loop without product-layer concerns."""
@ -155,47 +198,42 @@ class AgentRunner:
async def _try_drain_injections(
self,
spec: AgentRunSpec,
messages: list[dict[str, Any]],
assistant_message: dict[str, Any] | None,
injection_cycles: int,
ctx: RunContext,
*,
assistant_message: dict[str, Any] | None = None,
phase: str = "after error",
iteration: int | None = None,
) -> tuple[bool, int]:
"""Drain pending injections. Returns (should_continue, updated_cycles).
) -> bool:
"""Drain pending injections into ctx.messages.
If injections are found and we haven't exceeded _MAX_INJECTION_CYCLES,
append them to *messages* (and emit a checkpoint if *assistant_message*
and *iteration* are both provided) and return (True, cycles+1) so the
caller continues the iteration loop. Otherwise return (False, cycles).
Returns True if injections were appended (caller should continue).
Mutates ctx.injection_cycles and ctx.had_injections directly.
"""
if injection_cycles >= _MAX_INJECTION_CYCLES:
return False, injection_cycles
injections = await self._drain_injections(spec)
if ctx.injection_cycles >= _MAX_INJECTION_CYCLES:
return False
injections = await self._drain_injections(ctx.spec)
if not injections:
return False, injection_cycles
injection_cycles += 1
return False
ctx.injection_cycles += 1
if assistant_message is not None:
messages.append(assistant_message)
if iteration is not None:
await self._emit_checkpoint(
spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(messages, injections)
ctx.messages.append(assistant_message)
await self._emit_checkpoint(
ctx.spec,
{
"phase": "final_response",
"iteration": ctx.iteration,
"model": ctx.spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
self._append_injected_messages(ctx.messages, injections)
ctx.had_injections = True
logger.info(
"Injected {} follow-up message(s) {} ({}/{})",
len(injections), phase, injection_cycles, _MAX_INJECTION_CYCLES,
len(injections), phase, ctx.injection_cycles, _MAX_INJECTION_CYCLES,
)
return True, injection_cycles
return True
async def _drain_injections(self, spec: AgentRunSpec) -> list[dict[str, Any]]:
"""Drain pending user messages via the injection callback.
@ -243,55 +281,23 @@ class AgentRunner:
return injected_messages
async def run(self, spec: AgentRunSpec) -> AgentRunResult:
hook = spec.hook or AgentHook()
messages = list(spec.initial_messages)
final_content: str | None = None
tools_used: list[str] = []
usage: dict[str, int] = {"prompt_tokens": 0, "completion_tokens": 0}
error: str | None = None
stop_reason = "completed"
tool_events: list[dict[str, str]] = []
external_lookup_counts: dict[str, int] = {}
# Per-turn throttle for repeated attempts against the same outside target.
workspace_violation_counts: dict[str, int] = {}
empty_content_retries = 0
length_recovery_count = 0
had_injections = False
injection_cycles = 0
ctx = RunContext(
spec=spec,
hook=spec.hook or AgentHook(),
messages=list(spec.initial_messages),
)
for ctx.iteration in range(spec.max_iterations):
messages_for_model = await self._prepare_context(ctx)
hook_context = AgentHookContext(iteration=ctx.iteration, messages=ctx.messages)
await ctx.hook.before_iteration(hook_context)
response = await self._request_model(ctx.spec, messages_for_model, ctx.hook, hook_context)
for iteration in range(spec.max_iterations):
try:
# Keep the persisted conversation untouched. Context governance
# may repair or compact historical messages for the model, but
# those synthetic edits must not shift the append boundary used
# later when the caller saves only the new turn.
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
messages_for_model = self._microcompact(messages_for_model)
messages_for_model = self._apply_tool_result_budget(spec, messages_for_model)
messages_for_model = self._snip_history(spec, messages_for_model)
# Snipping may have created new orphans; clean them up.
messages_for_model = self._drop_orphan_tool_results(messages_for_model)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
except Exception:
logger.exception(
"Context governance failed on turn {} for {}; applying minimal repair",
iteration,
spec.session_key or "default",
)
try:
messages_for_model = self._drop_orphan_tool_results(messages)
messages_for_model = self._backfill_missing_tool_results(messages_for_model)
except Exception:
messages_for_model = messages
context = AgentHookContext(iteration=iteration, messages=messages)
await hook.before_iteration(context)
response = await self._request_model(spec, messages_for_model, hook, context)
raw_usage = self._usage_dict(response.usage)
context.response = response
context.usage = dict(raw_usage)
context.tool_calls = list(response.tool_calls)
self._accumulate_usage(usage, raw_usage)
hook_context.response = response
hook_context.usage = dict(raw_usage)
hook_context.tool_calls = list(response.tool_calls)
self._accumulate_usage(ctx.usage, raw_usage)
reasoning_text, cleaned_content = extract_reasoning(
response.reasoning_content,
@ -299,47 +305,49 @@ class AgentRunner:
response.content,
)
response.content = cleaned_content
if reasoning_text and not context.streamed_reasoning:
await hook.emit_reasoning(reasoning_text)
await hook.emit_reasoning_end()
context.streamed_reasoning = True
if reasoning_text and not hook_context.streamed_reasoning:
await ctx.hook.emit_reasoning(reasoning_text)
await ctx.hook.emit_reasoning_end()
hook_context.streamed_reasoning = True
if response.should_execute_tools:
context.tool_calls = list(response.tool_calls)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True)
hook_context.tool_calls = list(response.tool_calls)
if ctx.hook.wants_streaming():
await ctx.hook.on_stream_end(hook_context, resuming=True)
openai_tool_calls = [tc.to_openai_tool_call() for tc in response.tool_calls]
assistant_message = build_assistant_message(
response.content or "",
tool_calls=[tc.to_openai_tool_call() for tc in response.tool_calls],
tool_calls=openai_tool_calls,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
)
messages.append(assistant_message)
tools_used.extend(tc.name for tc in response.tool_calls)
ctx.messages.append(assistant_message)
ctx.tools_used.extend(tc.name for tc in response.tool_calls)
await self._emit_checkpoint(
spec,
ctx.spec,
{
"phase": "awaiting_tools",
"iteration": iteration,
"model": spec.model,
"iteration": ctx.iteration,
"model": ctx.spec.model,
"assistant_message": assistant_message,
"completed_tool_results": [],
"pending_tool_calls": [tc.to_openai_tool_call() for tc in response.tool_calls],
"pending_tool_calls": openai_tool_calls,
},
)
await hook.before_execute_tools(context)
await ctx.hook.before_execute_tools(hook_context)
results, new_events, fatal_error = await self._execute_tools(
spec,
ctx.spec,
response.tool_calls,
external_lookup_counts,
workspace_violation_counts,
ctx.external_lookup_counts,
ctx.workspace_violation_counts,
)
tool_events.extend(new_events)
context.tool_results = list(results)
context.tool_events = list(new_events)
ctx.tool_events.extend(new_events)
hook_context.tool_results = list(results)
hook_context.tool_events = list(new_events)
completed_tool_results: list[dict[str, Any]] = []
for tool_call, result in zip(response.tool_calls, results):
tool_message = {
@ -347,112 +355,101 @@ class AgentRunner:
"tool_call_id": tool_call.id,
"name": tool_call.name,
"content": self._normalize_tool_result(
spec,
ctx.spec,
tool_call.id,
tool_call.name,
result,
),
}
messages.append(tool_message)
ctx.messages.append(tool_message)
completed_tool_results.append(tool_message)
if fatal_error is not None:
error = f"Error: {type(fatal_error).__name__}: {fatal_error}"
final_content = error
stop_reason = "tool_error"
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
if await self._terminate_with_error(
ctx, hook_context,
error=f"Error: {type(fatal_error).__name__}: {fatal_error}",
stop_reason="tool_error",
append_fn="final",
phase="after tool error",
)
if should_continue:
had_injections = True
):
continue
break
await self._emit_checkpoint(
spec,
ctx.spec,
{
"phase": "tools_completed",
"iteration": iteration,
"model": spec.model,
"iteration": ctx.iteration,
"model": ctx.spec.model,
"assistant_message": assistant_message,
"completed_tool_results": completed_tool_results,
"pending_tool_calls": [],
},
)
empty_content_retries = 0
length_recovery_count = 0
# Checkpoint 1: drain injections after tools, before next LLM call
_drained, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after tool execution",
)
if _drained:
had_injections = True
await hook.after_iteration(context)
ctx.empty_content_retries = 0
ctx.length_recovery_count = 0
await self._try_drain_injections(ctx, phase="after tool execution")
await ctx.hook.after_iteration(hook_context)
continue
if response.has_tool_calls:
logger.warning(
"Ignoring tool calls under finish_reason='{}' for {}",
response.finish_reason,
spec.session_key or "default",
ctx.spec.session_key or "default",
)
clean = hook.finalize_content(context, response.content)
clean = ctx.hook.finalize_content(hook_context, response.content)
if response.finish_reason != "error" and is_blank_text(clean):
empty_content_retries += 1
if empty_content_retries < _MAX_EMPTY_RETRIES:
ctx.empty_content_retries += 1
if ctx.empty_content_retries < _MAX_EMPTY_RETRIES:
logger.warning(
"Empty response on turn {} for {} ({}/{}); retrying",
iteration,
spec.session_key or "default",
empty_content_retries,
ctx.iteration,
ctx.spec.session_key or "default",
ctx.empty_content_retries,
_MAX_EMPTY_RETRIES,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
await hook.after_iteration(context)
if ctx.hook.wants_streaming():
await ctx.hook.on_stream_end(hook_context, resuming=False)
await ctx.hook.after_iteration(hook_context)
continue
logger.warning(
"Empty response on turn {} for {} after {} retries; attempting finalization",
iteration,
spec.session_key or "default",
empty_content_retries,
ctx.iteration,
ctx.spec.session_key or "default",
ctx.empty_content_retries,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=False)
response = await self._request_finalization_retry(spec, messages_for_model)
if ctx.hook.wants_streaming():
await ctx.hook.on_stream_end(hook_context, resuming=False)
response = await self._request_finalization_retry(ctx.spec, messages_for_model)
retry_usage = self._usage_dict(response.usage)
self._accumulate_usage(usage, retry_usage)
raw_usage = self._merge_usage(raw_usage, retry_usage)
context.response = response
context.usage = dict(raw_usage)
context.tool_calls = list(response.tool_calls)
clean = hook.finalize_content(context, response.content)
self._accumulate_usage(ctx.usage, retry_usage)
hook_context.response = response
hook_context.usage = dict(self._merge_usage(hook_context.usage, retry_usage))
hook_context.tool_calls = list(response.tool_calls)
clean = ctx.hook.finalize_content(hook_context, response.content)
if response.finish_reason == "length" and not is_blank_text(clean):
length_recovery_count += 1
if length_recovery_count <= _MAX_LENGTH_RECOVERIES:
ctx.length_recovery_count += 1
if ctx.length_recovery_count <= _MAX_LENGTH_RECOVERIES:
logger.info(
"Output truncated on turn {} for {} ({}/{}); continuing",
iteration,
spec.session_key or "default",
length_recovery_count,
ctx.iteration,
ctx.spec.session_key or "default",
ctx.length_recovery_count,
_MAX_LENGTH_RECOVERIES,
)
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=True)
messages.append(build_assistant_message(
if ctx.hook.wants_streaming():
await ctx.hook.on_stream_end(hook_context, resuming=True)
ctx.messages.append(build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
))
messages.append(build_length_recovery_message())
await hook.after_iteration(context)
ctx.messages.append(build_length_recovery_message())
await ctx.hook.after_iteration(hook_context)
continue
assistant_message: dict[str, Any] | None = None
@ -463,115 +460,131 @@ class AgentRunner:
thinking_blocks=response.thinking_blocks,
)
# Check for mid-turn injections BEFORE signaling stream end.
# If injections are found we keep the stream alive (resuming=True)
# so streaming channels don't prematurely finalize the card.
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, assistant_message, injection_cycles,
should_continue = await self._try_drain_injections(
ctx,
assistant_message=assistant_message,
phase="after final response",
iteration=iteration,
)
if should_continue:
had_injections = True
if hook.wants_streaming():
await hook.on_stream_end(context, resuming=should_continue)
if ctx.hook.wants_streaming():
await ctx.hook.on_stream_end(hook_context, resuming=should_continue)
if should_continue:
await hook.after_iteration(context)
await ctx.hook.after_iteration(hook_context)
continue
if response.finish_reason == "error":
final_content = clean or spec.error_message or _DEFAULT_ERROR_MESSAGE
stop_reason = "error"
error = final_content
self._append_model_error_placeholder(messages)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
if await self._terminate_with_error(
ctx, hook_context,
error=clean or ctx.spec.error_message or _DEFAULT_ERROR_MESSAGE,
stop_reason="error",
append_fn="model_error",
phase="after LLM error",
)
if should_continue:
had_injections = True
continue
break
if is_blank_text(clean):
final_content = EMPTY_FINAL_RESPONSE_MESSAGE
stop_reason = "empty_final_response"
error = final_content
self._append_final_message(messages, final_content)
context.final_content = final_content
context.error = error
context.stop_reason = stop_reason
await hook.after_iteration(context)
should_continue, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after empty response",
)
if should_continue:
had_injections = True
):
continue
break
messages.append(assistant_message or build_assistant_message(
if is_blank_text(clean):
if await self._terminate_with_error(
ctx, hook_context,
error=EMPTY_FINAL_RESPONSE_MESSAGE,
stop_reason="empty_final_response",
append_fn="final",
phase="after empty response",
):
continue
break
ctx.messages.append(assistant_message or build_assistant_message(
clean,
reasoning_content=response.reasoning_content,
thinking_blocks=response.thinking_blocks,
))
await self._emit_checkpoint(
spec,
ctx.spec,
{
"phase": "final_response",
"iteration": iteration,
"model": spec.model,
"assistant_message": messages[-1],
"iteration": ctx.iteration,
"model": ctx.spec.model,
"assistant_message": ctx.messages[-1],
"completed_tool_results": [],
"pending_tool_calls": [],
},
)
final_content = clean
context.final_content = final_content
context.stop_reason = stop_reason
await hook.after_iteration(context)
ctx.final_content = clean
hook_context.final_content = ctx.final_content
hook_context.stop_reason = ctx.stop_reason
await ctx.hook.after_iteration(hook_context)
break
else:
stop_reason = "max_iterations"
if spec.max_iterations_message:
final_content = spec.max_iterations_message.format(
max_iterations=spec.max_iterations,
ctx.stop_reason = "max_iterations"
if ctx.spec.max_iterations_message:
ctx.final_content = ctx.spec.max_iterations_message.format(
max_iterations=ctx.spec.max_iterations,
)
else:
final_content = render_template(
ctx.final_content = render_template(
"agent/max_iterations_message.md",
strip=True,
max_iterations=spec.max_iterations,
max_iterations=ctx.spec.max_iterations,
)
self._append_final_message(messages, final_content)
# Drain any remaining injections so they are appended to the
# conversation history instead of being re-published as
# independent inbound messages by _dispatch's finally block.
# We ignore should_continue here because the for-loop has already
# exhausted all iterations.
drained_after_max_iterations, injection_cycles = await self._try_drain_injections(
spec, messages, None, injection_cycles,
phase="after max_iterations",
)
if drained_after_max_iterations:
had_injections = True
self._append_final_message(ctx.messages, ctx.final_content)
await self._try_drain_injections(ctx, phase="after max_iterations")
return AgentRunResult(
final_content=final_content,
messages=messages,
tools_used=tools_used,
usage=usage,
stop_reason=stop_reason,
error=error,
tool_events=tool_events,
had_injections=had_injections,
)
return ctx.to_result()
async def _prepare_context(self, ctx: RunContext) -> list[dict[str, Any]]:
"""Build a derived message list safe for the LLM provider.
Applies orphan cleanup, backfill, microcompaction, budget truncation,
and history snipping. Never mutates ctx.messages.
Stores the result in ctx.messages_for_model for downstream retry access.
"""
try:
msgs = self._drop_orphan_tool_results(ctx.messages)
msgs = self._backfill_missing_tool_results(msgs)
msgs = self._microcompact(msgs)
msgs = self._apply_tool_result_budget(ctx.spec, msgs)
msgs = self._snip_history(ctx.spec, msgs)
msgs = self._drop_orphan_tool_results(msgs)
msgs = self._backfill_missing_tool_results(msgs)
except Exception:
logger.exception(
"Context governance failed on turn {} for {}; applying minimal repair",
ctx.iteration,
ctx.spec.session_key or "default",
)
try:
msgs = self._drop_orphan_tool_results(ctx.messages)
msgs = self._backfill_missing_tool_results(msgs)
except Exception:
msgs = ctx.messages
ctx.messages_for_model = msgs
return msgs
async def _terminate_with_error(
self,
ctx: RunContext,
hook_context: AgentHookContext,
*,
error: str,
stop_reason: str,
append_fn: str,
phase: str,
) -> bool:
"""Set error state, append a message, and optionally continue via injections."""
ctx.final_content = error
ctx.error = error
ctx.stop_reason = stop_reason
if append_fn == "model_error":
self._append_model_error_placeholder(ctx.messages)
else:
self._append_final_message(ctx.messages, error)
hook_context.final_content = ctx.final_content
hook_context.error = ctx.error
hook_context.stop_reason = ctx.stop_reason
await ctx.hook.after_iteration(hook_context)
return await self._try_drain_injections(ctx, phase=phase)
def _build_request_kwargs(
self,