@@ -280,7 +280,21 @@ def fit(self, X_train, X_test, y_train, y_test):
280280 ]
281281 )
282282
283- for name , model in tqdm (CLASSIFIERS ):
283+ # Here is for issue #114, if you do not want to choose all classifiers
284+ # you can call LazyClassifier as shown below:
285+ # LazyClassifier(classifiers=["DecisionTreeClassifier", "RandomForestClassifier"])
286+ temp_list = []
287+ if self .classifiers is not CLASSIFIERS :
288+ for name , model in all_estimators (): # an example of all_estimators() ↓↓↓↓
289+ for classifier in self .classifiers : # ('SVC', <class 'sklearn.svm._classes.SVC'>)
290+ if classifier is name :
291+ full_name = (name , model )
292+ temp_list .append (full_name )
293+ self .classifiers = temp_list .copy ()
294+ if not self .classifiers :
295+ print ("Invalid Classifier(s)" )
296+
297+ for name , model in tqdm (self .classifiers ):
284298 start = time .time ()
285299 try :
286300 if "random_state" in model ().get_params ().keys ():
@@ -502,13 +516,15 @@ def __init__(
502516 custom_metric = None ,
503517 predictions = False ,
504518 random_state = 42 ,
519+ regressors = REGRESSORS ,
505520 ):
506521 self .verbose = verbose
507522 self .ignore_warnings = ignore_warnings
508523 self .custom_metric = custom_metric
509524 self .predictions = predictions
510525 self .models = {}
511526 self .random_state = random_state
527+ self .regressors = REGRESSORS
512528
513529 def fit (self , X_train , X_test , y_train , y_test ):
514530 """Fit Regression algorithms to X_train and y_train, predict and score on X_test, y_test.
@@ -565,7 +581,21 @@ def fit(self, X_train, X_test, y_train, y_test):
565581 ]
566582 )
567583
568- for name , model in tqdm (REGRESSORS ):
584+ # Here is for issue #114, if you do not want to choose all regressors
585+ # you can call LazyRegressor as shown below:
586+ # LazyRegressor(regressors=["DecisionTreeRegressor", "RandomForestRegressor", "Ridge"])
587+ temp_list = []
588+ if self .regressors is not REGRESSORS :
589+ for name , model in all_estimators (): # an example of all_estimators() ↓↓↓↓
590+ for regressor in self .regressors : # ('SVC', <class 'sklearn.svm._classes.SVC'>)
591+ if regressor is name :
592+ full_name = (name , model )
593+ temp_list .append (full_name )
594+ self .regressors = temp_list .copy ()
595+ if not self .regressors :
596+ print ("Invalid Regressor(s)" )
597+
598+ for name , model in tqdm (self .regressors ):
569599 start = time .time ()
570600 try :
571601 if "random_state" in model ().get_params ().keys ():
0 commit comments