1818import logging
1919import math
2020import os
21+ import random
2122import shutil
2223from contextlib import nullcontext
2324from pathlib import Path
@@ -1094,6 +1095,14 @@ def load_model_hook(models, input_dir):
10941095 # TODO: Should a parameter be set here for passing? This is not present in Flux.
10951096 crops_coords_top_left = torch .tensor ([(0 , 0 )], dtype = prompt_embeds .dtype , device = prompt_embeds .device )
10961097 crops_coords_top_left = crops_coords_top_left .repeat (len (batch ["captions" ]), 1 )
1098+
1099+ # this could be optimized by not having to do any text encoding and just
1100+ # doing zeros on specified shapes for `prompt_embeds` and `pooled_prompt_embeds`
1101+ if args .proportion_empty_prompts and random .random () < args .proportion_empty_prompts :
1102+ # Here, use pooled_prompt_embeds (the 16 pad-token embeddings) directly as prompt_embeds
1103+ prompt_embeds = pooled_prompt_embeds
1104+ if args .offload :
1105+ text_encoding_pipeline = text_encoding_pipeline .to ("cpu" )
10971106 # Predict.
10981107 noise_pred_cond = cogview4_transformer (
10991108 hidden_states = concatenated_noisy_model_input ,
@@ -1104,17 +1113,6 @@ def load_model_hook(models, input_dir):
11041113 crop_coords = crops_coords_top_left ,
11051114 return_dict = False ,
11061115 )[0 ]
1107-
1108- noise_pred_uncond = cogview4_transformer (
1109- hidden_states = concatenated_noisy_model_input ,
1110- encoder_hidden_states = pooled_prompt_embeds ,
1111- timestep = timesteps ,
1112- original_size = original_size ,
1113- target_size = target_size ,
1114- crop_coords = crops_coords_top_left ,
1115- return_dict = False ,
1116- )[0 ]
1117- model_pred = noise_pred_uncond + (noise_pred_cond - noise_pred_uncond )
11181116 # these weighting schemes use a uniform timestep sampling
11191117 # and instead post-weight the loss
11201118 weighting = compute_loss_weighting_for_sd3 (weighting_scheme = args .weighting_scheme , sigmas = sigmas )
@@ -1123,7 +1121,7 @@ def load_model_hook(models, input_dir):
11231121
11241122 weighting = weighting .view (len (batch ["captions" ]), 1 , 1 , 1 )
11251123 loss = torch .mean (
1126- (weighting .float () * (model_pred .float () - target .float ()) ** 2 ).reshape (target .shape [0 ], - 1 ), 1
1124+ (weighting .float () * (noise_pred_cond .float () - target .float ()) ** 2 ).reshape (target .shape [0 ], - 1 ), 1
11271125 )
11281126 loss = loss .mean ()
11291127 accelerator .backward (loss )
0 commit comments