Skip to content

Commit aa933b9

Browse files
author
Vianney Taquet
committed
Fix Optional alpha
1 parent 9bed02c commit aa933b9

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

mapie/estimators.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def _check_cv(
302302
def _check_alpha(
303303
self,
304304
alpha: Optional[Union[float, Iterable[float]]] = None
305-
) -> np.ndarray:
305+
) -> Optional[np.ndarray]:
306306
"""
307307
Check alpha and prepare it as a np.ndarray
308308
@@ -326,30 +326,29 @@ def _check_alpha(
326326
If alpha is not a float or an Iterable of floats between 0 and 1.
327327
"""
328328
if alpha is None:
329-
return np.zeros(1)
329+
return alpha
330+
if isinstance(alpha, float):
331+
alpha_np = np.array([alpha])
332+
elif isinstance(alpha, Iterable):
333+
alpha_np = np.array(alpha)
330334
else:
331-
if isinstance(alpha, float):
332-
alpha_np = np.array([alpha])
333-
elif isinstance(alpha, Iterable):
334-
alpha_np = np.array(alpha)
335-
else:
336-
raise ValueError(
337-
"Invalid alpha. Allowed values are float or Iterable."
338-
)
339-
if len(alpha_np.shape) != 1:
340-
raise ValueError(
341-
"Invalid alpha. "
342-
"Please provide a one-dimensional list of values."
343-
)
344-
if alpha_np.dtype.type not in [np.float64, np.float32]:
345-
raise ValueError(
346-
"Invalid alpha. Allowed values are Iterable of floats."
347-
)
348-
if np.any((alpha_np <= 0) | (alpha_np >= 1)):
349-
raise ValueError(
350-
"Invalid alpha. Allowed values are between 0 and 1."
351-
)
352-
return alpha_np
335+
raise ValueError(
336+
"Invalid alpha. Allowed values are float or Iterable."
337+
)
338+
if len(alpha_np.shape) != 1:
339+
raise ValueError(
340+
"Invalid alpha. "
341+
"Please provide a one-dimensional list of values."
342+
)
343+
if alpha_np.dtype.type not in [np.float64, np.float32]:
344+
raise ValueError(
345+
"Invalid alpha. Allowed values are Iterable of floats."
346+
)
347+
if np.any((alpha_np <= 0) | (alpha_np >= 1)):
348+
raise ValueError(
349+
"Invalid alpha. Allowed values are between 0 and 1."
350+
)
351+
return alpha_np
353352

354353
def _check_n_features_in(
355354
self,
@@ -593,6 +592,7 @@ def predict(
593592
# (n_alpha,) : alpha
594593
# (n_samples_test, n_alpha) : y_pred_low, y_pred_up
595594
# (n_samples_test, n_samples_train) : y_pred_multi, low/up_bounds
595+
alpha_ = cast(np.ndarray, alpha_)
596596
if self.method in ["naive", "base"] or self.cv == "prefit":
597597
quantile = np.quantile(
598598
self.residuals_, 1 - alpha_, interpolation="higher"

0 commit comments

Comments
 (0)