Skip to content

Commit f2ec8df

Browse files
committed
coverage 100%
1 parent d4ae198 commit f2ec8df

File tree

4 files changed

+150
-20
lines changed

4 files changed

+150
-20
lines changed

examples/regression/1-quickstart/plot_prefit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def f(x: NDArray) -> NDArray:
8383
estimator = LGBMRegressor(
8484
objective='quantile',
8585
alpha=0.5,
86-
) # Note that this is the same model as used for QR
86+
) # Note that this is the same model as used for QR
8787
estimator.fit(X_train.reshape(-1, 1), y_train)
8888
list_estimators.append(estimator)
8989

mapie/quantile_regression.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
fit_estimator,
2323
check_lower_upper_bounds,
2424
check_defined_variables_predict_cqr,
25+
check_estimator_fit_predict,
2526
)
2627
from ._compatibility import np_quantile
2728
from .regression import MapieRegressor
@@ -190,6 +191,12 @@ def _check_alpha(
190191
ValueError
191192
If the value of alpha is not between 0 and 1.0.
192193
"""
194+
if self.cv == "prefit":
195+
warnings.warn(
196+
"WARNING: The alpha that is set needs to be the same"
197+
+ " as the alpha of your prefitted model in the following"
198+
" order [alpha/2, 1 - alpha/2, 0.5]"
199+
)
193200
if isinstance(alpha, float):
194201
if np.any(np.logical_or(alpha <= 0, alpha >= 1.0)):
195202
raise ValueError(
@@ -251,11 +258,7 @@ def _check_estimator(
251258
solver="highs-ds",
252259
alpha=0.0,
253260
)
254-
if not (hasattr(estimator, "fit") and hasattr(estimator, "predict")):
255-
raise ValueError(
256-
"Invalid estimator. "
257-
"Please provide a regressor with fit and predict methods."
258-
)
261+
check_estimator_fit_predict(estimator)
259262
if isinstance(estimator, Pipeline):
260263
self._check_estimator(estimator[-1])
261264
return estimator
@@ -408,10 +411,6 @@ def _check_calib_set(
408411
def _check_prefit_params(
409412
self,
410413
estimator: List[Union[RegressorMixin, Pipeline]],
411-
X: ArrayLike,
412-
y: ArrayLike,
413-
X_calib: Optional[ArrayLike] = None,
414-
y_calib: Optional[ArrayLike] = None,
415414
) -> None:
416415
"""
417416
Check the parameters set for the specific case of prefit
@@ -451,19 +450,14 @@ def _check_prefit_params(
451450
)
452451
if len(estimator) == 3:
453452
for est in estimator:
454-
self._check_estimator(est)
453+
check_estimator_fit_predict(est)
454+
check_is_fitted(est)
455455
else:
456456
raise ValueError(
457457
"You need to have provided 3 different estimators, they"
458458
" need to be preset with alpha values in the following"
459459
" order [alpha/2, 1 - alpha/2, 0.5]."
460460
)
461-
if self.alpha is not None:
462-
warnings.warn(
463-
"WARNING: The alpha that is set needs to be the same"
464-
+ " as the alpha of your prefitted model in the following"
465-
" order [alpha/2, 1 - alpha/2, 0.5]"
466-
)
467461

468462
def fit(
469463
self,
@@ -536,12 +530,11 @@ def fit(
536530
self.estimators_: List[RegressorMixin] = []
537531
if self.cv == "prefit":
538532
estimator = cast(List, self.estimator)
539-
self._check_prefit_params(estimator, X, y)
533+
alpha = self._check_alpha(self.alpha)
534+
self._check_prefit_params(estimator)
540535
X_calib, y_calib = indexable(X, y)
541536

542537
self.n_calib_samples = _num_samples(y_calib)
543-
check_alpha_and_n_samples(self.alpha, self.n_calib_samples)
544-
545538
y_calib_preds = np.full(
546539
shape=(3, self.n_calib_samples),
547540
fill_value=np.nan

mapie/tests/test_quantile_regression.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,120 @@ def test_linear_regression_results(strategy: str) -> None:
519519
np.testing.assert_allclose(coverage, COVERAGES[strategy], rtol=1e-2)
520520

521521

522+
def test_quantile_prefit_non_list() -> None:
523+
"""
524+
Test that there is a list of estimators provided when cv='prefit'
525+
is called for MapieQuantileRegressor.
526+
"""
527+
with pytest.raises(
528+
ValueError,
529+
match=r".*Estimator for prefit must be an iterable object.*",
530+
):
531+
not_an_iterable = 10
532+
mapie_reg = MapieQuantileRegressor(
533+
estimator=not_an_iterable,
534+
cv="prefit"
535+
)
536+
mapie_reg.fit(
537+
X_calib_toy,
538+
y_calib_toy
539+
)
540+
541+
542+
def test_quantile_prefit_three_estimators() -> None:
543+
"""
544+
Test that there is a list of estimators three estimators provided for
545+
cv="prefit".
546+
"""
547+
with pytest.raises(
548+
ValueError,
549+
match=r".*You need to have provided 3 different estimators, th*",
550+
):
551+
gb_trained1, gb_trained2 = clone(gb), clone(gb)
552+
gb_trained1.fit(X_train, y_train)
553+
gb_trained2.fit(X_train, y_train)
554+
list_estimators = [gb_trained1, gb_trained2]
555+
mapie_reg = MapieQuantileRegressor(
556+
estimator=list_estimators,
557+
cv="prefit"
558+
)
559+
mapie_reg.fit(
560+
X_calib,
561+
y_calib
562+
)
563+
564+
565+
def test_prefit_no_fit_predict() -> None:
566+
"""
567+
Check that the user is warned that the alphas need to be correctly set.
568+
"""
569+
with pytest.raises(
570+
ValueError,
571+
match=r"Invalid estimator. Please provide a regressor with fit and*",
572+
):
573+
gb_trained1, gb_trained2 = clone(gb), clone(gb)
574+
gb_trained1.fit(X_train, y_train)
575+
gb_trained2.fit(X_train, y_train)
576+
gb_trained3 = 3
577+
list_estimators = [gb_trained1, gb_trained2, gb_trained3]
578+
mapie_reg = MapieQuantileRegressor(
579+
estimator=list_estimators,
580+
cv="prefit",
581+
alpha=0.3
582+
)
583+
mapie_reg.fit(
584+
X_calib,
585+
y_calib
586+
)
587+
588+
589+
def test_non_trained_estimator() -> None:
590+
"""
591+
Check that the user is warned that the alphas need to be correctly set.
592+
"""
593+
with pytest.raises(
594+
ValueError,
595+
match=r".*instance is not fitted yet. Call 'fit' with appropriate*",
596+
):
597+
gb_trained1, gb_trained2, gb_trained3 = clone(gb), clone(gb), clone(gb)
598+
gb_trained1.fit(X_train, y_train)
599+
gb_trained2.fit(X_train, y_train)
600+
list_estimators = [gb_trained1, gb_trained2, gb_trained3]
601+
mapie_reg = MapieQuantileRegressor(
602+
estimator=list_estimators,
603+
cv="prefit",
604+
alpha=0.3
605+
)
606+
mapie_reg.fit(
607+
X_calib,
608+
y_calib
609+
)
610+
611+
612+
def test_warning_alpha_prefit() -> None:
613+
"""
614+
Check that the user is warned that the alphas need to be correctly set.
615+
"""
616+
with pytest.warns(
617+
UserWarning,
618+
match=r".*WARNING: The alpha that is set needs to be the same*"
619+
):
620+
gb_trained1, gb_trained2, gb_trained3 = clone(gb), clone(gb), clone(gb)
621+
gb_trained1.fit(X_train, y_train)
622+
gb_trained2.fit(X_train, y_train)
623+
gb_trained3.fit(X_train, y_train)
624+
list_estimators = [gb_trained1, gb_trained2, gb_trained3]
625+
mapie_reg = MapieQuantileRegressor(
626+
estimator=list_estimators,
627+
cv="prefit",
628+
alpha=0.3
629+
)
630+
mapie_reg.fit(
631+
X_calib,
632+
y_calib
633+
)
634+
635+
522636
@pytest.mark.parametrize("estimator", ESTIMATOR)
523637
def test_pipeline_compatibility(estimator: RegressorMixin) -> None:
524638
"""Check that MAPIE works on pipeline based on pandas dataframes"""

mapie/utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,3 +563,26 @@ def check_defined_variables_predict_cqr(
563563
"WARNING: Alpha should not be specified in the prediction method\n"
564564
+ "with conformalized quantile regression."
565565
)
566+
567+
568+
def check_estimator_fit_predict(
569+
estimator: Union[RegressorMixin, ClassifierMixin]
570+
) -> None:
571+
"""
572+
Check that the estimator has a fit and precict method.
573+
574+
Parameters
575+
----------
576+
estimator : Union[RegressorMixin, ClassifierMixin]
577+
Estimator to train.
578+
579+
Raises
580+
------
581+
ValueError
582+
If the estimator does not have a fit or predict attribute.
583+
"""
584+
if not (hasattr(estimator, "fit") and hasattr(estimator, "predict")):
585+
raise ValueError(
586+
"Invalid estimator. "
587+
"Please provide a regressor with fit and predict methods."
588+
)

0 commit comments

Comments
 (0)