Skip to content

Commit ebf107a

Browse files
author
Thibault Cordier
committed
UPD: move check cv - cs function
1 parent b724c35 commit ebf107a

File tree

2 files changed

+16
-38
lines changed

2 files changed

+16
-38
lines changed

mapie/classification.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
import numpy as np
77
from sklearn.base import BaseEstimator, ClassifierMixin
8-
from sklearn.model_selection import BaseCrossValidator
8+
from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit
99
from sklearn.preprocessing import LabelEncoder
1010
from sklearn.utils import check_random_state
1111
from sklearn.utils.validation import (_check_y, check_is_fitted, indexable)
1212

1313
from mapie._typing import ArrayLike, NDArray
1414
from mapie.conformity_scores import BaseClassificationScore
15+
from mapie.conformity_scores.sets.raps import RAPSConformityScore
1516
from mapie.conformity_scores.utils import (
1617
check_depreciated_size_raps, check_classification_conformity_score,
1718
check_target
@@ -39,6 +40,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
3940
If ``None``, estimator defaults to a ``LogisticRegression`` instance.
4041
4142
method: Optional[str]
43+
[DEPRECIATED see instead conformity_score]
4244
Method to choose for prediction interval estimates.
4345
Choose among:
4446
@@ -119,7 +121,7 @@ class MapieClassifier(BaseEstimator, ClassifierMixin):
119121
120122
By default ``None``.
121123
122-
conformity_score_function_: BaseClassificationScore
124+
conformity_score: BaseClassificationScore
123125
Score function that handle all that is related to conformity scores.
124126
125127
In any case, the `conformity_score` parameter takes precedence over the
@@ -378,12 +380,22 @@ def _check_fit_parameter(
378380
)
379381
check_depreciated_size_raps(size_raps)
380382
cs_estimator.set_external_attributes(
381-
cv=self.cv,
382383
classes=self.classes_,
383384
label_encoder=self.label_encoder_,
384385
size_raps=size_raps,
385386
random_state=self.random_state
386387
)
388+
if (
389+
isinstance(cs_estimator, RAPSConformityScore) and
390+
not (
391+
self.cv in ["split", "prefit"] or
392+
isinstance(self.cv, BaseShuffleSplit)
393+
)
394+
):
395+
raise ValueError(
396+
"RAPS method can only be used "
397+
"with ``cv='split'`` and ``cv='prefit'``."
398+
)
387399

388400
# Cast
389401
X, y_enc, y = cast(NDArray, X), cast(NDArray, y_enc), cast(NDArray, y)

mapie/conformity_scores/sets/raps.py

Lines changed: 1 addition & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import numpy as np
44
from sklearn.calibration import LabelEncoder
5-
from sklearn.model_selection import (BaseCrossValidator, BaseShuffleSplit,
6-
StratifiedShuffleSplit)
5+
from sklearn.model_selection import StratifiedShuffleSplit
76
from sklearn.utils import _safe_indexing
87
from sklearn.utils.validation import _num_samples
98

@@ -49,9 +48,6 @@ class RAPSConformityScore(APSConformityScore):
4948
quantiles_: ArrayLike of shape (n_alpha)
5049
The quantiles estimated from ``get_sets`` method.
5150
52-
cv: Union[int, str, BaseCrossValidator]
53-
The cross-validation strategy for computing scores.
54-
5551
label_encoder: LabelEncoder
5652
The label encoder used to encode the labels.
5753
@@ -60,8 +56,6 @@ class RAPSConformityScore(APSConformityScore):
6056
k_star for the RAPS method.
6157
"""
6258

63-
valid_cv_ = ["prefit", "split"]
64-
6559
def __init__(
6660
self,
6761
size_raps: Optional[float] = 0.2
@@ -72,7 +66,6 @@ def __init__(
7266
def set_external_attributes(
7367
self,
7468
*,
75-
cv: Optional[Union[str, BaseCrossValidator, BaseShuffleSplit]] = None,
7669
label_encoder: Optional[LabelEncoder] = None,
7770
size_raps: Optional[float] = None,
7871
**kwargs
@@ -82,11 +75,6 @@ def set_external_attributes(
8275
8376
Parameters
8477
----------
85-
cv: Optional[Union[int, str, BaseCrossValidator]]
86-
The cross-validation strategy for computing scores.
87-
88-
By default ``None``.
89-
9078
label_encoder: Optional[LabelEncoder]
9179
The label encoder used to encode the labels.
9280
@@ -99,28 +87,9 @@ def set_external_attributes(
9987
By default ``None``.
10088
"""
10189
super().set_external_attributes(**kwargs)
102-
self.cv = cast(Union[str, BaseCrossValidator, BaseShuffleSplit], cv)
10390
self.label_encoder_ = cast(LabelEncoder, label_encoder)
10491
self.size_raps = size_raps
10592

106-
def _check_cv(self):
107-
"""
108-
Check that if the method used is ``"raps"``, then
109-
the cross validation strategy is ``"prefit"``.
110-
111-
Raises
112-
------
113-
ValueError
114-
If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``.
115-
"""
116-
if not (
117-
self.cv in self.valid_cv_ or isinstance(self.cv, BaseShuffleSplit)
118-
):
119-
raise ValueError(
120-
"RAPS method can only be used "
121-
f"with cv in {self.valid_cv_}."
122-
)
123-
12493
def split_data(
12594
self,
12695
X: NDArray,
@@ -162,9 +131,6 @@ def split_data(
162131
- NDArray of shape (n_samples,)
163132
- NDArray of shape (n_samples,)
164133
"""
165-
# Checks
166-
self._check_cv()
167-
168134
# Split data for raps method
169135
raps_split = StratifiedShuffleSplit(
170136
n_splits=1,

0 commit comments

Comments
 (0)