@@ -233,6 +233,12 @@ def get_args():
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -283,12 +289,16 @@ def get_args():
             init_kwargs,
         )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, torch.float32)

     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
     # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
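The `dtype` passed to `text_encoder.to(dtype=dtype)` is not defined in this hunk. A minimal sketch of how it is presumably derived from the existing `--fp16`/`--bf16` flags elsewhere in the script; the `argparse.Namespace` values here are placeholder assumptions, not part of the diff:

import argparse

import torch

# Hypothetical stand-in for the script's parsed arguments; the real script
# builds them in get_args(), and only the precision-related flags are shown.
args = argparse.Namespace(fp16=False, bf16=True, typecast_text_encoder=True)

# Assumed derivation of `dtype` from the --fp16/--bf16 flags (defined
# elsewhere in the script, not shown in this hunk).
dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

# The new flag only gates casting of the T5 encoder; the VAE is always
# converted in float32 and the transformer is converted with `dtype`.
print(dtype)  # torch.bfloat16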
@@ -320,11 +330,6 @@ def get_args():
         scheduler=scheduler,
     )

-    if args.fp16:
-        pipe = pipe.to(dtype=torch.float16)
-    if args.bf16:
-        pipe = pipe.to(dtype=torch.bfloat16)
-
     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
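With the blanket `pipe.to(dtype=...)` cast removed and no variant written, downstream loading picks the precision via `torch_dtype`. A minimal sketch of loading the converted pipeline, assuming it was saved to a local placeholder directory and targets the 5B model (bf16; the 2B model would use `torch.float16`):

import torch
from diffusers import CogVideoXPipeline

# "path/to/converted-cogvideox" is a placeholder for the output directory
# used during conversion; bfloat16 matches the 5B model, float16 the 2B one.
pipe = CogVideoXPipeline.from_pretrained(
    "path/to/converted-cogvideox", torch_dtype=torch.bfloat16
)
pipe.to("cuda")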