11
11
from sklearn .multioutput import MultiOutputClassifier
12
12
from sklearn .pipeline import Pipeline
13
13
from sklearn .utils import check_random_state
14
- from sklearn .utils .validation import (
15
- _check_y ,
16
- _num_samples ,
17
- check_is_fitted ,
18
- indexable ,
19
- )
14
+ from sklearn .utils .validation import (_check_y , _num_samples , check_is_fitted ,
15
+ indexable )
20
16
21
17
from numpy .typing import ArrayLike , NDArray
22
18
from .control_risk .crc_rcps import find_lambda_star , get_r_hat_plus
@@ -157,25 +153,30 @@ class PrecisionRecallController(BaseEstimator, ClassifierMixin):
157
153
[False True False]
158
154
[False True False]]
159
155
"""
160
-
161
- valid_methods_by_metric_ = {"precision" : ["ltt" ], "recall" : ["rcps" , "crc" ]}
156
+ valid_methods_by_metric_ = {
157
+ "precision" : ["ltt" ],
158
+ "recall" : ["rcps" , "crc" ]
159
+ }
162
160
valid_methods = list (chain (* valid_methods_by_metric_ .values ()))
163
161
valid_metric_ = list (valid_methods_by_metric_ .keys ())
164
162
valid_bounds_ = ["hoeffding" , "bernstein" , "wsr" , None ]
165
163
lambdas = np .arange (0 , 1 , 0.01 )
166
164
n_lambdas = len (lambdas )
167
- fit_attributes = ["single_estimator_" , "risks" ]
165
+ fit_attributes = [
166
+ "single_estimator_" ,
167
+ "risks"
168
+ ]
168
169
sigma_init = 0.25 # Value given in the paper [1]
169
- cal_size = 0 .3
170
+ cal_size = .3
170
171
171
172
def __init__ (
172
173
self ,
173
174
estimator : Optional [ClassifierMixin ] = None ,
174
- metric_control : Optional [str ] = " recall" ,
175
+ metric_control : Optional [str ] = ' recall' ,
175
176
method : Optional [str ] = None ,
176
177
n_jobs : Optional [int ] = None ,
177
178
random_state : Optional [Union [int , np .random .RandomState ]] = None ,
178
- verbose : int = 0 ,
179
+ verbose : int = 0
179
180
) -> None :
180
181
self .estimator = estimator
181
182
self .metric_control = metric_control
@@ -210,18 +211,16 @@ def _check_method(self) -> None:
210
211
self .method = cast (str , self .method )
211
212
self .metric_control = cast (str , self .metric_control )
212
213
213
- if (
214
- self .method
215
- not in self .valid_methods_by_metric_ [self .metric_control ]
216
- ):
214
+ if self .method not in self .valid_methods_by_metric_ [
215
+ self .metric_control
216
+ ]:
217
217
raise ValueError (
218
218
"Invalid method for metric: "
219
- + "You are controlling "
220
- + self .metric_control
221
- + " and you are using invalid method: "
222
- + self .method
223
- + ". Use instead: "
224
- + "" .join (self .valid_methods_by_metric_ [self .metric_control ])
219
+ + "You are controlling " + self .metric_control
220
+ + " and you are using invalid method: " + self .method
221
+ + ". Use instead: " + "" .join (self .valid_methods_by_metric_ [
222
+ self .metric_control ]
223
+ )
225
224
)
226
225
227
226
def _check_all_labelled (self , y : NDArray ) -> None :
@@ -242,7 +241,9 @@ def _check_all_labelled(self, y: NDArray) -> None:
242
241
"""
243
242
if not (y .sum (axis = 1 ) > 0 ).all ():
244
243
raise ValueError (
245
- "Invalid y. All observations should contain at least one label."
244
+ "Invalid y. "
245
+ "All observations should contain at "
246
+ "least one label."
246
247
)
247
248
248
249
def _check_delta (self , delta : Optional [float ]):
@@ -267,7 +268,8 @@ def _check_delta(self, delta: Optional[float]):
267
268
"""
268
269
if (not isinstance (delta , float )) and (delta is not None ):
269
270
raise ValueError (
270
- f"Invalid delta. delta must be a float, not a { type (delta )} "
271
+ "Invalid delta. "
272
+ f"delta must be a float, not a {type(delta)}"
271
273
)
272
274
if (self .method == "rcps" ) or (self .method == "ltt" ):
273
275
if delta is None :
@@ -276,8 +278,11 @@ def _check_delta(self, delta: Optional[float]):
276
278
"delta cannot be ``None`` when controlling "
277
279
"Recall with RCPS or Precision with LTT"
278
280
)
279
- elif (delta <= 0 ) or (delta >= 1 ):
280
- raise ValueError ("Invalid delta. delta must be in ]0, 1[" )
281
+ elif ((delta <= 0 ) or (delta >= 1 )):
282
+ raise ValueError (
283
+ "Invalid delta. "
284
+ "delta must be in ]0, 1["
285
+ )
281
286
if (self .method == "crc" ) and (delta is not None ):
282
287
warnings .warn (
283
288
"WARNING: you are using crc method, hence "
@@ -297,8 +302,7 @@ def _check_valid_index(self, alpha: NDArray):
297
302
if self .valid_index [i ] == []:
298
303
warnings .warn (
299
304
"Warning: LTT method has returned an empty sequence"
300
- + " for alpha="
301
- + str (alpha [i ])
305
+ + " for alpha=" + str (alpha [i ])
302
306
)
303
307
304
308
def _check_estimator (
@@ -357,12 +361,14 @@ def _check_estimator(
357
361
"use partial_fit."
358
362
)
359
363
if (estimator is None ) and (_refit ):
360
- estimator = MultiOutputClassifier (LogisticRegression ())
364
+ estimator = MultiOutputClassifier (
365
+ LogisticRegression ()
366
+ )
361
367
X_train , X_conf , y_train , y_conf = train_test_split (
362
- X ,
363
- y ,
364
- test_size = self .conformalize_size ,
365
- random_state = self .random_state ,
368
+ X ,
369
+ y ,
370
+ test_size = self .conformalize_size ,
371
+ random_state = self .random_state ,
366
372
)
367
373
estimator .fit (X_train , y_train )
368
374
warnings .warn (
@@ -454,7 +460,8 @@ def _check_metric_control(self):
454
460
self .method = "ltt"
455
461
456
462
def _transform_pred_proba (
457
- self , y_pred_proba : Union [Sequence [NDArray ], NDArray ]
463
+ self ,
464
+ y_pred_proba : Union [Sequence [NDArray ], NDArray ]
458
465
) -> NDArray :
459
466
"""If the output of the predict_proba is a list of arrays (output of
460
467
the ``predict_proba`` of ``MultiOutputClassifier``) we transform it
@@ -476,7 +483,7 @@ def _transform_pred_proba(
476
483
else :
477
484
y_pred_proba_stacked = np .stack (
478
485
y_pred_proba , # type: ignore
479
- axis = 0 ,
486
+ axis = 0
480
487
)[:, :, 1 ]
481
488
y_pred_proba_array = np .moveaxis (y_pred_proba_stacked , 0 , - 1 )
482
489
@@ -519,7 +526,10 @@ def partial_fit(
519
526
520
527
X , y = indexable (X , y )
521
528
_check_y (y , multi_output = True )
522
- estimator , X , y = self ._check_estimator (X , y , self .estimator , _refit )
529
+ estimator , X , y = self ._check_estimator (
530
+ X , y , self .estimator ,
531
+ _refit
532
+ )
523
533
524
534
y = cast (NDArray , y )
525
535
X = cast (NDArray , X )
@@ -551,11 +561,15 @@ def partial_fit(
551
561
y_pred_proba_array = self ._transform_pred_proba (y_pred_proba )
552
562
if self .metric_control == "recall" :
553
563
partial_risk = compute_risk_recall (
554
- self .lambdas , y_pred_proba_array , y
564
+ self .lambdas ,
565
+ y_pred_proba_array ,
566
+ y
555
567
)
556
568
else : # self.metric_control == "precision"
557
569
partial_risk = compute_risk_precision (
558
- self .lambdas , y_pred_proba_array , y
570
+ self .lambdas ,
571
+ y_pred_proba_array ,
572
+ y
559
573
)
560
574
self .risks = np .concatenate ([self .risks , partial_risk ], axis = 0 )
561
575
@@ -565,7 +579,7 @@ def fit(
565
579
self ,
566
580
X : ArrayLike ,
567
581
y : ArrayLike ,
568
- conformalize_size : Optional [float ] = 0 .3 ,
582
+ conformalize_size : Optional [float ] = .3 ,
569
583
) -> PrecisionRecallController :
570
584
"""
571
585
Fit the base estimator or use the fitted base estimator.
@@ -597,7 +611,7 @@ def predict(
597
611
X : ArrayLike ,
598
612
alpha : Optional [Union [float , Iterable [float ]]] = None ,
599
613
delta : Optional [float ] = None ,
600
- bound : Optional [Union [str , None ]] = None ,
614
+ bound : Optional [Union [str , None ]] = None
601
615
) -> Union [NDArray , Tuple [NDArray , NDArray ]]:
602
616
"""
603
617
Prediction sets on new samples based on target confidence
@@ -660,37 +674,35 @@ def predict(
660
674
661
675
y_pred_proba_array = self ._transform_pred_proba (y_pred_proba )
662
676
y_pred_proba_array = np .repeat (
663
- y_pred_proba_array , len (alpha_np ), axis = 2
677
+ y_pred_proba_array ,
678
+ len (alpha_np ),
679
+ axis = 2
664
680
)
665
- if self .metric_control == " precision" :
681
+ if self .metric_control == ' precision' :
666
682
self .n_obs = len (self .risks )
667
683
self .r_hat = self .risks .mean (axis = 0 )
668
684
self .valid_index , self .p_values = ltt_procedure (
669
685
self .r_hat , alpha_np , delta , self .n_obs
670
686
)
671
687
self ._check_valid_index (alpha_np )
672
688
self .lambdas_star , self .r_star = find_lambda_control_star (
673
- self .r_hat , self .valid_index , self .lambdas
689
+ self .r_hat , self .valid_index , self .lambdas
674
690
)
675
691
y_pred_proba_array = (
676
- y_pred_proba_array
677
- > np .array (self .lambdas_star )[np .newaxis , np .newaxis , :]
692
+ y_pred_proba_array >
693
+ np .array (self .lambdas_star )[np .newaxis , np .newaxis , :]
678
694
)
679
695
680
696
else :
681
697
self .r_hat , self .r_hat_plus = get_r_hat_plus (
682
- self .risks ,
683
- self .lambdas ,
684
- self .method ,
685
- bound ,
686
- delta ,
687
- self .sigma_init ,
698
+ self .risks , self .lambdas , self .method ,
699
+ bound , delta , self .sigma_init
688
700
)
689
701
self .lambdas_star = find_lambda_star (
690
702
self .lambdas , self .r_hat_plus , alpha_np
691
703
)
692
704
y_pred_proba_array = (
693
- y_pred_proba_array
694
- > self .lambdas_star [np .newaxis , np .newaxis , :]
705
+ y_pred_proba_array >
706
+ self .lambdas_star [np .newaxis , np .newaxis , :]
695
707
)
696
708
return y_pred , y_pred_proba_array
0 commit comments