Skip to content

Commit 030a467

Browse files
change time_shift
1 parent f1ccdd2 commit 030a467

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,18 +55,25 @@
5555
"""
5656

5757

58+
def time_shift(self, mu: float, shift_sigma: float, sigmas: torch.Tensor):
59+
return mu / (mu + (1 / sigmas - 1) ** shift_sigma)
60+
61+
5862
def calculate_shift(
59-
image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15
63+
self,
64+
image_seq_len,
65+
base_seq_len: int = 256,
6066
):
61-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
62-
b = base_shift - m * base_seq_len
63-
mu = image_seq_len * m + b
64-
return mu
67+
if isinstance(image_seq_len, int):
68+
mu = math.sqrt(image_seq_len / base_seq_len)
69+
elif isinstance(image_seq_len, torch.Tensor):
70+
mu = torch.sqrt(image_seq_len / base_seq_len)
71+
else:
72+
raise ValueError(f'Invalid type for image_seq_len: {type(image_seq_len)}')
6573

74+
mu = mu * 0.75 + 0.25
6675

67-
def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
68-
return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)
69-
76+
return mu
7077

7178
# def retrieve_timesteps(
7279
# scheduler,
@@ -598,7 +605,8 @@ def __call__(
598605
max_sequence_length=max_sequence_length,
599606
device=device,
600607
)
601-
608+
torch.save(prompt_embeds, '/share/home/zyx/prompt_embeds.pt')
609+
torch.save(negative_prompt_embeds, '/share/home/zyx/negative_prompt_embeds.pt')
602610
# 5. Prepare latents.
603611
latent_channels = self.transformer.config.in_channels
604612
latents = self.prepare_latents(

0 commit comments

Comments
 (0)