Skip to content

Commit ea56788

Browse files
committed
update
1 parent 9edddc1 commit ea56788

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

scripts/convert_cogvideox_to_diffusers.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ def get_args():
233233
parser.add_argument(
234234
"--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
235235
)
236+
parser.add_argument(
237+
"--typecast_text_encoder",
238+
action="store_true",
239+
default=False,
240+
help="Whether or not to apply fp16/bf16 precision to text_encoder",
241+
)
236242
# For CogVideoX-2B, num_layers is 30. For 5B, it is 42
237243
parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
238244
# For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
@@ -283,12 +289,16 @@ def get_args():
283289
init_kwargs,
284290
)
285291
if args.vae_ckpt_path is not None:
286-
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
292+
# Keep VAE in float32 for better quality
293+
vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, torch.float32)
287294

288295
text_encoder_id = "google/t5-v1_1-xxl"
289296
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
290297
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
291298

299+
if args.typecast_text_encoder:
300+
text_encoder = text_encoder.to(dtype=dtype)
301+
292302
# Apparently, the conversion does not work anymore without this :shrug:
293303
for param in text_encoder.parameters():
294304
param.data = param.data.contiguous()
@@ -320,11 +330,6 @@ def get_args():
320330
scheduler=scheduler,
321331
)
322332

323-
if args.fp16:
324-
pipe = pipe.to(dtype=torch.float16)
325-
if args.bf16:
326-
pipe = pipe.to(dtype=torch.bfloat16)
327-
328333
# We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
329334
# for users to specify variant when the default is not fp32 and they want to run with the correct default (which
330335
# is either fp16/bf16 here).

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -683,6 +683,7 @@ def __call__(
683683

684684
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
685685
timestep = t.expand(latent_model_input.shape[0])
686+
686687
# predict noise model_output
687688
noise_pred = self.transformer(
688689
hidden_states=latent_model_input,

0 commit comments

Comments (0)