From 5d1ea43858f90a0ba9478af1116b8356a0208a40 Mon Sep 17 00:00:00 2001
From: pikaxinge <2392811793@qq.com>
Date: Thu, 2 Apr 2026 18:39:24 +0000
Subject: [PATCH] fix: robust Retry-After extraction across provider backends

Carry the provider-supplied retry wait on the response object instead of
relying only on the error text.

- LLMResponse gains an optional `retry_after` field (seconds).
- LLMProvider._extract_retry_after_from_headers reads a `Retry-After`
  header given either as a number of seconds or as an HTTP-date.
- LLMProvider._extract_retry_after additionally recognises
  "try again in N<unit>", "wait N<unit> before retry" and
  `retry_after` / `retry-after` key/value forms.
- LLMProvider._to_retry_seconds converts ms/s/min values to seconds and
  floors the result at 0.1s.
- The Anthropic, Azure OpenAI and OpenAI-compatible providers read the
  header from the exception's `response.headers` and fall back to the
  error message text.
- The Codex provider raises _CodexHTTPError, which carries the parsed
  header value.
- The retry loop prefers `response.retry_after` over text extraction.

---
 nanobot/providers/anthropic_provider.py       | 13 +++-
 nanobot/providers/azure_openai_provider.py    | 13 ++--
 nanobot/providers/base.py                     | 64 ++++++++++++++++---
 nanobot/providers/openai_codex_provider.py    | 16 ++++-
 nanobot/providers/openai_compat_provider.py   | 13 ++--
 tests/providers/test_provider_retry.py        | 34 +++++++++-
 .../test_provider_retry_after_hints.py        | 42 ++++++++++++
 7 files changed, 172 insertions(+), 23 deletions(-)
 create mode 100644 tests/providers/test_provider_retry_after_hints.py

diff --git a/nanobot/providers/anthropic_provider.py b/nanobot/providers/anthropic_provider.py
index eaec77789..0625d23b7 100644
--- a/nanobot/providers/anthropic_provider.py
+++ b/nanobot/providers/anthropic_provider.py
@@ -401,6 +401,15 @@ class AnthropicProvider(LLMProvider):
     # Public API
     # ------------------------------------------------------------------
 
+    @staticmethod
+    def _handle_error(e: Exception) -> LLMResponse:
+        msg = f"Error calling LLM: {e}"
+        response = getattr(e, "response", None)
+        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
+        if retry_after is None:
+            retry_after = LLMProvider._extract_retry_after(msg)
+        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
+
     async def chat(
         self,
         messages: list[dict[str, Any]],
@@ -419,7 +428,7 @@ class AnthropicProvider(LLMProvider):
             response = await self._client.messages.create(**kwargs)
             return self._parse_response(response)
         except Exception as e:
-            return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
+            return self._handle_error(e)
 
     async def chat_stream(
         self,
@@ -464,7 +473,7 @@ class AnthropicProvider(LLMProvider):
                 finish_reason="error",
             )
         except Exception as e:
-            return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
+            return self._handle_error(e)
 
     def get_default_model(self) -> str:
         return self.default_model
diff --git a/nanobot/providers/azure_openai_provider.py b/nanobot/providers/azure_openai_provider.py
index 12c74be02..2c42be6b3 100644
--- a/nanobot/providers/azure_openai_provider.py
+++ b/nanobot/providers/azure_openai_provider.py
@@ -113,9 +113,14 @@ class AzureOpenAIProvider(LLMProvider):
 
     @staticmethod
     def _handle_error(e: Exception) -> LLMResponse:
-        body = getattr(e, "body", None) or getattr(getattr(e, "response", None), "text", None)
-        msg = f"Error: {str(body).strip()[:500]}" if body else f"Error calling Azure OpenAI: {e}"
-        return LLMResponse(content=msg, finish_reason="error")
+        response = getattr(e, "response", None)
+        body = getattr(e, "body", None) or getattr(response, "text", None)
+        body_text = str(body).strip() if body is not None else ""
+        msg = f"Error: {body_text[:500]}" if body_text else f"Error calling Azure OpenAI: {e}"
+        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
+        if retry_after is None:
+            retry_after = LLMProvider._extract_retry_after(msg)
+        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
 
     # ------------------------------------------------------------------
     # Public API
@@ -174,4 +179,4 @@ class AzureOpenAIProvider(LLMProvider):
             return self._handle_error(e)
 
     def get_default_model(self) -> str:
-        return self.default_model
\ No newline at end of file
+        return self.default_model
diff --git a/nanobot/providers/base.py b/nanobot/providers/base.py
index 852e9c973..9638d1d80 100644
--- a/nanobot/providers/base.py
+++ b/nanobot/providers/base.py
@@ -6,6 +6,8 @@ import re
 from abc import ABC, abstractmethod
 from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from email.utils import parsedate_to_datetime
 from typing import Any
 
 from loguru import logger
@@ -49,6 +51,7 @@ class LLMResponse:
     tool_calls: list[ToolCallRequest] = field(default_factory=list)
     finish_reason: str = "stop"
     usage: dict[str, int] = field(default_factory=dict)
+    retry_after: float | None = None  # Provider supplied retry wait in seconds.
     reasoning_content: str | None = None  # Kimi, DeepSeek-R1 etc.
     thinking_blocks: list[dict] | None = None  # Anthropic extended thinking
 
@@ -334,16 +337,57 @@ class LLMProvider(ABC):
     @classmethod
     def _extract_retry_after(cls, content: str | None) -> float | None:
         text = (content or "").lower()
-        match = re.search(r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?", text)
-        if not match:
-            return None
-        value = float(match.group(1))
-        unit = (match.group(2) or "s").lower()
-        if unit in {"ms", "milliseconds"}:
+        patterns = (
+            r"retry after\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)?",
+            r"try again in\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)",
+            r"wait\s+(\d+(?:\.\d+)?)\s*(ms|milliseconds|s|sec|secs|seconds|m|min|minutes)\s*before retry",
+            r"retry[_-]?after[\"'\s:=]+(\d+(?:\.\d+)?)",
+        )
+        for idx, pattern in enumerate(patterns):
+            match = re.search(pattern, text)
+            if not match:
+                continue
+            value = float(match.group(1))
+            unit = match.group(2) if idx < 3 else "s"
+            return cls._to_retry_seconds(value, unit)
+        return None
+
+    @classmethod
+    def _to_retry_seconds(cls, value: float, unit: str | None = None) -> float:
+        normalized_unit = (unit or "s").lower()
+        if normalized_unit in {"ms", "milliseconds"}:
             return max(0.1, value / 1000.0)
-        if unit in {"m", "min", "minutes"}:
-            return value * 60.0
-        return value
+        if normalized_unit in {"m", "min", "minutes"}:
+            return max(0.1, value * 60.0)
+        return max(0.1, value)
+
+    @classmethod
+    def _extract_retry_after_from_headers(cls, headers: Any) -> float | None:
+        if not headers:
+            return None
+        retry_after: Any = None
+        if hasattr(headers, "get"):
+            retry_after = headers.get("retry-after") or headers.get("Retry-After")
+        if retry_after is None and isinstance(headers, dict):
+            for key, value in headers.items():
+                if isinstance(key, str) and key.lower() == "retry-after":
+                    retry_after = value
+                    break
+        if retry_after is None:
+            return None
+        retry_after_text = str(retry_after).strip()
+        if not retry_after_text:
+            return None
+        if re.fullmatch(r"\d+(?:\.\d+)?", retry_after_text):
+            return cls._to_retry_seconds(float(retry_after_text), "s")
+        try:
+            retry_at = parsedate_to_datetime(retry_after_text)
+        except Exception:
+            return None
+        if retry_at.tzinfo is None:
+            retry_at = retry_at.replace(tzinfo=timezone.utc)
+        remaining = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds()
+        return max(0.1, remaining)
 
     async def _sleep_with_heartbeat(
         self,
@@ -416,7 +460,7 @@ class LLMProvider(ABC):
                 break
 
             base_delay = delays[min(attempt - 1, len(delays) - 1)]
-            delay = self._extract_retry_after(response.content) or base_delay
+            delay = response.retry_after or self._extract_retry_after(response.content) or base_delay
             if persistent:
                 delay = min(delay, self._PERSISTENT_MAX_DELAY)
 
diff --git a/nanobot/providers/openai_codex_provider.py b/nanobot/providers/openai_codex_provider.py
index 265b4b106..44cb24786 100644
--- a/nanobot/providers/openai_codex_provider.py
+++ b/nanobot/providers/openai_codex_provider.py
@@ -79,7 +79,9 @@ class OpenAICodexProvider(LLMProvider):
                 )
             return LLMResponse(content=content, tool_calls=tool_calls, finish_reason=finish_reason)
         except Exception as e:
-            return LLMResponse(content=f"Error calling Codex: {e}", finish_reason="error")
+            msg = f"Error calling Codex: {e}"
+            retry_after = getattr(e, "retry_after", None) or self._extract_retry_after(msg)
+            return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
 
     async def chat(
         self, messages: list[dict[str, Any]], tools: list[dict[str, Any]] | None = None,
@@ -120,6 +122,12 @@ def _build_headers(account_id: str, token: str) -> dict[str, str]:
     }
 
 
+class _CodexHTTPError(RuntimeError):
+    def __init__(self, message: str, retry_after: float | None = None):
+        super().__init__(message)
+        self.retry_after = retry_after
+
+
 async def _request_codex(
     url: str,
     headers: dict[str, str],
@@ -131,7 +139,11 @@ async def _request_codex(
         async with client.stream("POST", url, headers=headers, json=body) as response:
             if response.status_code != 200:
                 text = await response.aread()
-                raise RuntimeError(_friendly_error(response.status_code, text.decode("utf-8", "ignore")))
+                retry_after = LLMProvider._extract_retry_after_from_headers(response.headers)
+                raise _CodexHTTPError(
+                    _friendly_error(response.status_code, text.decode("utf-8", "ignore")),
+                    retry_after=retry_after,
+                )
             return await consume_sse(response, on_content_delta)
 
 
diff --git a/nanobot/providers/openai_compat_provider.py b/nanobot/providers/openai_compat_provider.py
index 3e0a34fbf..db463773f 100644
--- a/nanobot/providers/openai_compat_provider.py
+++ b/nanobot/providers/openai_compat_provider.py
@@ -571,9 +571,14 @@ class OpenAICompatProvider(LLMProvider):
 
     @staticmethod
     def _handle_error(e: Exception) -> LLMResponse:
-        body = getattr(e, "doc", None) or getattr(getattr(e, "response", None), "text", None)
-        msg = f"Error: {body.strip()[:500]}" if body and body.strip() else f"Error calling LLM: {e}"
-        return LLMResponse(content=msg, finish_reason="error")
+        response = getattr(e, "response", None)
+        body = getattr(e, "doc", None) or getattr(response, "text", None)
+        body_text = str(body).strip() if body is not None else ""
+        msg = f"Error: {body_text[:500]}" if body_text else f"Error calling LLM: {e}"
+        retry_after = LLMProvider._extract_retry_after_from_headers(getattr(response, "headers", None))
+        if retry_after is None:
+            retry_after = LLMProvider._extract_retry_after(msg)
+        return LLMResponse(content=msg, finish_reason="error", retry_after=retry_after)
 
     # ------------------------------------------------------------------
     # Public API
@@ -646,4 +651,4 @@ class OpenAICompatProvider(LLMProvider):
             return self._handle_error(e)
 
     def get_default_model(self) -> str:
-        return self.default_model
\ No newline at end of file
+        return self.default_model
diff --git a/tests/providers/test_provider_retry.py b/tests/providers/test_provider_retry.py
index 1d8facf52..61e58e22a 100644
--- a/tests/providers/test_provider_retry.py
+++ b/tests/providers/test_provider_retry.py
@@ -240,6 +240,39 @@ async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypa
     assert progress and "7s" in progress[0]
 
 
+def test_extract_retry_after_supports_common_provider_formats() -> None:
+    assert LLMProvider._extract_retry_after('{"error":{"retry_after":20}}') == 20.0
+    assert LLMProvider._extract_retry_after("Rate limit reached, please try again in 20s") == 20.0
+    assert LLMProvider._extract_retry_after("retry-after: 20") == 20.0
+
+
+def test_extract_retry_after_from_headers_supports_numeric_and_http_date() -> None:
+    assert LLMProvider._extract_retry_after_from_headers({"Retry-After": "20"}) == 20.0
+    assert LLMProvider._extract_retry_after_from_headers({"retry-after": "20"}) == 20.0
+    assert LLMProvider._extract_retry_after_from_headers(
+        {"Retry-After": "Wed, 21 Oct 2015 07:28:00 GMT"},
+    ) == 0.1
+
+
+@pytest.mark.asyncio
+async def test_chat_with_retry_prefers_structured_retry_after_when_present(monkeypatch) -> None:
+    provider = ScriptedProvider([
+        LLMResponse(content="429 rate limit", finish_reason="error", retry_after=9.0),
+        LLMResponse(content="ok"),
+    ])
+    delays: list[float] = []
+
+    async def _fake_sleep(delay: float) -> None:
+        delays.append(delay)
+
+    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _fake_sleep)
+
+    response = await provider.chat_with_retry(messages=[{"role": "user", "content": "hello"}])
+
+    assert response.content == "ok"
+    assert delays == [9.0]
+
+
 @pytest.mark.asyncio
 async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
     provider = ScriptedProvider([
@@ -263,4 +296,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
 
     assert provider.calls == 10
     assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
-
diff --git a/tests/providers/test_provider_retry_after_hints.py b/tests/providers/test_provider_retry_after_hints.py
new file mode 100644
index 000000000..b3bbdb0f3
--- /dev/null
+++ b/tests/providers/test_provider_retry_after_hints.py
@@ -0,0 +1,42 @@
+from types import SimpleNamespace
+
+from nanobot.providers.anthropic_provider import AnthropicProvider
+from nanobot.providers.azure_openai_provider import AzureOpenAIProvider
+from nanobot.providers.openai_compat_provider import OpenAICompatProvider
+
+
+def test_openai_compat_error_captures_retry_after_from_headers() -> None:
+    err = Exception("boom")
+    err.doc = None
+    err.response = SimpleNamespace(
+        text='{"error":{"message":"Rate limit exceeded"}}',
+        headers={"Retry-After": "20"},
+    )
+
+    response = OpenAICompatProvider._handle_error(err)
+
+    assert response.retry_after == 20.0
+
+
+def test_azure_openai_error_captures_retry_after_from_headers() -> None:
+    err = Exception("boom")
+    err.body = {"message": "Rate limit exceeded"}
+    err.response = SimpleNamespace(
+        text='{"error":{"message":"Rate limit exceeded"}}',
+        headers={"Retry-After": "20"},
+    )
+
+    response = AzureOpenAIProvider._handle_error(err)
+
+    assert response.retry_after == 20.0
+
+
+def test_anthropic_error_captures_retry_after_from_headers() -> None:
+    err = Exception("boom")
+    err.response = SimpleNamespace(
+        headers={"Retry-After": "20"},
+    )
+
+    response = AnthropicProvider._handle_error(err)
+
+    assert response.retry_after == 20.0