Skip to content

Commit 13da57f

Browse files
committed
chore: Refactor train-test split in plot_cqr_tutorial.py
1 parent 044ae69 commit 13da57f

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

examples/regression/4-tutorials/plot_cqr_tutorial.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,6 @@ class :class:`~mapie.subsample.Subsample` (note that the `alpha` parameter is
101101
y['MedHouseVal'],
102102
random_state=random_state
103103
)
104-
X_train, X_calib, y_train, y_calib = train_test_split(
105-
X_train,
106-
y_train,
107-
random_state=random_state
108-
)
109104

110105

111106
##############################################################################
@@ -267,13 +262,19 @@ def plot_prediction_intervals(
267262
if strategy == "cqr":
268263
mapie = MapieQuantileRegressor(estimator, **params)
269264
mapie.fit(
270-
X_train, y_train,
271-
X_calib=X_calib, y_calib=y_calib,
265+
X_train,
266+
y_train,
267+
calib_size=0.3,
272268
random_state=random_state
273269
)
274270
y_pred[strategy], y_pis[strategy] = mapie.predict(X_test)
275271
else:
276-
mapie = MapieRegressor(estimator, **params, random_state=random_state)
272+
mapie = MapieRegressor(
273+
estimator,
274+
test_size=0.3,
275+
random_state=random_state,
276+
**params
277+
)
277278
mapie.fit(X_train, y_train)
278279
y_pred[strategy], y_pis[strategy] = mapie.predict(X_test, alpha=0.2)
279280
(

0 commit comments

Comments
 (0)