fix: use structured error metadata for app-layer retry

pikaxinge 2026-04-02 18:14:09 +00:00
parent 7113ad34f4
commit b951b37c97
5 changed files with 355 additions and 8 deletions
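
The retry loop previously had to match text markers in the error string (e.g. "server error", "temporarily unavailable") to decide whether an attempt was transient; with this change, providers attach structured metadata to LLMResponse (status code, error kind, Retry-After, x-should-retry) and the retry policy consults those fields first. A minimal sketch of the new classification, assuming LLMResponse and LLMProvider are importable from nanobot.providers.base (the module path suggested by the tests below):

from nanobot.providers.base import LLMProvider, LLMResponse  # assumed module path

# A 408/409/429 or 5xx status code is now enough to mark an error retryable,
# even when the message contains no recognizable keyword.
resp = LLMResponse(content="request failed", finish_reason="error", error_status_code=429)
assert LLMProvider._is_transient_response(resp) is True

# An explicit x-should-retry: false from the server overrides everything else.
resp = LLMResponse(content="429 rate limit", finish_reason="error", error_should_retry=False)
assert LLMProvider._is_transient_response(resp) is False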

View File

@ -3,10 +3,12 @@
from __future__ import annotations
import asyncio
import email.utils
import os
import re
import secrets
import string
import time
from collections.abc import Awaitable, Callable
from typing import Any
@ -51,6 +53,73 @@ class AnthropicProvider(LLMProvider):
client_kw["default_headers"] = extra_headers client_kw["default_headers"] = extra_headers
self._client = AsyncAnthropic(**client_kw) self._client = AsyncAnthropic(**client_kw)
@staticmethod
def _parse_retry_after_headers(headers: Any) -> float | None:
if headers is None:
return None
try:
retry_ms = headers.get("retry-after-ms")
if retry_ms is not None:
value = float(retry_ms) / 1000.0
if value > 0:
return value
except (TypeError, ValueError):
pass
retry_after = headers.get("retry-after")
try:
if retry_after is not None:
value = float(retry_after)
if value > 0:
return value
except (TypeError, ValueError):
pass
if retry_after is None:
return None
retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None:
return None
retry_date = email.utils.mktime_tz(retry_date_tuple)
value = float(retry_date - time.time())
return value if value > 0 else None
@classmethod
def _error_response(cls, e: Exception) -> LLMResponse:
response = getattr(e, "response", None)
headers = getattr(response, "headers", None)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
should_retry: bool | None = None
if headers is not None:
raw = headers.get("x-should-retry")
if isinstance(raw, str):
lowered = raw.strip().lower()
if lowered == "true":
should_retry = True
elif lowered == "false":
should_retry = False
error_kind: str | None = None
error_name = e.__class__.__name__.lower()
if "timeout" in error_name:
error_kind = "timeout"
elif "connection" in error_name:
error_kind = "connection"
return LLMResponse(
content=f"Error calling LLM: {e}",
finish_reason="error",
error_status_code=int(status_code) if status_code is not None else None,
error_kind=error_kind,
error_retry_after_s=cls._parse_retry_after_headers(headers),
error_should_retry=should_retry,
)
@staticmethod
def _strip_prefix(model: str) -> str:
if model.startswith("anthropic/"):
@ -419,7 +488,7 @@ class AnthropicProvider(LLMProvider):
response = await self._client.messages.create(**kwargs)
return self._parse_response(response)
except Exception as e:
- return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
return self._error_response(e)
async def chat_stream(
self,
@ -462,9 +531,10 @@ class AnthropicProvider(LLMProvider):
f"{idle_timeout_s} seconds" f"{idle_timeout_s} seconds"
), ),
finish_reason="error", finish_reason="error",
error_kind="timeout",
)
except Exception as e:
- return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
return self._error_response(e)
def get_default_model(self) -> str:
return self.default_model
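
For reference, a small illustration of how the Retry-After parsing added above resolves the different header forms; the header dicts below are hypothetical inputs, not fixtures from this commit:

from nanobot.providers.anthropic_provider import AnthropicProvider

# retry-after-ms is tried first and interpreted as milliseconds.
assert AnthropicProvider._parse_retry_after_headers({"retry-after-ms": "250"}) == 0.25
# A numeric retry-after is taken as seconds.
assert AnthropicProvider._parse_retry_after_headers({"retry-after": "1.5"}) == 1.5
# An HTTP-date in the past, or missing headers, yields None, so callers fall
# back to their normal backoff schedule.
assert AnthropicProvider._parse_retry_after_headers({"retry-after": "Mon, 01 Jan 2001 00:00:00 GMT"}) is None
assert AnthropicProvider._parse_retry_after_headers(None) is None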

View File

@ -51,6 +51,11 @@ class LLMResponse:
usage: dict[str, int] = field(default_factory=dict)
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
# Structured error metadata used by retry policy when finish_reason == "error".
error_status_code: int | None = None
error_kind: str | None = None # e.g. "timeout", "connection"
error_retry_after_s: float | None = None
error_should_retry: bool | None = None
@property
def has_tool_calls(self) -> bool:
@ -88,6 +93,8 @@ class LLMProvider(ABC):
"server error", "server error",
"temporarily unavailable", "temporarily unavailable",
) )
_RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})
_TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"})
_SENTINEL = object()
@ -191,6 +198,23 @@ class LLMProvider(ABC):
err = (content or "").lower()
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
@classmethod
def _is_transient_response(cls, response: LLMResponse) -> bool:
"""Prefer structured error metadata, fallback to text markers for legacy providers."""
if response.error_should_retry is not None:
return bool(response.error_should_retry)
if response.error_status_code is not None:
status = int(response.error_status_code)
if status in cls._RETRYABLE_STATUS_CODES or status >= 500:
return True
kind = (response.error_kind or "").strip().lower()
if kind in cls._TRANSIENT_ERROR_KINDS:
return True
return cls._is_transient_error(response.content)
@staticmethod
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
@ -345,6 +369,12 @@ class LLMProvider(ABC):
return value * 60.0
return value
@classmethod
def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None:
if response.error_retry_after_s is not None and response.error_retry_after_s > 0:
return response.error_retry_after_s
return cls._extract_retry_after(response.content)
async def _sleep_with_heartbeat(
self,
delay: float,
@ -393,7 +423,7 @@ class LLMProvider(ABC):
last_error_key = error_key
identical_error_count = 1 if error_key else 0
- if not self._is_transient_error(response.content):
if not self._is_transient_response(response):
stripped = self._strip_image_content(original_messages)
if stripped is not None and stripped != kw["messages"]:
logger.warning(
@ -416,7 +446,7 @@ class LLMProvider(ABC):
break
base_delay = delays[min(attempt - 1, len(delays) - 1)]
- delay = self._extract_retry_after(response.content) or base_delay
delay = self._extract_retry_after_from_response(response) or base_delay
if persistent:
delay = min(delay, self._PERSISTENT_MAX_DELAY)
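
A companion sketch of the new delay selection (same assumed nanobot.providers.base import path as above): a structured error_retry_after_s takes precedence over any "retry after Ns" hint in the error text, which _extract_retry_after still handles when the field is absent.

from nanobot.providers.base import LLMProvider, LLMResponse  # assumed module path

resp = LLMResponse(
    content="429 rate limit, retry after 99s",
    finish_reason="error",
    error_retry_after_s=0.2,
)
# The structured value wins over the 99 seconds mentioned in the text.
assert LLMProvider._extract_retry_after_from_response(resp) == 0.2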

View File

@ -3,10 +3,12 @@
from __future__ import annotations
import asyncio
import email.utils
import hashlib
import os
import secrets
import string
import time
import uuid
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
@ -569,11 +571,85 @@ class OpenAICompatProvider(LLMProvider):
usage=usage,
)
@staticmethod
def _parse_retry_after_headers(headers: Any) -> float | None:
if headers is None:
return None
try:
retry_ms = headers.get("retry-after-ms")
if retry_ms is not None:
value = float(retry_ms) / 1000.0
if value > 0:
return value
except (TypeError, ValueError):
pass
retry_after = headers.get("retry-after")
try:
if retry_after is not None:
value = float(retry_after)
if value > 0:
return value
except (TypeError, ValueError):
pass
if retry_after is None:
return None
retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None:
return None
retry_date = email.utils.mktime_tz(retry_date_tuple)
value = float(retry_date - time.time())
return value if value > 0 else None
@classmethod
def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
response = getattr(e, "response", None)
headers = getattr(response, "headers", None)
status_code = getattr(e, "status_code", None)
if status_code is None and response is not None:
status_code = getattr(response, "status_code", None)
should_retry: bool | None = None
if headers is not None:
raw = headers.get("x-should-retry")
if isinstance(raw, str):
lowered = raw.strip().lower()
if lowered == "true":
should_retry = True
elif lowered == "false":
should_retry = False
error_kind: str | None = None
error_name = e.__class__.__name__.lower()
if "timeout" in error_name:
error_kind = "timeout"
elif "connection" in error_name:
error_kind = "connection"
return {
"error_status_code": int(status_code) if status_code is not None else None,
"error_kind": error_kind,
"error_retry_after_s": cls._parse_retry_after_headers(headers),
"error_should_retry": should_retry,
}
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
- body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
- msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
- return LLMResponse(content=msg, finish_reason="error")
body = (
getattr(e, "doc", None)
or getattr(e, "body", None)
or getattr(getattr(e, "response", None), "text", None)
)
body_text = body if isinstance(body, str) else str(body) if body is not None else ""
msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
return LLMResponse(
content=msg,
finish_reason="error",
**OpenAICompatProvider._extract_error_metadata(e),
)
# ------------------------------------------------------------------
# Public API
@ -641,6 +717,7 @@ class OpenAICompatProvider(LLMProvider):
f"{idle_timeout_s} seconds" f"{idle_timeout_s} seconds"
), ),
finish_reason="error", finish_reason="error",
error_kind="timeout",
)
except Exception as e:
return self._handle_error(e)

View File

@ -0,0 +1,77 @@
from types import SimpleNamespace
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def _fake_response(
*,
status_code: int,
headers: dict[str, str] | None = None,
text: str = "",
) -> SimpleNamespace:
return SimpleNamespace(
status_code=status_code,
headers=headers or {},
text=text,
)
def test_openai_handle_error_extracts_structured_metadata() -> None:
class FakeStatusError(Exception):
pass
err = FakeStatusError("boom")
err.status_code = 409
err.response = _fake_response(
status_code=409,
headers={"retry-after-ms": "250", "x-should-retry": "false"},
text='{"error":"conflict"}',
)
err.body = {"error": "conflict"}
response = OpenAICompatProvider._handle_error(err)
assert response.finish_reason == "error"
assert response.error_status_code == 409
assert response.error_retry_after_s == 0.25
assert response.error_should_retry is False
def test_openai_handle_error_marks_timeout_kind() -> None:
class FakeTimeoutError(Exception):
pass
response = OpenAICompatProvider._handle_error(FakeTimeoutError("timeout"))
assert response.finish_reason == "error"
assert response.error_kind == "timeout"
def test_anthropic_error_response_extracts_structured_metadata() -> None:
class FakeStatusError(Exception):
pass
err = FakeStatusError("boom")
err.status_code = 408
err.response = _fake_response(
status_code=408,
headers={"retry-after": "1.5", "x-should-retry": "true"},
)
response = AnthropicProvider._error_response(err)
assert response.finish_reason == "error"
assert response.error_status_code == 408
assert response.error_retry_after_s == 1.5
assert response.error_should_retry is True
def test_anthropic_error_response_marks_connection_kind() -> None:
class FakeConnectionError(Exception):
pass
response = AnthropicProvider._error_response(FakeConnectionError("connection"))
assert response.finish_reason == "error"
assert response.error_kind == "connection"

View File

@ -240,6 +240,100 @@ async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypa
assert progress and "7s" in progress[0] assert progress and "7s" in progress[0]
@pytest.mark.asyncio
async def test_chat_with_retry_retries_structured_status_code_without_keyword(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="request failed",
finish_reason="error",
error_status_code=409,
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_retries_structured_timeout_kind(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="request failed",
finish_reason="error",
error_kind="timeout",
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert provider.calls == 2
assert delays == [1]
@pytest.mark.asyncio
async def test_chat_with_retry_structured_should_retry_false_disables_retry(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="429 rate limit",
finish_reason="error",
error_should_retry=False,
),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.finish_reason == "error"
assert provider.calls == 1
assert delays == []
@pytest.mark.asyncio
async def test_chat_with_retry_prefers_structured_retry_after(monkeypatch) -> None:
provider = ScriptedProvider([
LLMResponse(
content="429 rate limit, retry after 99s",
finish_reason="error",
error_retry_after_s=0.2,
),
LLMResponse(content="ok"),
])
delays: list[float] = []
async def _fake_sleep(delay: float) -> None:
delays.append(delay)
monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
assert response.content == "ok"
assert delays == [0.2]
@pytest.mark.asyncio
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
provider = ScriptedProvider([
@ -263,4 +357,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
assert provider.calls == 10
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]