|
55 | 55 | """ |
56 | 56 |
|
57 | 57 |
|
def time_shift(self, mu: float, shift_sigma: float, sigmas: torch.Tensor):
    """Apply a resolution-dependent shift to a noise schedule.

    Each sigma is remapped as ``mu / (mu + (1 / sigma - 1) ** shift_sigma)``,
    so larger ``mu`` pushes the schedule toward noisier timesteps.

    Args:
        mu: Shift parameter (as produced by ``calculate_shift``).
        shift_sigma: Exponent controlling how sharply the shift is applied.
        sigmas: Tensor of sigma values in (0, 1]; a sigma of 0 would divide
            by zero, so callers are expected to exclude it.

    Returns:
        Tensor of shifted sigmas, same shape as ``sigmas``.
    """
    # NOTE(review): `self` is accepted but never read — looks like this was
    # converted from a module-level function; confirm it belongs on the class.
    stretched = (1.0 / sigmas - 1.0) ** shift_sigma
    return mu / (mu + stretched)
| 61 | + |
def calculate_shift(
    self,
    image_seq_len,
    base_seq_len: int = 256,
    max_shift: float = 0.75,
    base_shift: float = 0.25,
):
    """Compute the resolution-dependent timestep-shift parameter ``mu``.

    ``mu`` grows with the square root of the image sequence length relative
    to ``base_seq_len``, then is affinely rescaled into a useful range:
    ``mu = sqrt(image_seq_len / base_seq_len) * max_shift + base_shift``.
    At ``image_seq_len == base_seq_len`` this yields ``max_shift + base_shift``
    (1.0 with the defaults).

    Args:
        image_seq_len: Number of image tokens; an ``int``/``float`` scalar or
            a ``torch.Tensor`` (tensor input returns a tensor ``mu``).
        base_seq_len: Reference sequence length the shift is normalized to.
        max_shift: Scale applied to the square-root term. Default 0.75
            preserves the previous hard-coded behavior.
        base_shift: Constant offset added to the scaled term. Default 0.25
            preserves the previous hard-coded behavior.

    Returns:
        The shift parameter ``mu`` (float for scalar input, tensor for
        tensor input), to be consumed by ``time_shift``.

    Raises:
        ValueError: If ``image_seq_len`` is neither a number nor a tensor.
    """
    # Generalized from `isinstance(..., int)`: a float sequence length is
    # mathematically identical, so accept it rather than raising.
    if isinstance(image_seq_len, (int, float)):
        mu = math.sqrt(image_seq_len / base_seq_len)
    elif isinstance(image_seq_len, torch.Tensor):
        mu = torch.sqrt(image_seq_len / base_seq_len)
    else:
        raise ValueError(f'Invalid type for image_seq_len: {type(image_seq_len)}')

    # Affine rescale; constants were previously hard-coded as 0.75 / 0.25.
    mu = mu * max_shift + base_shift

    return mu
70 | 77 |
|
71 | 78 | # def retrieve_timesteps( |
72 | 79 | # scheduler, |
@@ -598,7 +605,8 @@ def __call__( |
598 | 605 | max_sequence_length=max_sequence_length, |
599 | 606 | device=device, |
600 | 607 | ) |
601 | | - |
| 608 | + torch.save(prompt_embeds, '/share/home/zyx/prompt_embeds.pt') |
| 609 | + torch.save(negative_prompt_embeds, '/share/home/zyx/negative_prompt_embeds.pt') |
602 | 610 | # 5. Prepare latents. |
603 | 611 | latent_channels = self.transformer.config.in_channels |
604 | 612 | latents = self.prepare_latents( |
|
0 commit comments