@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
8080 "post_attn1_layernorm" : "norm2.norm" ,
8181 "time_embed.0" : "time_embedding.linear_1" ,
8282 "time_embed.2" : "time_embedding.linear_2" ,
83+ "ofs_embed.0" : "ofs_embedding.linear_1" ,
84+ "ofs_embed.2" : "ofs_embedding.linear_2" ,
8385 "mixins.patch_embed" : "patch_embed" ,
8486 "mixins.final_layer.norm_final" : "norm_out.norm" ,
8587 "mixins.final_layer.linear" : "proj_out" ,
@@ -150,7 +152,8 @@ def convert_transformer(
150152 num_layers = num_layers ,
151153 num_attention_heads = num_attention_heads ,
152154 use_rotary_positional_embeddings = use_rotary_positional_embeddings ,
153- use_learned_positional_embeddings = i2v ,
155+ ofs_embed_dim = 512 if (i2v and init_kwargs ["patch_size_t" ] is not None ) else None , # CogVideoX1.5-5B-I2V
156+ use_learned_positional_embeddings = i2v and init_kwargs ["patch_size_t" ] is None , # CogVideoX-5B-I2V
154157 ** init_kwargs ,
155158 ).to (dtype = dtype )
156159
@@ -210,7 +213,7 @@ def get_init_kwargs(version: str):
210213 "patch_bias" : False ,
211214 "sample_height" : 768 // vae_scale_factor_spatial ,
212215 "sample_width" : 1360 // vae_scale_factor_spatial ,
213- "sample_frames" : 81 ,
216+ "sample_frames" : 81 , # TODO: Need Test with 161 for 10 seconds
214217 }
215218 else :
216219 raise ValueError ("Unsupported version of CogVideoX." )
0 commit comments