Merge PR #2761: fix: Retry-After was ignored, causing premature retries

fix: Retry-After was ignored, causing premature retries (now honors header/json hints)
This commit is contained in:
Xubin Ren 2026-04-04 03:10:14 +08:00 committed by GitHub
commit 6fbcecc880
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 172 additions and 23 deletions

View File

@ -401,6 +401,15 @@ class AnthropicProvider(LLMProvider):
# Public API # Public API
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
    """Convert a provider exception into an error LLMResponse.

    Prefers a retry hint parsed from the HTTP ``Retry-After`` header on the
    exception's attached response; falls back to scanning the error text.
    """
    message = f"Error calling LLM: {e}"
    resp = getattr(e, "response", None)
    hint = LLMProvider._extract_retry_after_from_headers(getattr(resp, "headers", None))
    if hint is None:
        hint = LLMProvider._extract_retry_after(message)
    return LLMResponse(content=message, finish_reason="error", retry_after=hint)
async def chat( async def chat(
self, self,
messages: list[dict[str, Any]], messages: list[dict[str, Any]],
@ -419,7 +428,7 @@ class AnthropicProvider(LLMProvider):
response = await self._client.messages.create(**kwargs) response = await self._client.messages.create(**kwargs)
return self._parse_response(response) return self._parse_response(response)
except Exception as e: except Exception as e:
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") return self._handle_error(e)
async def chat_stream( async def chat_stream(
self, self,
@ -464,7 +473,7 @@ class AnthropicProvider(LLMProvider):
finish_reason="error", finish_reason="error",
) )
except Exception as e: except Exception as e:
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error") return self._handle_error(e)
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self.default_model return self.default_model

View File

@ -113,9 +113,14 @@ class AzureOpenAIProvider(LLMProvider):
@staticmethod @staticmethod
def _handle_error(e: Exception) -> LLMResponse: def _handle_error(e: Exception) -> LLMResponse:
body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None) response = getattr(e, "response", None)
msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}" body = getattr(e, "body", None) or getattr(response, "text", None)
return LLMResponse(content=msg, finish_reason="error") body_text = str(body).strip() if body is not None else ""
msg = f"Error: {body_text[:500]}" if body_text else f"Error calling Azure OpenAI: {e}"
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
if retry_after is None:
retry_after = LLMProvider._extract_retry_after(msg)
return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API # Public API
@ -174,4 +179,4 @@ class AzureOpenAIProvider(LLMProvider):
return self._handle_error(e) return self._handle_error(e)
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self.default_model return self.default_model

View File

@ -6,6 +6,8 @@ import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Any from typing import Any
from loguru import logger from loguru import logger
@ -49,6 +51,7 @@ class LLMResponse:
tool_calls: list[ToolCallRequest] = field(default_factory=list) tool_calls: list[ToolCallRequest] = field(default_factory=list)
finish_reason: str = "stop" finish_reason: str = "stop"
usage: dict[str, int] = field(default_factory=dict) usage: dict[str, int] = field(default_factory=dict)
retry_after: float | None = None # Provider supplied retry wait in seconds.
reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc. reasoning_content: str | None = None # Kimi, DeepSeek-R1, MiMo etc.
thinking_blocks: list[dict] | None = None # Anthropic extended thinking thinking_blocks: list[dict] | None = None # Anthropic extended thinking
@ -334,16 +337,57 @@ class LLMProvider(ABC):
@classmethod @classmethod
def _extract_retry_after(cls, content: str | None) -> float | None: def _extract_retry_after(cls, content: str | None) -> float | None:
text = (content or "").lower() text = (content or "").lower()
match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text) patterns = (
if not match: r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?",
return None r"try again in\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)",
value = float(match.group(1)) r"wait\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)\s*before retry",
unit = (match.group(2) or "s").lower() r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)",
if unit in {"ms", "milliseconds"}: )
for idx, pattern in enumerate(patterns):
match = re.search(pattern, text)
if not match:
continue
value = float(match.group(1))
unit = match.group(2) if idx < 3 else "s"
return cls._to_retry_seconds(value, unit)
return None
@classmethod
def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float:
normalized_unit = (unit or "s").lower()
if normalized_unit in {"ms", "milliseconds"}:
return max(0.1, value / 1000.0) return max(0.1, value / 1000.0)
if unit in {"m", "min", "minutes"}: if normalized_unit in {"m", "min", "minutes"}:
return value * 60.0 return max(0.1, value * 60.0)
return value return max(0.1, value)
@classmethod
def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
    """Parse a ``Retry-After`` hint (numeric seconds or HTTP-date) from *headers*.

    Returns the wait in seconds (clamped to >= 0.1 by the helpers) or None
    when no usable value is present.
    """
    if not headers:
        return None
    raw: Any = None
    if hasattr(headers, "get"):
        raw = headers.get("retry-after") or headers.get("Retry-After")
    if raw is None and isinstance(headers, dict):
        # Fallback: case-insensitive scan for unusual key spellings.
        raw = next(
            (
                value
                for key, value in headers.items()
                if isinstance(key, str) and key.lower() == "retry-after"
            ),
            None,
        )
    if raw is None:
        return None
    text = str(raw).strip()
    if not text:
        return None
    # Numeric form: a plain delay in seconds.
    if re.fullmatch(r"\d+(?:\.\d+)?", text):
        return cls._to_retry_seconds(float(text), "s")
    # HTTP-date form: convert to a remaining wall-clock delay.
    try:
        when = parsedate_to_datetime(text)
    except Exception:
        return None
    if when.tzinfo is None:
        when = when.replace(tzinfo=timezone.utc)
    return max(0.1, (when - datetime.now(when.tzinfo)).total_seconds())
async def _sleep_with_heartbeat( async def _sleep_with_heartbeat(
self, self,
@ -416,7 +460,7 @@ class LLMProvider(ABC):
break break
base_delay = delays[min(attempt - 1, len(delays) - 1)] base_delay = delays[min(attempt - 1, len(delays) - 1)]
delay = self._extract_retry_after(response.content) or base_delay delay = response.retry_after or self._extract_retry_after(response.content) or base_delay
if persistent: if persistent:
delay = min(delay, self._PERSISTENT_MAX_DELAY) delay = min(delay, self._PERSISTENT_MAX_DELAY)

View File

@ -79,7 +79,9 @@ class OpenAICodexProvider(LLMProvider):
) )
return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason) return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
except Exception as e: except Exception as e:
return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error") msg = f"Error calling Codex: {e}"
retry_after = getattr(e, "retry_after", None) or self._extract_retry_after(msg)
return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
async def chat( async def chat(
self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None, self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
@ -120,6 +122,12 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]:
} }
class _CodexHTTPError(RuntimeError):
def __init__(self, message: str, retry_after: float | None = None):
super().__init__(message)
self.retry_after = retry_after
async def _request_codex( async def _request_codex(
url: str, url: str,
headers: dict[str, str], headers: dict[str, str],
@ -131,7 +139,11 @@ async def _request_codex(
async with client.stream("POST", url, headers=headers, json=body) as response: async with client.stream("POST", url, headers=headers, json=body) as response:
if response.status_code != 200: if response.status_code != 200:
text = await response.aread() text = await response.aread()
raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore"))) retry_after = LLMProvider._extract_retry_after_from_headers(response.headers)
raise _CodexHTTPError(
_friendly_error(response.status_code, text.decode("utf-8", "ignore")),
retry_after=retry_after,
)
return await consume_sse(response, on_content_delta) return await consume_sse(response, on_content_delta)

View File

@ -584,9 +584,14 @@ class OpenAICompatProvider(LLMProvider):
@staticmethod @staticmethod
def _handle_error(e: Exception) -> LLMResponse: def _handle_error(e: Exception) -> LLMResponse:
body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None) response = getattr(e, "response", None)
msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}" body = getattr(e, "doc", None) or getattr(response, "text", None)
return LLMResponse(content=msg, finish_reason="error") body_text = str(body).strip() if body is not None else ""
msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}"
retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
if retry_after is None:
retry_after = LLMProvider._extract_retry_after(msg)
return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Public API # Public API
@ -662,4 +667,4 @@ class OpenAICompatProvider(LLMProvider):
return self._handle_error(e) return self._handle_error(e)
def get_default_model(self) -> str: def get_default_model(self) -> str:
return self.default_model return self.default_model

View File

@ -240,6 +240,39 @@ async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypa
assert progress and "7s" in progress[0] assert progress and "7s" in progress[0]
def test_extract_retry_after_supports_common_provider_formats() -> None:
    """Common provider phrasings should all yield a 20-second hint."""
    samples = (
        '{"error":{"retry_after":20}}',
        "Rate limit reached, please try again in 20s",
        "retry-after: 20",
    )
    for sample in samples:
        assert LLMProvider._extract_retry_after(sample) == 20.0
def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> None:
    """Numeric values parse as seconds; a past HTTP-date clamps to the 0.1 floor."""
    for headers in ({"Retry-After": "20"}, {"retry-after": "20"}):
        assert LLMProvider._extract_retry_after_from_headers(headers) == 20.0
    past_date = {"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"}
    assert LLMProvider._extract_retry_after_from_headers(past_date) == 0.1
@pytest.mark.asyncio
async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None:
    """A structured retry_after on the error response must drive the sleep delay."""
    provider = ScriptedProvider([
        LLMResponse(content="429 rate limit", finish_reason="error", retry_after=9.0),
        LLMResponse(content="ok"),
    ])
    observed: list[float] = []

    async def _record_sleep(delay: float) -> None:
        observed.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)
    result = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
    assert result.content == "ok"
    assert observed == [9.0]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None: async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
provider = ScriptedProvider([ provider = ScriptedProvider([
@ -263,4 +296,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
assert provider.calls == 10 assert provider.calls == 10
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4] assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]

View File

@ -0,0 +1,42 @@
from types import SimpleNamespace
from nanobot.providers.anthropic_provider import AnthropicProvider
from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
def test_openai_compat_error_captures_retry_after_from_headers() -> None:
    """Retry-After on the attached response must surface on the LLMResponse."""
    exc = Exception("boom")
    exc.doc = None
    exc.response = SimpleNamespace(
        text='{"error":{"message":"Rate limit exceeded"}}',
        headers={"Retry-After": "20"},
    )
    assert OpenAICompatProvider._handle_error(exc).retry_after == 20.0
def test_azure_openai_error_captures_retry_after_from_headers() -> None:
    """Azure error handling should pick up the Retry-After header too."""
    exc = Exception("boom")
    exc.body = {"message": "Rate limit exceeded"}
    exc.response = SimpleNamespace(
        text='{"error":{"message":"Rate limit exceeded"}}',
        headers={"Retry-After": "20"},
    )
    assert AzureOpenAIProvider._handle_error(exc).retry_after == 20.0
def test_anthropic_error_captures_retry_after_from_headers() -> None:
    """Anthropic errors with only headers (no body) should still yield a hint."""
    exc = Exception("boom")
    exc.response = SimpleNamespace(headers={"Retry-After": "20"})
    assert AnthropicProvider._handle_error(exc).retry_after == 20.0