
Commit 7ab4a3f

fix the timestep init and sigma
1 parent ca000dd commit 7ab4a3f


3 files changed: +101 −46 lines changed


scripts/convert_cogview4_to_diffusers.py

Lines changed: 1 addition & 1 deletion
@@ -231,7 +231,7 @@ def main(args):
             "prediction_type": "v_prediction",
             "rescale_betas_zero_snr": True,
             "set_alpha_to_one": True,
-            "timestep_spacing": "trailing",
+            "timestep_spacing": "linspace",
         }
     )
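For reference, a hedged sketch of how the converted config might be consumed. The constructor call is an assumption for illustration; only the class name CogView4DDIMScheduler and the module path appear elsewhere in this commit.

```python
# Hypothetical usage of the converted scheduler config; the kwargs mirror the
# dict above, but the exact constructor signature is assumed, not shown here.
from diffusers.schedulers.scheduling_ddim_cogview4 import CogView4DDIMScheduler

scheduler = CogView4DDIMScheduler(
    prediction_type="v_prediction",
    rescale_betas_zero_snr=True,
    set_alpha_to_one=True,
    timestep_spacing="linspace",  # this commit changes it from "trailing"
)
```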

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 98 additions & 19 deletions
@@ -16,6 +16,7 @@
 import inspect
 from typing import Callable, Dict, List, Optional, Tuple, Union

+import math
 import torch
 from transformers import GlmModel

@@ -53,7 +54,19 @@
 """


-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def calculate_shift(
+    image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)
+
+
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
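As a quick, self-contained check of the two helpers added above (the 1024x1024 geometry, vae_scale_factor of 8, and patch_size of 2 are assumed example values):

```python
import math

import torch

# Local copies of the helpers added above, so the numbers can be checked in isolation.
def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

def time_shift(mu, shift_sigma, sigmas):
    return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)

# Assumed example: a 1024x1024 image gives ((1024 // 8) * (1024 // 8)) // 2**2 = 4096
# latent patches, which is exactly max_seq_len, so mu = max_shift = 1.15.
mu = calculate_shift(4096)
print(time_shift(mu, 1.0, torch.tensor([1.0, 0.5, 0.1])))
# sigma = 1.0 is a fixed point; smaller sigmas are pulled up (0.5 -> ~0.76,
# 0.1 -> ~0.26), i.e. the schedule spends more of its steps at high noise.
```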
@@ -203,7 +216,7 @@ def _get_glm_embeds(
         )
         text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)

-        prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids)[0]
+        prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids.to(self.text_encoder.model.device))[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -573,6 +586,16 @@ def __call__(
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+        image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+            self.transformer.config.patch_size**2
+        )
+        mu = calculate_shift(image_seq_len)
+        sigmas = timesteps / self.scheduler.config.num_train_timesteps
+        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])  # append zero at the end
+
+        self.sigmas = time_shift(mu, 1.0, sigmas)  # this is for noise control
         self._num_timesteps = len(timesteps)

         # 5. Prepare latents.
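A hedged walk-through of the schedule built in step 4 above, assuming 1000 training timesteps and 50 inference steps:

```python
import torch

num_train_timesteps, num_inference_steps = 1000, 50  # assumed example values

# "linspace" timesteps as produced by the scheduler change below in this commit
timesteps = torch.linspace(num_train_timesteps, 1, num_inference_steps).long()

sigmas = timesteps / num_train_timesteps      # descending values in (0, 1]
sigmas = torch.cat([sigmas, torch.zeros(1)])  # the appended zero lets the last
                                              # step integrate down to fully
                                              # denoised latents
print(sigmas[0].item(), sigmas[-2].item(), sigmas[-1].item())  # 1.0 0.001 0.0
```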
@@ -611,17 +634,81 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
             old_pred_original_sample = None
+            # for i, t in enumerate(timesteps):
+            #     if self.interrupt:
+            #         continue
+            #
+            #     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+            #     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+            #
+            #     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+            #     timestep = t.expand(latent_model_input.shape[0])
+            #
+            #     # predict noise model_output
+            #     noise_pred = self.transformer(
+            #         hidden_states=latent_model_input,
+            #         encoder_hidden_states=prompt_embeds,
+            #         timestep=timestep,
+            #         original_size=original_size,
+            #         target_size=target_size,
+            #         crop_coords=crops_coords_top_left,
+            #         return_dict=False,
+            #     )[0]
+            #     noise_pred = noise_pred.float()
+            #
+            #     # perform guidance
+            #     if self.do_classifier_free_guidance:
+            #         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            #         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+            #
+            #     # compute the previous noisy sample x_t -> x_t-1
+            #     if not isinstance(self.scheduler, CogView4DDIMScheduler):
+            #         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+            #     else:
+            #         latents, old_pred_original_sample = self.scheduler.step(
+            #             model_output=noise_pred,
+            #             timestep=t,
+            #             sample=latents,
+            #             **extra_step_kwargs,
+            #             return_dict=False,
+            #         )
+            #     latents = latents.to(prompt_embeds.dtype)
+            #
+            #     # call the callback, if provided
+            #     if callback_on_step_end is not None:
+            #         callback_kwargs = {}
+            #         for k in callback_on_step_end_tensor_inputs:
+            #             callback_kwargs[k] = locals()[k]
+            #         callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+            #
+            #         latents = callback_outputs.pop("latents", latents)
+            #         prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+            #         negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+            #
+            #     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+            #         progress_bar.update()
+            #
+            #     if XLA_AVAILABLE:
+            #         xm.mark_step()
+            # assume sigmas have already been computed, same as in the previous steps
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                # get the current sigma and the sigma of the next timestep
+                sigma = sigmas[i]
+                sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else sigma  # guard against going out of bounds
+
+                # adjust the latent model input according to the sigmas
+                latent_model_input = latents * sigma  # scale the latents by the current sigma
+                latent_model_input = (
+                    torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
+                )
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                # broadcast to the batch dimension for ONNX/Core ML compatibility
                 timestep = t.expand(latent_model_input.shape[0])

-                # predict noise model_output
+                # predict the noise
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -633,25 +720,18 @@ def __call__(
                 )[0]
                 noise_pred = noise_pred.float()

-                # perform guidance
+                # apply classifier-free guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

-                # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogView4DDIMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        model_output=noise_pred,
-                        timestep=t,
-                        sample=latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
-                    )
+                # update the latents from the predicted noise and the sigmas
+                latents = latents + (sigma_next - sigma) * noise_pred  # compute the new latents using the sigmas
+
+                # then carry the updated latents into the next step
                 latents = latents.to(prompt_embeds.dtype)

-                # call the callback, if provided
+                # run the callback, if one was provided
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
@@ -667,7 +747,6 @@ def __call__(
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
-
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0
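The replacement update above reads like a first-order Euler step along the model's prediction. That framing is an interpretation, and the toy below fabricates noise_pred purely to show the arithmetic of `latents + (sigma_next - sigma) * noise_pred`:

```python
import torch

torch.manual_seed(0)
latents = torch.randn(1, 4, 8, 8)       # stand-in latents
sigmas = torch.tensor([1.0, 0.5, 0.0])  # toy schedule, with the appended zero

for i in range(len(sigmas) - 1):
    sigma, sigma_next = sigmas[i], sigmas[i + 1]
    # Fabricated "velocity": if the clean sample were zero, x_t = sigma * noise,
    # so the flow-matching velocity (noise - data) would equal x_t / sigma.
    noise_pred = latents / sigma
    # The commit's update rule: step from sigma down to sigma_next.
    latents = latents + (sigma_next - sigma) * noise_pred

print(latents.abs().max())  # tensor(0.) -- the zero final sigma lands on clean data
```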

src/diffusers/schedulers/scheduling_ddim_cogview4.py

Lines changed: 2 additions & 26 deletions
@@ -318,10 +318,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         # Generate timesteps according to the specified spacing method
         if self.config.timestep_spacing == "linspace":
             timesteps = (
-                np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
-                .round()[::-1]
-                .copy()
-                .astype(np.int64)
+                np.linspace(self.config.num_train_timesteps, 1, num_inference_steps)
+                .astype(np.int64)  # Only for CogView4
             )
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
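To see what the rewritten "linspace" branch returns compared with the old one, a quick check with assumed values (1000 training timesteps, 50 inference steps):

```python
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 50  # assumed example values

# old branch: 0 .. num_train_timesteps - 1, rounded, then reversed
old = np.linspace(0, num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)

# new branch: num_train_timesteps down to 1, truncated to int64
new = np.linspace(num_train_timesteps, 1, num_inference_steps).astype(np.int64)

print(old[:3], old[-3:])  # [999 979 958] ... [41 20  0]
print(new[:3], new[-3:])  # [1000  979  959] ... [41 21  1]
```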
@@ -339,28 +337,6 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
         # Convert the numpy array of timesteps into a PyTorch tensor
         self.timesteps = torch.from_numpy(timesteps).to(device)

-        # ===== change for cogview4 ====
-        # The new dynamic shifting code starts here.
-
-        # Convert integer timesteps to float for further manipulation
-        times_float = self.timesteps.float() / float(self.config.num_train_timesteps)
-
-        # Apply the shift_scale factor
-        times_float = self.config.shift_scale * times_float
-
-        # Convert the shifted floats back to integer indices for timesteps
-        new_timesteps = (times_float * self.config.num_train_timesteps).round().long().clamp_min(0)
-
-        # Ensure the timesteps are in descending order and unique
-        new_timesteps = new_timesteps.unique().flip(0)
-        if len(new_timesteps) == 0:
-            # If all values somehow got collapsed, fall back to a single timestep
-            new_timesteps = torch.zeros(1, dtype=torch.long, device=device)
-
-        # Overwrite the original timesteps with our newly shifted timesteps
-        self.timesteps = new_timesteps
-        # =====
-
     def step(
         self,
         model_output: torch.Tensor,
