Skip to content

Commit ba14fc8

Browse files
committed
feat:: implement adjusted r-square metric for lazyregressor
1 parent 330461b commit ba14fc8

File tree

1 file changed

+39
-38
lines changed

1 file changed

+39
-38
lines changed

lazypredict/Supervised.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ def get_card_split(df, cols, n=11):
147147

148148
# Helper class for performing classification
149149

150-
151150
class LazyClassifier:
152151
"""
153152
This module helps in fitting to all the classification algorithms that are available in Scikit-learn
@@ -404,10 +403,14 @@ def provide_models(self, X_train, X_test, y_train, y_test):
404403
"""
405404
if len(self.models.keys()) == 0:
406405
self.fit(X_train,X_test,y_train,y_test)
407-
406+
408407
return self.models
409408

410409

410+
def adjusted_rsquared(r2, n, p):
411+
return 1 - (1-r2) * ((n-1) / (n-p-1))
412+
413+
411414
# Helper class for performing classification
412415

413416

@@ -521,13 +524,14 @@ def fit(self, X_train, X_test, y_train, y_test):
521524
Returns predictions of all the models in a Pandas DataFrame.
522525
"""
523526
R2 = []
527+
ADJR2 = []
524528
RMSE = []
525529
# WIN = []
526530
names = []
527531
TIME = []
528532
predictions = {}
529533

530-
if self.custom_metric is not None:
534+
if self.custom_metric:
531535
CUSTOM_METRIC = []
532536

533537
if isinstance(X_train, np.ndarray):
@@ -565,61 +569,58 @@ def fit(self, X_train, X_test, y_train, y_test):
565569
pipe = Pipeline(
566570
steps=[("preprocessor", preprocessor), ("regressor", model())]
567571
)
572+
568573
pipe.fit(X_train, y_train)
569574
self.models[name] = pipe
570575
y_pred = pipe.predict(X_test)
576+
571577
r_squared = r2_score(y_test, y_pred)
578+
adj_rsquared = adjusted_rsquared(r_squared, X_test.shape[0], X_test.shape[1])
572579
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
580+
573581
names.append(name)
574582
R2.append(r_squared)
583+
ADJR2.append(adj_rsquared)
575584
RMSE.append(rmse)
576585
TIME.append(time.time() - start)
577-
if self.custom_metric is not None:
586+
587+
if self.custom_metric:
578588
custom_metric = self.custom_metric(y_test, y_pred)
579589
CUSTOM_METRIC.append(custom_metric)
580590

581591
if self.verbose > 0:
582-
if self.custom_metric is not None:
583-
print(
584-
{
585-
"Model": name,
586-
"R-Squared": r_squared,
587-
"RMSE": rmse,
588-
self.custom_metric.__name__: custom_metric,
589-
"Time taken": time.time() - start,
590-
}
591-
)
592-
else:
593-
print(
594-
{
595-
"Model": name,
596-
"R-Squared": r_squared,
597-
"RMSE": rmse,
598-
"Time taken": time.time() - start,
599-
}
600-
)
592+
scores_verbose = {
593+
"Model": name,
594+
"R-Squared": r_squared,
595+
"Adjusted R-Squared": adj_rsquared,
596+
"RMSE": rmse,
597+
"Time taken": time.time() - start,
598+
}
599+
600+
if self.custom_metric:
601+
scores_verbose[self.custom_metric.__name__] = custom_metric
602+
603+
print(scores_verbose)
601604
if self.predictions:
602605
predictions[name] = y_pred
603606
except Exception as exception:
604607
if self.ignore_warnings is False:
605608
print(name + " model failed to execute")
606609
print(exception)
607610

608-
if self.custom_metric is None:
609-
scores = pd.DataFrame(
610-
{"Model": names, "R-Squared": R2, "RMSE": RMSE, "Time Taken": TIME}
611-
)
612-
else:
613-
scores = pd.DataFrame(
614-
{
615-
"Model": names,
616-
"R-Squared": R2,
617-
"RMSE": RMSE,
618-
self.custom_metric.__name__: CUSTOM_METRIC,
619-
"Time Taken": TIME,
620-
}
621-
)
622-
scores = scores.sort_values(by="R-Squared", ascending=False).set_index("Model")
611+
scores = {
612+
"Model": names,
613+
"Adjusted R-Squared": ADJR2,
614+
"R-Squared": R2,
615+
"RMSE": RMSE,
616+
"Time Taken": TIME
617+
}
618+
619+
if self.custom_metric:
620+
scores[self.custom_metric.__name__] = CUSTOM_METRIC
621+
622+
scores = pd.DataFrame(scores)
623+
scores = scores.sort_values(by="Adjusted R-Squared", ascending=False).set_index("Model")
623624

624625
if self.predictions:
625626
predictions_df = pd.DataFrame.from_dict(predictions)

0 commit comments

Comments
 (0)