Skip to content

Commit 6feddf5

Browse files
Applied recommended changes
1 parent 9ed2708 commit 6feddf5

File tree

1 file changed

+31
-30
lines changed

1 file changed

+31
-30
lines changed

lazypredict/Supervised.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,8 @@ class LazyClassifier:
162162
When function is provided, models are evaluated based on the custom evaluation metric provided.
163163
prediction : bool, optional (default=False)
164164
When set to True, the predictions of all the models models are returned as dataframe.
165+
classifiers : list, optional (default="all")
166+
When function is provided, trains the chosen classifier(s).
165167
166168
Examples
167169
--------
@@ -217,7 +219,7 @@ def __init__(
217219
custom_metric=None,
218220
predictions=False,
219221
random_state=42,
220-
classifiers = CLASSIFIERS
222+
classifiers = "all"
221223
):
222224
self.verbose = verbose
223225
self.ignore_warnings = ignore_warnings
@@ -282,20 +284,18 @@ def fit(self, X_train, X_test, y_train, y_test):
282284
]
283285
)
284286

285-
# Here is for issue #114, if you do not want to choose all classifiers
286-
# you can call LazyClassifier as shown below:
287-
# LazyClassifier(classifiers=["DecisionTreeClassifier", "RandomForestClassifier"])
288-
temp_list = []
289-
if self.classifiers is not CLASSIFIERS:
290-
for name, model in all_estimators(): # an example of all_estimators() ↓↓↓↓
291-
for classifier in self.classifiers: # ('SVC', <class 'sklearn.svm._classes.SVC'>)
292-
if classifier is name:
293-
full_name = (name, model)
294-
temp_list.append(full_name)
295-
self.classifiers = temp_list.copy()
296-
if not self.classifiers:
297-
print("Invalid Classifier(s)")
298-
287+
if self.classifiers == "all":
288+
self.classifiers = CLASSIFIERS
289+
else:
290+
try:
291+
temp_list=[]
292+
for classifier in self.classifiers:
293+
full_name = (classifier.__class__.__name__, classifier)
294+
temp_list.append(full_name)
295+
self.classifiers = temp_list
296+
except Exception as exception:
297+
print(exception)
298+
print("Invalid Classifier(s)")
299299
for name, model in tqdm(self.classifiers):
300300
start = time.time()
301301
try:
@@ -445,6 +445,8 @@ class LazyRegressor:
445445
When function is provided, models are evaluated based on the custom evaluation metric provided.
446446
prediction : bool, optional (default=False)
447447
When set to True, the predictions of all the models models are returned as dataframe.
448+
regressors : list, optional (default="all")
449+
When function is provided, trains the chosen regressor(s).
448450
449451
Examples
450452
--------
@@ -518,15 +520,15 @@ def __init__(
518520
custom_metric=None,
519521
predictions=False,
520522
random_state=42,
521-
regressors=REGRESSORS,
523+
regressors="all",
522524
):
523525
self.verbose = verbose
524526
self.ignore_warnings = ignore_warnings
525527
self.custom_metric = custom_metric
526528
self.predictions = predictions
527529
self.models = {}
528530
self.random_state = random_state
529-
self.regressors = regressors
531+
self.regressors = regressors
530532

531533
def fit(self, X_train, X_test, y_train, y_test):
532534
"""Fit Regression algorithms to X_train and y_train, predict and score on X_test, y_test.
@@ -583,19 +585,18 @@ def fit(self, X_train, X_test, y_train, y_test):
583585
]
584586
)
585587

586-
# Here is for issue #114, if you do not want to choose all regressors
587-
# you can call LazyRegressor as shown below:
588-
# LazyRegressor(regressors=["DecisionTreeRegressor", "RandomForestRegressor", "Ridge"])
589-
temp_list = []
590-
if self.regressors is not REGRESSORS:
591-
for name, model in all_estimators(): # an example of all_estimators() ↓↓↓↓
592-
for regressor in self.regressors: # ('SVC', <class 'sklearn.svm._classes.SVC'>)
593-
if regressor is name:
594-
full_name = (name, model)
595-
temp_list.append(full_name)
596-
self.regressors = temp_list.copy()
597-
if not self.regressors:
598-
print("Invalid Regressor(s)")
588+
if self.regressors == "all":
589+
self.regressors = REGRESSORS
590+
else:
591+
try:
592+
temp_list = []
593+
for regressor in self.regressors:
594+
full_name = (regressor.__class__.__name__, regressor)
595+
temp_list.append(full_name)
596+
self.regressors = temp_list
597+
except Exception as exception:
598+
print(exception)
599+
print("Invalid Regressor(s)")
599600

600601
for name, model in tqdm(self.regressors):
601602
start = time.time()

0 commit comments

Comments
 (0)