@@ -127,7 +127,6 @@ def log_validation(cogview4_transformer, args, accelerator, weight_dtype, step,
127127 num_inference_steps = 50 ,
128128 guidance_scale = args .guidance_scale ,
129129 generator = generator ,
130- max_sequence_length = 512 ,
131130 height = args .resolution ,
132131 width = args .resolution ,
133132 ).images [0 ]
@@ -1075,7 +1074,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
10751074 mu = torch .sqrt (image_seq_lens / 256 )
10761075 mu = mu * 0.75 + 0.25
10771076 scale_factors = mu / (mu + (1 / sigmas - 1 ) ** 1.0 ).to (dtype = pixel_latents .dtype , device = pixel_latents .device )
1078- scale_factors = scale_factors .view (4 , 1 , 1 , 1 )
1077+ scale_factors = scale_factors .view (len ( batch [ "captions" ]) , 1 , 1 , 1 )
10791078 noisy_model_input = (1.0 - scale_factors ) * pixel_latents + scale_factors * noise
10801079 concatenated_noisy_model_input = torch .cat ([noisy_model_input , control_latents ], dim = 1 )
10811080 text_encoding_pipeline = text_encoding_pipeline .to ("cuda" )
@@ -1114,7 +1113,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
11141113 # flow-matching loss
11151114 target = noise - pixel_latents
11161115
1117- weighting = weighting .unsqueeze ( 1 ). unsqueeze ( 2 ). unsqueeze ( 3 ) # [4 , 1, 1, 1]
1116+ weighting = weighting .view ( len ( batch [ "captions" ]) , 1 , 1 , 1 )
11181117 loss = torch .mean ((weighting .float () * (model_pred .float () - target .float ()) ** 2 ).reshape (target .shape [0 ], - 1 ),1 )
11191118 loss = loss .mean ()
11201119 accelerator .backward (loss )
0 commit comments