Skip to content

Commit 74ce283

Browse files
author
Vincent Blot
committed
ENH: directly add proba cumsed in random_tie_breaking function
1 parent 3b638e6 commit 74ce283

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

mapie/classification.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ def _add_random_tie_breaking(
410410
self,
411411
prediction_sets: ArrayLike,
412412
y_pred_index_last: ArrayLike,
413-
y_pred_proba: ArrayLike,
413+
y_pred_proba_cumsum: ArrayLike,
414414
y_pred_proba_last: ArrayLike
415415
) -> ArrayLike:
416416
"""
@@ -424,8 +424,8 @@ def _add_random_tie_breaking(
424424
Prediction set for each observation and each alpha.
425425
y_pred_index_last : ArrayLike of shape (n_samples, n_alpha)
426426
Index of the last included label.
427-
y_pred_proba : ArrayLike of shape (n_samples, n_classes)
428-
Probability output of the model.
427+
y_pred_proba_cumsum : ArrayLike of shape (n_samples, n_classes)
428+
Cumsumed probability of the model in the original order.
429429
y_pred_proba_last : ArrayLike of shape (n_samples, n_alpha)
430430
Last included probability.
431431
@@ -436,25 +436,31 @@ def _add_random_tie_breaking(
436436
labels.
437437
"""
438438
# filter sorting probabilities with kept labels
439-
y_proba_filtered = np.stack([
440-
y_pred_proba * prediction_sets[:, :, iq]
441-
for iq, _ in enumerate(self.quantiles_)
442-
], axis=2)
439+
y_proba_last_cumsumed = np.stack(
440+
[
441+
np.squeeze(
442+
np.take_along_axis(
443+
y_pred_proba_cumsum,
444+
y_pred_index_last[:, iq].reshape(-1, 1),
445+
axis=1
446+
)
447+
)
448+
for iq, _ in enumerate(self.quantiles_)
449+
], axis=1
450+
)
443451
# compute V parameter from Romano+(2020)
444452
vs = np.stack(
445453
[
446454
(
447-
np.sum(
448-
y_proba_filtered[:, :, iq], axis=1
449-
)
455+
y_proba_last_cumsumed[:, iq]
450456
- quantile
451457
) / y_pred_proba_last[:, iq]
452458
for iq, quantile in enumerate(self.quantiles_)
453459
], axis=1,
454460
)
455461
# get random numbers for each observation and alpha value
456462
random_state = check_random_state(self.random_state)
457-
us = random_state.uniform(size=y_pred_proba.shape[0])
463+
us = random_state.uniform(size=prediction_sets.shape[0])
458464
# remove last label from comparison between uniform number and V
459465
vs_less_than_us = vs < us[:, np.newaxis]
460466
np.put_along_axis(
@@ -683,7 +689,7 @@ def predict(
683689
prediction_sets = self._add_random_tie_breaking(
684690
prediction_sets,
685691
y_pred_index_last,
686-
y_pred_proba,
692+
y_pred_proba_cumsum,
687693
y_pred_proba_last
688694
)
689695
else:

0 commit comments

Comments
 (0)