11
11
from sklearn .multioutput import MultiOutputClassifier
12
12
from sklearn .pipeline import Pipeline
13
13
from sklearn .utils import check_random_state
14
- from sklearn .utils .validation import (_check_y , _num_samples , check_is_fitted ,
15
- indexable )
14
+ from sklearn .utils .validation import (
15
+ _check_y ,
16
+ _num_samples ,
17
+ check_is_fitted ,
18
+ indexable ,
19
+ )
16
20
17
21
from numpy .typing import ArrayLike , NDArray
18
22
from .control_risk .crc_rcps import find_lambda_star , get_r_hat_plus
@@ -127,7 +131,7 @@ class PrecisionRecallController(BaseEstimator, ClassifierMixin):
127
131
[1] Lihua Lei Jitendra Malik Stephen Bates, Anastasios Angelopoulos
128
132
and Michael I. Jordan. Distribution-free, risk-controlling prediction
129
133
sets. CoRR, abs/2101.02703, 2021.
130
- URL https://arxiv.org/abs/2101.02703.39
134
+ URL https://arxiv.org/abs/2101.02703
131
135
132
136
[2] Angelopoulos, Anastasios N., Stephen, Bates, Adam, Fisch, Lihua,
133
137
Lei, and Tal, Schuster. "Conformal Risk Control." (2022).
@@ -153,30 +157,25 @@ class PrecisionRecallController(BaseEstimator, ClassifierMixin):
153
157
[False True False]
154
158
[False True False]]
155
159
"""
156
- valid_methods_by_metric_ = {
157
- "precision" : ["ltt" ],
158
- "recall" : ["rcps" , "crc" ]
159
- }
160
+
161
+ valid_methods_by_metric_ = {"precision" : ["ltt" ], "recall" : ["rcps" , "crc" ]}
160
162
valid_methods = list (chain (* valid_methods_by_metric_ .values ()))
161
163
valid_metric_ = list (valid_methods_by_metric_ .keys ())
162
164
valid_bounds_ = ["hoeffding" , "bernstein" , "wsr" , None ]
163
165
lambdas = np .arange (0 , 1 , 0.01 )
164
166
n_lambdas = len (lambdas )
165
- fit_attributes = [
166
- "single_estimator_" ,
167
- "risks"
168
- ]
167
+ fit_attributes = ["single_estimator_" , "risks" ]
169
168
sigma_init = 0.25 # Value given in the paper [1]
170
- cal_size = .3
169
+ cal_size = 0 .3
171
170
172
171
def __init__ (
173
172
self ,
174
173
estimator : Optional [ClassifierMixin ] = None ,
175
- metric_control : Optional [str ] = ' recall' ,
174
+ metric_control : Optional [str ] = " recall" ,
176
175
method : Optional [str ] = None ,
177
176
n_jobs : Optional [int ] = None ,
178
177
random_state : Optional [Union [int , np .random .RandomState ]] = None ,
179
- verbose : int = 0
178
+ verbose : int = 0 ,
180
179
) -> None :
181
180
self .estimator = estimator
182
181
self .metric_control = metric_control
@@ -211,16 +210,18 @@ def _check_method(self) -> None:
211
210
self .method = cast (str , self .method )
212
211
self .metric_control = cast (str , self .metric_control )
213
212
214
- if self .method not in self .valid_methods_by_metric_ [
215
- self .metric_control
216
- ]:
213
+ if (
214
+ self .method
215
+ not in self .valid_methods_by_metric_ [self .metric_control ]
216
+ ):
217
217
raise ValueError (
218
218
"Invalid method for metric: "
219
- + "You are controlling " + self .metric_control
220
- + " and you are using invalid method: " + self .method
221
- + ". Use instead: " + "" .join (self .valid_methods_by_metric_ [
222
- self .metric_control ]
223
- )
219
+ + "You are controlling "
220
+ + self .metric_control
221
+ + " and you are using invalid method: "
222
+ + self .method
223
+ + ". Use instead: "
224
+ + "" .join (self .valid_methods_by_metric_ [self .metric_control ])
224
225
)
225
226
226
227
def _check_all_labelled (self , y : NDArray ) -> None :
@@ -241,9 +242,7 @@ def _check_all_labelled(self, y: NDArray) -> None:
241
242
"""
242
243
if not (y .sum (axis = 1 ) > 0 ).all ():
243
244
raise ValueError (
244
- "Invalid y. "
245
- "All observations should contain at "
246
- "least one label."
245
+ "Invalid y. All observations should contain at least one label."
247
246
)
248
247
249
248
def _check_delta (self , delta : Optional [float ]):
@@ -268,8 +267,7 @@ def _check_delta(self, delta: Optional[float]):
268
267
"""
269
268
if (not isinstance (delta , float )) and (delta is not None ):
270
269
raise ValueError (
271
- "Invalid delta. "
272
- f"delta must be a float, not a { type (delta )} "
270
+ f"Invalid delta. delta must be a float, not a { type (delta )} "
273
271
)
274
272
if (self .method == "rcps" ) or (self .method == "ltt" ):
275
273
if delta is None :
@@ -278,11 +276,8 @@ def _check_delta(self, delta: Optional[float]):
278
276
"delta cannot be ``None`` when controlling "
279
277
"Recall with RCPS or Precision with LTT"
280
278
)
281
- elif ((delta <= 0 ) or (delta >= 1 )):
282
- raise ValueError (
283
- "Invalid delta. "
284
- "delta must be in ]0, 1["
285
- )
279
+ elif (delta <= 0 ) or (delta >= 1 ):
280
+ raise ValueError ("Invalid delta. delta must be in ]0, 1[" )
286
281
if (self .method == "crc" ) and (delta is not None ):
287
282
warnings .warn (
288
283
"WARNING: you are using crc method, hence "
@@ -302,7 +297,8 @@ def _check_valid_index(self, alpha: NDArray):
302
297
if self .valid_index [i ] == []:
303
298
warnings .warn (
304
299
"Warning: LTT method has returned an empty sequence"
305
- + " for alpha=" + str (alpha [i ])
300
+ + " for alpha="
301
+ + str (alpha [i ])
306
302
)
307
303
308
304
def _check_estimator (
@@ -361,14 +357,12 @@ def _check_estimator(
361
357
"use partial_fit."
362
358
)
363
359
if (estimator is None ) and (_refit ):
364
- estimator = MultiOutputClassifier (
365
- LogisticRegression ()
366
- )
360
+ estimator = MultiOutputClassifier (LogisticRegression ())
367
361
X_train , X_conf , y_train , y_conf = train_test_split (
368
- X ,
369
- y ,
370
- test_size = self .conformalize_size ,
371
- random_state = self .random_state ,
362
+ X ,
363
+ y ,
364
+ test_size = self .conformalize_size ,
365
+ random_state = self .random_state ,
372
366
)
373
367
estimator .fit (X_train , y_train )
374
368
warnings .warn (
@@ -460,8 +454,7 @@ def _check_metric_control(self):
460
454
self .method = "ltt"
461
455
462
456
def _transform_pred_proba (
463
- self ,
464
- y_pred_proba : Union [Sequence [NDArray ], NDArray ]
457
+ self , y_pred_proba : Union [Sequence [NDArray ], NDArray ]
465
458
) -> NDArray :
466
459
"""If the output of the predict_proba is a list of arrays (output of
467
460
the ``predict_proba`` of ``MultiOutputClassifier``) we transform it
@@ -483,7 +476,7 @@ def _transform_pred_proba(
483
476
else :
484
477
y_pred_proba_stacked = np .stack (
485
478
y_pred_proba , # type: ignore
486
- axis = 0
479
+ axis = 0 ,
487
480
)[:, :, 1 ]
488
481
y_pred_proba_array = np .moveaxis (y_pred_proba_stacked , 0 , - 1 )
489
482
@@ -526,10 +519,7 @@ def partial_fit(
526
519
527
520
X , y = indexable (X , y )
528
521
_check_y (y , multi_output = True )
529
- estimator , X , y = self ._check_estimator (
530
- X , y , self .estimator ,
531
- _refit
532
- )
522
+ estimator , X , y = self ._check_estimator (X , y , self .estimator , _refit )
533
523
534
524
y = cast (NDArray , y )
535
525
X = cast (NDArray , X )
@@ -561,15 +551,11 @@ def partial_fit(
561
551
y_pred_proba_array = self ._transform_pred_proba (y_pred_proba )
562
552
if self .metric_control == "recall" :
563
553
partial_risk = compute_risk_recall (
564
- self .lambdas ,
565
- y_pred_proba_array ,
566
- y
554
+ self .lambdas , y_pred_proba_array , y
567
555
)
568
556
else : # self.metric_control == "precision"
569
557
partial_risk = compute_risk_precision (
570
- self .lambdas ,
571
- y_pred_proba_array ,
572
- y
558
+ self .lambdas , y_pred_proba_array , y
573
559
)
574
560
self .risks = np .concatenate ([self .risks , partial_risk ], axis = 0 )
575
561
@@ -579,7 +565,7 @@ def fit(
579
565
self ,
580
566
X : ArrayLike ,
581
567
y : ArrayLike ,
582
- conformalize_size : Optional [float ] = .3
568
+ conformalize_size : Optional [float ] = 0.3 ,
583
569
) -> PrecisionRecallController :
584
570
"""
585
571
Fit the base estimator or use the fitted base estimator.
@@ -611,7 +597,7 @@ def predict(
611
597
X : ArrayLike ,
612
598
alpha : Optional [Union [float , Iterable [float ]]] = None ,
613
599
delta : Optional [float ] = None ,
614
- bound : Optional [Union [str , None ]] = None
600
+ bound : Optional [Union [str , None ]] = None ,
615
601
) -> Union [NDArray , Tuple [NDArray , NDArray ]]:
616
602
"""
617
603
Prediction sets on new samples based on target confidence
@@ -674,35 +660,37 @@ def predict(
674
660
675
661
y_pred_proba_array = self ._transform_pred_proba (y_pred_proba )
676
662
y_pred_proba_array = np .repeat (
677
- y_pred_proba_array ,
678
- len (alpha_np ),
679
- axis = 2
663
+ y_pred_proba_array , len (alpha_np ), axis = 2
680
664
)
681
- if self .metric_control == ' precision' :
665
+ if self .metric_control == " precision" :
682
666
self .n_obs = len (self .risks )
683
667
self .r_hat = self .risks .mean (axis = 0 )
684
668
self .valid_index , self .p_values = ltt_procedure (
685
669
self .r_hat , alpha_np , delta , self .n_obs
686
670
)
687
671
self ._check_valid_index (alpha_np )
688
672
self .lambdas_star , self .r_star = find_lambda_control_star (
689
- self .r_hat , self .valid_index , self .lambdas
673
+ self .r_hat , self .valid_index , self .lambdas
690
674
)
691
675
y_pred_proba_array = (
692
- y_pred_proba_array >
693
- np .array (self .lambdas_star )[np .newaxis , np .newaxis , :]
676
+ y_pred_proba_array
677
+ > np .array (self .lambdas_star )[np .newaxis , np .newaxis , :]
694
678
)
695
679
696
680
else :
697
681
self .r_hat , self .r_hat_plus = get_r_hat_plus (
698
- self .risks , self .lambdas , self .method ,
699
- bound , delta , self .sigma_init
682
+ self .risks ,
683
+ self .lambdas ,
684
+ self .method ,
685
+ bound ,
686
+ delta ,
687
+ self .sigma_init ,
700
688
)
701
689
self .lambdas_star = find_lambda_star (
702
690
self .lambdas , self .r_hat_plus , alpha_np
703
691
)
704
692
y_pred_proba_array = (
705
- y_pred_proba_array >
706
- self .lambdas_star [np .newaxis , np .newaxis , :]
693
+ y_pred_proba_array
694
+ > self .lambdas_star [np .newaxis , np .newaxis , :]
707
695
)
708
696
return y_pred , y_pred_proba_array
0 commit comments