315315 agg_scores = "mean"
316316 )
317317 ),
318+ "raps" : (
319+ Params (
320+ method = "raps" ,
321+ cv = "prefit" ,
322+ test_size = None ,
323+ random_state = random_state
324+ ),
325+ ParamsPredict (
326+ include_last_label = True ,
327+ agg_scores = "mean"
328+ )
329+ ),
330+ "raps_split" : (
331+ Params (
332+ method = "raps" ,
333+ cv = StratifiedShuffleSplit (
334+ n_splits = 1 , train_size = 0.5 , random_state = random_state
335+ ),
336+ test_size = None ,
337+ random_state = random_state
338+ ),
339+ ParamsPredict (
340+ include_last_label = True ,
341+ agg_scores = "mean"
342+ )
343+ ),
344+ "raps_randomized" : (
345+ Params (
346+ method = "raps" ,
347+ cv = "prefit" ,
348+ test_size = None ,
349+ random_state = random_state
350+ ),
351+ ParamsPredict (
352+ include_last_label = "randomized" ,
353+ agg_scores = "mean"
354+ )
355+ ),
318356}
319357
320358STRATEGIES_BINARY = {
365403 include_last_label = False ,
366404 agg_scores = "crossval"
367405 )
368- )
406+ ),
369407}
370408
371409COVERAGES = {
389427 "naive_split" : 5 / 9 ,
390428 "top_k" : 1.0 ,
391429 "top_k_split" : 1.0 ,
392- "raps" : 6 / 9 ,
393- "raps_randomized" : 3 / 9
394430}
395431
396432COVERAGES_BINARY = {
710746 random_state = random_state ,
711747)
712748
713- LARGE_STRATEGIES = {
714- "lac" : (
715- Params (
716- method = "lac" ,
717- cv = "prefit" ,
718- test_size = None ,
719- random_state = random_state
720- ),
721- ParamsPredict (
722- include_last_label = False ,
723- agg_scores = "mean"
724- )
725- ),
726- "lac_split" : (
727- Params (
728- method = "lac" ,
729- cv = "split" ,
730- test_size = 0.5 ,
731- random_state = random_state
732- ),
733- ParamsPredict (
734- include_last_label = False ,
735- agg_scores = "mean"
736- )
737- ),
738- "aps" : (
739- Params (
740- method = "aps" ,
741- cv = "prefit" ,
742- test_size = None ,
743- random_state = random_state
744- ),
745- ParamsPredict (
746- include_last_label = True ,
747- agg_scores = "mean"
748- )
749- ),
750- "aps_split" : (
751- Params (
752- method = "aps" ,
753- cv = "split" ,
754- test_size = 0.5 ,
755- random_state = random_state
756- ),
757- ParamsPredict (
758- include_last_label = True ,
759- agg_scores = "mean"
760- )
761- ),
762- "aps_randomized" : (
763- Params (
764- method = "aps" ,
765- cv = "prefit" ,
766- test_size = None ,
767- random_state = random_state
768- ),
769- ParamsPredict (
770- include_last_label = "randomized" ,
771- agg_scores = "mean"
772- )
773- ),
774- "naive" : (
775- Params (
776- method = "naive" ,
777- cv = "prefit" ,
778- test_size = None ,
779- random_state = random_state
780- ),
781- ParamsPredict (
782- include_last_label = True ,
783- agg_scores = "mean"
784- )
785- ),
786- "naive_split" : (
787- Params (
788- method = "naive" ,
789- cv = "split" ,
790- test_size = 0.5 ,
791- random_state = random_state
792- ),
793- ParamsPredict (
794- include_last_label = True ,
795- agg_scores = "mean"
796- )
797- ),
798- "top_k" : (
799- Params (
800- method = "top_k" ,
801- cv = "prefit" ,
802- test_size = None ,
803- random_state = random_state
804- ),
805- ParamsPredict (
806- include_last_label = True ,
807- agg_scores = "mean"
808- )
809- ),
810- "top_k_split" : (
811- Params (
812- method = "top_k" ,
813- cv = "split" ,
814- test_size = 0.5 ,
815- random_state = random_state
816- ),
817- ParamsPredict (
818- include_last_label = True ,
819- agg_scores = "mean"
820- )
821- ),
822- "raps" : (
823- Params (
824- method = "raps" ,
825- cv = "prefit" ,
826- test_size = None ,
827- random_state = random_state
828- ),
829- ParamsPredict (
830- include_last_label = True ,
831- agg_scores = "mean"
832- )
833- ),
834- "raps_split" : (
835- Params (
836- method = "raps" ,
837- cv = StratifiedShuffleSplit (
838- n_splits = 1 , train_size = 0.5 , random_state = random_state
839- ),
840- test_size = None ,
841- random_state = random_state
842- ),
843- ParamsPredict (
844- include_last_label = True ,
845- agg_scores = "mean"
846- )
847- ),
848- "raps_randomized" : (
849- Params (
850- method = "raps" ,
851- cv = "prefit" ,
852- test_size = None ,
853- random_state = random_state
854- ),
855- ParamsPredict (
856- include_last_label = "randomized" ,
857- agg_scores = "mean"
858- )
859- ),
860- }
861-
862749LARGE_COVERAGES = {
863750 "lac" : 0.802 ,
864751 "lac_split" : 0.842 ,
865- "aps " : 0.928 ,
866- "aps_split " : 0.93 ,
752+ "aps_include " : 0.928 ,
753+ "aps_include_split " : 0.93 ,
867754 "aps_randomized" : 0.802 ,
868755 "naive" : 0.936 ,
869756 "naive_split" : 0.914 ,
@@ -1046,9 +933,9 @@ def test_binary_classif_same_result() -> None:
1046933@pytest .mark .parametrize ("strategy" , [* STRATEGIES ])
1047934def test_valid_estimator (strategy : str ) -> None :
1048935 """Test that valid estimators are not corrupted, for all strategies."""
1049- clf = LogisticRegression ().fit (X_toy , y_toy )
936+ clf = LogisticRegression ().fit (X , y )
1050937 mapie_clf = MapieClassifier (estimator = clf , ** STRATEGIES [strategy ][0 ])
1051- mapie_clf .fit (X_toy , y_toy )
938+ mapie_clf .fit (X , y )
1052939 assert (
1053940 isinstance (mapie_clf .estimator_ .single_estimator_ , LogisticRegression )
1054941 )
@@ -1500,11 +1387,9 @@ def test_valid_prediction(alpha: Any) -> None:
15001387 mapie_clf .predict (X_toy , alpha = alpha )
15011388
15021389
1503- @pytest .mark .parametrize ("strategy" , [* STRATEGIES ])
1390+ @pytest .mark .parametrize ("strategy" , [* COVERAGES ])
15041391def test_toy_dataset_predictions (strategy : str ) -> None :
15051392 """Test prediction sets estimated by MapieClassifier on a toy dataset"""
1506- if strategy == "aps_randomized_cv_crossval" :
1507- return
15081393 args_init , args_predict = STRATEGIES [strategy ]
15091394 if "split" not in strategy :
15101395 clf = LogisticRegression ().fit (X_toy , y_toy )
@@ -1525,10 +1410,10 @@ def test_toy_dataset_predictions(strategy: str) -> None:
15251410 )
15261411
15271412
1528- @pytest .mark .parametrize ("strategy" , [* LARGE_STRATEGIES ])
1413+ @pytest .mark .parametrize ("strategy" , [* LARGE_COVERAGES ])
15291414def test_large_dataset_predictions (strategy : str ) -> None :
15301415 """Test prediction sets estimated by MapieClassifier on a larger dataset"""
1531- args_init , args_predict = LARGE_STRATEGIES [strategy ]
1416+ args_init , args_predict = STRATEGIES [strategy ]
15321417 if "split" not in strategy :
15331418 clf = LogisticRegression ().fit (X , y )
15341419 else :
@@ -1748,13 +1633,15 @@ def test_pred_loof_isnan() -> None:
17481633@pytest .mark .parametrize ("strategy" , [* STRATEGIES ])
17491634def test_pipeline_compatibility (strategy : str ) -> None :
17501635 """Check that MAPIE works on pipeline based on pandas dataframes"""
1636+ X = np .random .randint (0 , 100 , size = 100 )
1637+ X_cat = np .random .choice (["A" , "B" , "C" ], size = X .shape [0 ])
17511638 X = pd .DataFrame (
17521639 {
1753- "x_cat" : [ "A" , "A" , "B" , "A" , "A" , "B" ] ,
1754- "x_num" : [ 0 , 1 , 1 , 4 , np . nan , 5 ] ,
1640+ "x_cat" : X_cat ,
1641+ "x_num" : X ,
17551642 }
17561643 )
1757- y = pd . Series ([ 0 , 1 , 2 , 0 , 1 , 0 ])
1644+ y = np . random . randint ( 0 , 4 , size = ( 100 , 1 )) # 3 classes
17581645 numeric_preprocessor = Pipeline (
17591646 [
17601647 ("imputer" , SimpleImputer (strategy = "mean" )),
@@ -1833,7 +1720,7 @@ def test_regularize_conf_scores_shape(k_lambda) -> None:
18331720 Test that the conformity scores have the correct shape.
18341721 """
18351722 lambda_ , k = k_lambda [0 ], k_lambda [1 ]
1836- args_init , _ = LARGE_STRATEGIES ["raps" ]
1723+ args_init , _ = STRATEGIES ["raps" ]
18371724 clf = LogisticRegression ().fit (X , y )
18381725 mapie_clf = MapieClassifier (estimator = clf , ** args_init )
18391726 conf_scores = np .random .rand (100 , 1 )
0 commit comments