refactor: consolidate _parse_retry_after_headers into base class

Merge the three retry-after header parsers (base, OpenAI, Anthropic)
into a single _extract_retry_after_from_headers on LLMProvider that
handles retry-after-ms, case-insensitive lookup, and HTTP date.

Remove the per-provider _parse_retry_after_headers duplicates and
their now-unused email.utils / time imports. Add test for retry-after-ms.

Made-with: Cursor
This commit is contained in:
Xubin Ren 2026-04-06 08:44:52 +00:00
parent aeba9a23e6
commit 35f53a721d
4 changed files with 32 additions and 89 deletions

View File

@ -3,12 +3,10 @@
from __future__ import annotations
import asyncio
import email.utils
import os
import re
import secrets
import string
import time
from collections.abc import Awaitable, Callable
from typing import Any
@ -54,49 +52,6 @@ class AnthropicProvider(LLMProvider):
client_kw["max_retries"] = 0
self._client = AsyncAnthropic(**client_kw)
@staticmethod
def _parse_retry_after_headers(headers: Any) -> float | None:
if headers is None:
return None
def _header_value(name: str) -> Any:
if hasattr(headers, "get"):
value = headers.get(name) or headers.get(name.title())
if value is not None:
return value
if isinstance(headers, dict):
for key, value in headers.items():
if isinstance(key, str) and key.lower() == name.lower():
return value
return None
try:
retry_ms = _header_value("retry-after-ms")
if retry_ms is not None:
value = float(retry_ms) / 1000.0
if value > 0:
return value
except (TypeError, ValueError):
pass
retry_after = _header_value("retry-after")
try:
if retry_after is not None:
value = float(retry_after)
if value > 0:
return value
except (TypeError, ValueError):
pass
if retry_after is None:
return None
retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None:
return None
retry_date = email.utils.mktime_tz(retry_date_tuple)
value = float(retry_date - time.time())
return value if value > 0 else None
@classmethod
def _handle_error(cls, e: Exception) -> LLMResponse:
response = getattr(e, "response", None)
@ -115,7 +70,7 @@ class AnthropicProvider(LLMProvider):
payload = None
payload_text = payload if isinstance(payload, str) else str(payload) if payload is not None else ""
msg = f"Error: {payload_text.strip()[:500]}" if payload_text.strip() else f"Error calling LLM: {e}"
retry_after = cls._parse_retry_after_headers(headers)
retry_after = cls._extract_retry_after_from_headers(headers)
if retry_after is None:
retry_after = LLMProvider._extract_retry_after(msg)

View File

@ -524,14 +524,28 @@ class LLMProvider(ABC):
def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
if not headers:
return None
retry_after: Any = None
if hasattr(headers, "get"):
retry_after = headers.get("retry-after") or headers.get("Retry-After")
if retry_after is None and isinstance(headers, dict):
for key, value in headers.items():
if isinstance(key, str) and key.lower() == "retry-after":
retry_after = value
break
def _header_value(name: str) -> Any:
if hasattr(headers, "get"):
value = headers.get(name) or headers.get(name.title())
if value is not None:
return value
if isinstance(headers, dict):
for key, value in headers.items():
if isinstance(key, str) and key.lower() == name.lower():
return value
return None
try:
retry_ms = _header_value("retry-after-ms")
if retry_ms is not None:
value = float(retry_ms) / 1000.0
if value > 0:
return value
except (TypeError, ValueError):
pass
retry_after = _header_value("retry-after")
if retry_after is None:
return None
retry_after_text = str(retry_after).strip()

View File

@ -3,13 +3,11 @@
from __future__ import annotations
import asyncio
import email.utils
import hashlib
import importlib.util
import os
import secrets
import string
import time
import uuid
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any
@ -636,38 +634,6 @@ class OpenAICompatProvider(LLMProvider):
reasoning_content="".join(reasoning_parts) or None,
)
@staticmethod
def _parse_retry_after_headers(headers: Any) -> float | None:
if headers is None:
return None
try:
retry_ms = headers.get("retry-after-ms")
if retry_ms is not None:
value = float(retry_ms) / 1000.0
if value > 0:
return value
except (TypeError, ValueError):
pass
retry_after = headers.get("retry-after")
try:
if retry_after is not None:
value = float(retry_after)
if value > 0:
return value
except (TypeError, ValueError):
pass
if retry_after is None:
return None
retry_date_tuple = email.utils.parsedate_tz(retry_after)
if retry_date_tuple is None:
return None
retry_date = email.utils.mktime_tz(retry_date_tuple)
value = float(retry_date - time.time())
return value if value > 0 else None
@classmethod
def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
response = getattr(e, "response", None)
@ -712,7 +678,7 @@ class OpenAICompatProvider(LLMProvider):
"error_kind": error_kind,
"error_type": error_type,
"error_code": error_code,
"error_retry_after_s": cls._parse_retry_after_headers(headers),
"error_retry_after_s": cls._extract_retry_after_from_headers(headers),
"error_should_retry": should_retry,
}

View File

@ -254,6 +254,14 @@ def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> No
) == 0.1
def test_extract_retry_after_from_headers_supports_retry_after_ms() -> None:
assert LLMProvider._extract_retry_after_from_headers({"retry-after-ms": "250"}) == 0.25
assert LLMProvider._extract_retry_after_from_headers({"Retry-After-Ms": "1000"}) == 1.0
assert LLMProvider._extract_retry_after_from_headers(
{"retry-after-ms": "500", "retry-after": "10"},
) == 0.5
@pytest.mark.asyncio
async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None:
provider = ScriptedProvider([