@@ -92,6 +92,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "post_attn1_layernorm": "norm2.norm",
     "time_embed.0": "time_embedding.linear_1",
     "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
     "mixins.patch_embed": "patch_embed",
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
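
As a quick sanity check (not part of the diff), here is how one of the new `ofs_embed` entries flows through the prefix-strip and rename loop that `convert_transformer` applies in a later hunk:

```python
# Tracing a new "ofs_embed" rename entry through the key-rewriting loop
# from convert_transformer (shown further down in this commit).
PREFIX_KEY = "model.diffusion_model."
key = "model.diffusion_model.ofs_embed.0.weight"
new_key = key[len(PREFIX_KEY) :]  # "ofs_embed.0.weight"
new_key = new_key.replace("ofs_embed.0", "ofs_embedding.linear_1")
assert new_key == "ofs_embedding.linear_1.weight"
```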
@@ -146,12 +148,13 @@ def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key:
 
 
 def convert_transformer(
-    ckpt_path: str,
-    num_layers: int,
-    num_attention_heads: int,
-    use_rotary_positional_embeddings: bool,
-    i2v: bool,
-    dtype: torch.dtype,
+    ckpt_path: str,
+    num_layers: int,
+    num_attention_heads: int,
+    use_rotary_positional_embeddings: bool,
+    i2v: bool,
+    dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."
@@ -161,11 +164,13 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        use_learned_positional_embeddings=i2v,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
     ).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
-        new_key = key[len(PREFIX_KEY):]
+        new_key = key[len(PREFIX_KEY) :]
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_inplace(original_state_dict, key, new_key)
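
For reference, a sketch (not part of the diff) of the variant matrix these two new arguments encode; `patch_size_t` comes from `get_transformer_init_kwargs`, added later in this commit:

```python
# How ofs_embed_dim and use_learned_positional_embeddings resolve per variant,
# given i2v and init_kwargs["patch_size_t"] (None for 1.0, 2 for 1.5).
#
#   variant               i2v    patch_size_t  ofs_embed_dim  learned pos. emb.
#   CogVideoX T2V (1.0)   False  None          None           False
#   CogVideoX-5B-I2V      True   None          None           True
#   CogVideoX1.5 T2V      False  2             None           False
#   CogVideoX1.5-5B-I2V   True   2             512            False
```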
@@ -175,13 +180,18 @@ def convert_transformer(
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
+
     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer
 
 
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -199,6 +209,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae
 
 
+def get_transformer_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 768 // vae_scale_factor_spatial,
+            "sample_width": 1360 // vae_scale_factor_spatial,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
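
A quick check (not part of the diff) of the latent-grid sizes these defaults imply, given the spatial VAE downsampling factor of 8:

```python
assert get_transformer_init_kwargs("1.0")["sample_height"] == 60   # 480 // 8
assert get_transformer_init_kwargs("1.0")["sample_width"] == 90    # 720 // 8
assert get_transformer_init_kwargs("1.5")["sample_height"] == 96   # 768 // 8
assert get_transformer_init_kwargs("1.5")["sample_width"] == 170   # 1360 // 8
```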
@@ -214,6 +252,12 @@ def get_args():
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -226,7 +270,18 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
     return parser.parse_args()
 
 
@@ -242,21 +297,27 @@ def get_args():
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
 
     if args.transformer_ckpt_path is not None:
+        init_kwargs = get_transformer_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
             args.num_attention_heads,
             args.use_rotary_positional_embeddings,
             args.i2v,
             dtype,
+            init_kwargs,
         )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
 
-    text_encoder_id = "/share/official_pretrains/hf_home/t5-v1_1-xxl"
+    text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
 
+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
     # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
@@ -288,11 +349,6 @@ def get_args():
         scheduler=scheduler,
     )
 
-    if args.fp16:
-        pipe = pipe.to(dtype=torch.float16)
-    if args.bf16:
-        pipe = pipe.to(dtype=torch.bfloat16)
-
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
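
The blanket `pipe.to(dtype=...)` casts are removed because they would also downcast the VAE, undoing the float32 conversion introduced above; precision is now handled per component. Tying the new flags together, a hypothetical invocation for converting CogVideoX1.5-5B-I2V might look like this (all paths are placeholders, the script name is an assumption, and the 5B sizes follow the comments in `get_args`):

```python
# Hypothetical invocation, written as an argv list to stay in Python.
import subprocess

subprocess.run(
    [
        "python",
        "convert_cogvideox_to_diffusers.py",  # script name is an assumption
        "--transformer_ckpt_path", "/path/to/mp_rank_00_model_states.pt",  # placeholder
        "--vae_ckpt_path", "/path/to/3d-vae.pt",  # placeholder
        "--version", "1.5",
        "--i2v",
        "--num_layers", "42",  # 5B, per the get_args comments
        "--num_attention_heads", "48",  # 5B, per the get_args comments
        "--use_rotary_positional_embeddings",
        "--snr_shift_scale", "1.0",  # 5B, per the get_args comments
        "--typecast_text_encoder",
        "--bf16",
    ],
    check=True,
)
```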