Skip to content

Commit 56ceaa6

Browse files
update latent
1 parent 7ab4a3f commit 56ceaa6

File tree

1 file changed

+27
-81
lines changed

1 file changed

+27
-81
lines changed

src/diffusers/pipelines/cogview4/pipeline_cogview4.py

Lines changed: 27 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -582,19 +582,18 @@ def __call__(
582582
device=device,
583583
)
584584
if self.do_classifier_free_guidance:
585-
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
585+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=1)
586586

587587
# 4. Prepare timesteps
588588
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
589-
590589
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
591-
self.transformer.config.patch_size ** 2
590+
self.transformer.config.patch_size**2
592591
)
593592
mu = calculate_shift(image_seq_len)
594593
sigmas = timesteps / self.scheduler.config.num_train_timesteps
595594
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # Append zero at the end
596595

597-
self.sigmas = time_shift(mu, 1.0, sigmas) # This is for noisy contr
596+
self.sigmas = time_shift(mu, 1.0, sigmas).to(torch.long).to("cpu") # This is for noisy control of cogview4
598597

599598
self._num_timesteps = len(timesteps)
600599

@@ -630,113 +629,59 @@ def __call__(
630629

631630
# 8. Denoising loop
632631
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
633-
634632
with self.progress_bar(total=num_inference_steps) as progress_bar:
635633
# for DPM-solver++
636634
old_pred_original_sample = None
637-
# for i, t in enumerate(timesteps):
638-
# if self.interrupt:
639-
# continue
640-
#
641-
# latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
642-
# latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
643-
#
644-
# # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
645-
# timestep = t.expand(latent_model_input.shape[0])
646-
#
647-
# # predict noise model_output
648-
# noise_pred = self.transformer(
649-
# hidden_states=latent_model_input,
650-
# encoder_hidden_states=prompt_embeds,
651-
# timestep=timestep,
652-
# original_size=original_size,
653-
# target_size=target_size,
654-
# crop_coords=crops_coords_top_left,
655-
# return_dict=False,
656-
# )[0]
657-
# noise_pred = noise_pred.float()
658-
#
659-
# # perform guidance
660-
# if self.do_classifier_free_guidance:
661-
# noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
662-
# noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
663-
#
664-
# # compute the previous noisy sample x_t -> x_t-1
665-
# if not isinstance(self.scheduler, CogView4DDIMScheduler):
666-
# latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
667-
# else:
668-
# latents, old_pred_original_sample = self.scheduler.step(
669-
# model_output=noise_pred,
670-
# timestep=t,
671-
# sample=latents,
672-
# **extra_step_kwargs,
673-
# return_dict=False,
674-
# )
675-
# latents = latents.to(prompt_embeds.dtype)
676-
#
677-
# # call the callback, if provided
678-
# if callback_on_step_end is not None:
679-
# callback_kwargs = {}
680-
# for k in callback_on_step_end_tensor_inputs:
681-
# callback_kwargs[k] = locals()[k]
682-
# callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
683-
#
684-
# latents = callback_outputs.pop("latents", latents)
685-
# prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
686-
# negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
687-
#
688-
# if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
689-
# progress_bar.update()
690-
#
691-
# if XLA_AVAILABLE:
692-
# xm.mark_step()
693-
# 假设 sigmas 已经计算好了,和之前的步骤一样
694635
for i, t in enumerate(timesteps):
695636
if self.interrupt:
696637
continue
697638

698-
# 获取当前的 sigma 和下一个时间步的 sigma
699-
sigma = sigmas[i]
700-
sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else sigma # 防止越界
701-
702-
# 根据 sigmas 修改 latent 模型输入
703-
latent_model_input = latents * sigma # 使用当前 sigma 调整 latents
704-
latent_model_input = torch.cat(
705-
[latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
639+
# latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
640+
latent_model_input = latents # For CogView4 concat the text embed and only use prompt
706641
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
707642

708-
# 广播到 batch 维度,以便与 ONNX/Core ML 兼容
709-
timestep = t.expand(latent_model_input.shape[0])
643+
# Use sigma instead of timestep directly
644+
sigma = self.sigmas[i] # Get the corresponding sigma value
645+
timestep = sigma.expand(latent_model_input.shape[0]).to(device) # Use sigma to scale the timestep
710646

711-
# 预测噪声
647+
# predict noise model_output using sigma
712648
noise_pred = self.transformer(
713649
hidden_states=latent_model_input,
714650
encoder_hidden_states=prompt_embeds,
715-
timestep=timestep,
651+
timestep=timestep, # Pass sigma as timestep for noise prediction
716652
original_size=original_size,
717653
target_size=target_size,
718654
crop_coords=crops_coords_top_left,
719655
return_dict=False,
720656
)[0]
721657
noise_pred = noise_pred.float()
722658

723-
# 执行引导
659+
# perform guidance
724660
if self.do_classifier_free_guidance:
725661
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
726662
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
727663

728-
# 根据预测的噪声和 sigmas 更新 latents
729-
latents = latents + (sigma_next - sigma) * noise_pred # 使用 sigmas 计算新的 latents
730-
731-
# 或者使用更新后的 latents 进行下一步计算
664+
# compute the previous noisy sample x_t -> x_t-1 using sigma (not timestep)
665+
if not isinstance(self.scheduler, CogView4DDIMScheduler):
666+
latents = self.scheduler.step(noise_pred, sigma, latents, **extra_step_kwargs, return_dict=False)[
667+
0
668+
]
669+
else:
670+
latents, old_pred_original_sample = self.scheduler.step(
671+
model_output=noise_pred,
672+
timestep=sigma, # Use sigma here as timestep
673+
sample=latents,
674+
**extra_step_kwargs,
675+
return_dict=False,
676+
)
732677
latents = latents.to(prompt_embeds.dtype)
733678

734-
# 如果有回调,执行回调
679+
# call the callback, if provided
735680
if callback_on_step_end is not None:
736681
callback_kwargs = {}
737682
for k in callback_on_step_end_tensor_inputs:
738683
callback_kwargs[k] = locals()[k]
739-
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
684+
callback_outputs = callback_on_step_end(self, i, sigma, callback_kwargs)
740685

741686
latents = callback_outputs.pop("latents", latents)
742687
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
@@ -747,6 +692,7 @@ def __call__(
747692

748693
if XLA_AVAILABLE:
749694
xm.mark_step()
695+
750696
if not output_type == "latent":
751697
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
752698
0

0 commit comments

Comments
 (0)