-# Copyright 2024 The HuggingFace Team.
-# All rights reserved.
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import (
     check_min_version,
     convert_unet_state_dict_to_peft,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms as TT
+import numpy as np


 if is_wandb_available():
@@ -236,6 +236,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped using this mode. Choose between ['center', 'random', 'none'].",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -442,6 +448,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -450,19 +457,22 @@ def __init__(
         id_token: Optional[str] = None,
     ) -> None:
         super().__init__()
+
         self.instance_data_root = Path(instance_data_root) if instance_data_root is not None else None
         self.dataset_name = dataset_name
         self.dataset_config_name = dataset_config_name
         self.caption_column = caption_column
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
         self.skip_frames_end = skip_frames_end
         self.cache_dir = cache_dir
         self.id_token = id_token or ""
+
         if dataset_name is not None:
             self.instance_prompts, self.instance_video_paths = self._load_dataset_from_hub()
         else:
@@ -561,6 +571,38 @@ def _load_dataset_from_local_path(self):

         return instance_prompts, instance_videos

+    def _resize_for_rectangle_crop(self, arr):
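+        # Resize so one side of the frames matches the target size exactly and the
+        # other side overflows, then crop the overflow per `self.video_reshape_mode`.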
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
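+        # 'center' crops the middle of the frame; 'random' and 'none' both sample a random offset.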
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
@@ -571,14 +613,15 @@ def _preprocess_data(self):

         decord.bridge.set_bridge("torch")

-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )

+        videos = []
+
         for filename in self.instance_video_paths:
+            progress_dataset_bar.update(1)
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)

@@ -605,9 +648,12 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
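+            # Normalize uint8 pixels from [0, 255] to [-1, 1]; equivalent to the removed Lambda transform.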
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+
+        progress_dataset_bar.close()

         return videos

@@ -727,7 +773,7 @@ def log_validation(

     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
+        video = pipe(**pipeline_args, generator=generator, output_type="pil").frames[0]
         videos.append(video)

     for tracker in accelerator.trackers:
@@ -756,7 +802,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()
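+    # free_memory() replaces the clear_objs_and_retain_memory helper used previously.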

     return videos

@@ -1204,6 +1251,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1212,7 +1260,8 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )

-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         image = video[:, :, :1].clone()
@@ -1238,7 +1287,13 @@ def encode_video(video):
         )
         for prompt in train_dataset.instance_prompts
     ]
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()

     def collate_fn(examples):
         videos = []
@@ -1378,9 +1433,6 @@ def collate_fn(examples):
     )
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

-    # Delete VAE and Text Encoder to save memory
-    clear_objs_and_retain_memory([vae, text_encoder])
-
     # For DeepSpeed training
     model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config

@@ -1550,7 +1602,8 @@ def collate_fn(examples):
         )

         # Cleanup trained models to save memory
-        clear_objs_and_retain_memory([transformer])
+        del transformer
+        free_memory()

         # Final test inference
         pipe = CogVideoXImageToVideoPipeline.from_pretrained(