import argparse
from typing import Any, Dict

import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import (
    AutoencoderKL,
    CogVideoXDDIMScheduler,
    CogView3PlusPipeline,
    CogView3PlusTransformer2DModel,
)
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
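
# Example invocation (a sketch; the script filename and checkpoint paths are
# placeholders, not taken from the original repository):
#   python convert_cogview3_to_diffusers.py \
#       --transformer_ckpt_path /path/to/transformer.pt \
#       --vae_ckpt_path /path/to/vae.pt \
#       --output_path ./cogview3-plus-diffusers \
#       --bf16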


def reassign_query_key_value_inplace(key: str, state_dict: Dict[str, Any]):
    # The original checkpoint stores Q, K and V as one fused "query_key_value"
    # projection; split it into the separate to_q/to_k/to_v keys diffusers expects.
    to_q_key = key.replace("query_key_value", "to_q")
    to_k_key = key.replace("query_key_value", "to_k")
    to_v_key = key.replace("query_key_value", "to_v")
    to_q, to_k, to_v = torch.chunk(state_dict[key], chunks=3, dim=0)
    state_dict[to_q_key] = to_q
    state_dict[to_k_key] = to_k
    state_dict[to_v_key] = to_v
    state_dict.pop(key)


def reassign_query_key_layernorm_inplace(key: str, state_dict: Dict[str, Any]):
    # Move the query/key layernorm weights to diffusers' attn.norm_q/attn.norm_k.
    layer_id, weight_or_bias = key.split(".")[-2:]

    if "query" in key:
        new_key = f"transformer_blocks.{layer_id}.attn.norm_q.{weight_or_bias}"
    elif "key" in key:
        new_key = f"transformer_blocks.{layer_id}.attn.norm_k.{weight_or_bias}"

    state_dict[new_key] = state_dict.pop(key)


def reassign_adaln_norm_inplace(key: str, state_dict: Dict[str, Any]):
    # The fused AdaLN projection packs 12 chunks of modulation parameters;
    # regroup them into the norm1/norm2 linear layers used by diffusers.
    layer_id, _, weight_or_bias = key.split(".")[-3:]

    weights_or_biases = state_dict[key].chunk(12, dim=0)
    norm1_weights_or_biases = torch.cat(weights_or_biases[0:3] + weights_or_biases[6:9])
    norm2_weights_or_biases = torch.cat(weights_or_biases[3:6] + weights_or_biases[9:12])

    norm1_key = f"transformer_blocks.{layer_id}.norm1.linear.{weight_or_bias}"
    state_dict[norm1_key] = norm1_weights_or_biases

    norm2_key = f"transformer_blocks.{layer_id}.norm2.linear.{weight_or_bias}"
    state_dict[norm2_key] = norm2_weights_or_biases

    state_dict.pop(key)


def remove_keys_inplace(key: str, state_dict: Dict[str, Any]):
    state_dict.pop(key)


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Unwrap the common nesting patterns ("model"/"module"/"state_dict") used by
    # different checkpointing frameworks.
    state_dict = saved_dict
    if "model" in state_dict:
        state_dict = state_dict["model"]
    if "module" in state_dict:
        state_dict = state_dict["module"]
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    return state_dict
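

# Substring-based renames applied to every key in the original state dict: each
# occurrence of the left-hand substring is replaced with the right-hand one.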
TRANSFORMER_KEYS_RENAME_DICT = {
    "transformer": "transformer_blocks",
    "attention": "attn",
    "mlp": "mlp.net",
    "dense_h_to_4h": "0.proj",
    "dense_4h_to_h": "2",
    ".layers": "",
    "dense": "to_out.0",
    "mixins.patch_embed": "image_patch_embed",
    "mixins.adaln.adaln_modules": "adaln_module",
    "time_embed": "time_embed",
    "label_emb": "label_embed",
    "mixins.final_layer.adaln": "final_layer.adaln",
    "mixins.final_layer.linear": "proj_out",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
    "query_key_value": reassign_query_key_value_inplace,
}

TOKENIZER_MAX_LENGTH = 224


# The CogView3Plus VAE maps onto the standard diffusers AutoencoderKL; only the
# usual LDM-style key remapping is needed.
def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    vae = AutoencoderKL(
        in_channels=3,
        out_channels=3,
        down_block_types=("DownEncoderBlock2D",) * 4,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 512, 1024, 1024),
        layers_per_block=3,
        act_fn="silu",
        latent_channels=16,
        norm_num_groups=32,
        sample_size=1024,
        scaling_factor=scaling_factor,
        force_upcast=True,
        use_quant_conv=False,
        use_post_quant_conv=False,
        mid_block_add_attention=False,
    ).to(dtype=dtype)

    # Convert the state dict to a format compatible with diffusers
    converted_state_dict = convert_ldm_vae_checkpoint(original_state_dict, vae.config)

    # Load the converted state dict into the VAE model
    vae.load_state_dict(converted_state_dict, strict=False)

    return vae


def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    state_dict[new_key] = state_dict.pop(old_key)
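

# Build a CogView3PlusTransformer2DModel and load the renamed original weights into it.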
def convert_transformer(
    ckpt_path: str,
    num_layers: int,
    num_attention_heads: int,
    dtype: torch.dtype,
):
    PREFIX_KEY = "model.diffusion_model."

    original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
    transformer = CogView3PlusTransformer2DModel(
        in_channels=16,
        num_layers=num_layers,
        num_attention_heads=num_attention_heads,
    ).to(dtype=dtype)

    # First pass: strip the checkpoint prefix and apply the substring renames.
    for key in list(original_state_dict.keys()):
        new_key = key[len(PREFIX_KEY):] if key.startswith(PREFIX_KEY) else key
        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_inplace(original_state_dict, key, new_key)

    # Second pass: handle keys that need structural changes (e.g. the fused QKV projection).
    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, original_state_dict)

    transformer.load_state_dict(original_state_dict, strict=True)
    return transformer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
    )
    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument("--fp16", action="store_true", default=False, help="Whether to save the model weights in fp16")
    parser.add_argument("--bf16", action="store_true", default=False, help="Whether to save the model weights in bf16")
    parser.add_argument(
        "--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving"
    )
    parser.add_argument(
        "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
    )
    parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
    parser.add_argument("--num_attention_heads", type=int, default=64, help="Number of attention heads")
    parser.add_argument("--scaling_factor", type=float, default=0.18215, help="Scaling factor in the VAE")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()

    transformer = None
    vae = None

    if args.fp16 and args.bf16:
        raise ValueError("You cannot pass both --fp16 and --bf16 at the same time.")

    dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

    if args.transformer_ckpt_path is not None:
        transformer = convert_transformer(
            args.transformer_ckpt_path,
            args.num_layers,
            args.num_attention_heads,
            dtype,
        )

    if args.vae_ckpt_path is not None:
        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)

    # T5 v1.1 XXL is used as the text encoder.
    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

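    # The scheduler below mirrors the CogVideoX DDIM configuration: a zero-SNR
    # rescaled, v-prediction schedule with trailing timestep spacing.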
    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )

    pipe = CogView3PlusPipeline(
        tokenizer=tokenizer,
        vae=vae,
        text_encoder=text_encoder,
        transformer=transformer,
        scheduler=scheduler,
    )

    if args.fp16:
        pipe = pipe.to(dtype=torch.float16)
    if args.bf16:
        pipe = pipe.to(dtype=torch.bfloat16)

    # We don't use a variant here because the model is meant to be run in fp16 or
    # bf16 by default. It would be odd for users to have to specify a variant when
    # the default dtype is already not fp32.

    # Sharding the checkpoint with max_shard_size is necessary for users with
    # insufficient memory, such as those using Colab and notebooks, as it reduces
    # the memory needed while loading the model.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
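
    # After saving, the converted pipeline can be reloaded as usual (a sketch;
    # the prompt is a placeholder):
    #   pipe = CogView3PlusPipeline.from_pretrained(args.output_path, torch_dtype=torch.bfloat16)
    #   image = pipe(prompt="a photo of an astronaut riding a horse").images[0]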