1717from redis .asyncio .retry import Retry
1818from redis .auth .idp import IdentityProviderInterface
1919from redis .auth .token import JWToken
20+ from redis .auth .token_manager import RetryPolicy , TokenManagerConfig
2021from redis .backoff import NoBackoff
2122from redis .credentials import CredentialProvider
22- from redis_entraid .cred_provider import EntraIdCredentialsProvider , TokenAuthConfig
23+ from redis_entraid .cred_provider import (
24+ DEFAULT_DELAY_IN_MS ,
25+ DEFAULT_EXPIRATION_REFRESH_RATIO ,
26+ DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
27+ DEFAULT_MAX_ATTEMPTS ,
28+ EntraIdCredentialsProvider ,
29+ )
2330from redis_entraid .identity_provider import (
2431 ManagedIdentityIdType ,
32+ ManagedIdentityProviderConfig ,
2533 ManagedIdentityType ,
26- create_provider_from_managed_identity ,
27- create_provider_from_service_principal ,
34+ ServicePrincipalIdentityProviderConfig ,
35+ _create_provider_from_managed_identity ,
36+ _create_provider_from_service_principal ,
2837)
2938from tests .conftest import REDIS_INFO
3039
@@ -255,41 +264,58 @@ def identity_provider(request) -> IdentityProviderInterface:
255264 return mock_identity_provider ()
256265
257266 auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
267+ config = get_identity_provider_config (request = request )
258268
259269 if auth_type == "MANAGED_IDENTITY" :
260- return _get_managed_identity_provider (request )
270+ return _create_provider_from_managed_identity (config )
271+
272+ return _create_provider_from_service_principal (config )
273+
274+
275+ def get_identity_provider_config (
276+ request ,
277+ ) -> Union [ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig ]:
278+ if hasattr (request , "param" ):
279+ kwargs = request .param .get ("idp_kwargs" , {})
280+ else :
281+ kwargs = {}
261282
262- return _get_service_principal_provider (request )
283+ auth_type = kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
284+
285+ if auth_type == AuthType .MANAGED_IDENTITY :
286+ return _get_managed_identity_provider_config (request )
263287
288+ return _get_service_principal_provider_config (request )
264289
265- def _get_managed_identity_provider ( request ):
266- authority = os . getenv ( "AZURE_AUTHORITY" )
290+
291+ def _get_managed_identity_provider_config ( request ) -> ManagedIdentityProviderConfig :
267292 resource = os .getenv ("AZURE_RESOURCE" )
268- id_value = os .getenv ("AZURE_ID_VALUE " , None )
293+ id_value = os .getenv ("AZURE_USER_ASSIGNED_MANAGED_ID " , None )
269294
270295 if hasattr (request , "param" ):
271296 kwargs = request .param .get ("idp_kwargs" , {})
272297 else :
273298 kwargs = {}
274299
275300 identity_type = kwargs .pop ("identity_type" , ManagedIdentityType .SYSTEM_ASSIGNED )
276- id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .CLIENT_ID )
301+ id_type = kwargs .pop ("id_type" , ManagedIdentityIdType .OBJECT_ID )
277302
278- return create_provider_from_managed_identity (
303+ return ManagedIdentityProviderConfig (
279304 identity_type = identity_type ,
280305 resource = resource ,
281306 id_type = id_type ,
282307 id_value = id_value ,
283- authority = authority ,
284- ** kwargs ,
308+ kwargs = kwargs ,
285309 )
286310
287311
288- def _get_service_principal_provider (request ):
312+ def _get_service_principal_provider_config (
313+ request ,
314+ ) -> ServicePrincipalIdentityProviderConfig :
289315 client_id = os .getenv ("AZURE_CLIENT_ID" )
290316 client_credential = os .getenv ("AZURE_CLIENT_SECRET" )
291- authority = os .getenv ("AZURE_AUTHORITY " )
292- scopes = os .getenv ("AZURE_REDIS_SCOPES" , [] )
317+ tenant_id = os .getenv ("AZURE_TENANT_ID " )
318+ scopes = os .getenv ("AZURE_REDIS_SCOPES" , None )
293319
294320 if hasattr (request , "param" ):
295321 kwargs = request .param .get ("idp_kwargs" , {})
@@ -303,14 +329,14 @@ def _get_service_principal_provider(request):
303329 if isinstance (scopes , str ):
304330 scopes = scopes .split ("," )
305331
306- return create_provider_from_service_principal (
332+ return ServicePrincipalIdentityProviderConfig (
307333 client_id = client_id ,
308334 client_credential = client_credential ,
309335 scopes = scopes ,
310336 timeout = timeout ,
311337 token_kwargs = token_kwargs ,
312- authority = authority ,
313- ** kwargs ,
338+ tenant_id = tenant_id ,
339+ app_kwargs = kwargs ,
314340 )
315341
316342
@@ -322,31 +348,29 @@ def get_credential_provider(request) -> CredentialProvider:
322348 return cred_provider_class (** cred_provider_kwargs )
323349
324350 idp = identity_provider (request )
325- initial_delay_in_ms = cred_provider_kwargs .get ("initial_delay_in_ms" , 0 )
326- block_for_initial = cred_provider_kwargs .get ("block_for_initial" , False )
327351 expiration_refresh_ratio = cred_provider_kwargs .get (
328- "expiration_refresh_ratio" , TokenAuthConfig . DEFAULT_EXPIRATION_REFRESH_RATIO
352+ "expiration_refresh_ratio" , DEFAULT_EXPIRATION_REFRESH_RATIO
329353 )
330354 lower_refresh_bound_millis = cred_provider_kwargs .get (
331- "lower_refresh_bound_millis" , TokenAuthConfig .DEFAULT_LOWER_REFRESH_BOUND_MILLIS
332- )
333- max_attempts = cred_provider_kwargs .get (
334- "max_attempts" , TokenAuthConfig .DEFAULT_MAX_ATTEMPTS
355+ "lower_refresh_bound_millis" , DEFAULT_LOWER_REFRESH_BOUND_MILLIS
335356 )
336- delay_in_ms = cred_provider_kwargs .get (
337- "delay_in_ms" , TokenAuthConfig .DEFAULT_DELAY_IN_MS
357+ max_attempts = cred_provider_kwargs .get ("max_attempts" , DEFAULT_MAX_ATTEMPTS )
358+ delay_in_ms = cred_provider_kwargs .get ("delay_in_ms" , DEFAULT_DELAY_IN_MS )
359+
360+ token_mgr_config = TokenManagerConfig (
361+ expiration_refresh_ratio = expiration_refresh_ratio ,
362+ lower_refresh_bound_millis = lower_refresh_bound_millis ,
363+ token_request_execution_timeout_in_ms = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS , # noqa
364+ retry_policy = RetryPolicy (
365+ max_attempts = max_attempts ,
366+ delay_in_ms = delay_in_ms ,
367+ ),
338368 )
339369
340- auth_config = TokenAuthConfig (idp )
341- auth_config .expiration_refresh_ratio = expiration_refresh_ratio
342- auth_config .lower_refresh_bound_millis = lower_refresh_bound_millis
343- auth_config .max_attempts = max_attempts
344- auth_config .delay_in_ms = delay_in_ms
345-
346370 return EntraIdCredentialsProvider (
347- config = auth_config ,
348- initial_delay_in_ms = initial_delay_in_ms ,
349- block_for_initial = block_for_initial ,
371+ identity_provider = idp ,
372+ token_manager_config = token_mgr_config ,
373+ initial_delay_in_ms = delay_in_ms ,
350374 )
351375
352376
0 commit comments