Skip to content

Commit e169e7b

Browse files
Update train_cogvideox_image_to_video_lora.py
1 parent f28708d commit e169e7b

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

finetune/train_cogvideox_image_to_video_lora.py

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -573,36 +573,36 @@ def _load_dataset_from_local_path(self):
573573
return instance_prompts, instance_videos
574574

575575
def _resize_for_rectangle_crop(self, arr):
576-
image_size = self.height, self.width
577-
reshape_mode = self.video_reshape_mode
578-
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
579-
arr = resize(
580-
arr,
581-
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
582-
interpolation=InterpolationMode.BICUBIC,
583-
)
584-
else:
585-
arr = resize(
586-
arr,
587-
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
588-
interpolation=InterpolationMode.BICUBIC,
589-
)
576+
image_size = self.height, self.width
577+
reshape_mode = self.video_reshape_mode
578+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
579+
arr = resize(
580+
arr,
581+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
582+
interpolation=InterpolationMode.BICUBIC,
583+
)
584+
else:
585+
arr = resize(
586+
arr,
587+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
588+
interpolation=InterpolationMode.BICUBIC,
589+
)
590590

591-
h, w = arr.shape[2], arr.shape[3]
592-
arr = arr.squeeze(0)
591+
h, w = arr.shape[2], arr.shape[3]
592+
arr = arr.squeeze(0)
593593

594-
delta_h = h - image_size[0]
595-
delta_w = w - image_size[1]
594+
delta_h = h - image_size[0]
595+
delta_w = w - image_size[1]
596596

597-
if reshape_mode == "random" or reshape_mode == "none":
598-
top = np.random.randint(0, delta_h + 1)
599-
left = np.random.randint(0, delta_w + 1)
600-
elif reshape_mode == "center":
601-
top, left = delta_h // 2, delta_w // 2
602-
else:
603-
raise NotImplementedError
604-
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
605-
return arr
597+
if reshape_mode == "random" or reshape_mode == "none":
598+
top = np.random.randint(0, delta_h + 1)
599+
left = np.random.randint(0, delta_w + 1)
600+
elif reshape_mode == "center":
601+
top, left = delta_h // 2, delta_w // 2
602+
else:
603+
raise NotImplementedError
604+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
605+
return arr
606606

607607
def _preprocess_data(self):
608608
try:
@@ -622,8 +622,7 @@ def _preprocess_data(self):
622622
videos = []
623623

624624
for filename in self.instance_video_paths:
625-
progress_dataset_bar.update(1)
626-
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
625+
video_reader = decord.VideoReader(uri=filename.as_posix())
627626
video_num_frames = len(video_reader)
628627

629628
start_frame = min(self.skip_frames_start, video_num_frames)
@@ -651,8 +650,12 @@ def _preprocess_data(self):
651650
# Training transforms
652651
frames = (frames - 127.5) / 127.5
653652
frames = frames.permute(0, 3, 1, 2) # [F, C, H, W]
653+
progress_dataset_bar.set_description(
654+
f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
655+
)
654656
frames = self._resize_for_rectangle_crop(frames)
655657
videos.append(frames.contiguous()) # [F, C, H, W]
658+
progress_dataset_bar.update(1)
656659

657660
progress_dataset_bar.close()
658661

0 commit comments

Comments
 (0)