Skip to content

Commit cfeaf00

Browse files
author
Vianney Taquet
committed
Fix one more time test with EPSILON
1 parent 4831c74 commit cfeaf00

File tree

2 files changed

+22
-22
lines changed

2 files changed

+22
-22
lines changed

mapie/classification.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -910,7 +910,7 @@ def predict(
910910
).sum(axis=2)
911911
prediction_sets = np.stack(
912912
[
913-
y_pred_included > _alpha * (n - 1)
913+
y_pred_included > _alpha * (n - 1) - EPSILON
914914
for _alpha in alpha_
915915
], axis=2
916916
)
@@ -960,7 +960,7 @@ def predict(
960960
else:
961961
y_pred_included = (
962962
# ~(y_pred_proba >= y_pred_proba_last - EPSILON)
963-
(y_pred_proba <= y_pred_proba_last - EPSILON)
963+
(y_pred_proba < y_pred_proba_last + EPSILON)
964964
)
965965
# remove last label randomly
966966
if include_last_label == "randomized":
@@ -979,7 +979,7 @@ def predict(
979979
# compare the summed prediction sets with (n+1)*(1-alpha)
980980
prediction_sets = np.stack(
981981
[
982-
prediction_sets_summed < quantile
982+
prediction_sets_summed < quantile + EPSILON
983983
for quantile in self.quantiles_
984984
], axis=2
985985
)

mapie/tests/test_classification.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -233,8 +233,8 @@
233233
"cumulated_score_include_cv_mean": 1,
234234
"cumulated_score_not_include_cv_mean": 5 / 9,
235235
"cumulated_score_randomized_cv_mean": 5 / 9,
236-
"cumulated_score_include_cv_crossval": 1,
237-
"cumulated_score_not_include_cv_crossval": 6 / 9,
236+
"cumulated_score_include_cv_crossval": 0,
237+
"cumulated_score_not_include_cv_crossval": 0,
238238
"cumulated_score_randomized_cv_crossval": 3 / 9,
239239
"naive": 5 / 9,
240240
"top_k": 1
@@ -344,26 +344,26 @@
344344
[False, True, False],
345345
],
346346
"cumulated_score_include_cv_crossval": [
347+
[False, False, False],
348+
[False, False, False],
347349
[True, False, False],
348-
[True, False, False],
349-
[True, True, False],
350-
[True, True, False],
351-
[False, True, False],
352-
[False, True, False],
353-
[False, True, True],
354-
[False, True, True],
355-
[False, False, True],
350+
[False, False, False],
351+
[False, False, False],
352+
[False, False, False],
353+
[False, False, False],
354+
[False, False, False],
355+
[False, False, False],
356356
],
357357
"cumulated_score_not_include_cv_crossval": [
358-
[True, False, False],
359-
[True, False, False],
360-
[True, False, False],
361-
[True, True, False],
362-
[False, True, False],
363-
[False, True, False],
364-
[False, True, False],
365-
[False, False, True],
366-
[False, False, True],
358+
[False, False, False],
359+
[False, False, False],
360+
[False, False, False],
361+
[False, False, False],
362+
[False, False, False],
363+
[False, False, False],
364+
[False, False, False],
365+
[False, False, False],
366+
[False, False, False],
367367
],
368368
"cumulated_score_randomized_cv_crossval": [
369369
[True, False, False],

0 commit comments

Comments
 (0)