mirror of
https://github.com/HKUDS/nanobot.git
synced 2026-04-08 20:23:41 +00:00
fix: use structured error metadata for app-layer retry
This commit is contained in:
parent
7113ad34f4
commit
b951b37c97
@ -3,10 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import email.utils
|
||||
import os
|
||||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
@ -51,6 +53,73 @@ class AnthropicProvider(LLMProvider):
|
||||
client_kw["default_headers"] = extra_headers
|
||||
self._client = AsyncAnthropic(**client_kw)
|
||||
|
||||
@staticmethod
|
||||
def _parse_retry_after_headers(headers: Any) -> float | None:
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
retry_ms = headers.get("retry-after-ms")
|
||||
if retry_ms is not None:
|
||||
value = float(retry_ms) / 1000.0
|
||||
if value > 0:
|
||||
return value
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
retry_after = headers.get("retry-after")
|
||||
try:
|
||||
if retry_after is not None:
|
||||
value = float(retry_after)
|
||||
if value > 0:
|
||||
return value
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if retry_after is None:
|
||||
return None
|
||||
retry_date_tuple = email.utils.parsedate_tz(retry_after)
|
||||
if retry_date_tuple is None:
|
||||
return None
|
||||
retry_date = email.utils.mktime_tz(retry_date_tuple)
|
||||
value = float(retry_date - time.time())
|
||||
return value if value > 0 else None
|
||||
|
||||
@classmethod
def _error_response(cls, e: Exception) -> LLMResponse:
    """Convert an SDK exception into an error LLMResponse with structured
    retry metadata (status code, error kind, retry-after, x-should-retry)."""
    response = getattr(e, "response", None)
    headers = getattr(response, "headers", None)

    # Status code may live on the exception itself or on its response object.
    status_code = getattr(e, "status_code", None)
    if status_code is None and response is not None:
        status_code = getattr(response, "status_code", None)

    # Explicit server hint: "x-should-retry: true|false" overrides heuristics.
    should_retry: bool | None = None
    if headers is not None:
        raw = headers.get("x-should-retry")
        if isinstance(raw, str):
            should_retry = {"true": True, "false": False}.get(raw.strip().lower())

    # Classify transport-level failures by exception class name.
    name = type(e).__name__.lower()
    if "timeout" in name:
        error_kind: str | None = "timeout"
    elif "connection" in name:
        error_kind = "connection"
    else:
        error_kind = None

    return LLMResponse(
        content=f"Error calling LLM: {e}",
        finish_reason="error",
        error_status_code=None if status_code is None else int(status_code),
        error_kind=error_kind,
        error_retry_after_s=cls._parse_retry_after_headers(headers),
        error_should_retry=should_retry,
    )
|
||||
|
||||
@staticmethod
|
||||
def _strip_prefix(model: str) -> str:
|
||||
if model.startswith("anthropic/"):
|
||||
@ -419,7 +488,7 @@ class AnthropicProvider(LLMProvider):
|
||||
response = await self._client.messages.create(**kwargs)
|
||||
return self._parse_response(response)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||
return self._error_response(e)
|
||||
|
||||
async def chat_stream(
|
||||
self,
|
||||
@ -462,9 +531,10 @@ class AnthropicProvider(LLMProvider):
|
||||
f"{idle_timeout_s} seconds"
|
||||
),
|
||||
finish_reason="error",
|
||||
error_kind="timeout",
|
||||
)
|
||||
except Exception as e:
|
||||
return LLMResponse(content=f"Error calling LLM: {e}", finish_reason="error")
|
||||
return self._error_response(e)
|
||||
|
||||
def get_default_model(self) -> str:
    """Return the provider's configured default model name."""
    default = self.default_model
    return default
|
||||
|
||||
@ -51,6 +51,11 @@ class LLMResponse:
|
||||
usage: dict[str, int] = field(default_factory=dict)
|
||||
reasoning_content: str | None = None # Kimi, DeepSeek-R1 etc.
|
||||
thinking_blocks: list[dict] | None = None # Anthropic extended thinking
|
||||
# Structured error metadata used by retry policy when finish_reason == "error".
|
||||
error_status_code: int | None = None
|
||||
error_kind: str | None = None # e.g. "timeout", "connection"
|
||||
error_retry_after_s: float | None = None
|
||||
error_should_retry: bool | None = None
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
@ -88,6 +93,8 @@ class LLMProvider(ABC):
|
||||
"server error",
|
||||
"temporarily unavailable",
|
||||
)
|
||||
_RETRYABLE_STATUS_CODES = frozenset({408, 409, 429})
|
||||
_TRANSIENT_ERROR_KINDS = frozenset({"timeout", "connection"})
|
||||
|
||||
_SENTINEL = object()
|
||||
|
||||
@ -191,6 +198,23 @@ class LLMProvider(ABC):
|
||||
err = (content or "").lower()
|
||||
return any(marker in err for marker in cls._TRANSIENT_ERROR_MARKERS)
|
||||
|
||||
@classmethod
def _is_transient_response(cls, response: LLMResponse) -> bool:
    """Prefer structured error metadata, fallback to text markers for legacy providers."""
    # An explicit server hint wins over every heuristic.
    hint = response.error_should_retry
    if hint is not None:
        return bool(hint)

    # Retryable HTTP status: the known transient set, or any 5xx.
    status = response.error_status_code
    if status is not None:
        code = int(status)
        if code >= 500 or code in cls._RETRYABLE_STATUS_CODES:
            return True

    # Transport-level failures (timeouts, dropped connections) are transient.
    if (response.error_kind or "").strip().lower() in cls._TRANSIENT_ERROR_KINDS:
        return True

    # Legacy path: scan the error text for known transient markers.
    return cls._is_transient_error(response.content)
|
||||
|
||||
@staticmethod
|
||||
def _strip_image_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]] | None:
|
||||
"""Replace image_url blocks with text placeholder. Returns None if no images found."""
|
||||
@ -345,6 +369,12 @@ class LLMProvider(ABC):
|
||||
return value * 60.0
|
||||
return value
|
||||
|
||||
@classmethod
def _extract_retry_after_from_response(cls, response: LLMResponse) -> float | None:
    """Return the retry delay for an error response.

    Prefers the structured ``error_retry_after_s`` field when it carries a
    positive value; otherwise falls back to parsing the error text.
    """
    structured = response.error_retry_after_s
    if structured is not None and structured > 0:
        return structured
    return cls._extract_retry_after(response.content)
|
||||
|
||||
async def _sleep_with_heartbeat(
|
||||
self,
|
||||
delay: float,
|
||||
@ -393,7 +423,7 @@ class LLMProvider(ABC):
|
||||
last_error_key = error_key
|
||||
identical_error_count = 1 if error_key else 0
|
||||
|
||||
if not self._is_transient_error(response.content):
|
||||
if not self._is_transient_response(response):
|
||||
stripped = self._strip_image_content(original_messages)
|
||||
if stripped is not None and stripped != kw["messages"]:
|
||||
logger.warning(
|
||||
@ -416,7 +446,7 @@ class LLMProvider(ABC):
|
||||
break
|
||||
|
||||
base_delay = delays[min(attempt - 1, len(delays) - 1)]
|
||||
delay = self._extract_retry_after(response.content) or base_delay
|
||||
delay = self._extract_retry_after_from_response(response) or base_delay
|
||||
if persistent:
|
||||
delay = min(delay, self._PERSISTENT_MAX_DELAY)
|
||||
|
||||
|
||||
@ -3,10 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import email.utils
|
||||
import hashlib
|
||||
import os
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@ -569,11 +571,85 @@ class OpenAICompatProvider(LLMProvider):
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_retry_after_headers(headers: Any) -> float | None:
|
||||
if headers is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
retry_ms = headers.get("retry-after-ms")
|
||||
if retry_ms is not None:
|
||||
value = float(retry_ms) / 1000.0
|
||||
if value > 0:
|
||||
return value
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
retry_after = headers.get("retry-after")
|
||||
try:
|
||||
if retry_after is not None:
|
||||
value = float(retry_after)
|
||||
if value > 0:
|
||||
return value
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
|
||||
if retry_after is None:
|
||||
return None
|
||||
retry_date_tuple = email.utils.parsedate_tz(retry_after)
|
||||
if retry_date_tuple is None:
|
||||
return None
|
||||
retry_date = email.utils.mktime_tz(retry_date_tuple)
|
||||
value = float(retry_date - time.time())
|
||||
return value if value > 0 else None
|
||||
|
||||
@classmethod
def _extract_error_metadata(cls, e: Exception) -> dict[str, Any]:
    """Collect structured retry metadata from an SDK exception.

    Returns kwargs for LLMResponse: ``error_status_code``, ``error_kind``,
    ``error_retry_after_s`` and ``error_should_retry``.
    """
    response = getattr(e, "response", None)
    headers = getattr(response, "headers", None)

    # The status code may be attached to the exception or its response.
    status_code = getattr(e, "status_code", None)
    if status_code is None and response is not None:
        status_code = getattr(response, "status_code", None)

    # Explicit "x-should-retry" header overrides heuristics downstream.
    should_retry: bool | None = None
    if headers is not None:
        raw = headers.get("x-should-retry")
        if isinstance(raw, str):
            should_retry = {"true": True, "false": False}.get(raw.strip().lower())

    # Classify transport-level failures by exception class name.
    name = type(e).__name__.lower()
    error_kind: str | None = None
    if "timeout" in name:
        error_kind = "timeout"
    elif "connection" in name:
        error_kind = "connection"

    return {
        "error_status_code": None if status_code is None else int(status_code),
        "error_kind": error_kind,
        "error_retry_after_s": cls._parse_retry_after_headers(headers),
        "error_should_retry": should_retry,
    }
|
||||
|
||||
@staticmethod
def _handle_error(e: Exception) -> LLMResponse:
    """Convert an SDK exception into an error LLMResponse with structured
    retry metadata.

    Prefers the parsed error body (``doc``/``body``) or the raw response
    text for the message, truncated to 500 characters; falls back to
    ``str(e)`` when no body text is available.
    """
    # BUG FIX: the previous version returned early with a bare
    # LLMResponse(content=msg, finish_reason="error"), leaving the
    # structured-metadata construction below it unreachable dead code.
    # Keep only the intended implementation.
    body = (
        getattr(e, "doc", None)
        or getattr(e, "body", None)
        or getattr(getattr(e, "response", None), "text", None)
    )
    # Body may be a dict (parsed JSON) or any other object; stringify it.
    body_text = body if isinstance(body, str) else str(body) if body is not None else ""
    msg = f"Error: {body_text.strip()[:500]}" if body_text.strip() else f"Error calling LLM: {e}"
    return LLMResponse(
        content=msg,
        finish_reason="error",
        **OpenAICompatProvider._extract_error_metadata(e),
    )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Public API
|
||||
@ -641,6 +717,7 @@ class OpenAICompatProvider(LLMProvider):
|
||||
f"{idle_timeout_s} seconds"
|
||||
),
|
||||
finish_reason="error",
|
||||
error_kind="timeout",
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_error(e)
|
||||
|
||||
77
tests/providers/test_provider_error_metadata.py
Normal file
77
tests/providers/test_provider_error_metadata.py
Normal file
@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
from nanobot.providers.anthropic_provider import AnthropicProvider
|
||||
from nanobot.providers.openai_compat_provider import OpenAICompatProvider
|
||||
|
||||
|
||||
def _fake_response(
|
||||
*,
|
||||
status_code: int,
|
||||
headers: dict[str, str] | None = None,
|
||||
text: str = "",
|
||||
) -> SimpleNamespace:
|
||||
return SimpleNamespace(
|
||||
status_code=status_code,
|
||||
headers=headers or {},
|
||||
text=text,
|
||||
)
|
||||
|
||||
|
||||
def test_openai_handle_error_extracts_structured_metadata() -> None:
    """_handle_error should lift status, retry-after and retry hint from the error."""
    class FakeStatusError(Exception):
        pass

    err = FakeStatusError("boom")
    err.status_code = 409
    err.response = _fake_response(
        status_code=409,
        headers={"retry-after-ms": "250", "x-should-retry": "false"},
        text='{"error":"conflict"}',
    )
    err.body = {"error": "conflict"}

    result = OpenAICompatProvider._handle_error(err)

    assert result.finish_reason == "error"
    assert result.error_status_code == 409
    assert result.error_retry_after_s == 0.25
    assert result.error_should_retry is False
|
||||
|
||||
|
||||
def test_openai_handle_error_marks_timeout_kind() -> None:
    """Exceptions whose class name contains 'timeout' get error_kind='timeout'."""
    class FakeTimeoutError(Exception):
        pass

    result = OpenAICompatProvider._handle_error(FakeTimeoutError("timeout"))

    assert result.finish_reason == "error"
    assert result.error_kind == "timeout"
|
||||
|
||||
|
||||
def test_anthropic_error_response_extracts_structured_metadata() -> None:
    """_error_response should surface status, retry-after and the retry hint."""
    class FakeStatusError(Exception):
        pass

    err = FakeStatusError("boom")
    err.status_code = 408
    err.response = _fake_response(
        status_code=408,
        headers={"retry-after": "1.5", "x-should-retry": "true"},
    )

    result = AnthropicProvider._error_response(err)

    assert result.finish_reason == "error"
    assert result.error_status_code == 408
    assert result.error_retry_after_s == 1.5
    assert result.error_should_retry is True
|
||||
|
||||
|
||||
def test_anthropic_error_response_marks_connection_kind() -> None:
    """Exceptions whose class name contains 'connection' get error_kind='connection'."""
    class FakeConnectionError(Exception):
        pass

    result = AnthropicProvider._error_response(FakeConnectionError("connection"))

    assert result.finish_reason == "error"
    assert result.error_kind == "connection"
|
||||
@ -240,6 +240,100 @@ async def test_chat_with_retry_uses_retry_after_and_emits_wait_progress(monkeypa
|
||||
assert progress and "7s" in progress[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_retries_structured_status_code_without_keyword(monkeypatch) -> None:
    """A structured 409 status triggers a retry even without transient text markers."""
    scripted = ScriptedProvider([
        LLMResponse(
            content="request failed",
            finish_reason="error",
            error_status_code=409,
        ),
        LLMResponse(content="ok"),
    ])
    observed_delays: list[float] = []

    async def _record_sleep(delay: float) -> None:
        observed_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    result = await scripted.chat_with_retry(messages=[{"role": "user", "content": "hello"}])

    assert result.content == "ok"
    assert scripted.calls == 2
    assert observed_delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_retries_structured_timeout_kind(monkeypatch) -> None:
    """A structured error_kind of 'timeout' is treated as transient and retried."""
    scripted = ScriptedProvider([
        LLMResponse(
            content="request failed",
            finish_reason="error",
            error_kind="timeout",
        ),
        LLMResponse(content="ok"),
    ])
    observed_delays: list[float] = []

    async def _record_sleep(delay: float) -> None:
        observed_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    result = await scripted.chat_with_retry(messages=[{"role": "user", "content": "hello"}])

    assert result.content == "ok"
    assert scripted.calls == 2
    assert observed_delays == [1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_structured_should_retry_false_disables_retry(monkeypatch) -> None:
    """An explicit should-retry=False hint wins over transient-looking error text."""
    scripted = ScriptedProvider([
        LLMResponse(
            content="429 rate limit",
            finish_reason="error",
            error_should_retry=False,
        ),
    ])
    observed_delays: list[float] = []

    async def _record_sleep(delay: float) -> None:
        observed_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    result = await scripted.chat_with_retry(messages=[{"role": "user", "content": "hello"}])

    assert result.finish_reason == "error"
    assert scripted.calls == 1
    assert observed_delays == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_chat_with_retry_prefers_structured_retry_after(monkeypatch) -> None:
    """Structured error_retry_after_s beats the retry-after parsed from error text."""
    scripted = ScriptedProvider([
        LLMResponse(
            content="429 rate limit, retry after 99s",
            finish_reason="error",
            error_retry_after_s=0.2,
        ),
        LLMResponse(content="ok"),
    ])
    observed_delays: list[float] = []

    async def _record_sleep(delay: float) -> None:
        observed_delays.append(delay)

    monkeypatch.setattr("nanobot.providers.base.asyncio.sleep", _record_sleep)

    result = await scripted.chat_with_retry(messages=[{"role": "user", "content": "hello"}])

    assert result.content == "ok"
    assert observed_delays == [0.2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monkeypatch) -> None:
|
||||
provider = ScriptedProvider([
|
||||
@ -263,4 +357,3 @@ async def test_persistent_retry_aborts_after_ten_identical_transient_errors(monk
|
||||
assert provider.calls == 10
|
||||
assert delays == [1, 2, 4, 4, 4, 4, 4, 4, 4]
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user