Skip to content

Commit b2d03b1

Browse files
committed
Add : new raise value error and linked unit test
1 parent 20a881e commit b2d03b1

File tree

2 files changed

+36
-9
lines changed

2 files changed

+36
-9
lines changed

mapie/tests/test_regression.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,6 +985,27 @@ def test_using_one_predict_parameter_into_predict_but_not_in_fit() -> None:
985985
mapie_fitted.predict(X_test, **predict_params)
986986

987987

988+
def test_using_one_predict_parameter_into_fit_but_not_in_predict() -> None:
989+
"""Test that using predict parameters in the fit method
990+
without using one predict_parameter in
991+
the predict method raises an error"""
992+
custom_gbr = CustomGradientBoostingRegressor(random_state=random_state)
993+
X_train, X_test, y_train, y_test = (
994+
train_test_split(X, y, test_size=0.2, random_state=random_state)
995+
)
996+
mapie = MapieRegressor(estimator=custom_gbr)
997+
predict_params = {'check_predict_params': True}
998+
mapie_fitted = mapie.fit(X_train, y_train, predict_params=predict_params)
999+
1000+
with pytest.raises(ValueError, match=(
1001+
r"Using one 'predict_param' in the fit method "
1002+
r"without using one 'predict_param' in the predict method. "
1003+
r"Please ensure one 'predict_param' "
1004+
r"is used in the predict method before calling it."
1005+
)):
1006+
mapie_fitted.predict(X_test)
1007+
1008+
9881009
def test_predict_infinite_intervals() -> None:
9891010
"""Test that MapieRegressor produces infinite bounds with alpha=0"""
9901011
mapie_reg = MapieRegressor().fit(X, y)

mapie/utils.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1398,12 +1398,18 @@ def check_predict_params(
13981398
If any predict_params are used in the predict method but none
13991399
are used in the fit method.
14001400
"""
1401-
if (len(predict_params) > 0 and
1402-
predict_params_used_in_fit is False and
1403-
cv != "prefit"):
1404-
raise ValueError(
1405-
f"Using 'predict_param' '{predict_params}' "
1406-
f"without using one 'predict_param' in the fit method. "
1407-
f"Please ensure one 'predict_param' "
1408-
f"is used in the fit method before calling predict."
1409-
)
1401+
if cv != "prefit":
1402+
if len(predict_params) > 0 and predict_params_used_in_fit is False:
1403+
raise ValueError(
1404+
f"Using 'predict_param' '{predict_params}' "
1405+
f"without using one 'predict_param' in the fit method. "
1406+
f"Please ensure one 'predict_param' "
1407+
f"is used in the fit method before calling predict."
1408+
)
1409+
if len(predict_params) == 0 and predict_params_used_in_fit is True:
1410+
raise ValueError(
1411+
"Using one 'predict_param' in the fit method "
1412+
"without using one 'predict_param' in the predict method. "
1413+
"Please ensure one 'predict_param' "
1414+
"is used in the predict method before calling it."
1415+
)

0 commit comments

Comments
 (0)