Skip to content

Commit 414b6bc

Browse files
FIX: extend tests to all methods
1 parent eb7ad23 commit 414b6bc

File tree

1 file changed

+52
-165
lines changed

1 file changed

+52
-165
lines changed

mapie/tests/test_classification.py

Lines changed: 52 additions & 165 deletions
Original file line number | Diff line number | Diff line change
@@ -315,6 +315,44 @@
315315
agg_scores="mean"
316316
)
317317
),
318+
"raps": (
319+
Params(
320+
method="raps",
321+
cv="prefit",
322+
test_size=None,
323+
random_state=random_state
324+
),
325+
ParamsPredict(
326+
include_last_label=True,
327+
agg_scores="mean"
328+
)
329+
),
330+
"raps_split": (
331+
Params(
332+
method="raps",
333+
cv=StratifiedShuffleSplit(
334+
n_splits=1, train_size=0.5, random_state=random_state
335+
),
336+
test_size=None,
337+
random_state=random_state
338+
),
339+
ParamsPredict(
340+
include_last_label=True,
341+
agg_scores="mean"
342+
)
343+
),
344+
"raps_randomized": (
345+
Params(
346+
method="raps",
347+
cv="prefit",
348+
test_size=None,
349+
random_state=random_state
350+
),
351+
ParamsPredict(
352+
include_last_label="randomized",
353+
agg_scores="mean"
354+
)
355+
),
318356
}
319357

320358
STRATEGIES_BINARY = {
@@ -365,7 +403,7 @@
365403
include_last_label=False,
366404
agg_scores="crossval"
367405
)
368-
)
406+
),
369407
}
370408

371409
COVERAGES = {
@@ -389,8 +427,6 @@
389427
"naive_split": 5/9,
390428
"top_k": 1.0,
391429
"top_k_split": 1.0,
392-
"raps": 6/9,
393-
"raps_randomized": 3/9
394430
}
395431

396432
COVERAGES_BINARY = {
@@ -710,160 +746,11 @@
710746
random_state=random_state,
711747
)
712748

713-
LARGE_STRATEGIES = {
714-
"lac": (
715-
Params(
716-
method="lac",
717-
cv="prefit",
718-
test_size=None,
719-
random_state=random_state
720-
),
721-
ParamsPredict(
722-
include_last_label=False,
723-
agg_scores="mean"
724-
)
725-
),
726-
"lac_split": (
727-
Params(
728-
method="lac",
729-
cv="split",
730-
test_size=0.5,
731-
random_state=random_state
732-
),
733-
ParamsPredict(
734-
include_last_label=False,
735-
agg_scores="mean"
736-
)
737-
),
738-
"aps": (
739-
Params(
740-
method="aps",
741-
cv="prefit",
742-
test_size=None,
743-
random_state=random_state
744-
),
745-
ParamsPredict(
746-
include_last_label=True,
747-
agg_scores="mean"
748-
)
749-
),
750-
"aps_split": (
751-
Params(
752-
method="aps",
753-
cv="split",
754-
test_size=0.5,
755-
random_state=random_state
756-
),
757-
ParamsPredict(
758-
include_last_label=True,
759-
agg_scores="mean"
760-
)
761-
),
762-
"aps_randomized": (
763-
Params(
764-
method="aps",
765-
cv="prefit",
766-
test_size=None,
767-
random_state=random_state
768-
),
769-
ParamsPredict(
770-
include_last_label="randomized",
771-
agg_scores="mean"
772-
)
773-
),
774-
"naive": (
775-
Params(
776-
method="naive",
777-
cv="prefit",
778-
test_size=None,
779-
random_state=random_state
780-
),
781-
ParamsPredict(
782-
include_last_label=True,
783-
agg_scores="mean"
784-
)
785-
),
786-
"naive_split": (
787-
Params(
788-
method="naive",
789-
cv="split",
790-
test_size=0.5,
791-
random_state=random_state
792-
),
793-
ParamsPredict(
794-
include_last_label=True,
795-
agg_scores="mean"
796-
)
797-
),
798-
"top_k": (
799-
Params(
800-
method="top_k",
801-
cv="prefit",
802-
test_size=None,
803-
random_state=random_state
804-
),
805-
ParamsPredict(
806-
include_last_label=True,
807-
agg_scores="mean"
808-
)
809-
),
810-
"top_k_split": (
811-
Params(
812-
method="top_k",
813-
cv="split",
814-
test_size=0.5,
815-
random_state=random_state
816-
),
817-
ParamsPredict(
818-
include_last_label=True,
819-
agg_scores="mean"
820-
)
821-
),
822-
"raps": (
823-
Params(
824-
method="raps",
825-
cv="prefit",
826-
test_size=None,
827-
random_state=random_state
828-
),
829-
ParamsPredict(
830-
include_last_label=True,
831-
agg_scores="mean"
832-
)
833-
),
834-
"raps_split": (
835-
Params(
836-
method="raps",
837-
cv=StratifiedShuffleSplit(
838-
n_splits=1, train_size=0.5, random_state=random_state
839-
),
840-
test_size=None,
841-
random_state=random_state
842-
),
843-
ParamsPredict(
844-
include_last_label=True,
845-
agg_scores="mean"
846-
)
847-
),
848-
"raps_randomized": (
849-
Params(
850-
method="raps",
851-
cv="prefit",
852-
test_size=None,
853-
random_state=random_state
854-
),
855-
ParamsPredict(
856-
include_last_label="randomized",
857-
agg_scores="mean"
858-
)
859-
),
860-
}
861-
862749
LARGE_COVERAGES = {
863750
"lac": 0.802,
864751
"lac_split": 0.842,
865-
"aps": 0.928,
866-
"aps_split": 0.93,
752+
"aps_include": 0.928,
753+
"aps_include_split": 0.93,
867754
"aps_randomized": 0.802,
868755
"naive": 0.936,
869756
"naive_split": 0.914,
@@ -1046,9 +933,9 @@ def test_binary_classif_same_result() -> None:
1046933
@pytest.mark.parametrize("strategy", [*STRATEGIES])
1047934
def test_valid_estimator(strategy: str) -> None:
1048935
"""Test that valid estimators are not corrupted, for all strategies."""
1049-
clf = LogisticRegression().fit(X_toy, y_toy)
936+
clf = LogisticRegression().fit(X, y)
1050937
mapie_clf = MapieClassifier(estimator=clf, **STRATEGIES[strategy][0])
1051-
mapie_clf.fit(X_toy, y_toy)
938+
mapie_clf.fit(X, y)
1052939
assert (
1053940
isinstance(mapie_clf.estimator_.single_estimator_, LogisticRegression)
1054941
)
@@ -1500,11 +1387,9 @@ def test_valid_prediction(alpha: Any) -> None:
15001387
mapie_clf.predict(X_toy, alpha=alpha)
15011388

15021389

1503-
@pytest.mark.parametrize("strategy", [*STRATEGIES])
1390+
@pytest.mark.parametrize("strategy", [*COVERAGES])
15041391
def test_toy_dataset_predictions(strategy: str) -> None:
15051392
"""Test prediction sets estimated by MapieClassifier on a toy dataset"""
1506-
if strategy == "aps_randomized_cv_crossval":
1507-
return
15081393
args_init, args_predict = STRATEGIES[strategy]
15091394
if "split" not in strategy:
15101395
clf = LogisticRegression().fit(X_toy, y_toy)
@@ -1525,10 +1410,10 @@ def test_toy_dataset_predictions(strategy: str) -> None:
15251410
)
15261411

15271412

1528-
@pytest.mark.parametrize("strategy", [*LARGE_STRATEGIES])
1413+
@pytest.mark.parametrize("strategy", [*LARGE_COVERAGES])
15291414
def test_large_dataset_predictions(strategy: str) -> None:
15301415
"""Test prediction sets estimated by MapieClassifier on a larger dataset"""
1531-
args_init, args_predict = LARGE_STRATEGIES[strategy]
1416+
args_init, args_predict = STRATEGIES[strategy]
15321417
if "split" not in strategy:
15331418
clf = LogisticRegression().fit(X, y)
15341419
else:
@@ -1748,13 +1633,15 @@ def test_pred_loof_isnan() -> None:
17481633
@pytest.mark.parametrize("strategy", [*STRATEGIES])
17491634
def test_pipeline_compatibility(strategy: str) -> None:
17501635
"""Check that MAPIE works on pipeline based on pandas dataframes"""
1636+
X = np.random.randint(0, 100, size=100)
1637+
X_cat = np.random.choice(["A", "B", "C"], size=X.shape[0])
17511638
X = pd.DataFrame(
17521639
{
1753-
"x_cat": ["A", "A", "B", "A", "A", "B"],
1754-
"x_num": [0, 1, 1, 4, np.nan, 5],
1640+
"x_cat": X_cat,
1641+
"x_num": X,
17551642
}
17561643
)
1757-
y = pd.Series([0, 1, 2, 0, 1, 0])
1644+
y = np.random.randint(0, 4, size=(100, 1)) # 4 classes: labels 0-3 (randint's upper bound is exclusive)
17581645
numeric_preprocessor = Pipeline(
17591646
[
17601647
("imputer", SimpleImputer(strategy="mean")),
@@ -1833,7 +1720,7 @@ def test_regularize_conf_scores_shape(k_lambda) -> None:
18331720
Test that the conformity scores have the correct shape.
18341721
"""
18351722
lambda_, k = k_lambda[0], k_lambda[1]
1836-
args_init, _ = LARGE_STRATEGIES["raps"]
1723+
args_init, _ = STRATEGIES["raps"]
18371724
clf = LogisticRegression().fit(X, y)
18381725
mapie_clf = MapieClassifier(estimator=clf, **args_init)
18391726
conf_scores = np.random.rand(100, 1)

0 commit comments

Comments
 (0)