@@ -302,7 +302,7 @@ def _check_cv(
302302 def _check_alpha (
303303 self ,
304304 alpha : Optional [Union [float , Iterable [float ]]] = None
305- ) -> np .ndarray :
305+ ) -> Optional [ np .ndarray ] :
306306 """
307307 Check alpha and prepare it as a np.ndarray
308308
@@ -326,30 +326,29 @@ def _check_alpha(
326326 If alpha is not a float or an Iterable of floats between 0 and 1.
327327 """
328328 if alpha is None :
329- return np .zeros (1 )
329+ return alpha
330+ if isinstance (alpha , float ):
331+ alpha_np = np .array ([alpha ])
332+ elif isinstance (alpha , Iterable ):
333+ alpha_np = np .array (alpha )
330334 else :
331- if isinstance (alpha , float ):
332- alpha_np = np .array ([alpha ])
333- elif isinstance (alpha , Iterable ):
334- alpha_np = np .array (alpha )
335- else :
336- raise ValueError (
337- "Invalid alpha. Allowed values are float or Iterable."
338- )
339- if len (alpha_np .shape ) != 1 :
340- raise ValueError (
341- "Invalid alpha. "
342- "Please provide a one-dimensional list of values."
343- )
344- if alpha_np .dtype .type not in [np .float64 , np .float32 ]:
345- raise ValueError (
346- "Invalid alpha. Allowed values are Iterable of floats."
347- )
348- if np .any ((alpha_np <= 0 ) | (alpha_np >= 1 )):
349- raise ValueError (
350- "Invalid alpha. Allowed values are between 0 and 1."
351- )
352- return alpha_np
335+ raise ValueError (
336+ "Invalid alpha. Allowed values are float or Iterable."
337+ )
338+ if len (alpha_np .shape ) != 1 :
339+ raise ValueError (
340+ "Invalid alpha. "
341+ "Please provide a one-dimensional list of values."
342+ )
343+ if alpha_np .dtype .type not in [np .float64 , np .float32 ]:
344+ raise ValueError (
345+ "Invalid alpha. Allowed values are Iterable of floats."
346+ )
347+ if np .any ((alpha_np <= 0 ) | (alpha_np >= 1 )):
348+ raise ValueError (
349+ "Invalid alpha. Allowed values are between 0 and 1."
350+ )
351+ return alpha_np
353352
354353 def _check_n_features_in (
355354 self ,
@@ -593,6 +592,7 @@ def predict(
593592 # (n_alpha,) : alpha
594593 # (n_samples_test, n_alpha) : y_pred_low, y_pred_up
595594 # (n_samples_test, n_samples_train) : y_pred_multi, low/up_bounds
595+ alpha_ = cast (np .ndarray , alpha_ )
596596 if self .method in ["naive" , "base" ] or self .cv == "prefit" :
597597 quantile = np .quantile (
598598 self .residuals_ , 1 - alpha_ , interpolation = "higher"
0 commit comments