@@ -173,9 +173,13 @@ def convert_transformer(
173173 return transformer
174174
175175
176- def convert_vae (ckpt_path : str , scaling_factor : float , dtype : torch .dtype ):
176+ def convert_vae (ckpt_path : str , scaling_factor : float , version : str , dtype : torch .dtype ):
177+ init_kwargs = {"scaling_factor" : scaling_factor }
178+ if version == "1.5" :
179+ init_kwargs .update ({"invert_scale_latents" : True })
180+
177181 original_state_dict = get_state_dict (torch .load (ckpt_path , map_location = "cpu" , mmap = True ))
178- vae = AutoencoderKLCogVideoX (scaling_factor = scaling_factor ).to (dtype = dtype )
182+ vae = AutoencoderKLCogVideoX (** init_kwargs ).to (dtype = dtype )
179183
180184 for key in list (original_state_dict .keys ()):
181185 new_key = key [:]
@@ -193,7 +197,7 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
193197 return vae
194198
195199
196- def get_init_kwargs (version : str ):
200+ def get_transformer_init_kwargs (version : str ):
197201 if version == "1.0" :
198202 vae_scale_factor_spatial = 8
199203 init_kwargs = {
@@ -281,7 +285,7 @@ def get_args():
281285 dtype = torch .float16 if args .fp16 else torch .bfloat16 if args .bf16 else torch .float32
282286
283287 if args .transformer_ckpt_path is not None :
284- init_kwargs = get_init_kwargs (args .version )
288+ init_kwargs = get_transformer_init_kwargs (args .version )
285289 transformer = convert_transformer (
286290 args .transformer_ckpt_path ,
287291 args .num_layers ,
@@ -293,7 +297,7 @@ def get_args():
293297 )
294298 if args .vae_ckpt_path is not None :
295299 # Keep VAE in float32 for better quality
296- vae = convert_vae (args .vae_ckpt_path , args .scaling_factor , torch .float32 )
300+ vae = convert_vae (args .vae_ckpt_path , args .scaling_factor , args .version , torch .float32 )
297301
298302 text_encoder_id = "google/t5-v1_1-xxl"
299303 tokenizer = T5Tokenizer .from_pretrained (text_encoder_id , model_max_length = TOKENIZER_MAX_LENGTH )
0 commit comments