@@ -162,6 +162,8 @@ class LazyClassifier:
162162 When function is provided, models are evaluated based on the custom evaluation metric provided.
163163 prediction : bool, optional (default=False)
164164 When set to True, the predictions of all the models models are returned as dataframe.
165+ classifiers : list, optional (default="all")
166+ When function is provided, trains the chosen classifier(s).
165167
166168 Examples
167169 --------
@@ -217,7 +219,7 @@ def __init__(
217219 custom_metric = None ,
218220 predictions = False ,
219221 random_state = 42 ,
220- classifiers = CLASSIFIERS
222+ classifiers = "all"
221223 ):
222224 self .verbose = verbose
223225 self .ignore_warnings = ignore_warnings
@@ -282,20 +284,18 @@ def fit(self, X_train, X_test, y_train, y_test):
282284 ]
283285 )
284286
285- # Here is for issue #114, if you do not want to choose all classifiers
286- # you can call LazyClassifier as shown below:
287- # LazyClassifier(classifiers=["DecisionTreeClassifier", "RandomForestClassifier"])
288- temp_list = []
289- if self .classifiers is not CLASSIFIERS :
290- for name , model in all_estimators (): # an example of all_estimators() ↓↓↓↓
291- for classifier in self .classifiers : # ('SVC', <class 'sklearn.svm._classes.SVC'>)
292- if classifier is name :
293- full_name = (name , model )
294- temp_list .append (full_name )
295- self .classifiers = temp_list .copy ()
296- if not self .classifiers :
297- print ("Invalid Classifier(s)" )
298-
287+ if self .classifiers == "all" :
288+ self .classifiers = CLASSIFIERS
289+ else :
290+ try :
291+ temp_list = []
292+ for classifier in self .classifiers :
293+ full_name = (classifier .__class__ .__name__ , classifier )
294+ temp_list .append (full_name )
295+ self .classifiers = temp_list
296+ except Exception as exception :
297+ print (exception )
298+ print ("Invalid Classifier(s)" )
299299 for name , model in tqdm (self .classifiers ):
300300 start = time .time ()
301301 try :
@@ -445,6 +445,8 @@ class LazyRegressor:
445445 When function is provided, models are evaluated based on the custom evaluation metric provided.
446446 prediction : bool, optional (default=False)
447447 When set to True, the predictions of all the models models are returned as dataframe.
448+ regressors : list, optional (default="all")
449+ When function is provided, trains the chosen regressor(s).
448450
449451 Examples
450452 --------
@@ -518,15 +520,15 @@ def __init__(
518520 custom_metric = None ,
519521 predictions = False ,
520522 random_state = 42 ,
521- regressors = REGRESSORS ,
523+ regressors = "all" ,
522524 ):
523525 self .verbose = verbose
524526 self .ignore_warnings = ignore_warnings
525527 self .custom_metric = custom_metric
526528 self .predictions = predictions
527529 self .models = {}
528530 self .random_state = random_state
529- self .regressors = regressors
531+ self .regressors = regressors
530532
531533 def fit (self , X_train , X_test , y_train , y_test ):
532534 """Fit Regression algorithms to X_train and y_train, predict and score on X_test, y_test.
@@ -583,19 +585,18 @@ def fit(self, X_train, X_test, y_train, y_test):
583585 ]
584586 )
585587
586- # Here is for issue #114, if you do not want to choose all regressors
587- # you can call LazyRegressor as shown below:
588- # LazyRegressor(regressors=["DecisionTreeRegressor", "RandomForestRegressor", "Ridge"])
589- temp_list = []
590- if self .regressors is not REGRESSORS :
591- for name , model in all_estimators (): # an example of all_estimators() ↓↓↓↓
592- for regressor in self .regressors : # ('SVC', <class 'sklearn.svm._classes.SVC'>)
593- if regressor is name :
594- full_name = (name , model )
595- temp_list .append (full_name )
596- self .regressors = temp_list .copy ()
597- if not self .regressors :
598- print ("Invalid Regressor(s)" )
588+ if self .regressors == "all" :
589+ self .regressors = REGRESSORS
590+ else :
591+ try :
592+ temp_list = []
593+ for regressor in self .regressors :
594+ full_name = (regressor .__class__ .__name__ , regressor )
595+ temp_list .append (full_name )
596+ self .regressors = temp_list
597+ except Exception as exception :
598+ print (exception )
599+ print ("Invalid Regressor(s)" )
599600
600601 for name , model in tqdm (self .regressors ):
601602 start = time .time ()
0 commit comments