Skip to content

Commit 4c874fa

Browse files
author
gmartinonQM
committed
update tests
1 parent 53f8c4d commit 4c874fa

File tree

1 file changed

+16
-20
lines changed

1 file changed

+16
-20
lines changed

mapie/tests/test_estimators.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sklearn.dummy import DummyRegressor
1212
from sklearn.linear_model import LinearRegression
1313
from sklearn.model_selection import LeaveOneOut, KFold, train_test_split
14-
from sklearn.pipeline import make_pipeline
14+
from sklearn.pipeline import Pipeline, make_pipeline
1515
from sklearn.exceptions import NotFittedError
1616
from sklearn.utils.estimator_checks import parametrize_with_checks
1717
from sklearn.utils.validation import check_is_fitted
@@ -159,25 +159,25 @@ def test_invalid_prefit_estimator(estimator: RegressorMixin) -> None:
159159
mapie.fit(X_toy, y_toy)
160160

161161

162-
def test_valid_prefit_raw_estimator() -> None:
163-
"""Test that fitted raw estimator with prefit cv raise no errors."""
164-
estimator = LinearRegression().fit(X_toy, y_toy)
162+
@pytest.mark.parametrize(
163+
"estimator", [
164+
LinearRegression(),
165+
make_pipeline(LinearRegression())
166+
]
167+
)
168+
def test_valid_prefit_estimator(estimator: RegressorMixin) -> None:
169+
"""Test that fitted estimators with prefit cv raise no errors."""
170+
estimator.fit(X_toy, y_toy)
165171
mapie = MapieRegressor(estimator=estimator, cv="prefit")
166172
mapie.fit(X_toy, y_toy)
167-
check_is_fitted(mapie.single_estimator_)
173+
if isinstance(estimator, Pipeline):
174+
check_is_fitted(mapie.single_estimator_[-1])
175+
else:
176+
check_is_fitted(mapie.single_estimator_)
168177
check_is_fitted(mapie, ["n_features_in_", "single_estimator_", "estimators_", "k_", "residuals_"])
169178
assert mapie.n_features_in_ == 1
170179

171180

172-
def test_valid_prefit_pipeline() -> None:
173-
"""Test that fitted pipeline with prefit cv raise no errors."""
174-
estimator = make_pipeline(LinearRegression()).fit(X_toy, y_toy)
175-
mapie = MapieRegressor(estimator=estimator, cv="prefit")
176-
mapie.fit(X_toy, y_toy)
177-
check_is_fitted(mapie.single_estimator_[-1])
178-
assert mapie.n_features_in_ == 1
179-
180-
181181
def test_invalid_prefit_estimator_shape() -> None:
182182
"""Test that estimators fitted with a wrong number of features raise errors."""
183183
estimator = LinearRegression().fit(X_reg, y_reg)
@@ -469,9 +469,7 @@ def test_results_prefit_ignore_method() -> None:
469469

470470

471471
def test_results_prefit_naive() -> None:
472-
"""
473-
Test that prefit, fit and predict on the same dataset is equivalent to the "naive" method.
474-
"""
472+
"""Test that prefit, fit and predict on the same dataset is equivalent to the "naive" method."""
475473
estimator = LinearRegression().fit(X_reg, y_reg)
476474
mapie = MapieRegressor(alpha=0.05, estimator=estimator, cv="prefit")
477475
mapie.fit(X_reg, y_reg)
@@ -483,9 +481,7 @@ def test_results_prefit_naive() -> None:
483481

484482

485483
def test_results_prefit() -> None:
486-
"""
487-
Test prefit results on a standard train/validation/test split.
488-
"""
484+
"""Test prefit results on a standard train/validation/test split."""
489485
X_train_val, X_test, y_train_val, y_test = train_test_split(X_reg, y_reg, test_size=1/10, random_state=1)
490486
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=1/9, random_state=1)
491487
estimator = LinearRegression().fit(X_train, y_train)

0 commit comments

Comments
 (0)