Skip to content

Commit 312e348

Browse files
Merge pull request #313 from brendalf/feat/ajusted_rsquared
Implement adjusted r-square metric for LazyRegressor
2 parents c6389ef + ba14fc8 commit 312e348

File tree

1 file changed

+39
-38
lines changed

1 file changed

+39
-38
lines changed

lazypredict/Supervised.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ def get_card_split(df, cols, n=11):
148148

149149
# Helper class for performing classification
150150

151-
152151
class LazyClassifier:
153152
"""
154153
This module helps in fitting to all the classification algorithms that are available in Scikit-learn
@@ -405,10 +404,14 @@ def provide_models(self, X_train, X_test, y_train, y_test):
405404
"""
406405
if len(self.models.keys()) == 0:
407406
self.fit(X_train,X_test,y_train,y_test)
408-
407+
409408
return self.models
410409

411410

411+
def adjusted_rsquared(r2, n, p):
    """
    Compute the adjusted R-squared: 1 - (1 - r2) * (n - 1) / (n - p - 1).

    Parameters
    ----------
    r2 : float
        Ordinary (unadjusted) R-squared score.
    n : int
        Number of observations the score was computed on.
    p : int
        Number of predictors (features).

    Returns
    -------
    float
        The adjusted R-squared, or NaN when ``n - p - 1 <= 0``.
    """
    residual_dof = n - p - 1
    # With no residual degrees of freedom the adjustment is undefined.
    # Return NaN rather than raising ZeroDivisionError (n == p + 1) or
    # producing a meaningless value (n < p + 1), so one undefined metric
    # does not abort the rest of a caller's scoring.
    if residual_dof <= 0:
        return float("nan")
    return 1 - (1 - r2) * ((n - 1) / residual_dof)
413+
414+
412415
# Helper class for performing regression
413416

414417

@@ -522,13 +525,14 @@ def fit(self, X_train, X_test, y_train, y_test):
522525
Returns predictions of all the models in a Pandas DataFrame.
523526
"""
524527
R2 = []
528+
ADJR2 = []
525529
RMSE = []
526530
# WIN = []
527531
names = []
528532
TIME = []
529533
predictions = {}
530534

531-
if self.custom_metric is not None:
535+
if self.custom_metric:
532536
CUSTOM_METRIC = []
533537

534538
if isinstance(X_train, np.ndarray):
@@ -566,61 +570,58 @@ def fit(self, X_train, X_test, y_train, y_test):
566570
pipe = Pipeline(
567571
steps=[("preprocessor", preprocessor), ("regressor", model())]
568572
)
573+
569574
pipe.fit(X_train, y_train)
570575
self.models[name] = pipe
571576
y_pred = pipe.predict(X_test)
577+
572578
r_squared = r2_score(y_test, y_pred)
579+
adj_rsquared = adjusted_rsquared(r_squared, X_test.shape[0], X_test.shape[1])
573580
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
581+
574582
names.append(name)
575583
R2.append(r_squared)
584+
ADJR2.append(adj_rsquared)
576585
RMSE.append(rmse)
577586
TIME.append(time.time() - start)
578-
if self.custom_metric is not None:
587+
588+
if self.custom_metric:
579589
custom_metric = self.custom_metric(y_test, y_pred)
580590
CUSTOM_METRIC.append(custom_metric)
581591

582592
if self.verbose > 0:
583-
if self.custom_metric is not None:
584-
print(
585-
{
586-
"Model": name,
587-
"R-Squared": r_squared,
588-
"RMSE": rmse,
589-
self.custom_metric.__name__: custom_metric,
590-
"Time taken": time.time() - start,
591-
}
592-
)
593-
else:
594-
print(
595-
{
596-
"Model": name,
597-
"R-Squared": r_squared,
598-
"RMSE": rmse,
599-
"Time taken": time.time() - start,
600-
}
601-
)
593+
scores_verbose = {
594+
"Model": name,
595+
"R-Squared": r_squared,
596+
"Adjusted R-Squared": adj_rsquared,
597+
"RMSE": rmse,
598+
"Time taken": time.time() - start,
599+
}
600+
601+
if self.custom_metric:
602+
scores_verbose[self.custom_metric.__name__] = custom_metric
603+
604+
print(scores_verbose)
602605
if self.predictions:
603606
predictions[name] = y_pred
604607
except Exception as exception:
605608
if self.ignore_warnings is False:
606609
print(name + " model failed to execute")
607610
print(exception)
608611

609-
if self.custom_metric is None:
610-
scores = pd.DataFrame(
611-
{"Model": names, "R-Squared": R2, "RMSE": RMSE, "Time Taken": TIME}
612-
)
613-
else:
614-
scores = pd.DataFrame(
615-
{
616-
"Model": names,
617-
"R-Squared": R2,
618-
"RMSE": RMSE,
619-
self.custom_metric.__name__: CUSTOM_METRIC,
620-
"Time Taken": TIME,
621-
}
622-
)
623-
scores = scores.sort_values(by="R-Squared", ascending=False).set_index("Model")
612+
scores = {
613+
"Model": names,
614+
"Adjusted R-Squared": ADJR2,
615+
"R-Squared": R2,
616+
"RMSE": RMSE,
617+
"Time Taken": TIME
618+
}
619+
620+
if self.custom_metric:
621+
scores[self.custom_metric.__name__] = CUSTOM_METRIC
622+
623+
scores = pd.DataFrame(scores)
624+
scores = scores.sort_values(by="Adjusted R-Squared", ascending=False).set_index("Model")
624625

625626
if self.predictions:
626627
predictions_df = pd.DataFrame.from_dict(predictions)

0 commit comments

Comments
 (0)