 # limitations under the License.
 
 import inspect
-import math
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -385,6 +384,13 @@ def check_inputs(
     def guidance_scale(self):
         return self._guidance_scale
 
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier-free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -404,7 +410,6 @@ def __call__(
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 5.0,
-        use_dynamic_cfg: bool = False,
         num_images_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -545,14 +550,14 @@ def __call__(
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 4. Prepare timesteps
@@ -580,7 +585,7 @@ def __call__(
         target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
         crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             original_size = torch.cat([original_size, original_size])
             target_size = torch.cat([target_size, target_size])
             crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
@@ -599,7 +604,7 @@ def __call__(
                 if self.interrupt:
                     continue
 
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -618,11 +623,7 @@ def __call__(
                 noise_pred = noise_pred.float()
 
                 # perform guidance
-                if use_dynamic_cfg:
-                    self._guidance_scale = 1 + guidance_scale * (
-                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
-                    )
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
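The guidance step that this diff now gates on the new `do_classifier_free_guidance` property is the standard classifier-free guidance combination referenced in the added comment (Imagen paper, eq. (2)). Below is a minimal, self-contained sketch of that combination; the helper name `apply_classifier_free_guidance` and the tensor shapes are illustrative assumptions, not part of the pipeline.

import torch

def apply_classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    # `noise_pred` stacks the unconditional and conditional predictions along the
    # batch dimension, matching the `torch.cat([latents] * 2)` model input above.
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    # Imagen eq. (2) with w = guidance_scale: w = 1 returns the conditional
    # prediction unchanged, i.e. no classifier-free guidance is applied.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

guidance_scale = 5.0
do_classifier_free_guidance = guidance_scale > 1  # mirrors the new property
noise_pred = torch.randn(2, 4, 64, 64)  # [uncond, cond] stacked on dim 0
if do_classifier_free_guidance:
    noise_pred = apply_classifier_free_guidance(noise_pred, guidance_scale)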