Remove check from the test

This commit is contained in:
Kunal Karmakar 2026-06-04 16:33:26 +00:00 committed by Xubin Ren
parent 9fdc6f892a
commit fa423dffbc

View File

@ -104,12 +104,11 @@ def test_init_missing_key_uses_aad_token_provider(monkeypatch):
@pytest.mark.asyncio
async def test_aad_token_provider_wires_into_sdk_auth_headers(monkeypatch):
"""End-to-end: SDK ``_refresh_api_key`` invokes our callable and the
resulting bearer token shows up in ``auth_headers``.
This is a regression guard against a future refactor that constructs
``_AzureTokenProvider`` but forgets to pass it to ``AsyncOpenAI`` (in
which case the outgoing request would carry no Authorization header).
"""Regression guard: the token provider must be wired into the
OpenAI SDK so ``_refresh_api_key`` pulls a fresh token and the
bearer ends up in ``auth_headers``. Without this, a refactor that
constructs ``_AzureTokenProvider`` but forgets to pass it to
``AsyncOpenAI`` would silently send unauthenticated requests.
"""
access_token = SimpleNamespace(token="token-A", expires_on=time.time() + 3600)
credential_instance = MagicMock()
@ -121,23 +120,18 @@ async def test_aad_token_provider_wires_into_sdk_auth_headers(monkeypatch):
api_key="", api_base="https://res.openai.azure.com",
)
# Before any refresh, the SDK has no key yet -> no auth header.
assert provider._client.api_key == ""
assert provider._client.auth_headers == {}
# Trigger the SDK's refresh path; it must call our async callable.
refreshed = await provider._client._refresh_api_key()
await provider._client._refresh_api_key()
assert refreshed == "token-A"
assert provider._client.api_key == "token-A"
assert provider._client.auth_headers == {"Authorization": "Bearer token-A"}
credential_instance.get_token.assert_awaited_with(
"https://cognitiveservices.azure.com/.default"
)
# A second refresh picks up a rotated token without re-instantiating
# the credential — proves we delegate per-request rather than caching
# the first value.
# Rotated token on second refresh proves per-request delegation (no client-side caching).
credential_instance.get_token = AsyncMock(
return_value=SimpleNamespace(token="token-B", expires_on=time.time() + 3600)
)