55import  time 
66from  datetime  import  datetime , timezone 
77from  enum  import  Enum 
8- from  typing  import  Callable , TypeVar 
8+ from  typing  import  Callable , TypeVar ,  Union 
99from  unittest  import  mock 
1010from  unittest .mock  import  Mock 
1111from  urllib .parse  import  urlparse 
1717from  redis  import  Sentinel 
1818from  redis .auth .idp  import  IdentityProviderInterface 
1919from  redis .auth .token  import  JWToken 
20+ from  redis .auth .token_manager  import  RetryPolicy , TokenManagerConfig 
2021from  redis .backoff  import  NoBackoff 
2122from  redis .cache  import  (
2223    CacheConfig ,
2930from  redis .credentials  import  CredentialProvider 
3031from  redis .exceptions  import  RedisClusterException 
3132from  redis .retry  import  Retry 
32- from  redis_entraid .cred_provider  import  EntraIdCredentialsProvider , TokenAuthConfig 
33+ from  redis_entraid .cred_provider  import  (
34+     DEFAULT_DELAY_IN_MS ,
35+     DEFAULT_EXPIRATION_REFRESH_RATIO ,
36+     DEFAULT_LOWER_REFRESH_BOUND_MILLIS ,
37+     DEFAULT_MAX_ATTEMPTS ,
38+     DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,
39+     EntraIdCredentialsProvider ,
40+ )
3341from  redis_entraid .identity_provider  import  (
3442    ManagedIdentityIdType ,
43+     ManagedIdentityProviderConfig ,
3544    ManagedIdentityType ,
36-     create_provider_from_managed_identity ,
37-     create_provider_from_service_principal ,
45+     ServicePrincipalIdentityProviderConfig ,
46+     _create_provider_from_managed_identity ,
47+     _create_provider_from_service_principal ,
3848)
3949from  tests .ssl_utils  import  get_tls_certificates 
4050
@@ -623,41 +633,58 @@ def identity_provider(request) -> IdentityProviderInterface:
623633        return  mock_identity_provider ()
624634
625635    auth_type  =  kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
636+     config  =  get_identity_provider_config (request = request )
626637
627638    if  auth_type  ==  "MANAGED_IDENTITY" :
628-         return  _get_managed_identity_provider (request )
639+         return  _create_provider_from_managed_identity (config )
640+ 
641+     return  _create_provider_from_service_principal (config )
629642
630-     return  _get_service_principal_provider (request )
631643
644+ def  get_identity_provider_config (
645+     request ,
646+ ) ->  Union [ManagedIdentityProviderConfig , ServicePrincipalIdentityProviderConfig ]:
647+     if  hasattr (request , "param" ):
648+         kwargs  =  request .param .get ("idp_kwargs" , {})
649+     else :
650+         kwargs  =  {}
651+ 
652+     auth_type  =  kwargs .pop ("auth_type" , AuthType .SERVICE_PRINCIPAL )
632653
633- def  _get_managed_identity_provider (request ):
634-     authority  =  os .getenv ("AZURE_AUTHORITY" )
654+     if  auth_type  ==  AuthType .MANAGED_IDENTITY :
655+         return  _get_managed_identity_provider_config (request )
656+ 
657+     return  _get_service_principal_provider_config (request )
658+ 
659+ 
660+ def  _get_managed_identity_provider_config (request ) ->  ManagedIdentityProviderConfig :
635661    resource  =  os .getenv ("AZURE_RESOURCE" )
636-     id_value  =  os .getenv ("AZURE_ID_VALUE " , None )
662+     id_value  =  os .getenv ("AZURE_USER_ASSIGNED_MANAGED_ID " , None )
637663
638664    if  hasattr (request , "param" ):
639665        kwargs  =  request .param .get ("idp_kwargs" , {})
640666    else :
641667        kwargs  =  {}
642668
643669    identity_type  =  kwargs .pop ("identity_type" , ManagedIdentityType .SYSTEM_ASSIGNED )
644-     id_type  =  kwargs .pop ("id_type" , ManagedIdentityIdType .CLIENT_ID )
670+     id_type  =  kwargs .pop ("id_type" , ManagedIdentityIdType .OBJECT_ID )
645671
646-     return  create_provider_from_managed_identity (
672+     return  ManagedIdentityProviderConfig (
647673        identity_type = identity_type ,
648674        resource = resource ,
649675        id_type = id_type ,
650676        id_value = id_value ,
651-         authority = authority ,
652-         ** kwargs ,
677+         kwargs = kwargs ,
653678    )
654679
655680
656- def  _get_service_principal_provider (request ):
681+ def  _get_service_principal_provider_config (
682+     request ,
683+ ) ->  ServicePrincipalIdentityProviderConfig :
657684    client_id  =  os .getenv ("AZURE_CLIENT_ID" )
658685    client_credential  =  os .getenv ("AZURE_CLIENT_SECRET" )
659-     authority  =  os .getenv ("AZURE_AUTHORITY " )
660-     scopes  =  os .getenv ("AZURE_REDIS_SCOPES" , [] )
686+     tenant_id  =  os .getenv ("AZURE_TENANT_ID " )
687+     scopes  =  os .getenv ("AZURE_REDIS_SCOPES" , None )
661688
662689    if  hasattr (request , "param" ):
663690        kwargs  =  request .param .get ("idp_kwargs" , {})
@@ -671,14 +698,14 @@ def _get_service_principal_provider(request):
671698    if  isinstance (scopes , str ):
672699        scopes  =  scopes .split ("," )
673700
674-     return  create_provider_from_service_principal (
701+     return  ServicePrincipalIdentityProviderConfig (
675702        client_id = client_id ,
676703        client_credential = client_credential ,
677704        scopes = scopes ,
678705        timeout = timeout ,
679706        token_kwargs = token_kwargs ,
680-         authority = authority ,
681-         ** kwargs ,
707+         tenant_id = tenant_id ,
708+         app_kwargs = kwargs ,
682709    )
683710
684711
@@ -690,31 +717,29 @@ def get_credential_provider(request) -> CredentialProvider:
690717        return  cred_provider_class (** cred_provider_kwargs )
691718
692719    idp  =  identity_provider (request )
693-     initial_delay_in_ms  =  cred_provider_kwargs .get ("initial_delay_in_ms" , 0 )
694-     block_for_initial  =  cred_provider_kwargs .get ("block_for_initial" , False )
695720    expiration_refresh_ratio  =  cred_provider_kwargs .get (
696-         "expiration_refresh_ratio" , TokenAuthConfig . DEFAULT_EXPIRATION_REFRESH_RATIO 
721+         "expiration_refresh_ratio" , DEFAULT_EXPIRATION_REFRESH_RATIO 
697722    )
698723    lower_refresh_bound_millis  =  cred_provider_kwargs .get (
699-         "lower_refresh_bound_millis" , TokenAuthConfig .DEFAULT_LOWER_REFRESH_BOUND_MILLIS 
700-     )
701-     max_attempts  =  cred_provider_kwargs .get (
702-         "max_attempts" , TokenAuthConfig .DEFAULT_MAX_ATTEMPTS 
724+         "lower_refresh_bound_millis" , DEFAULT_LOWER_REFRESH_BOUND_MILLIS 
703725    )
704-     delay_in_ms  =  cred_provider_kwargs .get (
705-         "delay_in_ms" , TokenAuthConfig .DEFAULT_DELAY_IN_MS 
726+     max_attempts  =  cred_provider_kwargs .get ("max_attempts" , DEFAULT_MAX_ATTEMPTS )
727+     delay_in_ms  =  cred_provider_kwargs .get ("delay_in_ms" , DEFAULT_DELAY_IN_MS )
728+ 
729+     token_mgr_config  =  TokenManagerConfig (
730+         expiration_refresh_ratio = expiration_refresh_ratio ,
731+         lower_refresh_bound_millis = lower_refresh_bound_millis ,
732+         token_request_execution_timeout_in_ms = DEFAULT_TOKEN_REQUEST_EXECUTION_TIMEOUT_IN_MS ,
733+         retry_policy = RetryPolicy (
734+             max_attempts = max_attempts ,
735+             delay_in_ms = delay_in_ms ,
736+         ),
706737    )
707738
708-     auth_config  =  TokenAuthConfig (idp )
709-     auth_config .expiration_refresh_ratio  =  expiration_refresh_ratio 
710-     auth_config .lower_refresh_bound_millis  =  lower_refresh_bound_millis 
711-     auth_config .max_attempts  =  max_attempts 
712-     auth_config .delay_in_ms  =  delay_in_ms 
713- 
714739    return  EntraIdCredentialsProvider (
715-         config = auth_config ,
716-         initial_delay_in_ms = initial_delay_in_ms ,
717-         block_for_initial = block_for_initial ,
740+         identity_provider = idp ,
741+         token_manager_config = token_mgr_config ,
742+         initial_delay_in_ms = delay_in_ms ,
718743    )
719744
720745
0 commit comments