
Commit 4339f65

update

1 parent: e26c3c4

File tree: 2 files changed (+76 additions, -24 deletions)


finetune/train_cogvideox_image_to_video_lora.py

Lines changed: 75 additions & 22 deletions
@@ -1,5 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -45,10 +44,7 @@
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
@@ -58,6 +54,10 @@
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms as TT
+import numpy as np
 
 
 if is_wandb_available():
@@ -236,6 +236,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -442,6 +448,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -450,19 +457,22 @@
         id_token: Optional[str] = None,
     ) -> None:
         super().__init__()
+
         self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
         self.dataset_name = dataset_name
         self.dataset_config_name = dataset_config_name
         self.caption_column = caption_column
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
         self.skip_frames_end = skip_frames_end
         self.cache_dir = cache_dir
         self.id_token = id_token or ""
+
         if dataset_name is not None:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
@@ -561,6 +571,38 @@ def _load_dataset_from_local_path(self):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
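
Note on the new helper: `_resize_for_rectangle_crop` scales a `[F, C, H, W]` clip so the target `height x width` rectangle fits inside it while preserving aspect ratio, then crops the overflow per `--video_reshape_mode` (`center` crops symmetrically; `random`, and currently also `none`, takes a random offset). Two quirks worth flagging: `arr.squeeze(0)` appears to be a no-op for multi-frame input, and `none` does not disable cropping. A minimal standalone sketch of the same logic, with hypothetical names and a dummy input:

```python
# Standalone sketch of the resize-then-crop logic above (hypothetical names,
# not part of the script). frames: [F, C, H, W] tensor.
import numpy as np
import torch
import torchvision.transforms as TT
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize


def resize_for_rectangle_crop(frames, height, width, reshape_mode="center"):
    if frames.shape[3] / frames.shape[2] > width / height:
        # Source is wider than the target: match heights, then crop the width.
        new_size = [height, int(frames.shape[3] * height / frames.shape[2])]
    else:
        # Source is taller than the target: match widths, then crop the height.
        new_size = [int(frames.shape[2] * width / frames.shape[3]), width]
    frames = resize(frames, size=new_size, interpolation=InterpolationMode.BICUBIC)

    delta_h = frames.shape[2] - height
    delta_w = frames.shape[3] - width
    if reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:  # "random" (the committed code also routes "none" here)
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    return TT.functional.crop(frames, top=top, left=left, height=height, width=width)


clip = torch.randn(49, 3, 360, 640)  # 16:9 source
print(resize_for_rectangle_crop(clip, 480, 720).shape)  # torch.Size([49, 3, 480, 720])
```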
@@ -571,14 +613,15 @@ def _preprocess_data(self):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
         )
 
+        videos = []
+
         for filename in self.instance_video_paths:
+            progress_dataset_bar.update(1)
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
 
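For context, the loading loop relies on decord's torch bridge: once `decord.bridge.set_bridge("torch")` is set, `VideoReader.get_batch` returns uint8 torch tensors shaped `[F, H, W, C]`, which is why the frames are permuted to `[F, C, H, W]` below. A small sketch of that pattern (the file path is a placeholder):

```python
# Minimal sketch of the decord loading pattern; "sample.mp4" is a placeholder.
import decord

decord.bridge.set_bridge("torch")  # make get_batch return torch tensors
reader = decord.VideoReader(uri="sample.mp4", width=720, height=480)
indices = list(range(0, len(reader), 2))  # e.g. every second frame
frames = reader.get_batch(indices)  # uint8 tensor, [F, H, W, C]
print(frames.shape, frames.dtype)
```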
@@ -605,9 +648,12 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+
+        progress_dataset_bar.close()
 
         return videos
 
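The inlined normalization is algebraically identical to the removed `transforms.Lambda`: x / 255 * 2 - 1 = (2x - 255) / 255 = (x - 127.5) / 127.5, mapping uint8 pixel values from [0, 255] onto [-1, 1]. A quick check:

```python
# Verify the removed Lambda and the new expression agree on all uint8 values.
import torch

x = torch.arange(0, 256, dtype=torch.float32)
old = x / 255.0 * 2.0 - 1.0       # removed transforms.Lambda
new = (x - 127.5) / 127.5         # inlined replacement
print(torch.allclose(old, new))   # True
print(new.min().item(), new.max().item())  # -1.0 1.0
```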
@@ -727,7 +773,7 @@ def log_validation(
 
     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+        video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
         videos.append(video)
 
     for tracker in accelerator.trackers:
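With `output_type="pil"`, `.frames[0]` is a list of `PIL.Image` frames rather than a numpy array, which is what the downstream logging consumes. A hedged end-to-end sketch of the same call outside the training script (model id, prompt, and file names are illustrative, and a CUDA device is assumed):

```python
# Sketch of a validation-style call with PIL output; paths are placeholders.
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")
image = load_image("input.jpg")  # placeholder conditioning image
video = pipe(prompt="a panda playing guitar", image=image, output_type="pil").frames[0]
export_to_video(video, "validation.mp4", fps=8)  # accepts the list of PIL frames
```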
@@ -756,7 +802,8 @@ def log_validation(
                 }
             )
 
-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()
 
     return videos
 
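Throughout the commit, the removed diffusers helper `clear_objs_and_retain_memory([...])` is replaced by an explicit `del` of the reference followed by `free_memory()`: the caller now drops the object itself, and `free_memory()` only reclaims what is no longer referenced. A rough local equivalent of that pairing, as an approximation rather than the diffusers implementation:

```python
# Rough sketch of what del + free_memory() accomplishes (approximation only).
import gc
import torch


def free_memory_sketch():
    gc.collect()                  # release unreferenced Python objects
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return unused cached CUDA memory
        torch.cuda.ipc_collect()


model = torch.nn.Linear(8, 8)  # stand-in for the pipeline / transformer
del model                      # drop the last reference first...
free_memory_sketch()           # ...then reclaim the memory it held
```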
@@ -1204,6 +1251,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1212,7 +1260,8 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         image = video[:, :, :1].clone()
@@ -1238,7 +1287,13 @@ def encode_video(video):
         )
         for prompt in train_dataset.instance_prompts
     ]
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = []
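Since the list comprehension calls `encode_video` exactly once per clip, passing the shared tqdm instance and ticking `bar.update(1)` inside it advances the bar one step per encoded video. A minimal generic sketch of the pattern (names are illustrative); wrapping the iterable directly, as in `[f(v) for v in tqdm(videos)]`, would be the more conventional form:

```python
# Generic sketch of the manual progress-bar pattern above (illustrative names).
from tqdm import tqdm


def encode(item, bar):
    bar.update(1)    # advance once per processed element
    return item * 2  # stand-in for the real VAE encode


items = list(range(5))
bar = tqdm(range(0, len(items)), desc="Encoding")  # used only as a counter
encoded = [encode(item, bar) for item in items]
bar.close()
print(encoded)  # [0, 2, 4, 6, 8]
```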
@@ -1378,9 +1433,6 @@ def collate_fn(examples):
     )
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
 
-    # Delete VAE and Text Encoder to save memory
-    clear_objs_and_retain_memory([vae, text_encoder])
-
     # For DeepSpeed training
     model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
 
@@ -1550,7 +1602,8 @@ def collate_fn(examples):
     )
 
     # Cleanup trained models to save memory
-    clear_objs_and_retain_memory([transformer])
+    del transformer
+    free_memory()
 
     # Final test inference
     pipe = CogVideoXImageToVideoPipeline.from_pretrained(

finetune/train_cogvideox_lora.py

Lines changed: 1 addition & 2 deletions
@@ -1,5 +1,4 @@
-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
