
Commit 270d407

remove dynamic cfg
1 parent 52f97f6 commit 270d407

1 file changed: +12 −11

src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py

Lines changed: 12 additions & 11 deletions
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
@@ -385,6 +384,13 @@ def check_inputs(
     def guidance_scale(self):
         return self._guidance_scale
 
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
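
For context, the update this new property gates is the standard classifier-free guidance combination from equation (2) of the Imagen paper, which appears further down in this diff. A minimal sketch (the wrapper function name is illustrative, not part of the commit):

import torch

def apply_classifier_free_guidance(
    noise_pred_uncond: torch.Tensor, noise_pred_text: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    # Imagen eq. (2): eps = eps_uncond + w * (eps_text - eps_uncond).
    # With guidance_scale == 1 this collapses to the conditional prediction alone,
    # which is why `do_classifier_free_guidance` only returns True for values > 1.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)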
@@ -404,7 +410,6 @@ def __call__(
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 5.0,
-        use_dynamic_cfg: bool = False,
         num_images_per_prompt: int = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
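
With `use_dynamic_cfg` removed from the signature, callers control guidance through `guidance_scale` alone. A rough usage sketch, assuming the `CogView3PlusPipeline` class exported by diffusers and the `THUDM/CogView3-Plus-3B` checkpoint (both outside this diff):

import torch
from diffusers import CogView3PlusPipeline

# Checkpoint id assumed for illustration.
pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3-Plus-3B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# guidance_scale > 1 enables classifier-free guidance; 1 (or less) skips it entirely.
image = pipe(
    "a watercolor painting of a lighthouse at dawn",
    guidance_scale=5.0,
    num_inference_steps=50,
).images[0]
image.save("lighthouse.png")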
@@ -545,14 +550,14 @@ def __call__(
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             negative_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             num_images_per_prompt=num_images_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 4. Prepare timesteps
@@ -580,7 +585,7 @@ def __call__(
         target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype)
         crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype)
 
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             original_size = torch.cat([original_size, original_size])
             target_size = torch.cat([target_size, target_size])
             crops_coords_top_left = torch.cat([crops_coords_top_left, crops_coords_top_left])
@@ -599,7 +604,7 @@ def __call__(
                 if self.interrupt:
                     continue
 
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
@@ -618,11 +623,7 @@ def __call__(
                 noise_pred = noise_pred.float()
 
                 # perform guidance
-                if use_dynamic_cfg:
-                    self._guidance_scale = 1 + guidance_scale * (
-                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
-                    )
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
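
For reference, the deleted `use_dynamic_cfg` branch rescaled the guidance weight each step with a cosine ramp. A standalone sketch that reproduces the removed expression (the function name is illustrative):

import math

def dynamic_guidance_scale(guidance_scale: float, t: float, num_inference_steps: int) -> float:
    # Mirrors the removed branch: moves the effective weight between 1 (no guidance)
    # and 1 + guidance_scale along a cosine curve of the normalized progress term.
    progress = ((num_inference_steps - t) / num_inference_steps) ** 5.0
    return 1 + guidance_scale * (1 - math.cos(math.pi * progress)) / 2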
