@@ -758,7 +758,6 @@ def main(args):
758758 revision = args .revision ,
759759 variant = args .variant ,
760760 )
761- vae_scale_factor = 2 ** (len (vae .config .block_out_channels ) - 1 )
762761 cogview4_transformer = CogView4Transformer2DModel .from_pretrained (
763762 args .pretrained_model_name_or_path ,
764763 subfolder = "transformer" ,
@@ -1081,9 +1080,8 @@ def load_model_hook(models, input_dir):
10811080 # TODO: Should the crop coordinates be passed in as a parameter instead of being fixed at (0, 0)? This is not present in Flux.
10821081 crops_coords_top_left = torch .tensor ([(0 , 0 )], dtype = prompt_embeds .dtype , device = prompt_embeds .device )
10831082 crops_coords_top_left = crops_coords_top_left .repeat (len (batch ["captions" ]), 1 )
1084-
10851083 # Predict.
1086- model_pred = cogview4_transformer (
1084+ noise_pred_cond = cogview4_transformer (
10871085 hidden_states = concatenated_noisy_model_input ,
10881086 encoder_hidden_states = prompt_embeds ,
10891087 timestep = timesteps ,
@@ -1093,6 +1091,16 @@ def load_model_hook(models, input_dir):
10931091 return_dict = False ,
10941092 )[0 ]
10951093
1094+ noise_pred_uncond = cogview4_transformer (
1095+ hidden_states = concatenated_noisy_model_input ,
1096+ encoder_hidden_states = pooled_prompt_embeds ,
1097+ timestep = timesteps ,
1098+ original_size = original_size ,
1099+ target_size = target_size ,
1100+ crop_coords = crops_coords_top_left ,
1101+ return_dict = False ,
1102+ )[0 ]
1103+ model_pred = noise_pred_uncond + args .guidance_scale * (noise_pred_cond - noise_pred_uncond )
10961104 # these weighting schemes use a uniform timestep sampling
10971105 # and instead post-weight the loss
10981106 weighting = compute_loss_weighting_for_sd3 (weighting_scheme = args .weighting_scheme , sigmas = sigmas )
0 commit comments