7272        >>> image = pipe( 
7373        ...     prompt, 
7474        ...     control_image=control_image, 
75-         ...     controlnet_conditioning_scale=0.6, 
75+         ...     control_guidance_start=0.2, 
76+         ...     control_guidance_end=0.8, 
77+         ...     controlnet_conditioning_scale=1.0, 
7678        ...     num_inference_steps=28, 
7779        ...     guidance_scale=3.5, 
7880        ... ).images[0] 
@@ -572,6 +574,8 @@ def __call__(
572574        num_inference_steps : int  =  28 ,
573575        timesteps : List [int ] =  None ,
574576        guidance_scale : float  =  7.0 ,
577+         control_guidance_start : Union [float , List [float ]] =  0.0 ,
578+         control_guidance_end : Union [float , List [float ]] =  1.0 ,
575579        control_image : PipelineImageInput  =  None ,
576580        control_mode : Optional [Union [int , List [int ]]] =  None ,
577581        controlnet_conditioning_scale : Union [float , List [float ]] =  1.0 ,
@@ -614,6 +618,10 @@ def __call__(
614618                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 
615619                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 
616620                usually at the expense of lower image quality. 
621+             control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): 
622+                 The fraction of total steps (between 0.0 and 1.0) at which the ControlNet starts applying. 
623+             control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): 
624+                 The fraction of total steps (between 0.0 and 1.0) at which the ControlNet stops applying. 
617625            control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: 
618626                    `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): 
619627                The ControlNet input condition to provide guidance to the `unet` for generation. If the type is 
@@ -674,6 +682,17 @@ def __call__(
674682        height  =  height  or  self .default_sample_size  *  self .vae_scale_factor 
675683        width  =  width  or  self .default_sample_size  *  self .vae_scale_factor 
676684
685+         if  not  isinstance (control_guidance_start , list ) and  isinstance (control_guidance_end , list ):
686+             control_guidance_start  =  len (control_guidance_end ) *  [control_guidance_start ]
687+         elif  not  isinstance (control_guidance_end , list ) and  isinstance (control_guidance_start , list ):
688+             control_guidance_end  =  len (control_guidance_start ) *  [control_guidance_end ]
689+         elif  not  isinstance (control_guidance_start , list ) and  not  isinstance (control_guidance_end , list ):
690+             mult  =  len (self .controlnet .nets ) if  isinstance (self .controlnet , FluxMultiControlNetModel ) else  1 
691+             control_guidance_start , control_guidance_end  =  (
692+                 mult  *  [control_guidance_start ],
693+                 mult  *  [control_guidance_end ],
694+             )
695+ 
677696        # 1. Check inputs. Raise error if not correct 
678697        self .check_inputs (
679698            prompt ,
@@ -839,7 +858,16 @@ def __call__(
839858        num_warmup_steps  =  max (len (timesteps ) -  num_inference_steps  *  self .scheduler .order , 0 )
840859        self ._num_timesteps  =  len (timesteps )
841860
842-         # 6. Denoising loop 
861+         # 6. Create a per-step list stating which controlnets to keep 
862+         controlnet_keep  =  []
863+         for  i  in  range (len (timesteps )):
864+             keeps  =  [
865+                 1.0  -  float (i  /  len (timesteps ) <  s  or  (i  +  1 ) /  len (timesteps ) >  e )
866+                 for  s , e  in  zip (control_guidance_start , control_guidance_end )
867+             ]
868+             controlnet_keep .append (keeps [0 ] if  isinstance (self .controlnet , FluxControlNetModel ) else  keeps )
869+ 
870+         # 7. Denoising loop 
843871        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
844872            for  i , t  in  enumerate (timesteps ):
845873                if  self .interrupt :
@@ -856,12 +884,20 @@ def __call__(
856884                guidance  =  torch .tensor ([guidance_scale ], device = device ) if  use_guidance  else  None 
857885                guidance  =  guidance .expand (latents .shape [0 ]) if  guidance  is  not   None  else  None 
858886
887+                 if  isinstance (controlnet_keep [i ], list ):
888+                     cond_scale  =  [c  *  s  for  c , s  in  zip (controlnet_conditioning_scale , controlnet_keep [i ])]
889+                 else :
890+                     controlnet_cond_scale  =  controlnet_conditioning_scale 
891+                     if  isinstance (controlnet_cond_scale , list ):
892+                         controlnet_cond_scale  =  controlnet_cond_scale [0 ]
893+                     cond_scale  =  controlnet_cond_scale  *  controlnet_keep [i ]
894+ 
859895                # controlnet 
860896                controlnet_block_samples , controlnet_single_block_samples  =  self .controlnet (
861897                    hidden_states = latents ,
862898                    controlnet_cond = control_image ,
863899                    controlnet_mode = control_mode ,
864-                     conditioning_scale = controlnet_conditioning_scale ,
900+                     conditioning_scale = cond_scale ,
865901                    timestep = timestep  /  1000 ,
866902                    guidance = guidance ,
867903                    pooled_projections = pooled_prompt_embeds ,
0 commit comments