@@ -301,7 +301,7 @@ def _check_cv(
301301
302302 def _check_alpha (
303303 self ,
304- alpha : Optional [ Union [float , Iterable [float ]]] = None
304+ alpha : Union [float , Iterable [float ]]
305305 ) -> np .ndarray :
306306 """
307307 Check alpha and prepare it as a np.ndarray
@@ -325,30 +325,27 @@ def _check_alpha(
325325 ValueError
326326 If alpha is not a float or an Iterable of floats between 0 and 1.
327327 """
328- if alpha is None :
329- alpha_np = None
328+ if isinstance (alpha , float ):
329+ alpha_np = np .array ([alpha ])
330+ elif isinstance (alpha , Iterable ):
331+ alpha_np = np .array (alpha )
330332 else :
331- if isinstance (alpha , float ):
332- alpha_np = np .array ([alpha ])
333- elif isinstance (alpha , Iterable ):
334- alpha_np = np .array (alpha )
335- else :
336- raise ValueError (
337- "Invalid alpha. Allowed values are float or Iterable."
338- )
339- if len (alpha_np .shape ) != 1 :
340- raise ValueError (
341- "Invalid alpha. "
342- "Please provide a one-dimensional list of values."
343- )
344- if alpha_np .dtype .type not in [np .float64 , np .float32 ]:
345- raise ValueError (
346- "Invalid alpha. Allowed values are Iterable of floats."
347- )
348- if np .any ((alpha_np <= 0 ) | (alpha_np >= 1 )):
349- raise ValueError (
350- "Invalid alpha. Allowed values are between 0 and 1."
351- )
333+ raise ValueError (
334+ "Invalid alpha. Allowed values are float or Iterable."
335+ )
336+ if len (alpha_np .shape ) != 1 :
337+ raise ValueError (
338+ "Invalid alpha. "
339+ "Please provide a one-dimensional list of values."
340+ )
341+ if alpha_np .dtype .type not in [np .float64 , np .float32 ]:
342+ raise ValueError (
343+ "Invalid alpha. Allowed values are Iterable of floats."
344+ )
345+ if np .any ((alpha_np <= 0 ) | (alpha_np >= 1 )):
346+ raise ValueError (
347+ "Invalid alpha. Allowed values are between 0 and 1."
348+ )
352349 return alpha_np
353350
354351 def _check_n_features_in (
@@ -582,7 +579,6 @@ def predict(
582579 ]
583580 )
584581 X = check_array (X , force_all_finite = False , dtype = ["float64" , "object" ])
585- alpha_ = self ._check_alpha (alpha )
586582 y_pred = self .single_estimator_ .predict (X )
587583
588584 if alpha is None :
@@ -593,6 +589,7 @@ def predict(
593589 # (n_alpha,) : alpha
594590 # (n_samples_test, n_alpha) : y_pred_low, y_pred_up
595591 # (n_samples_test, n_samples_train) : y_pred_multi, low/up_bounds
592+ alpha_ = self ._check_alpha (alpha )
596593 if self .method in ["naive" , "base" ] or self .cv == "prefit" :
597594 quantile = np .quantile (
598595 self .residuals_ , 1 - alpha_ , interpolation = "higher"
0 commit comments