@@ -214,7 +214,7 @@ def get_quantile(
214214 conformity_scores : NDArray ,
215215 alpha_np : NDArray ,
216216 axis : int ,
217- method : str
217+ reversed : bool = False
218218 ) -> NDArray :
219219 """
220220 Compute the alpha quantile of the conformity scores or the conformity
@@ -235,28 +235,29 @@ def get_quantile(
235235 axis: int
236236 The axis from which to compute the quantile.
237237
238- method: str
239- ``"higher"`` or ``"lower"`` the method to compute the quantile.
238+ reversed: bool
239+ Boolean specifying whether we take the upper or lower quantile,
240+ if False, the alpha quantile, otherwise the (1-alpha) quantile.
240241
241242 Returns
242243 -------
243244 NDArray of shape (1, n_alpha) or (n_samples, n_alpha)
244245 The quantile of the conformity scores.
245246 """
246- n_ref = conformity_scores .shape [- 1 ]
247- # TODO: assume that each group has same n_calib when using plus method
248- n_calib = np .min (np .sum (~ np .isnan (conformity_scores ), axis = 0 ))
249- quantile = np .column_stack ([
247+ n_ref = conformity_scores .shape [1 - axis ]
248+ n_calib = np .min (np .sum (~ np .isnan (conformity_scores ), axis = axis ))
249+ signed = 1 - 2 * reversed
250+ alpha_ref = (1 - 2 * alpha_np )* reversed + alpha_np
251+
252+ quantile = signed * np .column_stack ([
250253 np_nanquantile (
251- conformity_scores .astype (float ),
252- np .ceil (_alpha * (n_calib + 1 ))/ n_calib ,
254+ signed * conformity_scores .astype (float ),
255+ np .ceil (_alpha * (n_calib + 1 ))/ n_calib ,
253256 axis = axis ,
254- method = method
255- ) if n_calib and 0 < np .ceil (_alpha * (n_calib + 1 ))/ n_calib < 1
256- else np .nan * np .ones (n_ref ) if not n_calib
257- else np .inf * np .ones (n_ref ) if method == "higher"
258- else - np .inf * np .ones (n_ref )
259- for _alpha in alpha_np
257+ method = "lower"
258+ ) if 0 < np .ceil (_alpha * (n_calib + 1 ))/ n_calib < 1
259+ else np .inf * np .ones (n_ref )
260+ for _alpha in alpha_ref
260261 ])
261262 return quantile
262263
@@ -284,7 +285,7 @@ def _beta_optimize(
284285 -------
285286 NDArray
286287 Array of betas minimizing the differences
287- ``(1-alpa +beta)-quantile - beta-quantile``.
288+ ``(1-alpha +beta)-quantile - beta-quantile``.
288289 """
289290 beta_np = np .full (
290291 shape = (len (lower_bounds ), len (alpha_np )),
@@ -408,26 +409,34 @@ def get_bounds(
408409 X , y_pred_up , conformity_scores
409410 )
410411 bound_low = self .get_quantile (
411- conformity_scores_low , alpha_low , axis = 1 , method = "lower"
412+ conformity_scores_low , alpha_low , axis = 1 , reversed = True
412413 )
413414 bound_up = self .get_quantile (
414- conformity_scores_up , alpha_up , axis = 1 , method = "higher"
415+ conformity_scores_up , alpha_up , axis = 1
415416 )
417+
416418 else :
417- quantile_search = "higher" if self .sym else "lower"
418- alpha_low = 1 - alpha_np if self .sym else beta_np
419- alpha_up = 1 - alpha_np if self .sym else 1 - alpha_np + beta_np
419+ if self .sym :
420+ alpha_ref = 1 - alpha_np
421+ quantile_ref = self .get_quantile (
422+ conformity_scores [..., np .newaxis ], alpha_ref , axis = 0
423+ )
424+ quantile_low , quantile_up = - quantile_ref , quantile_ref
425+
426+ else :
427+ alpha_low , alpha_up = beta_np , 1 - alpha_np + beta_np
428+
429+ quantile_low = self .get_quantile (
430+ conformity_scores [..., np .newaxis ],
431+ alpha_low , axis = 0 , reversed = True
432+ )
433+ quantile_up = self .get_quantile (
434+ conformity_scores [..., np .newaxis ],
435+ alpha_up , axis = 0
436+ )
420437
421- quantile_low = self .get_quantile (
422- conformity_scores [..., np .newaxis ],
423- alpha_low , axis = 0 , method = quantile_search
424- )
425- quantile_up = self .get_quantile (
426- conformity_scores [..., np .newaxis ],
427- alpha_up , axis = 0 , method = "higher"
428- )
429438 bound_low = self .get_estimation_distribution (
430- X , y_pred_low , signed * quantile_low
439+ X , y_pred_low , quantile_low
431440 )
432441 bound_up = self .get_estimation_distribution (
433442 X , y_pred_up , quantile_up
0 commit comments