@@ -642,6 +642,10 @@ def prepare_latents(
642642    def  guidance_scale (self ):
643643        return  self ._guidance_scale 
644644
645+     @property  
646+     def  skip_guidance_layers (self ):
647+         return  self ._skip_guidance_layers 
648+ 
    @property
    def clip_skip(self):
        # Read-only accessor for the clip_skip value stored on the pipeline
        # (assigned from the `clip_skip` argument inside `__call__`).
        return self._clip_skip
@@ -694,6 +698,10 @@ def __call__(
694698        callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] =  None ,
695699        callback_on_step_end_tensor_inputs : List [str ] =  ["latents" ],
696700        max_sequence_length : int  =  256 ,
701+         skip_guidance_layers : List [int ] =  None ,
702+         skip_layer_guidance_scale : int  =  2.8 ,
703+         skip_layer_guidance_stop : int  =  0.2 ,
704+         skip_layer_guidance_start : int  =  0.01 ,
697705    ):
698706        r""" 
699707        Function invoked when calling the pipeline for generation. 
@@ -778,6 +786,22 @@ def __call__(
778786                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 
779787                `._callback_tensor_inputs` attribute of your pipeline class. 
780788            max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`. 
789+             skip_guidance_layers (`List[int]`, *optional*): 
790+                 A list of integers specifying transformer layers to skip when computing the additional 
791+                 skip-layer prediction used for skip-layer guidance. If not provided, skip-layer guidance 
792+                 is disabled. Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9]. 
793+             skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in 
794+                 `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers` 
795+                 with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers 
796+                 with a scale of `1`. 
797+             skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in 
798+                 `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in 
799+                 `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by 
800+                 StabilityAI for Stable Diffusion 3.5 Medium is 0.2. 
801+             skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in 
802+                 `skip_guidance_layers` will start. The guidance will be applied to the layers specified in 
803+                 `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by 
804+                 StabilityAI for Stable Diffusion 3.5 Medium is 0.01. 
781805
782806        Examples: 
783807
@@ -809,6 +833,7 @@ def __call__(
809833        )
810834
811835        self ._guidance_scale  =  guidance_scale 
836+         self ._skip_layer_guidance_scale  =  skip_layer_guidance_scale 
812837        self ._clip_skip  =  clip_skip 
813838        self ._joint_attention_kwargs  =  joint_attention_kwargs 
814839        self ._interrupt  =  False 
@@ -851,6 +876,9 @@ def __call__(
851876        )
852877
853878        if  self .do_classifier_free_guidance :
879+             if  skip_guidance_layers  is  not None :
880+                 original_prompt_embeds  =  prompt_embeds 
881+                 original_pooled_prompt_embeds  =  pooled_prompt_embeds 
854882            prompt_embeds  =  torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
855883            pooled_prompt_embeds  =  torch .cat ([negative_pooled_prompt_embeds , pooled_prompt_embeds ], dim = 0 )
856884
@@ -879,7 +907,11 @@ def __call__(
879907                    continue 
880908
881909                # expand the latents if we are doing classifier free guidance 
882-                 latent_model_input  =  torch .cat ([latents ] *  2 ) if  self .do_classifier_free_guidance  else  latents 
910+                 latent_model_input  =  (
911+                     torch .cat ([latents ] *  2 )
912+                     if  self .do_classifier_free_guidance  and  skip_guidance_layers  is  None 
913+                     else  latents 
914+                 )
883915                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 
884916                timestep  =  t .expand (latent_model_input .shape [0 ])
885917
@@ -896,6 +928,25 @@ def __call__(
896928                if  self .do_classifier_free_guidance :
897929                    noise_pred_uncond , noise_pred_text  =  noise_pred .chunk (2 )
898930                    noise_pred  =  noise_pred_uncond  +  self .guidance_scale  *  (noise_pred_text  -  noise_pred_uncond )
931+                     should_skip_layers  =  (
932+                         True 
933+                         if  i  >  num_inference_steps  *  skip_layer_guidance_start 
934+                         and  i  <  num_inference_steps  *  skip_layer_guidance_stop 
935+                         else  False 
936+                     )
937+                     if  skip_guidance_layers  is  not None  and  should_skip_layers :
938+                         noise_pred_skip_layers  =  self .transformer (
939+                             hidden_states = latent_model_input ,
940+                             timestep = timestep ,
941+                             encoder_hidden_states = original_prompt_embeds ,
942+                             pooled_projections = original_pooled_prompt_embeds ,
943+                             joint_attention_kwargs = self .joint_attention_kwargs ,
944+                             return_dict = False ,
945+                             skip_layers = skip_guidance_layers ,
946+                         )[0 ]
947+                         noise_pred  =  (
948+                             noise_pred  +  (noise_pred_text  -  noise_pred_skip_layers ) *  self ._skip_layer_guidance_scale 
949+                         )
899950
900951                # compute the previous noisy sample x_t -> x_t-1 
901952                latents_dtype  =  latents .dtype 
0 commit comments