Skip to content

Commit 3c36fff

Browse files
MTN: minor refactoring of JCAB including unit testing to achieve 100% coverage
1 parent 37277e0 commit 3c36fff

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

mapie/regression/regression.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -720,14 +720,7 @@ def __init__(
720720
JackknifeAfterBootstrapRegressor._VALID_AGGREGATION_METHODS
721721
)
722722

723-
if isinstance(resampling, int):
724-
cv = Subsample(n_resamplings=resampling)
725-
elif isinstance(resampling, Subsample):
726-
cv = resampling
727-
else:
728-
raise ValueError(
729-
"resampling must be an integer or a Subsample instance"
730-
)
723+
cv = self._check_and_convert_resampling_to_cv(resampling)
731724

732725
self._mapie_regressor = MapieRegressor(
733726
estimator=estimator,
@@ -901,6 +894,20 @@ def predict(
901894
)
902895
return cast_point_predictions_to_ndarray(predictions)
903896

897+
@staticmethod
898+
def _check_and_convert_resampling_to_cv(
899+
resampling: Union[int, Subsample]
900+
) -> Subsample:
901+
if isinstance(resampling, int):
902+
cv = Subsample(n_resamplings=resampling)
903+
elif isinstance(resampling, Subsample):
904+
cv = resampling
905+
else:
906+
raise ValueError(
907+
"resampling must be an integer or a Subsample instance"
908+
)
909+
return cv
910+
904911

905912
class MapieRegressor(RegressorMixin, BaseEstimator):
906913
"""

tests_v1/test_unit/test_regression.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
from mapie.subsample import Subsample
3+
from mapie.regression import JackknifeAfterBootstrapRegressor
4+
5+
6+
class TestCheckAndConvertResamplingToCv:
7+
def test_with_integer(self):
8+
regressor = JackknifeAfterBootstrapRegressor()
9+
cv = regressor._check_and_convert_resampling_to_cv(50)
10+
11+
assert isinstance(cv, Subsample)
12+
assert cv.n_resamplings == 50
13+
14+
def test_with_subsample(self):
15+
custom_subsample = Subsample(n_resamplings=25, random_state=42)
16+
regressor = JackknifeAfterBootstrapRegressor()
17+
cv = regressor._check_and_convert_resampling_to_cv(custom_subsample)
18+
19+
assert cv is custom_subsample
20+
21+
def test_with_invalid_input(self):
22+
regressor = JackknifeAfterBootstrapRegressor()
23+
24+
with pytest.raises(
25+
ValueError,
26+
match="resampling must be an integer or a Subsample instance"
27+
):
28+
regressor._check_and_convert_resampling_to_cv("invalid_input")

0 commit comments

Comments
 (0)