5858from  scipy  import  sparse  as  sp 
5959
6060if  sklearn_check_version ('1.2' ):
61-     from  sklearn .utils ._param_validation  import  Interval 
61+     from  sklearn .utils ._param_validation  import  Interval ,  StrOptions 
6262
6363
6464class  BaseRandomForest (ABC ):
@@ -193,7 +193,8 @@ class RandomForestClassifier(sklearn_RandomForestClassifier, BaseRandomForest):
193193        _parameter_constraints : dict  =  {
194194            ** sklearn_RandomForestClassifier ._parameter_constraints ,
195195            "max_bins" : [Interval (numbers .Integral , 2 , None , closed = "left" )],
196-             "min_bin_size" : [Interval (numbers .Integral , 1 , None , closed = "left" )]
196+             "min_bin_size" : [Interval (numbers .Integral , 1 , None , closed = "left" )],
197+             "splitter_mode" : [StrOptions ({"best" , "random" })]
197198        }
198199
199200    if  sklearn_check_version ('1.0' ):
@@ -218,7 +219,8 @@ def __init__(
218219                ccp_alpha = 0.0 ,
219220                max_samples = None ,
220221                max_bins = 256 ,
221-                 min_bin_size = 1 ):
222+                 min_bin_size = 1 ,
223+                 splitter_mode = 'best' ):
222224            super (RandomForestClassifier , self ).__init__ (
223225                n_estimators = n_estimators ,
224226                criterion = criterion ,
@@ -243,6 +245,7 @@ def __init__(
243245            self .max_bins  =  max_bins 
244246            self .min_bin_size  =  min_bin_size 
245247            self .min_impurity_split  =  None 
248+             self .splitter_mode  =  splitter_mode 
246249            # self._estimator = DecisionTreeClassifier() 
247250    else :
248251        def  __init__ (self ,
@@ -266,7 +269,8 @@ def __init__(self,
266269                     ccp_alpha = 0.0 ,
267270                     max_samples = None ,
268271                     max_bins = 256 ,
269-                      min_bin_size = 1 ):
272+                      min_bin_size = 1 ,
273+                      splitter_mode = 'best' ):
270274            super (RandomForestClassifier , self ).__init__ (
271275                n_estimators = n_estimators ,
272276                criterion = criterion ,
@@ -294,6 +298,7 @@ def __init__(self,
294298            self .max_bins  =  max_bins 
295299            self .min_bin_size  =  min_bin_size 
296300            self .min_impurity_split  =  None 
301+             self .splitter_mode  =  splitter_mode 
297302            # self._estimator = DecisionTreeClassifier() 
298303
299304    def  fit (self , X , y , sample_weight = None ):
@@ -529,6 +534,11 @@ def _estimators_(self):
529534    def  _onedal_cpu_supported (self , method_name , * data ):
530535        if  method_name  ==  'ensemble.RandomForestClassifier.fit' :
531536            ready , X , y , sample_weight  =  self ._onedal_ready (* data )
537+             if  self .splitter_mode  ==  'random' :
538+                 warnings .warn ("'random' splitter mode supports GPU devices only " 
539+                               "and requires oneDAL version >= 2023.1.1. " 
540+                               "Using 'best' mode instead." , RuntimeWarning )
541+                 self .splitter_mode  =  'best' 
532542            if  not  ready :
533543                return  False 
534544            elif  sp .issparse (X ):
@@ -570,6 +580,11 @@ def _onedal_cpu_supported(self, method_name, *data):
570580    def  _onedal_gpu_supported (self , method_name , * data ):
571581        if  method_name  ==  'ensemble.RandomForestClassifier.fit' :
572582            ready , X , y , sample_weight  =  self ._onedal_ready (* data )
583+             if  self .splitter_mode  ==  'random'  and  \
584+                     not  daal_check_version ((2023 , 'P' , 101 )):
585+                 warnings .warn ("'random' splitter mode requires OneDAL >= 2023.1.1. " 
586+                               "Using 'best' mode instead." , RuntimeWarning )
587+                 self .splitter_mode  =  'best' 
573588            if  not  ready :
574589                return  False 
575590            elif  sp .issparse (X ):
@@ -687,6 +702,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
687702            'min_bin_size' : self .min_bin_size ,
688703            'max_samples' : self .max_samples 
689704        }
705+         if  daal_check_version ((2023 , 'P' , 101 )):
706+             onedal_params ['splitter_mode' ] =  self .splitter_mode 
690707        self ._cached_estimators_  =  None 
691708
692709        # Compute 
@@ -729,7 +746,8 @@ class RandomForestRegressor(sklearn_RandomForestRegressor, BaseRandomForest):
729746        _parameter_constraints : dict  =  {
730747            ** sklearn_RandomForestRegressor ._parameter_constraints ,
731748            "max_bins" : [Interval (numbers .Integral , 2 , None , closed = "left" )],
732-             "min_bin_size" : [Interval (numbers .Integral , 1 , None , closed = "left" )]
749+             "min_bin_size" : [Interval (numbers .Integral , 1 , None , closed = "left" )],
750+             "splitter_mode" : [StrOptions ({"best" , "random" })]
733751        }
734752
735753    if  sklearn_check_version ('1.0' ):
@@ -754,7 +772,8 @@ def __init__(
754772                ccp_alpha = 0.0 ,
755773                max_samples = None ,
756774                max_bins = 256 ,
757-                 min_bin_size = 1 ):
775+                 min_bin_size = 1 ,
776+                 splitter_mode = 'best' ):
758777            super (RandomForestRegressor , self ).__init__ (
759778                n_estimators = n_estimators ,
760779                criterion = criterion ,
@@ -778,6 +797,7 @@ def __init__(
778797            self .max_bins  =  max_bins 
779798            self .min_bin_size  =  min_bin_size 
780799            self .min_impurity_split  =  None 
800+             self .splitter_mode  =  splitter_mode 
781801    else :
782802        def  __init__ (self ,
783803                     n_estimators = 100 , * ,
@@ -799,7 +819,8 @@ def __init__(self,
799819                     ccp_alpha = 0.0 ,
800820                     max_samples = None ,
801821                     max_bins = 256 ,
802-                      min_bin_size = 1 ):
822+                      min_bin_size = 1 ,
823+                      splitter_mode = 'best' ):
803824            super (RandomForestRegressor , self ).__init__ (
804825                n_estimators = n_estimators ,
805826                criterion = criterion ,
@@ -826,6 +847,7 @@ def __init__(self,
826847            self .max_bins  =  max_bins 
827848            self .min_bin_size  =  min_bin_size 
828849            self .min_impurity_split  =  None 
850+             self .splitter_mode  =  splitter_mode 
829851
830852    @property  
831853    def  _estimators_ (self ):
@@ -902,6 +924,11 @@ def _onedal_ready(self, X, y, sample_weight):
902924    def  _onedal_cpu_supported (self , method_name , * data ):
903925        if  method_name  ==  'ensemble.RandomForestRegressor.fit' :
904926            ready , X , y , sample_weight  =  self ._onedal_ready (* data )
927+             if  self .splitter_mode  ==  'random' :
928+                 warnings .warn ("'random' splitter mode supports GPU devices only " 
929+                               "and requires oneDAL version >= 2023.1.1. " 
930+                               "Using 'best' mode instead." , RuntimeWarning )
931+                 self .splitter_mode  =  'best' 
905932            if  not  ready :
906933                return  False 
907934            elif  not  (self .oob_score  and  daal_check_version (
@@ -947,6 +974,11 @@ def _onedal_cpu_supported(self, method_name, *data):
947974    def  _onedal_gpu_supported (self , method_name , * data ):
948975        if  method_name  ==  'ensemble.RandomForestRegressor.fit' :
949976            ready , X , y , sample_weight  =  self ._onedal_ready (* data )
977+             if  self .splitter_mode  ==  'random'  and  \
978+                     not  daal_check_version ((2023 , 'P' , 101 )):
979+                 warnings .warn ("'random' splitter mode requires OneDAL >= 2023.1.1. " 
980+                               "Using 'best' mode instead." , RuntimeWarning )
981+                 self .splitter_mode  =  'best' 
950982            if  not  ready :
951983                return  False 
952984            elif  not  (self .oob_score  and  daal_check_version (
@@ -1035,6 +1067,8 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None):
10351067            'variable_importance_mode' : 'mdi' ,
10361068            'max_samples' : self .max_samples 
10371069        }
1070+         if  daal_check_version ((2023 , 'P' , 101 )):
1071+             onedal_params ['splitter_mode' ] =  self .splitter_mode 
10381072        self ._cached_estimators_  =  None 
10391073        self ._onedal_estimator  =  self ._onedal_regressor (** onedal_params )
10401074        self ._onedal_estimator .fit (X , y , sample_weight , queue = queue )
0 commit comments