1717from  redis .asyncio .retry  import  Retry 
1818from  redis .auth .idp  import  IdentityProviderInterface 
1919from  redis .auth .token  import  JWToken 
20+ from  redis .auth .token_manager  import  RetryPolicy , TokenManagerConfig 
2021from  redis .backoff  import  NoBackoff 
2122from  redis .credentials  import  CredentialProvider 
22- from  redis_entraid .cred_provider  import  EntraIdCredentialsProvider , TokenAuthConfig 
23+ from  redis_entraid .cred_provider  import  (
24+     DEFAULT_DELAY_IN_MS ,
25+     DEFAULT_EXPIRATION_REFRESH_RATIO ,
26+     DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
27+     DEFAULT_MAX_ATTEMPTS ,
28+     EntraIdCredentialsProvider ,
29+ )
2330from  redis_entraid .identity_provider  import  (
2431    ManagedIdentityIdType ,
32+     ManagedIdentityProviderConfig ,
2533    ManagedIdentityType ,
26-     create_provider_from_managed_identity ,
27-     create_provider_from_service_principal ,
34+     ServicePrincipalIdentityProviderConfig ,
35+     _create_provider_from_managed_identity ,
36+     _create_provider_from_service_principal ,
2837)
2938from  tests .conftest  import  REDIS_INFO 
3039
@@ -255,41 +264,58 @@ def identity_provider(request) -> IdentityProviderInterface:
255264        return  mock_identity_provider ()
256265
257266    auth_type  =  kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
267+     config  =  get_identity_provider_config (request = request )
258268
259269    if  auth_type  ==  "MANAGED_IDENTITY" :
260-         return  _get_managed_identity_provider (request )
270+         return  _create_provider_from_managed_identity (config )
271+ 
272+     return  _create_provider_from_service_principal (config )
273+ 
274+ 
275+ def  get_identity_provider_config (
276+     request ,
277+ ) ->  Union [ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig ]:
278+     if  hasattr (request , "param" ):
279+         kwargs  =  request .param .get ("idp_kwargs" , {})
280+     else :
281+         kwargs  =  {}
261282
262-     return  _get_service_principal_provider (request )
283+     auth_type  =  kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
284+ 
285+     if  auth_type  ==  AuthType .MANAGED_IDENTITY :
286+         return  _get_managed_identity_provider_config (request )
263287
288+     return  _get_service_principal_provider_config (request )
264289
265- def   _get_managed_identity_provider ( request ): 
266-      authority   =   os . getenv ( "AZURE_AUTHORITY" ) 
290+ 
291+ def   _get_managed_identity_provider_config ( request )  ->   ManagedIdentityProviderConfig : 
267292    resource  =  os .getenv ("AZURE_RESOURCE" )
268-     id_value  =  os .getenv ("AZURE_ID_VALUE " , None )
293+     id_value  =  os .getenv ("AZURE_USER_ASSIGNED_MANAGED_ID " , None )
269294
270295    if  hasattr (request , "param" ):
271296        kwargs  =  request .param .get ("idp_kwargs" , {})
272297    else :
273298        kwargs  =  {}
274299
275300    identity_type  =  kwargs .pop ("identity_type" , ManagedIdentityType .SYSTEM_ASSIGNED )
276-     id_type  =  kwargs .pop ("id_type" , ManagedIdentityIdType .CLIENT_ID )
301+     id_type  =  kwargs .pop ("id_type" , ManagedIdentityIdType .OBJECT_ID )
277302
278-     return  create_provider_from_managed_identity (
303+     return  ManagedIdentityProviderConfig (
279304        identity_type = identity_type ,
280305        resource = resource ,
281306        id_type = id_type ,
282307        id_value = id_value ,
283-         authority = authority ,
284-         ** kwargs ,
308+         kwargs = kwargs ,
285309    )
286310
287311
288- def  _get_service_principal_provider (request ):
312+ def  _get_service_principal_provider_config (
313+     request ,
314+ ) ->  ServicePrincipalIdentityProviderConfig :
289315    client_id  =  os .getenv ("AZURE_CLIENT_ID" )
290316    client_credential  =  os .getenv ("AZURE_CLIENT_SECRET" )
291-     authority  =  os .getenv ("AZURE_AUTHORITY " )
292-     scopes  =  os .getenv ("AZURE_REDIS_SCOPES" , [] )
317+     tenant_id  =  os .getenv ("AZURE_TENANT_ID " )
318+     scopes  =  os .getenv ("AZURE_REDIS_SCOPES" , None )
293319
294320    if  hasattr (request , "param" ):
295321        kwargs  =  request .param .get ("idp_kwargs" , {})
@@ -303,14 +329,14 @@ def _get_service_principal_provider(request):
303329    if  isinstance (scopes , str ):
304330        scopes  =  scopes .split ("," )
305331
306-     return  create_provider_from_service_principal (
332+     return  ServicePrincipalIdentityProviderConfig (
307333        client_id = client_id ,
308334        client_credential = client_credential ,
309335        scopes = scopes ,
310336        timeout = timeout ,
311337        token_kwargs = token_kwargs ,
312-         authority = authority ,
313-         ** kwargs ,
338+         tenant_id = tenant_id ,
339+         app_kwargs = kwargs ,
314340    )
315341
316342
@@ -322,31 +348,29 @@ def get_credential_provider(request) -> CredentialProvider:
322348        return  cred_provider_class (** cred_provider_kwargs )
323349
324350    idp  =  identity_provider (request )
325-     initial_delay_in_ms  =  cred_provider_kwargs .get ("initial_delay_in_ms" , 0 )
326-     block_for_initial  =  cred_provider_kwargs .get ("block_for_initial" , False )
327351    expiration_refresh_ratio  =  cred_provider_kwargs .get (
328-         "expiration_refresh_ratio" , TokenAuthConfig . DEFAULT_EXPIRATION_REFRESH_RATIO 
352+         "expiration_refresh_ratio" , DEFAULT_EXPIRATION_REFRESH_RATIO 
329353    )
330354    lower_refresh_bound_millis  =  cred_provider_kwargs .get (
331-         "lower_refresh_bound_millis" , TokenAuthConfig .DEFAULT_LOWER_REFRESH_BOUND_MILLIS 
332-     )
333-     max_attempts  =  cred_provider_kwargs .get (
334-         "max_attempts" , TokenAuthConfig .DEFAULT_MAX_ATTEMPTS 
355+         "lower_refresh_bound_millis" , DEFAULT_LOWER_REFRESH_BOUND_MILLIS 
335356    )
336-     delay_in_ms  =  cred_provider_kwargs .get (
337-         "delay_in_ms" , TokenAuthConfig .DEFAULT_DELAY_IN_MS 
357+     max_attempts  =  cred_provider_kwargs .get ("max_attempts" , DEFAULT_MAX_ATTEMPTS )
358+     delay_in_ms  =  cred_provider_kwargs .get ("delay_in_ms" , DEFAULT_DELAY_IN_MS )
359+ 
360+     token_mgr_config  =  TokenManagerConfig (
361+         expiration_refresh_ratio = expiration_refresh_ratio ,
362+         lower_refresh_bound_millis = lower_refresh_bound_millis ,
363+         token_request_execution_timeout_in_ms = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,  # noqa 
364+         retry_policy = RetryPolicy (
365+             max_attempts = max_attempts ,
366+             delay_in_ms = delay_in_ms ,
367+         ),
338368    )
339369
340-     auth_config  =  TokenAuthConfig (idp )
341-     auth_config .expiration_refresh_ratio  =  expiration_refresh_ratio 
342-     auth_config .lower_refresh_bound_millis  =  lower_refresh_bound_millis 
343-     auth_config .max_attempts  =  max_attempts 
344-     auth_config .delay_in_ms  =  delay_in_ms 
345- 
346370    return  EntraIdCredentialsProvider (
347-         config = auth_config ,
348-         initial_delay_in_ms = initial_delay_in_ms ,
349-         block_for_initial = block_for_initial ,
371+         identity_provider = idp ,
372+         token_manager_config = token_mgr_config ,
373+         initial_delay_in_ms = delay_in_ms ,
350374    )
351375
352376
0 commit comments