Skip to content

Commit 677a553

Browse files
committed
Fix frame-batch loop to use ceiling division in CogVideoX VAE encode/decode; cast preprocessed image to prompt-embeds dtype and halve latent channel count in the image2video pipeline
1 parent 3df95b2 commit 677a553

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
10901090

10911091
frame_batch_size = self.num_sample_frames_batch_size
10921092
enc = []
1093-
for i in range(num_frames // frame_batch_size):
1093+
for i in range((num_frames + frame_batch_size - 1) // frame_batch_size):
10941094
remaining_frames = num_frames % frame_batch_size
10951095
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
10961096
end_frame = frame_batch_size * (i + 1) + remaining_frames
@@ -1141,7 +1141,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
11411141

11421142
frame_batch_size = self.num_latent_frames_batch_size
11431143
dec = []
1144-
for i in range(num_frames // frame_batch_size):
1144+
for i in range((num_frames + frame_batch_size - 1) // frame_batch_size):
11451145
remaining_frames = num_frames % frame_batch_size
11461146
start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
11471147
end_frame = frame_batch_size * (i + 1) + remaining_frames

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def prepare_latents(
367367
if latents is None:
368368
assert image.ndim == 4
369369
image = image.unsqueeze(2) # [B, C, F, H, W]
370+
print(image.shape)
370371

371372
if isinstance(generator, list):
372373
if len(generator) != batch_size:
@@ -392,6 +393,7 @@ def prepare_latents(
392393
width // self.vae_scale_factor_spatial,
393394
)
394395
latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
396+
print(init_latents.shape, latent_padding.shape)
395397
init_latents = torch.cat([init_latents, latent_padding], dim=1)
396398

397399
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
@@ -723,10 +725,11 @@ def __call__(
723725
self._num_timesteps = len(timesteps)
724726

725727
# 5. Prepare latents
726-
image = self.video_processor.preprocess(image, height=height, width=width).to(device)
727-
image = image.unsqueeze(2) # [B, C, F, H, W]
728+
image = self.video_processor.preprocess(image, height=height, width=width).to(
729+
device, dtype=prompt_embeds.dtype
730+
)
728731

729-
latent_channels = self.transformer.config.in_channels
732+
latent_channels = self.transformer.config.in_channels // 2
730733
latents = self.prepare_latents(
731734
image,
732735
batch_size * num_videos_per_prompt,

0 commit comments

Comments
 (0)