@@ -410,7 +410,7 @@ def _add_random_tie_breaking(
410410 self ,
411411 prediction_sets : ArrayLike ,
412412 y_pred_index_last : ArrayLike ,
413- y_pred_proba : ArrayLike ,
413+ y_pred_proba_cumsum : ArrayLike ,
414414 y_pred_proba_last : ArrayLike
415415 ) -> ArrayLike :
416416 """
@@ -424,8 +424,8 @@ def _add_random_tie_breaking(
424424 Prediction set for each observation and each alpha.
425425 y_pred_index_last : ArrayLike of shape (n_samples, n_alpha)
426426 Index of the last included label.
427- y_pred_proba : ArrayLike of shape (n_samples, n_classes)
428- Probability output of the model.
427+ y_pred_proba_cumsum : ArrayLike of shape (n_samples, n_classes)
428+ Cumsumed probability of the model in the original order .
429429 y_pred_proba_last : ArrayLike of shape (n_samples, n_alpha)
430430 Last included probability.
431431
@@ -436,25 +436,31 @@ def _add_random_tie_breaking(
436436 labels.
437437 """
438438 # filter sorting probabilities with kept labels
439- y_proba_filtered = np .stack ([
440- y_pred_proba * prediction_sets [:, :, iq ]
441- for iq , _ in enumerate (self .quantiles_ )
442- ], axis = 2 )
439+ y_proba_last_cumsumed = np .stack (
440+ [
441+ np .squeeze (
442+ np .take_along_axis (
443+ y_pred_proba_cumsum ,
444+ y_pred_index_last [:, iq ].reshape (- 1 , 1 ),
445+ axis = 1
446+ )
447+ )
448+ for iq , _ in enumerate (self .quantiles_ )
449+ ], axis = 1
450+ )
443451 # compute V parameter from Romano+(2020)
444452 vs = np .stack (
445453 [
446454 (
447- np .sum (
448- y_proba_filtered [:, :, iq ], axis = 1
449- )
455+ y_proba_last_cumsumed [:, iq ]
450456 - quantile
451457 ) / y_pred_proba_last [:, iq ]
452458 for iq , quantile in enumerate (self .quantiles_ )
453459 ], axis = 1 ,
454460 )
455461 # get random numbers for each observation and alpha value
456462 random_state = check_random_state (self .random_state )
457- us = random_state .uniform (size = y_pred_proba .shape [0 ])
463+ us = random_state .uniform (size = prediction_sets .shape [0 ])
458464 # remove last label from comparison between uniform number and V
459465 vs_less_than_us = vs < us [:, np .newaxis ]
460466 np .put_along_axis (
@@ -683,7 +689,7 @@ def predict(
683689 prediction_sets = self ._add_random_tie_breaking (
684690 prediction_sets ,
685691 y_pred_index_last ,
686- y_pred_proba ,
692+ y_pred_proba_cumsum ,
687693 y_pred_proba_last
688694 )
689695 else :
0 commit comments