22
33import numpy as np
44from sklearn .calibration import LabelEncoder
5- from sklearn .model_selection import (BaseCrossValidator , BaseShuffleSplit ,
6- StratifiedShuffleSplit )
5+ from sklearn .model_selection import StratifiedShuffleSplit
76from sklearn .utils import _safe_indexing
87from sklearn .utils .validation import _num_samples
98
@@ -49,9 +48,6 @@ class RAPSConformityScore(APSConformityScore):
4948 quantiles_: ArrayLike of shape (n_alpha)
5049 The quantiles estimated from ``get_sets`` method.
5150
52- cv: Union[int, str, BaseCrossValidator]
53- The cross-validation strategy for computing scores.
54-
5551 label_encoder: LabelEncoder
5652 The label encoder used to encode the labels.
5753
@@ -60,8 +56,6 @@ class RAPSConformityScore(APSConformityScore):
6056 k_star for the RAPS method.
6157 """
6258
63- valid_cv_ = ["prefit" , "split" ]
64-
6559 def __init__ (
6660 self ,
6761 size_raps : Optional [float ] = 0.2
@@ -72,7 +66,6 @@ def __init__(
7266 def set_external_attributes (
7367 self ,
7468 * ,
75- cv : Optional [Union [str , BaseCrossValidator , BaseShuffleSplit ]] = None ,
7669 label_encoder : Optional [LabelEncoder ] = None ,
7770 size_raps : Optional [float ] = None ,
7871 ** kwargs
@@ -82,11 +75,6 @@ def set_external_attributes(
8275
8376 Parameters
8477 ----------
85- cv: Optional[Union[int, str, BaseCrossValidator]]
86- The cross-validation strategy for computing scores.
87-
88- By default ``None``.
89-
9078 label_encoder: Optional[LabelEncoder]
9179 The label encoder used to encode the labels.
9280
@@ -99,28 +87,9 @@ def set_external_attributes(
9987 By default ``None``.
10088 """
10189 super ().set_external_attributes (** kwargs )
102- self .cv = cast (Union [str , BaseCrossValidator , BaseShuffleSplit ], cv )
10390 self .label_encoder_ = cast (LabelEncoder , label_encoder )
10491 self .size_raps = size_raps
10592
106- def _check_cv (self ):
107- """
108- Check that if the method used is ``"raps"``, then
109- the cross validation strategy is ``"prefit"``.
110-
111- Raises
112- ------
113- ValueError
114- If ``method`` is ``"raps"`` and ``cv`` is not ``"prefit"``.
115- """
116- if not (
117- self .cv in self .valid_cv_ or isinstance (self .cv , BaseShuffleSplit )
118- ):
119- raise ValueError (
120- "RAPS method can only be used "
121- f"with cv in { self .valid_cv_ } ."
122- )
123-
12493 def split_data (
12594 self ,
12695 X : NDArray ,
@@ -162,9 +131,6 @@ def split_data(
162131 - NDArray of shape (n_samples,)
163132 - NDArray of shape (n_samples,)
164133 """
165- # Checks
166- self ._check_cv ()
167-
168134 # Split data for raps method
169135 raps_split = StratifiedShuffleSplit (
170136 n_splits = 1 ,
0 commit comments