Skip to content

Commit f84b175

Browse files
Solve #114 (Allowing users to choose the models they want)
1 parent 1526a91 commit f84b175

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

lazypredict/Supervised.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,21 @@ def fit(self, X_train, X_test, y_train, y_test):
280280
]
281281
)
282282

283-
for name, model in tqdm(CLASSIFIERS):
283+
# Here is for issue #114, if you do not want to choose all classifiers
284+
# you can call LazyClassifier as shown below:
285+
# LazyClassifier(classifiers=["DecisionTreeClassifier", "RandomForestClassifier"])
286+
temp_list = []
287+
if self.classifiers is not CLASSIFIERS:
288+
for name, model in all_estimators(): # an example of all_estimators() ↓↓↓↓
289+
for classifier in self.classifiers: # ('SVC', <class 'sklearn.svm._classes.SVC'>)
290+
if classifier is name:
291+
full_name = (name, model)
292+
temp_list.append(full_name)
293+
self.classifiers = temp_list.copy()
294+
if not self.classifiers:
295+
print("Invalid Classifier(s)")
296+
297+
for name, model in tqdm(self.classifiers):
284298
start = time.time()
285299
try:
286300
if "random_state" in model().get_params().keys():
@@ -502,13 +516,15 @@ def __init__(
502516
custom_metric=None,
503517
predictions=False,
504518
random_state=42,
519+
regressors=REGRESSORS,
505520
):
506521
self.verbose = verbose
507522
self.ignore_warnings = ignore_warnings
508523
self.custom_metric = custom_metric
509524
self.predictions = predictions
510525
self.models = {}
511526
self.random_state = random_state
527+
self.regressors = REGRESSORS
512528

513529
def fit(self, X_train, X_test, y_train, y_test):
514530
"""Fit Regression algorithms to X_train and y_train, predict and score on X_test, y_test.
@@ -565,7 +581,21 @@ def fit(self, X_train, X_test, y_train, y_test):
565581
]
566582
)
567583

568-
for name, model in tqdm(REGRESSORS):
584+
# Here is for issue #114, if you do not want to choose all regressors
585+
# you can call LazyRegressor as shown below:
586+
# LazyRegressor(regressors=["DecisionTreeRegressor", "RandomForestRegressor", "Ridge"])
587+
temp_list = []
588+
if self.regressors is not REGRESSORS:
589+
for name, model in all_estimators(): # an example of all_estimators() ↓↓↓↓
590+
for regressor in self.regressors: # ('SVC', <class 'sklearn.svm._classes.SVC'>)
591+
if regressor is name:
592+
full_name = (name, model)
593+
temp_list.append(full_name)
594+
self.regressors = temp_list.copy()
595+
if not self.regressors:
596+
print("Invalid Regressor(s)")
597+
598+
for name, model in tqdm(self.regressors):
569599
start = time.time()
570600
try:
571601
if "random_state" in model().get_params().keys():

0 commit comments

Comments
 (0)