@@ -21,11 +21,15 @@
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import deprecate
+from ..utils import deprecate, is_scipy_available
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 
 
+if is_scipy_available():
+    import scipy.stats
+
+
 # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
 def betas_for_alpha_bar(
     num_diffusion_timesteps,
@@ -163,6 +167,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to [Beta
+            Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         use_lu_lambdas (`bool`, *optional*, defaults to `False`):
             Whether to use the uniform-logSNR for step sizes proposed by Lu's DPM-Solver in the noise schedule during
             the sampling process. If `True`, the sigmas and time steps are determined according to a sequence of
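For context, enabling the new option from user code would look like this (a minimal usage sketch, not part of the diff; the checkpoint name is a placeholder and scipy must be installed):

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Placeholder checkpoint; any pipeline with a DPM-Solver-compatible scheduler works.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# `from_config` accepts keyword overrides, so the new flag can be enabled while
# keeping the rest of the existing scheduler configuration intact.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_beta_sigmas=True)
```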
@@ -209,6 +216,7 @@ def __init__(
         euler_at_final: bool = False,
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         use_lu_lambdas: Optional[bool] = False,
         final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
         lambda_min_clipped: float = -float("inf"),
@@ -217,8 +225,12 @@ def __init__(
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
     ):
-        if sum([self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
-            raise ValueError("Only one of `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used.")
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if algorithm_type in ["dpmsolver", "sde-dpmsolver"]:
             deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead"
             deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", deprecation_message)
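Because the three special sigma schedules are mutually exclusive, a conflicting configuration now fails fast at construction time. A small illustrative sketch of the expected behavior:

```python
from diffusers import DPMSolverMultistepScheduler

try:
    # Requesting two schedules at once trips the new guard.
    DPMSolverMultistepScheduler(use_beta_sigmas=True, use_karras_sigmas=True)
except ValueError as err:
    print(err)
```

Note that if scipy is not installed, the `use_beta_sigmas=True` path raises `ImportError` before this check is reached.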
@@ -337,6 +349,8 @@ def set_timesteps(
             raise ValueError("Cannot use `timesteps` with `config.use_lu_lambdas = True`")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
 
         if timesteps is not None:
             timesteps = np.array(timesteps).astype(np.int64)
@@ -388,6 +402,9 @@ def set_timesteps(
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
         else:
             sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
 
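After any of these conversions, discrete timesteps are recovered by inverting the noise schedule with the existing `_sigma_to_t` helper, which brackets each sigma in log-sigma space and interpolates linearly. A simplified standalone sketch of that logic for a scalar sigma (an approximation for illustration, not the exact implementation):

```python
import numpy as np

def sigma_to_t(sigma: float, log_sigmas: np.ndarray) -> float:
    # log_sigmas is the ascending log-sigma table of the training schedule.
    log_sigma = np.log(max(sigma, 1e-10))
    dists = log_sigma - log_sigmas
    # Index of the last table entry at or below log_sigma (clipped so high_idx stays valid).
    low_idx = min(int(np.cumsum(dists >= 0).argmax()), len(log_sigmas) - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    # Linear interpolation between the two bracketing indices gives a fractional timestep.
    w = float(np.clip((low - log_sigma) / (low - high), 0, 1))
    return (1 - w) * low_idx + w * high_idx
```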
@@ -542,6 +559,38 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_beta
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.Tensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def convert_model_output(
         self,
         model_output: torch.Tensor,
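For intuition about the schedule `_convert_to_beta` produces: evenly spaced quantiles are pushed through the inverse CDF (`ppf`) of a Beta(0.6, 0.6) distribution, whose U-shaped density packs sigmas densely near both ends of the range. A standalone sketch with made-up sigma bounds:

```python
import numpy as np
import scipy.stats

sigma_min, sigma_max = 0.03, 14.6  # illustrative values only
alpha = beta = 0.6

# Quantiles run from 1 down to 0 so sigmas come out high-to-low,
# in the order the scheduler expects.
quantiles = 1 - np.linspace(0, 1, 10)
sigmas = sigma_min + scipy.stats.beta.ppf(quantiles, alpha, beta) * (sigma_max - sigma_min)

# Since ppf(1) == 1 and ppf(0) == 0, the schedule spans exactly [sigma_max, sigma_min],
# with consecutive steps spaced most tightly near both endpoints.
print(sigmas)
```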