Skip to content

Commit c24ab17

Browse files
committed
Fixed fixtures for async
1 parent b697e27 commit c24ab17

File tree

1 file changed

+40
-20
lines changed

1 file changed

+40
-20
lines changed

tests/test_asyncio/conftest.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import random
44
from contextlib import asynccontextmanager as _asynccontextmanager
55
from datetime import datetime, timezone
6+
from enum import Enum
67
from typing import Union
78

89
import jwt
910
import pytest
1011
import pytest_asyncio
1112
from redis_entraid.cred_provider import EntraIdCredentialsProvider, TokenAuthConfig
1213
from redis_entraid.identity_provider import ManagedIdentityType, create_provider_from_managed_identity, \
13-
create_provider_from_service_principal
14+
create_provider_from_service_principal, ManagedIdentityIdType
1415
from mock.mock import Mock
1516
from redis.credentials import CredentialProvider
1617

@@ -27,6 +28,10 @@
2728

2829
from .compat import mock
2930

31+
class AuthType(Enum):
32+
MANAGED_IDENTITY = "managed_identity"
33+
SERVICE_PRINCIPAL = "service_principal"
34+
3035

3136
async def _get_info(redis_url):
3237
client = redis.Redis.from_url(redis_url)
@@ -248,24 +253,34 @@ def mock_identity_provider() -> IdentityProviderInterface:
248253

249254

250255
def identity_provider(request) -> IdentityProviderInterface:
251-
auth_type = os.getenv("IDP_AUTH_TYPE")
256+
if hasattr(request, "param"):
257+
kwargs = request.param.get("idp_kwargs", {})
258+
else:
259+
kwargs = {}
252260

253261
if request.param.get("mock_idp", None) is not None:
254262
return mock_identity_provider()
255263

264+
auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL)
265+
256266
if auth_type == "MANAGED_IDENTITY":
257267
return _get_managed_identity_provider(request)
258268

259269
return _get_service_principal_provider(request)
260270

261271

262272
def _get_managed_identity_provider(request):
263-
authority = os.getenv("IDP_AUTHORITY")
264-
identity_type = ManagedIdentityType(os.getenv("IDP_IDENTITY_TYPE"))
265-
resource = os.getenv("IDP_RESOURCE")
266-
id_type = os.getenv("IDP_ID_TYPE", None)
267-
id_value = os.getenv("IDP_ID_VALUE", None)
268-
kwargs = request.param.get("idp_kwargs", {})
273+
authority = os.getenv("AZURE_AUTHORITY")
274+
resource = os.getenv("AZURE_RESOURCE")
275+
id_value = os.getenv("AZURE_ID_VALUE", None)
276+
277+
if hasattr(request, "param"):
278+
kwargs = request.param.get("idp_kwargs", {})
279+
else:
280+
kwargs = {}
281+
282+
identity_type = kwargs.pop("identity_type", ManagedIdentityType.SYSTEM_ASSIGNED)
283+
id_type = kwargs.pop("id_type", ManagedIdentityIdType.CLIENT_ID)
269284

270285
return create_provider_from_managed_identity(
271286
identity_type=identity_type,
@@ -278,18 +293,23 @@ def _get_managed_identity_provider(request):
278293

279294

280295
def _get_service_principal_provider(request):
281-
client_id = os.getenv("IDP_CLIENT_ID")
282-
client_credential = os.getenv("IDP_CLIENT_CREDENTIAL")
283-
authority = os.getenv("IDP_AUTHORITY")
284-
scopes = os.getenv("IDP_SCOPES", [])
285-
kwargs = request.param.get("idp_kwargs", {})
296+
client_id = os.getenv("AZURE_CLIENT_ID")
297+
client_credential = os.getenv("AZURE_CLIENT_SECRET")
298+
authority = os.getenv("AZURE_AUTHORITY")
299+
scopes = os.getenv("AZURE_REDIS_SCOPES", [])
300+
301+
if hasattr(request, "param"):
302+
kwargs = request.param.get("idp_kwargs", {})
303+
token_kwargs = request.param.get("token_kwargs", {})
304+
timeout = request.param.get("timeout", None)
305+
else:
306+
kwargs = {}
307+
token_kwargs = {}
308+
timeout = None
286309

287310
if isinstance(scopes, str):
288311
scopes = scopes.split(',')
289312

290-
token_kwargs = request.param.get("token_kwargs", {})
291-
timeout = request.param.get("timeout", None)
292-
293313
return create_provider_from_service_principal(
294314
client_id=client_id,
295315
client_credential=client_credential,
@@ -325,10 +345,10 @@ def get_credential_provider(request) -> CredentialProvider:
325345
)
326346

327347
auth_config = TokenAuthConfig(idp)
328-
auth_config.expiration_refresh_ratio(expiration_refresh_ratio)
329-
auth_config.lower_refresh_bound_millis(lower_refresh_bound_millis)
330-
auth_config.max_attempts(max_attempts)
331-
auth_config.delay_in_ms(delay_in_ms)
348+
auth_config.expiration_refresh_ratio = expiration_refresh_ratio
349+
auth_config.lower_refresh_bound_millis = lower_refresh_bound_millis
350+
auth_config.max_attempts = max_attempts
351+
auth_config.delay_in_ms = delay_in_ms
332352

333353
return EntraIdCredentialsProvider(
334354
config=auth_config,

0 commit comments

Comments
 (0)