55
66import numpy as np
77from sklearn .base import BaseEstimator , ClassifierMixin
8- from sklearn .model_selection import BaseCrossValidator , ShuffleSplit
8+ from sklearn .model_selection import (BaseCrossValidator , BaseShuffleSplit ,
9+ StratifiedShuffleSplit )
910from sklearn .preprocessing import LabelEncoder , label_binarize
1011from sklearn .utils import _safe_indexing , check_random_state
1112from sklearn .utils .multiclass import (check_classification_targets ,
@@ -301,9 +302,9 @@ def _check_raps(self):
301302 ValueError
302303 If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``.
303304 """
304- if (self .method == "raps" ) and (
305- (self .cv not in self .raps_valid_cv_ )
306- or isinstance (self .cv , ShuffleSplit )
305+ if (self .method == "raps" ) and not (
306+ (self .cv in self .raps_valid_cv_ )
307+ or isinstance (self .cv , BaseShuffleSplit )
307308 ):
308309 raise ValueError (
309310 "RAPS method can only be used "
@@ -926,7 +927,7 @@ def _check_fit_parameter(
926927 y: ArrayLike
927928 Target values.
928929
929- sample_weight: Optional[NDArray ] of shape (n_samples,)
930+ sample_weight: Optional[ArrayLike ] of shape (n_samples,)
930931 Non-null sample weights.
931932
932933 groups: Optional[ArrayLike] of shape (n_samples,)
@@ -940,8 +941,8 @@ def _check_fit_parameter(
940941 Optional[Union[int, str, BaseCrossValidator]],
941942 ArrayLike, NDArray, NDArray, Optional[NDArray],
942943 Optional[NDArray], ArrayLike]
943-
944944 Parameters checked
945+
945946 Raises
946947 ------
947948 ValueError
@@ -952,7 +953,6 @@ def _check_fit_parameter(
952953 If ``cv`` is `"prefit"`` or ``"split"`` and ``method`` is not
953954 ``"base"``.
954955 """
955-
956956 self ._check_parameters ()
957957 cv = check_cv (
958958 self .cv , test_size = self .test_size , random_state = self .random_state
@@ -979,15 +979,15 @@ def _check_fit_parameter(
979979 self .label_encoder_ = enc
980980 self ._check_target (y )
981981
982- return ( estimator , cv , X , y , y_enc , sample_weight , groups , n_samples )
982+ return estimator , cv , X , y , y_enc , sample_weight , groups , n_samples
983983
984984 def _split_data (
985985 self ,
986- X ,
987- y_enc ,
988- sample_weight ,
989- groups ,
990- size_raps
986+ X : ArrayLike ,
987+ y_enc : ArrayLike ,
988+ sample_weight : Optional [ ArrayLike ] = None ,
989+ groups : Optional [ ArrayLike ] = None ,
990+ size_raps : Optional [ float ] = None ,
991991 ):
992992 """Split data for raps method
993993
@@ -999,7 +999,7 @@ def _split_data(
999999 y_enc: ArrayLike
10001000 Target values as normalized encodings.
10011001
1002- sample_weight: Optional[NDArray ] of shape (n_samples,)
1002+ sample_weight: Optional[ArrayLike ] of shape (n_samples,)
10031003 Non-null sample weights.
10041004
10051005 groups: Optional[ArrayLike] of shape (n_samples,)
@@ -1015,34 +1015,38 @@ def _split_data(
10151015 -------
10161016 Tuple[NDArray, NDArray, NDArray, NDArray, Optional[NDArray],
10171017 Optional[NDArray]]
1018-
10191018 - NDArray of shape (n_samples, n_features)
10201019 - NDArray of shape (n_samples,)
10211020 - NDArray of shape (n_samples,)
10221021 - NDArray of shape (n_samples,)
10231022 - NDArray of shape (n_samples,)
10241023 - NDArray of shape (n_samples,)
10251024 """
1026- raps_split = ShuffleSplit (
1027- 1 , test_size = size_raps , random_state = self .random_state
1025+ # Split data for raps method
1026+ raps_split = StratifiedShuffleSplit (
1027+ n_splits = 1 , test_size = size_raps , random_state = self .random_state
10281028 )
1029- train_raps_index , val_raps_index = next (raps_split .split (X ))
1029+ train_raps_index , val_raps_index = next (raps_split .split (X , y_enc ))
10301030 X , self .X_raps , y_enc , self .y_raps = (
10311031 _safe_indexing (X , train_raps_index ),
10321032 _safe_indexing (X , val_raps_index ),
10331033 _safe_indexing (y_enc , train_raps_index ),
10341034 _safe_indexing (y_enc , val_raps_index ),
10351035 )
1036+
1037+ # Decode y_raps for use in the RAPS method
10361038 self .y_raps_no_enc = self .label_encoder_ .inverse_transform (self .y_raps )
10371039 y = self .label_encoder_ .inverse_transform (y_enc )
1040+
1041+ # Cast to NDArray for type checking
10381042 y_enc = cast (NDArray , y_enc )
10391043 n_samples = _num_samples (y_enc )
10401044 if sample_weight is not None :
1041- sample_weight = sample_weight [train_raps_index ]
10421045 sample_weight = cast (NDArray , sample_weight )
1046+ sample_weight = sample_weight [train_raps_index ]
10431047 if groups is not None :
1044- groups = groups [train_raps_index ]
10451048 groups = cast (NDArray , groups )
1049+ groups = groups [train_raps_index ]
10461050
10471051 return X , y_enc , y , n_samples , sample_weight , groups
10481052
@@ -1126,12 +1130,13 @@ def fit(
11261130 self .test_size ,
11271131 self .verbose ,
11281132 )
1129-
1133+ # Fit the prediction function
11301134 self .estimator_ = self .estimator_ .fit (
11311135 X , y , y_enc = y_enc , sample_weight = sample_weight , groups = groups ,
11321136 ** fit_params
11331137 )
11341138
1139+ # Predict on calibration data
11351140 y_pred_proba , y , y_enc = self .estimator_ .predict_proba_calib (
11361141 X , y , y_enc , groups
11371142 )
@@ -1176,10 +1181,6 @@ def fit(
11761181 "Invalid method. " f"Allowed values are { self .valid_methods_ } ."
11771182 )
11781183
1179- # In split-CP, we keep only the model fitted on train dataset
1180- if isinstance (cv , ShuffleSplit ):
1181- self .estimator_ .single_estimator_ = self .estimator_ .estimators_ [0 ]
1182-
11831184 return self
11841185
11851186 def predict (
@@ -1278,9 +1279,12 @@ def predict(
12781279 alpha_np = cast (NDArray , alpha )
12791280 check_alpha_and_n_samples (alpha_np , n )
12801281
1281- y_pred_proba = self .estimator_ .predict (X , alpha_np , agg_scores )
1282- # Check that sum of probas is equal to 1
1282+ y_pred_proba = self .estimator_ .predict (X , agg_scores )
12831283 y_pred_proba = self ._check_proba_normalized (y_pred_proba , axis = 1 )
1284+ if agg_scores != "crossval" :
1285+ y_pred_proba = np .repeat (
1286+ y_pred_proba [:, :, np .newaxis ], len (alpha_np ), axis = 2
1287+ )
12841288
12851289 # Choice of the quantile
12861290 if self .method == "naive" :
0 commit comments