Skip to content

Commit 7a1b579

Browse files
set patch_size_t as None by default
1 parent 0c98aad commit 7a1b579

File tree

2 files changed

+3
-14
lines changed

2 files changed

+3
-14
lines changed

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def __init__(
230230
sample_height: int = 60,
231231
sample_frames: int = 49,
232232
patch_size: int = 2,
233-
patch_size_t: int = 2,
233+
patch_size_t: Optional[int] = None,
234234
temporal_compression_ratio: int = 4,
235235
max_text_seq_length: int = 226,
236236
activation_fn: str = "gelu-approximate",

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -368,12 +368,12 @@ def prepare_extra_step_kwargs(self, generator, eta):
368368
extra_step_kwargs["generator"] = generator
369369
return extra_step_kwargs
370370

371+
# Copied from diffusers.pipelines.latte.pipeline_latte.LattePipeline.check_inputs
371372
def check_inputs(
372373
self,
373374
prompt,
374375
height,
375376
width,
376-
num_frames,
377377
negative_prompt,
378378
callback_on_step_end_tensor_inputs,
379379
prompt_embeds=None,
@@ -382,15 +382,6 @@ def check_inputs(
382382
if height % 8 != 0 or width % 8 != 0:
383383
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
384384

385-
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
386-
if (
387-
self.transformer.config.patch_size_t is not None
388-
and latent_frames % self.transformer.config.patch_size_t != 0
389-
):
390-
raise ValueError(
391-
f"Number of latent frames must be divisible by `{self.transformer.config.patch_size_t}` but got {latent_frames=}."
392-
)
393-
394385
if callback_on_step_end_tensor_inputs is not None and not all(
395386
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
396387
):
@@ -611,7 +602,6 @@ def __call__(
611602
prompt,
612603
height,
613604
width,
614-
num_frames,
615605
negative_prompt,
616606
callback_on_step_end_tensor_inputs,
617607
prompt_embeds,
@@ -744,8 +734,7 @@ def __call__(
744734
progress_bar.update()
745735

746736
if not output_type == "latent":
747-
breakpoint()
748-
video = self.decode_latents(latents)
737+
video = self.decode_latents(latents[:,1:])
749738
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
750739
else:
751740
video = latents

0 commit comments

Comments
 (0)