
Commit 3587317

invert_scale_latents
1 parent 8b28232 commit 3587317

3 files changed: +17 additions, -6 deletions


scripts/convert_cogvideox_to_diffusers.py

Lines changed: 9 additions & 5 deletions
@@ -173,9 +173,13 @@ def convert_transformer(
     return transformer
 
 
-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -193,7 +197,7 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae
 
 
-def get_init_kwargs(version: str):
+def get_transformer_init_kwargs(version: str):
     if version == "1.0":
         vae_scale_factor_spatial = 8
         init_kwargs = {
@@ -281,7 +285,7 @@ def get_args():
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32
 
     if args.transformer_ckpt_path is not None:
-        init_kwargs = get_init_kwargs(args.version)
+        init_kwargs = get_transformer_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
@@ -293,7 +297,7 @@ def get_args():
         )
     if args.vae_ckpt_path is not None:
         # Keep VAE in float32 for better quality
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, torch.float32)
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)
 
     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
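
For context, the converter now threads the CLI version through to the VAE config. A minimal sketch of the updated call, assuming the script's convert_vae is importable; the checkpoint path and scaling factor below are placeholders, not values from this commit:

import torch

# Placeholder inputs; real values come from the script's CLI arguments.
vae_ckpt_path = "cogvideox_vae.pt"
scaling_factor = 0.7

# For version "1.5" the VAE is built with invert_scale_latents=True;
# every other version keeps the previous construction.
vae = convert_vae(vae_ckpt_path, scaling_factor, version="1.5", dtype=torch.float32)
assert vae.config.invert_scale_latents is True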

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 1 addition & 0 deletions
@@ -1057,6 +1057,7 @@ def __init__(
         force_upcast: float = True,
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        invert_scale_latents: bool = False,
     ):
         super().__init__()
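
Because diffusers model __init__ arguments are captured by @register_to_config, the new flag is persisted in the model config and defaults to False, so existing CogVideoX 1.0 checkpoints load with unchanged behavior. A small illustration (default-constructed model, for demonstration only):

from diffusers import AutoencoderKLCogVideoX

# The default keeps the original latent scaling.
vae = AutoencoderKLCogVideoX()
print(vae.config.invert_scale_latents)  # False

# Opting in, as the conversion script now does for version "1.5":
vae_15 = AutoencoderKLCogVideoX(invert_scale_latents=True)
print(vae_15.config.invert_scale_latents)  # True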

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 7 additions & 1 deletion
@@ -381,7 +381,13 @@ def prepare_latents(
             image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
         image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-        image_latents = self.vae_scaling_factor_image * image_latents
+
+        if not self.vae.config.invert_scale_latents:
+            image_latents = self.vae_scaling_factor_image * image_latents
+        else:
+            # This is awkward but required because the CogVideoX team forgot to multiply the
+            # scaling factor during training :)
+            image_latents = 1 / self.vae_scaling_factor_image * image_latents
 
         padding_shape = (
             batch_size,
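
The branch reduces to one rule: multiply the encoded image latents by the VAE scaling factor in the normal case, divide by it when the config flag is set (the 1.5 checkpoints, where the factor was effectively applied in the inverse direction during training). A standalone sketch of that rule; scale_image_latents is a hypothetical helper, the pipeline keeps this logic inline:

import torch

def scale_image_latents(latents: torch.Tensor, scaling_factor: float, invert: bool) -> torch.Tensor:
    # Normal case: encoded latents are multiplied by the scaling factor.
    if not invert:
        return scaling_factor * latents
    # Inverted case: divide instead, matching the 1.5 checkpoints.
    return latents / scaling_factor

latents = torch.randn(1, 2, 16, 8, 8)  # [B, F, C, H, W], dummy data
out = scale_image_latents(latents, scaling_factor=0.7, invert=True)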
