33import random
44from contextlib import asynccontextmanager as _asynccontextmanager
55from datetime import datetime , timezone
6+ from enum import Enum
67from typing import Union
78
89import jwt
910import pytest
1011import pytest_asyncio
1112from redis_entraid .cred_provider import EntraIdCredentialsProvider , TokenAuthConfig
1213from redis_entraid .identity_provider import ManagedIdentityType , create_provider_from_managed_identity , \
13- create_provider_from_service_principal
14+ create_provider_from_service_principal , ManagedIdentityIdType
1415from mock .mock import Mock
1516from redis .credentials import CredentialProvider
1617
2728
2829from .compat import mock
2930
31+ class AuthType (Enum ):
32+ MANAGED_IDENTITY = "managed_identity"
33+ SERVICE_PRINCIPAL = "service_principal"
34+
3035
3136async def _get_info (redis_url ):
3237 client = redis .Redis .from_url (redis_url )
@@ -248,24 +253,34 @@ def mock_identity_provider() -> IdentityProviderInterface:
248253
249254
250255def identity_provider (request ) -> IdentityProviderInterface :
251- auth_type = os .getenv ("IDP_AUTH_TYPE" )
256+ if hasattr (request , "param" ):
257+ kwargs = request .param .get ("idp_kwargs" , {})
258+ else :
259+ kwargs = {}
252260
253261 if request .param .get ("mock_idp" , None ) is not None :
254262 return mock_identity_provider ()
255263
264+ auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
265+
256266 if auth_type == "MANAGED_IDENTITY" :
257267 return _get_managed_identity_provider (request )
258268
259269 return _get_service_principal_provider (request )
260270
261271
262272def _get_managed_identity_provider (request ):
263- authority = os .getenv ("IDP_AUTHORITY" )
264- identity_type = ManagedIdentityType (os .getenv ("IDP_IDENTITY_TYPE" ))
265- resource = os .getenv ("IDP_RESOURCE" )
266- id_type = os .getenv ("IDP_ID_TYPE" , None )
267- id_value = os .getenv ("IDP_ID_VALUE" , None )
268- kwargs = request .param .get ("idp_kwargs" , {})
273+ authority = os .getenv ("AZURE_AUTHORITY" )
274+ resource = os .getenv ("AZURE_RESOURCE" )
275+ id_value = os .getenv ("AZURE_ID_VALUE" , None )
276+
277+ if hasattr (request , "param" ):
278+ kwargs = request .param .get ("idp_kwargs" , {})
279+ else :
280+ kwargs = {}
281+
282+ identity_type = kwargs .pop ("identity_type" , ManagedIdentityType .SYSTEM_ASSIGNED )
283+ id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .CLIENT_ID )
269284
270285 return create_provider_from_managed_identity (
271286 identity_type = identity_type ,
@@ -278,18 +293,23 @@ def _get_managed_identity_provider(request):
278293
279294
280295def _get_service_principal_provider (request ):
281- client_id = os .getenv ("IDP_CLIENT_ID" )
282- client_credential = os .getenv ("IDP_CLIENT_CREDENTIAL" )
283- authority = os .getenv ("IDP_AUTHORITY" )
284- scopes = os .getenv ("IDP_SCOPES" , [])
285- kwargs = request .param .get ("idp_kwargs" , {})
296+ client_id = os .getenv ("AZURE_CLIENT_ID" )
297+ client_credential = os .getenv ("AZURE_CLIENT_SECRET" )
298+ authority = os .getenv ("AZURE_AUTHORITY" )
299+ scopes = os .getenv ("AZURE_REDIS_SCOPES" , [])
300+
301+ if hasattr (request , "param" ):
302+ kwargs = request .param .get ("idp_kwargs" , {})
303+ token_kwargs = request .param .get ("token_kwargs" , {})
304+ timeout = request .param .get ("timeout" , None )
305+ else :
306+ kwargs = {}
307+ token_kwargs = {}
308+ timeout = None
286309
287310 if isinstance (scopes , str ):
288311 scopes = scopes .split (',' )
289312
290- token_kwargs = request .param .get ("token_kwargs" , {})
291- timeout = request .param .get ("timeout" , None )
292-
293313 return create_provider_from_service_principal (
294314 client_id = client_id ,
295315 client_credential = client_credential ,
@@ -325,10 +345,10 @@ def get_credential_provider(request) -> CredentialProvider:
325345 )
326346
327347 auth_config = TokenAuthConfig (idp )
328- auth_config .expiration_refresh_ratio ( expiration_refresh_ratio )
329- auth_config .lower_refresh_bound_millis ( lower_refresh_bound_millis )
330- auth_config .max_attempts ( max_attempts )
331- auth_config .delay_in_ms ( delay_in_ms )
348+ auth_config .expiration_refresh_ratio = expiration_refresh_ratio
349+ auth_config .lower_refresh_bound_millis = lower_refresh_bound_millis
350+ auth_config .max_attempts = max_attempts
351+ auth_config .delay_in_ms = delay_in_ms
332352
333353 return EntraIdCredentialsProvider (
334354 config = auth_config ,
0 commit comments