Commit 27441fc

#skip frames 0
1 parent 7a1b579 commit 27441fc

2 files changed: +36 -10 lines changed


src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 11 additions & 1 deletion
@@ -334,6 +334,10 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
+        # For CogVideoX1.5, pad the latent frame count so it is divisible by `patch_size_t`
+        if self.transformer.config.patch_size_t is not None:
+            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
+
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
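
To make the new padding concrete, here is a minimal sketch of the arithmetic, assuming patch_size_t = 2 and vae_scale_factor_temporal = 4 (the values implied by the worked comment in the decode hunk below); all names here are illustrative, not part of the diff:

    # Sketch of the frame-padding arithmetic, with assumed CogVideoX1.5 values
    vae_scale_factor_temporal = 4  # temporal VAE compression factor
    patch_size_t = 2               # temporal patch size implied by the 21 -> 22 example

    num_frames = 81
    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21
    padded_frames = latent_frames + latent_frames % patch_size_t       # 21 + 21 % 2 = 22

    assert padded_frames % patch_size_t == 0
    print(latent_frames, padded_frames)  # 21 22

Note that x + x % p is only guaranteed to be a multiple of p when p = 2 or x is already aligned, which matches CogVideoX1.5's temporal patch size of 2.
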
@@ -734,7 +738,13 @@ def __call__(
                     progress_bar.update()
 
         if not output_type == "latent":
-            video = self.decode_latents(latents[:,1:])
+            # Work out how many padding frames were prepended from the size of the second (frame) dimension
+            num_latent_frames = latents.size(1)
+            # e.g. (81 - 1) // 4 + 1 = 21 expected latent frames; with 22 latents, start_frames = 22 - 21 = 1, so frame 0 is skipped
+            start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1)
+
+            # Slice off the padding frames before decoding
+            video = self.decode_latents(latents[:, start_frames:])
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
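
A hedged sketch of this trimming logic with illustrative tensor shapes (the channel and spatial sizes are placeholders, not taken from the pipeline):

    import torch

    vae_scale_factor_temporal = 4
    num_frames = 81

    # 22 latent frames: 21 "real" frames plus 1 padding frame added in prepare_latents
    latents = torch.randn(2, 22, 16, 60, 90)  # [B, F, C, H, W]; C/H/W are placeholders

    num_latent_frames = latents.size(1)                                                     # 22
    start_frames = num_latent_frames - ((num_frames - 1) // vae_scale_factor_temporal + 1)  # 22 - 21 = 1
    trimmed = latents[:, start_frames:]

    print(trimmed.shape)  # torch.Size([2, 21, 16, 60, 90])

Because start_frames is computed from the actual latent count rather than hard-coded to 1 as in the removed line, the same code path also works when patch_size_t is None: no padding is added, so start_frames is 0 and nothing is skipped.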

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 25 additions & 9 deletions
@@ -367,6 +367,10 @@ def prepare_latents(
             width // self.vae_scale_factor_spatial,
         )
 
+        # For CogVideoX1.5, pad the latent frame count so it is divisible by `patch_size_t`
+        if self.transformer.config.patch_size_t is not None:
+            shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
+
         image = image.unsqueeze(2)  # [B, C, F, H, W]
 
         if isinstance(generator, list):
@@ -386,9 +390,15 @@ def prepare_latents(
             height // self.vae_scale_factor_spatial,
             width // self.vae_scale_factor_spatial,
         )
+
         latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
         image_latents = torch.cat([image_latents, latent_padding], dim=1)
 
+        # Prepend a copy of the first frame(s) so the frame count is divisible by `patch_size_t`
+        if self.transformer.config.patch_size_t is not None:
+            first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
+            image_latents = torch.cat([first_frame, image_latents], dim=1)
+
         if latents is None:
             latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
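
A minimal sketch of the first-frame padding with placeholder shapes, again assuming patch_size_t = 2:

    import torch

    patch_size_t = 2  # assumed value, matching the examples above

    image_latents = torch.randn(2, 21, 16, 60, 90)  # [B, F, C, H, W], placeholder sizes

    # 21 % 2 = 1, so one copy of the first frame is prepended
    first_frame = image_latents[:, : image_latents.size(1) % patch_size_t, ...]
    image_latents = torch.cat([first_frame, image_latents], dim=1)

    print(image_latents.shape)  # torch.Size([2, 22, 16, 60, 90])

If the frame count is already a multiple of patch_size_t, the slice is empty and the torch.cat is a no-op, so no branch on the remainder is needed.
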
@@ -460,14 +470,14 @@ def check_inputs(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
-        if (
-            self.transformer.config.patch_size_t is not None
-            and latent_frames % self.transformer.config.patch_size_t != 0
-        ):
-            raise ValueError(
-                f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}."
-            )
+        # latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        # if (
+        #     self.transformer.config.patch_size_t is not None
+        #     and latent_frames % self.transformer.config.patch_size_t != 0
+        # ):
+        #     raise ValueError(
+        #         f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}."
+        #     )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
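
This divisibility check can be dropped because prepare_latents now pads the latent frames itself; the frame counts the old check rejected are exactly the ones the padding fixes up. A quick sketch of which num_frames values used to fail, assuming patch_size_t = 2:

    vae_scale_factor_temporal = 4
    patch_size_t = 2  # assumed value

    for num_frames in (49, 81, 85):
        latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
        ok = latent_frames % patch_size_t == 0
        print(num_frames, latent_frames, "passed" if ok else "rejected before, padded now")
    # 49 -> 13: rejected before, padded now
    # 81 -> 21: rejected before, padded now
    # 85 -> 22: passed
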
@@ -853,7 +863,13 @@ def adjust_resolution_to_divisible(image_height, image_width, tgt_height, tgt_wi
                     progress_bar.update()
 
         if not output_type == "latent":
-            video = self.decode_latents(latents)
+            # Work out how many padding frames were prepended from the size of the second (frame) dimension
+            num_latent_frames = latents.size(1)
+            # e.g. (81 - 1) // 4 + 1 = 21 expected latent frames; with 22 latents, start_frames = 22 - 21 = 1, so frame 0 is skipped
+            start_frames = num_latent_frames - ((num_frames - 1) // self.vae_scale_factor_temporal + 1)
+
+            # Slice off the padding frames before decoding
+            video = self.decode_latents(latents[:, start_frames:])
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents
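
Putting the two files together, a hedged end-to-end usage sketch (the model id and call arguments are assumptions for illustration, not taken from this commit):

    import torch
    from diffusers import CogVideoXImageToVideoPipeline
    from diffusers.utils import load_image

    # Model id is an assumption; any CogVideoX1.5 checkpoint with patch_size_t set applies
    pipe = CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX1.5-5B-I2V", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    image = load_image("input.jpg")

    # num_frames=81 -> 21 latent frames; with patch_size_t=2 the pipeline now pads to 22 in
    # prepare_latents and trims the extra leading frame before decoding, instead of raising in check_inputs
    video = pipe(image=image, prompt="a panda playing a guitar", num_frames=81).frames[0]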
