Merge PR #2761: fix: Retry-After was ignored, causing premature retries

fix: Retry-After was ignored, causing premature retries (now honors header/json hints)
Xubin Ren 2026-04-04 03:10:14 +08:00 committed by GitHub
commit 6fbcecc880
7 changed files with 172 additions and 23 deletions
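
The net effect: when a provider reports a structured wait, the retry loop honors it before falling back to parsing the error text or to its own backoff schedule. A minimal self-contained sketch of that precedence (values invented for illustration; the real line is in the base.py hunk below):

    # Hypothetical values standing in for one retry attempt.
    structured_hint = 9.0  # LLMResponse.retry_after, parsed from the Retry-After header
    parsed_hint = None     # _extract_retry_after(content) found nothing in the message
    base_delay = 2.0       # exponential-backoff step for this attempt

    delay = structured_hint or parsed_hint or base_delay
    print(delay)  # 9.0: the server-supplied wait wins over the local schedule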

View File

@@ -401,6 +401,15 @@ class AnthropicProvider(LLMProvider):
     # Public API
     # ------------------------------------------------------------------
 
+    @staticmethod
+    def _handle_error(e: Exception) -> LLMResponse:
+        msg = f"Error calling LLM: {e}"
+        response = getattr(e, "response", None)
+        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
+        if retry_after is None:
+            retry_after = LLMProvider._extract_retry_after(msg)
+        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
+
     async def chat(
         self,
         messages: list[dict[str, Any]],
@@ -419,7 +428,7 @@ class AnthropicProvider(LLMProvider):
             response = await self._client.messages.create(**kwargs)
             return self._parse_response(response)
         except Exception as e:
-            return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
+            return self._handle_error(e)
 
     async def chat_stream(
         self,
@@ -464,7 +473,7 @@ class AnthropicProvider(LLMProvider):
                 finish_reason="error",
             )
         except Exception as e:
-            return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
+            return self._handle_error(e)
 
     def get_default_model(self) -> str:
         return self.default_model

View File

@@ -113,9 +113,14 @@ class AzureOpenAIProvider(LLMProvider):
     @staticmethod
     def _handle_error(e: Exception) -> LLMResponse:
-        body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
-        msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
-        return LLMResponse(content=msg, finish_reason="error")
+        response = getattr(e, "response", None)
+        body = getattr(e, "body", None) or getattr(response, "text", None)
+        body_text = str(body).strip() if body is not None else ""
+        msg = f"Error: {body_text[:500]}" if body_text else f"Error calling Azure OpenAI: {e}"
+        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
+        if retry_after is None:
+            retry_after = LLMProvider._extract_retry_after(msg)
+        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
 
     # ------------------------------------------------------------------
     # Public API
@@ -174,4 +179,4 @@ class AzureOpenAIProvider(LLMProvider):
             return self._handle_error(e)
 
     def get_default_model(self) -> str:
-        return self.default_model
\ No newline at end of file
+        return self.default_model

View File

@@ -6,6 +6,8 @@ import re
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from email.utils import parsedate_to_datetime
 from typing import Any
 
 from loguru import logger
@@ -49,6 +51,7 @@ class LLMResponse:
     tool_calls: list[ToolCallRequest] = field(default_factory=list)
     finish_reason: str = "stop"
     usage: dict[str, int] = field(default_factory=dict)
+    retry_after: float | None = None  # Provider supplied retry wait in seconds.
     reasoning_content: str | None = None  # Kimi, DeepSeek-R1, MiMo etc.
     thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
@@ -334,16 +337,57 @@ class LLMProvider(ABC):
     @classmethod
     def _extract_retry_after(cls, content: str | None) -> float | None:
         text = (content or "").lower()
-        match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
-        if not match:
-            return None
-        value = float(match.group(1))
-        unit = (match.group(2) or "s").lower()
-        if unit in {"ms", "milliseconds"}:
+        patterns = (
+            r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?",
+            r"try again in\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)",
+            r"wait\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)\s*before retry",
+            r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)",
+        )
+        for idx, pattern in enumerate(patterns):
+            match = re.search(pattern, text)
+            if not match:
+                continue
+            value = float(match.group(1))
+            unit = match.group(2) if idx < 3 else "s"
+            return cls._to_retry_seconds(value, unit)
+        return None
+
+    @classmethod
+    def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float:
+        normalized_unit = (unit or "s").lower()
+        if normalized_unit in {"ms", "milliseconds"}:
             return max(0.1, value / 1000.0)
-        if unit in {"m", "min", "minutes"}:
-            return value * 60.0
-        return value
+        if normalized_unit in {"m", "min", "minutes"}:
+            return max(0.1, value * 60.0)
+        return max(0.1, value)
+
+    @classmethod
+    def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
+        if not headers:
+            return None
+        retry_after: Any = None
+        if hasattr(headers, "get"):
+            retry_after = headers.get("retry-after") or headers.get("Retry-After")
+        if retry_after is None and isinstance(headers, dict):
+            for key, value in headers.items():
+                if isinstance(key, str) and key.lower() == "retry-after":
+                    retry_after = value
+                    break
+        if retry_after is None:
+            return None
+        retry_after_text = str(retry_after).strip()
+        if not retry_after_text:
+            return None
+        if re.fullmatch(r"\d+(?:\.\d+)?", retry_after_text):
+            return cls._to_retry_seconds(float(retry_after_text), "s")
+        try:
+            retry_at = parsedate_to_datetime(retry_after_text)
+        except Exception:
+            return None
+        if retry_at.tzinfo is None:
+            retry_at = retry_at.replace(tzinfo=timezone.utc)
+        remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
+        return max(0.1, remaining)
 
     async def _sleep_with_heartbeat(
         self,
@@ -416,7 +460,7 @@
                 break
 
             base_delay = delays[min(attempt - 1, len(delays) - 1)]
-            delay = self._extract_retry_after(response.content) or base_delay
+            delay = response.retry_after or self._extract_retry_after(response.content) or base_delay
             if persistent:
                 delay = min(delay, self._PERSISTENT_MAX_DELAY)

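Retry-After can arrive as delta-seconds or as an HTTP-date (RFC 9110), and _extract_retry_after_from_headers above handles both. A standalone sketch of the date branch, using only the standard library:

    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    # Numeric form: "Retry-After: 20" is read directly as 20.0 seconds.
    # HTTP-date form: the wait is the delta from now, floored at 0.1s, so the
    # long-past date used in the tests collapses to the 0.1 floor.
    retry_at = parsedate_to_datetime("Wed, 21 Oct 2015 07:28:00 GMT")
    remaining = (retry_at - datetime.now(timezone.utc)).total_seconds()
    print(max(0.1, remaining))  # 0.1
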
View File

@@ -79,7 +79,9 @@ class OpenAICodexProvider(LLMProvider):
             )
             return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
         except Exception as e:
-            return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error")
+            msg = f"Error calling Codex: {e}"
+            retry_after = getattr(e, "retry_after", None) or self._extract_retry_after(msg)
+            return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
 
     async def chat(
         self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
@@ -120,6 +122,12 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]:
     }
 
 
+class _CodexHTTPError(RuntimeError):
+    def __init__(self, message: str, retry_after: float | None = None):
+        super().__init__(message)
+        self.retry_after = retry_after
+
+
 async def _request_codex(
     url: str,
     headers: dict[str, str],
@@ -131,7 +139,11 @@ async def _request_codex(
     async with client.stream("POST", url, headers=headers, json=body) as response:
         if response.status_code != 200:
             text = await response.aread()
-            raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
+            retry_after = LLMProvider._extract_retry_after_from_headers(response.headers)
+            raise _CodexHTTPError(
+                _friendly_error(response.status_code, text.decode("utf-8", "ignore")),
+                retry_after=retry_after,
+            )
         return await consume_sse(response, on_content_delta)

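Note that the provider reads the hint back with getattr rather than catching _CodexHTTPError by type, so any exception object carrying a retry_after attribute is honored. A small sketch of the pattern, with a hypothetical stand-in class:

    class FakeRateLimitError(RuntimeError):  # hypothetical stand-in for _CodexHTTPError
        def __init__(self, message: str, retry_after: float | None = None):
            super().__init__(message)
            self.retry_after = retry_after

    try:
        raise FakeRateLimitError("usage limit reached", retry_after=30.0)
    except Exception as e:
        hint = getattr(e, "retry_after", None)  # 30.0, no isinstance check needed
        print(f"Error calling Codex: {e} (retry in {hint}s)")
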
View File

@@ -584,9 +584,14 @@ class OpenAICompatProvider(LLMProvider):
     @staticmethod
     def _handle_error(e: Exception) -> LLMResponse:
-        body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
-        msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
-        return LLMResponse(content=msg, finish_reason="error")
+        response = getattr(e, "response", None)
+        body = getattr(e, "doc", None) or getattr(response, "text", None)
+        body_text = str(body).strip() if body is not None else ""
+        msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}"
+        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
+        if retry_after is None:
+            retry_after = LLMProvider._extract_retry_after(msg)
+        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
 
     # ------------------------------------------------------------------
     # Public API
@@ -662,4 +667,4 @@ class OpenAICompatProvider(LLMProvider):
             return self._handle_error(e)
 
     def get_default_model(self) -> str:
-        return self.default_model
\ No newline at end of file
+        return self.default_model

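With no Retry-After header, _handle_error falls back to scanning the message text, and if that also yields nothing, retry_after stays None and the retry loop keeps its backoff schedule. A sketch of that fallback path, assuming the module layout shown in the new test file at the end of this diff:

    from types import SimpleNamespace

    from nanobot.providers.openai_compat_provider import OpenAICompatProvider

    err = Exception("boom")
    err.doc = None  # no parsed JSON body on this exception
    err.response = SimpleNamespace(
        text='{"error":{"message":"Rate limit exceeded"}}',
        headers={},  # no Retry-After header this time
    )

    response = OpenAICompatProvider._handle_error(err)
    assert response.content.startswith("Error:")
    assert response.retry_after is None  # nothing to honor; backoff applies
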
View File

@@ -240,6 +240,39 @@ async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypatch) -> None:
     assert progress and "7s" in progress[0]
 
 
+def test_extract_retry_after_supports_common_provider_formats() -> None:
+    assert LLMProvider._extract_retry_after('{"error":{"retry_after":20}}') == 20.0
+    assert LLMProvider._extract_retry_after("Rate limit reached, please try again in 20s") == 20.0
+    assert LLMProvider._extract_retry_after("retry-after: 20") == 20.0
+
+
+def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> None:
+    assert LLMProvider._extract_retry_after_from_headers({"Retry-After": "20"}) == 20.0
+    assert LLMProvider._extract_retry_after_from_headers({"retry-after": "20"}) == 20.0
+    assert LLMProvider._extract_retry_after_from_headers(
+        {"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"},
+    ) == 0.1
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error", retry_after=9.0),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "ok"
+    assert delays == [9.0]
+
+
 @pytest.mark.asyncio
 async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
     provider = ScriptedProvider([
@@ -263,4 +296,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
 
     assert provider.calls == 10
     assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
-

View File

@@ -0,0 +1,42 @@
+from types import SimpleNamespace
+
+from nanobot.providers.anthropic_provider import AnthropicProvider
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+from nanobot.providers.openai_compat_provider import OpenAICompatProvider
+
+
+def test_openai_compat_error_captures_retry_after_from_headers() -> None:
+    err = Exception("boom")
+    err.doc = None
+    err.response = SimpleNamespace(
+        text='{"error":{"message":"Rate limit exceeded"}}',
+        headers={"Retry-After": "20"},
+    )
+
+    response = OpenAICompatProvider._handle_error(err)
+
+    assert response.retry_after == 20.0
+
+
+def test_azure_openai_error_captures_retry_after_from_headers() -> None:
+    err = Exception("boom")
+    err.body = {"message": "Rate limit exceeded"}
+    err.response = SimpleNamespace(
+        text='{"error":{"message":"Rate limit exceeded"}}',
+        headers={"Retry-After": "20"},
+    )
+
+    response = AzureOpenAIProvider._handle_error(err)
+
+    assert response.retry_after == 20.0
+
+
+def test_anthropic_error_captures_retry_after_from_headers() -> None:
+    err = Exception("boom")
+    err.response = SimpleNamespace(
+        headers={"Retry-After": "20"},
+    )
+
+    response = AnthropicProvider._handle_error(err)
+
+    assert response.retry_after == 20.0