@@ -367,6 +367,10 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
+        # For CogVideoX1.5, pad the latent frame dimension so it is divisible by patch_size_t (the padding frames are not used)
+        if self.transformer.config.patch_size_t is not None:
+            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
+
         image = image.unsqueeze(2)  # [B, C, F, H, W]
 
         if isinstance(generator, list):
@@ -386,9 +390,15 @@ def prepare_latents(
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
+
         latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
         image_latents = torch.cat([image_latents, latent_padding], dim=1)
 
+        # Prepend the first latent frame(s) along the frame dimension so the frame count is divisible by patch_size_t
+        if self.transformer.config.patch_size_t is not None:
+            first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
+            image_latents = torch.cat([first_frame, image_latents], dim=1)
+
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
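
A quick standalone sketch of the frame arithmetic behind the two prepare_latents changes above. The concrete values are assumptions for illustration only (patch_size_t = 2 as in CogVideoX1.5, num_frames = 81, vae_scale_factor_temporal = 4); they are not read from any config here:

# Worked example of the latent frame padding in prepare_latents (assumed values, illustration only)
patch_size_t = 2                 # assumed CogVideoX1.5 temporal patch size
num_frames = 81                  # assumed requested frame count
vae_scale_factor_temporal = 4    # assumed temporal VAE compression factor

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1   # 21
pad = latent_frames % patch_size_t                                   # 1 extra frame needed
padded_latent_frames = latent_frames + pad                           # 22, divisible by patch_size_t
print(latent_frames, pad, padded_latent_frames)                      # 21 1 22

With these numbers, the noise latents get one extra frame in their shape, and the image latents get a copy of their first frame prepended, so both end up with 22 frames.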
@@ -460,14 +470,14 @@ def check_inputs(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
-        if (
-            self.transformer.config.patch_size_t is not None
-            and latent_frames % self.transformer.config.patch_size_t != 0
-        ):
-            raise ValueError(
-                f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}."
-            )
+        # latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        # if (
+        #     self.transformer.config.patch_size_t is not None
+        #     and latent_frames % self.transformer.config.patch_size_t != 0
+        # ):
+        #     raise ValueError(
+        #         f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}."
+        #     )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -853,7 +863,13 @@ def adjust_resolution_to_divisible(image_height, image_width, tgt_height, tgt_wi
                     progress_bar.update()
 
         if not output_type == "latent":
-            video = self.decode_latents(latents)
+            # Calculate how many padding frames were prepended, based on the frame dimension (dim 1) of latents
+            num_latent_frames = latents.size(1)  # size of the frame dimension
+            # e.g. (81 - 1) // 4 + 1 = 21 expected frames while latents has 22, so start_frames = 22 - 21 = 1 and frame 0 is skipped
+            start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1)
+
+            # Slice latents starting from start_frames to drop the padded frames before decoding
+            video = self.decode_latents(latents[:, start_frames:])
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
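
The matching slice before decoding, sketched with the same assumed values as above (an illustration, not part of the pipeline code):

# Worked example of the start_frames computation before decode_latents (assumed values, illustration only)
num_frames = 81
vae_scale_factor_temporal = 4
num_latent_frames = 22   # padded frame count produced by prepare_latents under the assumptions above

start_frames = num_latent_frames - ((num_frames - 1) // vae_scale_factor_temporal + 1)
print(start_frames)       # 1 -> latents[:, 1:] drops the prepended padding frame before decoding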