@@ -582,19 +582,20 @@ def __call__(
             device=device,
         )
         if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
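+            # Join negative and positive prompt embeds along the sequence dim (dim=1) instead of the
+            # batch dim, so the batch is not duplicated for classifier-free guidance further down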
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=1)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
-
         image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
-            self.transformer.config.patch_size ** 2
+            self.transformer.config.patch_size**2
         )
         mu = calculate_shift(image_seq_len)
         sigmas = timesteps / self.scheduler.config.num_train_timesteps
         sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])  # Append zero at the end
 
-        self.sigmas = time_shift(mu, 1.0, sigmas)  # This is for noisy contr
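+        # Shift the sigma schedule by mu, then cast to integer values on CPU; these shifted sigmas
+        # are indexed per step and passed to the transformer as timesteps in the denoising loop below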
+        self.sigmas = time_shift(mu, 1.0, sigmas).to(torch.long).to("cpu")  # Noise control for CogView4
 
         self._num_timesteps = len(timesteps)
 
@@ -630,113 +629,61 @@ def __call__(
 
         # 8. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             # for DPM-solver++
             old_pred_original_sample = None
-            # for i, t in enumerate(timesteps):
-            #     if self.interrupt:
-            #         continue
-            #
-            #     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-            #     latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-            #
-            #     # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-            #     timestep = t.expand(latent_model_input.shape[0])
-            #
-            #     # predict noise model_output
-            #     noise_pred = self.transformer(
-            #         hidden_states=latent_model_input,
-            #         encoder_hidden_states=prompt_embeds,
-            #         timestep=timestep,
-            #         original_size=original_size,
-            #         target_size=target_size,
-            #         crop_coords=crops_coords_top_left,
-            #         return_dict=False,
-            #     )[0]
-            #     noise_pred = noise_pred.float()
-            #
-            #     # perform guidance
-            #     if self.do_classifier_free_guidance:
-            #         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            #         noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-            #
-            #     # compute the previous noisy sample x_t -> x_t-1
-            #     if not isinstance(self.scheduler, CogView4DDIMScheduler):
-            #         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-            #     else:
-            #         latents, old_pred_original_sample = self.scheduler.step(
-            #             model_output=noise_pred,
-            #             timestep=t,
-            #             sample=latents,
-            #             **extra_step_kwargs,
-            #             return_dict=False,
-            #         )
-            #     latents = latents.to(prompt_embeds.dtype)
-            #
-            #     # call the callback, if provided
-            #     if callback_on_step_end is not None:
-            #         callback_kwargs = {}
-            #         for k in callback_on_step_end_tensor_inputs:
-            #             callback_kwargs[k] = locals()[k]
-            #         callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-            #
-            #         latents = callback_outputs.pop("latents", latents)
-            #         prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-            #         negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-            #
-            #     if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-            #         progress_bar.update()
-            #
-            #     if XLA_AVAILABLE:
-            #         xm.mark_step()
-            # Assume sigmas have already been computed, just as in the previous steps
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
-                # Get the current sigma and the sigma of the next timestep
-                sigma = sigmas[i]
-                sigma_next = sigmas[i + 1] if i + 1 < len(sigmas) else sigma  # Guard against indexing out of bounds
-
-                # Adjust the latent model input according to the sigmas
-                latent_model_input = latents * sigma  # Scale the latents with the current sigma
-                latent_model_input = torch.cat(
-                    [latent_model_input] * 2) if self.do_classifier_free_guidance else latent_model_input
+                # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = latents  # For CogView4 the text embeds are concatenated, so only the prompt batch is used
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
-                # Broadcast to the batch dimension in a way that is compatible with ONNX/Core ML
-                timestep = t.expand(latent_model_input.shape[0])
+                # Use sigma in place of the timestep
+                sigma = self.sigmas[i]  # Look up the sigma for this step
+                timestep = sigma.expand(latent_model_input.shape[0]).to(device)  # Broadcast sigma across the batch dimension
 
-                # Predict the noise
+                # predict noise model_output using sigma
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
+                    timestep=timestep,  # Pass sigma as the timestep for noise prediction
                     original_size=original_size,
                     target_size=target_size,
                     crop_coords=crops_coords_top_left,
                     return_dict=False,
                 )[0]
                 noise_pred = noise_pred.float()
 
-                # Perform guidance
+                # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
-                # Update the latents from the predicted noise and the sigmas
-                latents = latents + (sigma_next - sigma) * noise_pred  # Compute the new latents using the sigmas
-
-                # Alternatively, use the updated latents for the next step's computation
+                # compute the previous noisy sample x_t -> x_t-1 using sigma (not the timestep)
+                if not isinstance(self.scheduler, CogView4DDIMScheduler):
+                    latents = self.scheduler.step(noise_pred, sigma, latents, **extra_step_kwargs, return_dict=False)[
+                        0
+                    ]
+                else:
+                    latents, old_pred_original_sample = self.scheduler.step(
+                        model_output=noise_pred,
+                        timestep=sigma,  # Use sigma here as the timestep
+                        sample=latents,
+                        **extra_step_kwargs,
+                        return_dict=False,
+                    )
                 latents = latents.to(prompt_embeds.dtype)
 
-                # Invoke the callback, if one is provided
+                # call the callback, if provided
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
                         callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
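+                    # The step callback now receives the per-step sigma in place of the raw timestep t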
684+                     callback_outputs  =  callback_on_step_end (self , i , sigma , callback_kwargs )
740685
741686                    latents  =  callback_outputs .pop ("latents" , latents )
742687                    prompt_embeds  =  callback_outputs .pop ("prompt_embeds" , prompt_embeds )
@@ -747,6 +692,7 @@ def __call__(
 
                 if XLA_AVAILABLE:
                     xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
                 0