 import inspect
 from typing import Callable, Dict, List, Optional, Tuple, Union

+import math
 import torch
 from transformers import GlmModel
 """


-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def calculate_shift(
+    image_seq_len, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
+def time_shift(mu: float, shift_sigma: float, sigmas: torch.Tensor):
+    return math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1) ** shift_sigma)
+
+
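`calculate_shift` linearly interpolates a shift value `mu` from the image token count, and `time_shift` warps a sigma schedule through a logistic-style map so the trajectory lingers longer at high noise levels. A minimal sketch of the warping behavior, assuming the two helpers above are in scope (the values are illustrative only):

```python
import torch

# a plain linear sigma grid in (0, 1]; the pipeline derives its grid from the timesteps
sigmas = torch.linspace(1.0, 0.05, 10)

mu = 1.15  # the value calculate_shift returns at max_seq_len = 4096
shifted = time_shift(mu, 1.0, sigmas)

# sigma = 1.0 maps to 1.0; small sigmas are pulled upward, e.g. 0.05 -> ~0.14
print(shifted)
```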
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
@@ -203,7 +216,7 @@ def _get_glm_embeds(
         )
         text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)

-        prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids)[0]
+        prompt_embeds = self.text_encoder.model.embed_tokens(text_input_ids.to(self.text_encoder.model.device))[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -573,6 +586,16 @@ def __call__(
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+
+        image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
+            self.transformer.config.patch_size**2
+        )
+        mu = calculate_shift(image_seq_len)
+        sigmas = timesteps / self.scheduler.config.num_train_timesteps
+        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])  # append a terminal sigma of zero
+
+        self.sigmas = time_shift(mu, 1.0, sigmas)  # resolution-shifted sigma schedule used by the denoising loop
+
         self._num_timesteps = len(timesteps)
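For concreteness, a worked trace of this preparation step; the config values below are assumptions for illustration (an 8x VAE downsample and patch size 2), not values confirmed by this diff:

```python
# assumed, hypothetical config values for illustration
vae_scale_factor, patch_size = 8, 2
height = width = 1024

image_seq_len = ((height // vae_scale_factor) * (width // vae_scale_factor)) // patch_size**2
# -> (128 * 128) // 4 = 4096 image tokens

mu = calculate_shift(image_seq_len)
# -> 1.15, since 4096 is exactly the default max_seq_len

# timesteps / num_train_timesteps then yields sigmas in (0, 1], a terminal 0.0 is
# appended, and time_shift(mu, 1.0, sigmas) produces the schedule stored in self.sigmas
```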
 
         # 5. Prepare latents.
@@ -611,17 +634,81 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
             old_pred_original_sample = None
+            # sigmas were computed above in step 4, alongside the timesteps
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                # sigma for the current step and for the next step (clamped to avoid indexing past the end)
+                sigma = sigmas[i]
+                sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else sigma
+
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                # scale the latents by the current sigma before forming the model input
+                latent_model_input = latents * sigma
+                latent_model_input = (
+                    torch.cat([latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
+                )
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
 
                 # predict noise model_output
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
@@ -633,25 +720,18 @@ def __call__(
                 )[0]
                 noise_pred = noise_pred.float()
 
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogView4DDIMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        model_output=noise_pred,
-                        timestep=t,
-                        sample=latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
-                    )
+                # Euler step over the shifted sigma schedule: x <- x + (sigma_next - sigma) * v
+                latents = latents + (sigma_next - sigma) * noise_pred
+
+                # cast back so the next iteration matches the model dtype
                 latents = latents.to(prompt_embeds.dtype)
 
                 # call the callback, if provided
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
@@ -667,7 +747,6 @@ def __call__(
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
-
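The replacement update inside the loop is a first-order Euler step of the flow ODE: the sample moves along the predicted velocity by the increment between consecutive sigmas. A self-contained toy sketch of the update rule (the random tensor stands in for the transformer's `noise_pred`, and the shapes are illustrative, not the pipeline's real ones):

```python
import torch

torch.manual_seed(0)
latents = torch.randn(1, 4, 16, 16)  # toy latent
sigmas = torch.tensor([1.0, 0.75, 0.5, 0.25, 0.0])

for i in range(len(sigmas) - 1):
    noise_pred = torch.randn_like(latents)  # stand-in for the transformer output
    # Euler step: x <- x + (sigma_next - sigma) * v
    latents = latents + (sigmas[i + 1] - sigmas[i]) * noise_pred

# after the loop, latents has been integrated from sigma = 1 (noise) toward sigma = 0 (data)
```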
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0