
Commit 4f51829

make pipeline work
1 parent 677a553 commit 4f51829

3 files changed: +47 additions, -48 deletions

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 5 additions & 3 deletions
@@ -1089,8 +1089,9 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
             return self.tiled_encode(x)
 
         frame_batch_size = self.num_sample_frames_batch_size
+        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
         enc = []
-        for i in range((num_frames + frame_batch_size - 1) // frame_batch_size):
+        for i in range(num_batches):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
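For context, the loop count changes from a ceil division to a floor division with a special case for single-frame inputs. A quick arithmetic sketch (the 49/8 frame counts below are illustrative assumptions, not values read from the diff):

    # Illustrative only: assumed frame counts, not read from the model config.
    num_frames = 49        # e.g. 1 leading frame + 48 following frames
    frame_batch_size = 8   # e.g. num_sample_frames_batch_size

    old_num_batches = (num_frames + frame_batch_size - 1) // frame_batch_size   # ceil -> 7
    new_num_batches = num_frames // frame_batch_size if num_frames > 1 else 1   # floor -> 6
    remaining_frames = num_frames % frame_batch_size                            # 1, folded into the first batch

    print(old_num_batches, new_num_batches, remaining_frames)  # 7 6 1

Since the remainder is already absorbed into the first window by the `start_frame`/`end_frame` bookkeeping in this hunk, the old ceil count appears to produce one extra window past the end of the video; the floor count avoids that.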
@@ -1141,7 +1142,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
 
         frame_batch_size = self.num_latent_frames_batch_size
         dec = []
-        for i in range((num_frames + frame_batch_size - 1) // frame_batch_size):
+        for i in range(num_frames // frame_batch_size):
             remaining_frames = num_frames % frame_batch_size
             start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
             end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1233,8 +1234,9 @@ def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
         for i in range(0, height, overlap_height):
             row = []
             for j in range(0, width, overlap_width):
+                num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
                 time = []
-                for k in range(num_frames // frame_batch_size):
+                for k in range(num_batches):
                     remaining_frames = num_frames % frame_batch_size
                     start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                     end_frame = frame_batch_size * (k + 1) + remaining_frames
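To make the index bookkeeping concrete, here is a small standalone sketch of the frame windows these loops iterate over (the frame counts and a latent batch size of 2 are assumptions for illustration only):

    # Sketch of the frame-window computation shared by _encode / _decode / tiled_encode above.
    def frame_windows(num_frames: int, frame_batch_size: int):
        num_batches = num_frames // frame_batch_size if num_frames > 1 else 1
        remaining_frames = num_frames % frame_batch_size
        windows = []
        for i in range(num_batches):
            # the remainder goes into the first window; later windows shift by it
            start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
            end_frame = frame_batch_size * (i + 1) + remaining_frames
            windows.append((start_frame, end_frame))
        return windows

    print(frame_windows(49, 8))  # [(0, 9), (9, 17), (17, 25), (25, 33), (33, 41), (41, 49)]
    print(frame_windows(13, 2))  # [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]

Every frame index from 0 to num_frames is covered exactly once, with the odd leading frame handled by the slightly larger first window.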

src/diffusers/models/transformers/cogvideox_transformer_3d.py

Lines changed: 4 additions & 1 deletion
@@ -465,8 +465,11 @@ def custom_forward(*inputs):
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - It is okay to use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
+        # - However, for CogVideoX-5b-I2V, the number of input channels is twice the number of output channels, because the image latents are concatenated with the noisy video latents
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
         if not return_dict:
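A minimal sketch of why inferring the channel dimension with `-1` works here: `proj_out` maps to `out_channels * p * p` features, so the inferred size is the output channel count even when the transformer's input has twice as many channels. All sizes below are illustrative assumptions:

    import torch

    # Assumed, illustrative sizes (e.g. 480x720 pixels -> 60x90 latent, 13 latent frames).
    batch_size, num_frames, height, width, p = 2, 13, 60, 90, 2
    out_channels = 16  # assumed output latent channels

    # proj_out output: (batch, num_frames * (height // p) * (width // p), out_channels * p * p)
    hidden_states = torch.randn(batch_size, num_frames * (height // p) * (width // p), out_channels * p * p)

    output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
    output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
    print(output.shape)  # torch.Size([2, 13, 16, 60, 90])

Hard-coding `channels` (the transformer input channel count) would break in the image-to-video case described by the comment above, where the input has twice the channels of the prediction.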

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 38 additions & 44 deletions
@@ -338,7 +338,7 @@ def encode_prompt(
 
     def prepare_latents(
         self,
-        image: Optional[torch.Tensor] = None,
+        image: torch.Tensor,
         batch_size: int = 1,
         num_channels_latents: int = 16,
         num_frames: int = 13,
@@ -363,47 +363,46 @@ def prepare_latents(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
+
+        assert image.ndim == 4
+        image = image.unsqueeze(2)  # [B, C, F, H, W]
+        print(image.shape)
 
-        if latents is None:
-            assert image.ndim == 4
-            image = image.unsqueeze(2)  # [B, C, F, H, W]
-            print(image.shape)
-
-            if isinstance(generator, list):
-                if len(generator) != batch_size:
-                    raise ValueError(
-                        f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                        f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                    )
+        if isinstance(generator, list):
+            if len(generator) != batch_size:
+                raise ValueError(
+                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+                )
 
-                init_latents = [
-                    retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
-                ]
-            else:
-                init_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
-
-            init_latents = torch.cat(init_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
-            init_latents = self.vae.config.scaling_factor * init_latents
-
-            padding_shape = (
-                batch_size,
-                num_frames - 1,
-                num_channels_latents,
-                height // self.vae_scale_factor_spatial,
-                width // self.vae_scale_factor_spatial,
-            )
-            latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
-            print(init_latents.shape, latent_padding.shape)
-            init_latents = torch.cat([init_latents, latent_padding], dim=1)
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
+            ]
+        else:
+            image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
 
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = torch.cat([noise, init_latents], dim=2)
+        image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4)  # [B, F, C, H, W]
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        padding_shape = (
+            batch_size,
+            num_frames - 1,
+            num_channels_latents,
+            height // self.vae_scale_factor_spatial,
+            width // self.vae_scale_factor_spatial,
+        )
+        latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
+        print(image_latents.shape, latent_padding.shape)
+        image_latents = torch.cat([image_latents, latent_padding], dim=1)
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         else:
             latents = latents.to(device)
 
         # scale the initial noise by the standard deviation required by the scheduler
         latents = latents * self.scheduler.init_noise_sigma
-        return latents
+        return latents, image_latents
 
     # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
     def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
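As a rough sanity check of what `prepare_latents` now returns: `latents` is pure noise for every latent frame, while `image_latents` carries the encoded first frame followed by zero padding. A sketch with assumed CogVideoX-style sizes (not actual pipeline output):

    import torch

    # Assumed, illustrative sizes for a 720x480, 49-frame generation.
    batch_size, num_channels_latents = 1, 16
    num_frames = 13                      # latent frames, e.g. (49 - 1) // 4 + 1
    height, width = 480 // 8, 720 // 8   # with vae_scale_factor_spatial = 8

    # noise for every latent frame
    latents = torch.randn(batch_size, num_frames, num_channels_latents, height, width)

    # encoded first frame, then zeros for the remaining num_frames - 1 latent frames
    first_frame_latent = torch.randn(batch_size, 1, num_channels_latents, height, width)
    latent_padding = torch.zeros(batch_size, num_frames - 1, num_channels_latents, height, width)
    image_latents = torch.cat([first_frame_latent, latent_padding], dim=1)

    print(latents.shape, image_latents.shape)
    # torch.Size([1, 13, 16, 60, 90]) torch.Size([1, 13, 16, 60, 90])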
@@ -446,7 +445,6 @@ def check_inputs(
         prompt,
         height,
         width,
-        strength,
         negative_prompt,
         callback_on_step_end_tensor_inputs,
         video=None,
@@ -457,9 +455,6 @@ def check_inputs(
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
-        if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
@@ -567,12 +562,10 @@ def __call__(
         num_frames: int = 49,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
-        strength: float = 0.8,
         guidance_scale: float = 6,
         use_dynamic_cfg: bool = False,
         num_videos_per_prompt: int = 1,
         eta: float = 0.0,
-        noise_aug_strength: float = 0.02,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -614,8 +607,6 @@ def __call__(
                 Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
                 in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
                 passed will be used. Must be in descending order.
-            strength (`float`, *optional*, defaults to 0.8):
-                Higher strength leads to more differences between original video and generated video.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -682,7 +673,6 @@ def __call__(
             prompt,
             height,
             width,
-            strength,
             negative_prompt,
             callback_on_step_end_tensor_inputs,
             prompt_embeds,
@@ -730,7 +720,7 @@ def __call__(
         )
 
         latent_channels = self.transformer.config.in_channels // 2
-        latents = self.prepare_latents(
+        latents, image_latents = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             latent_channels,
@@ -765,6 +755,10 @@ def __call__(
 
                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
+                latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
+                print(latent_model_input.shape)
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])
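Finally, a sketch of what the new concatenation does to the transformer input inside the denoising loop: `dim=2` is the channel axis of the `[B, F, C, H, W]` latents, so the transformer sees twice the channels it predicts, matching `latent_channels = self.transformer.config.in_channels // 2` above (shapes assumed as in the earlier sketch):

    import torch

    do_classifier_free_guidance = True

    latents = torch.randn(1, 13, 16, 60, 90)        # noisy video latents [B, F, C, H, W]
    image_latents = torch.randn(1, 13, 16, 60, 90)  # first-frame latents + zero padding

    latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
    latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents

    # channel-wise concat: 16 + 16 = 32 input channels
    latent_model_input = torch.cat([latent_model_input, latent_image_input], dim=2)
    print(latent_model_input.shape)  # torch.Size([2, 13, 32, 60, 90])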
