Skip to content

Commit 462207a

Browse files
authored
Merge pull request #148 from scikit-learn-contrib/Remove-the-check_image_is_input
ENH : remove image_input argument add modify tests accordingly
2 parents 1cbeade + 7221815 commit 462207a

File tree

3 files changed

+5
-96
lines changed

3 files changed

+5
-96
lines changed

mapie/classification.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
check_alpha_and_n_samples,
2828
check_n_jobs,
2929
check_verbose,
30-
check_input_is_image,
31-
fit_estimator,
30+
fit_estimator
3231
)
3332
from ._compatibility import np_quantile
3433

@@ -252,13 +251,8 @@ def _check_estimator(
252251
If the estimator is not fitted and ``cv`` attribute is "prefit".
253252
"""
254253
if estimator is None:
255-
if not self.image_input:
256-
return LogisticRegression(multi_class="multinomial").fit(X, y)
257-
else:
258-
raise ValueError(
259-
"Default LogisticRegression's input can't be an image."
260-
"Please provide a proper model."
261-
)
254+
return LogisticRegression(multi_class="multinomial").fit(X, y)
255+
262256
if isinstance(estimator, Pipeline):
263257
est = estimator[-1]
264258
else:
@@ -635,7 +629,6 @@ def fit(
635629
self,
636630
X: ArrayLike,
637631
y: ArrayLike,
638-
image_input: Optional[bool] = False,
639632
sample_weight: Optional[ArrayLike] = None,
640633
) -> MapieClassifier:
641634
"""
@@ -649,13 +642,6 @@ def fit(
649642
y : ArrayLike of shape (n_samples,)
650643
Training labels.
651644
652-
image_input: Optional[bool] = False
653-
Whether or not the X input is an image. If True, you must provide
654-
a model that accepts image as input (e.g., a Neural Network). All
655-
Scikit-learn classifiers only accept two-dimensional inputs.
656-
657-
By default False.
658-
659645
sample_weight : Optional[ArrayLike] of shape (n_samples,)
660646
Sample weights for fitting the out-of-fold models.
661647
If None, then samples are equally weighted.
@@ -671,12 +657,10 @@ def fit(
671657
The model itself.
672658
"""
673659
# Checks
674-
self.image_input = image_input
675660
self._check_parameters()
676661
cv = check_cv(self.cv)
677662
estimator = self._check_estimator(X, y, self.estimator)
678-
if self.image_input:
679-
check_input_is_image(X)
663+
680664
X, y = indexable(X, y)
681665
y = _check_y(y)
682666
assert type_of_target(y) == "multiclass"
@@ -849,8 +833,6 @@ def predict(
849833
include_last_label = self._check_include_last_label(include_last_label)
850834
alpha = cast(Optional[NDArray], check_alpha(alpha))
851835
check_is_fitted(self, self.fit_attributes)
852-
if self.image_input:
853-
check_input_is_image(X)
854836

855837
# Estimate prediction sets
856838
y_pred = self.single_estimator_.predict(X)

mapie/tests/test_classification.py

Lines changed: 1 addition & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,6 @@
415415
}
416416
]
417417

418-
X_WRONG_IMAGE = [
419-
np.zeros((3, 1024, 1024, 3, 1)),
420-
np.zeros((3, 512))
421-
]
422418
X_good_image = np.zeros((3, 1024, 1024, 3))
423419
y_toy_image = np.array([0, 0, 1])
424420

@@ -827,7 +823,7 @@ def test_image_cumulated_scores(X: Dict[str, ArrayLike]) -> None:
827823
cv="prefit",
828824
random_state=42
829825
)
830-
mapie.fit(cumclf.X_calib, cumclf.y_calib, image_input=True)
826+
mapie.fit(cumclf.X_calib, cumclf.y_calib)
831827
np.testing.assert_allclose(mapie.conformity_scores_, cumclf.y_calib_scores)
832828
# predict
833829
_, y_ps = mapie.predict(
@@ -894,51 +890,6 @@ def test_classifier_without_classes_attribute(
894890
mapie.fit(X_toy, y_toy)
895891

896892

897-
@pytest.mark.parametrize("X_wrong_image", X_WRONG_IMAGE)
898-
def test_wrong_image_shape_fit(X_wrong_image: ArrayLike) -> None:
899-
"""
900-
Test that ValueError is raised if image has not 3 or 4 dimensions in fit.
901-
"""
902-
cumclf = ImageClassifier(X_wrong_image, y_toy_image)
903-
cumclf.fit(cumclf.X_calib, cumclf.y_calib)
904-
mapie = MapieClassifier(
905-
cumclf,
906-
method="cumulated_score",
907-
cv="prefit",
908-
random_state=42
909-
)
910-
with pytest.raises(ValueError, match=r"Invalid X.*"):
911-
mapie.fit(cumclf.X_calib, cumclf.y_calib, image_input=True)
912-
913-
914-
@pytest.mark.parametrize("X_wrong_image", X_WRONG_IMAGE)
915-
def test_wrong_image_shape_predict(X_wrong_image: ArrayLike) -> None:
916-
"""
917-
Test that ValueError is raised if image has not
918-
3 or 4 dimensions in predict.
919-
"""
920-
cumclf = ImageClassifier(X_good_image, y_toy_image)
921-
cumclf.fit(cumclf.X_calib, cumclf.y_calib)
922-
mapie = MapieClassifier(
923-
cumclf,
924-
method="cumulated_score",
925-
cv="prefit",
926-
random_state=42
927-
)
928-
mapie.fit(cumclf.X_calib, cumclf.y_calib, image_input=True,)
929-
with pytest.raises(ValueError, match=r"Invalid X.*"):
930-
mapie.predict(X_wrong_image)
931-
932-
933-
def test_undefined_model() -> None:
934-
"""
935-
Test ValueError is raised if no model is specified with image input.
936-
"""
937-
mapie = MapieClassifier()
938-
with pytest.raises(ValueError, match=r"LogisticRegression's input.*"):
939-
mapie.fit(X_good_image, y_toy_image, image_input=True,)
940-
941-
942893
@pytest.mark.parametrize("method", WRONG_METHODS)
943894
def test_method_error_in_fit(monkeypatch: Any, method: str) -> None:
944895
"""Test else condition for the method in .fit"""

mapie/utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -404,27 +404,3 @@ def check_nan_in_aposteriori_prediction(X: ArrayLike) -> None:
404404
+ "belongs to every resamplings.\n"
405405
"Increase the number of resamplings"
406406
)
407-
408-
409-
def check_input_is_image(X: ArrayLike) -> None:
410-
"""
411-
Check if the image has 3 or 4 dimensions
412-
413-
Parameters
414-
----------
415-
X: Union[
416-
ArrayLike[n_samples, width, height],
417-
ArrayLike[n_samples, width, height, n_channels]
418-
]
419-
Image input
420-
421-
Raises
422-
------
423-
ValueError
424-
"""
425-
if len(np.array(X).shape) not in [3, 4]:
426-
raise ValueError(
427-
"Invalid X."
428-
"When X is an image, the number of dimensions"
429-
"must be equal to 3 or 4."
430-
)

0 commit comments

Comments
 (0)