Skip to content

Commit f28708d

Browse files
Update train_cogvideox_image_to_video_lora.py
1 parent 4339f65 commit f28708d

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

finetune/train_cogvideox_image_to_video_lora.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from torchvision.transforms import InterpolationMode
5959
import torchvision.transforms as TT
6060
import numpy as np
61+
from diffusers.image_processor import VaeImageProcessor
6162

6263

6364
if is_wandb_available():
@@ -773,8 +774,13 @@ def log_validation(
773774

774775
videos = []
775776
for _ in range(args.num_validation_videos):
776-
video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
777-
videos.append(video)
777+
pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
778+
pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
779+
780+
image_np = VaeImageProcessor.pt_to_numpy(pt_images)
781+
image_pil = VaeImageProcessor.numpy_to_pil(image_np)
782+
783+
videos.append(image_pil)
778784

779785
for tracker in accelerator.trackers:
780786
phase_name = "test" if is_final_validation else "validation"

0 commit comments

Comments
 (0)