Skip to content

Commit 4426376

Browse files
Merge pull request #457 from scikit-learn-contrib/447-mapieregressor-sets-method-to-base-from-aci
Solve conflict between ACI and base methods in MapieTimeSeriesRegressor
2 parents f98bafc + aa587c7 commit 4426376

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

AUTHORS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,5 @@ Contributors
4040
* Pierre de Fréminville <pidefrem>
4141
* Ambros Marzetta <ambrosm>
4242
* Carl McBride Ellis <Carl-McBride-Ellis>
43+
* Baptiste Calot <[email protected]>
4344
To be continued ...

HISTORY.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ History
66
------------------
77

88
* Fix the quantile formula to ensure valid coverage for any number of calibration data in `ConformityScore`.
9+
* Fix overloading of the value of the `method` attribute when using `MapieRegressor` and `MapieTimeSeriesRegressor`.
910
* Fix conda versionning.
1011
* Reduce precision for test in `MapieCalibrator`.
1112
* Fix invalid certificate when downloading data.

mapie/regression/regression.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,8 @@ def _check_fit_parameters(
425425
cv = check_cv(
426426
self.cv, test_size=self.test_size, random_state=self.random_state
427427
)
428-
if self.cv in ["split", "prefit"] and self.method != "base":
428+
if self.cv in ["split", "prefit"] and \
429+
self.method in ["naive", "plus", "minmax"]:
429430
self.method = "base"
430431
estimator = self._check_estimator(self.estimator)
431432
agg_function = self._check_agg_function(self.agg_function)

mapie/tests/test_regression.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,3 +793,19 @@ def test_predict_infinite_intervals() -> None:
793793
_, y_pis = mapie_reg.predict(X, alpha=0., allow_infinite_bounds=True)
794794
np.testing.assert_allclose(y_pis[:, 0, 0], -np.inf)
795795
np.testing.assert_allclose(y_pis[:, 1, 0], np.inf)
796+
797+
798+
@pytest.mark.parametrize("method", ["minmax", "naive", "plus", "base"])
799+
@pytest.mark.parametrize("cv", ["split", "prefit"])
800+
def test_check_change_method_to_base(method: str, cv: str) -> None:
801+
"""Test of the overloading of method attribute to `base` method in fit"""
802+
803+
X_train, X_val, y_train, y_val = train_test_split(
804+
X, y, test_size=0.5, random_state=random_state
805+
)
806+
estimator = LinearRegression().fit(X_train, y_train)
807+
mapie_reg = MapieRegressor(
808+
cv=cv, method=method, estimator=estimator
809+
)
810+
mapie_reg.fit(X_val, y_val)
811+
assert mapie_reg.method == "base"

mapie/tests/test_time_series_regression.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,3 +518,24 @@ def test_method_error_in_update(monkeypatch: Any, method: str) -> None:
518518
with pytest.raises(ValueError, match=r".*Invalid method.*"):
519519
mapie_ts_reg.fit(X_toy, y_toy)
520520
mapie_ts_reg.update(X_toy, y_toy)
521+
522+
523+
@pytest.mark.parametrize("method", ["enbpi", "aci"])
524+
@pytest.mark.parametrize("cv", ["split", "prefit"])
525+
def test_methods_preservation_in_fit(method: str, cv: str) -> None:
526+
"""Test of enbpi and aci method preservation in the fit MapieRegressor"""
527+
528+
X_train_val, X_test, y_train_val, y_test = train_test_split(
529+
X, y, test_size=0.33, random_state=random_state
530+
)
531+
X_train, X_val, y_train, y_val = train_test_split(
532+
X_train_val, y_train_val, test_size=0.5, random_state=random_state
533+
)
534+
estimator = LinearRegression().fit(X_train, y_train)
535+
mapie_ts_reg = MapieTimeSeriesRegressor(
536+
estimator=estimator,
537+
cv=cv, method=method
538+
)
539+
mapie_ts_reg.fit(X_val, y_val)
540+
mapie_ts_reg.update(X_test, y_test, gamma=0.1, alpha=0.1)
541+
assert mapie_ts_reg.method == method

0 commit comments

Comments
 (0)