@@ -573,36 +573,36 @@ def _load_dataset_from_local_path(self):
573573 return instance_prompts , instance_videos
574574
def _resize_for_rectangle_crop(self, arr):
    """Resize a clip so it fully covers the target frame, then crop it to size.

    The clip is scaled with bicubic interpolation, preserving its aspect
    ratio, until one spatial side equals the target and the other is at least
    as large. The surplus on the longer side is then cropped away.

    Args:
        arr: Video tensor indexed as ``[F, C, H, W]`` (dims 2 and 3 are read
            as height and width).

    Returns:
        Tensor with the same leading dimensions as ``arr`` and spatial size
        ``(self.height, self.width)``.

    Raises:
        NotImplementedError: If ``self.video_reshape_mode`` is not one of
            ``"random"``, ``"none"`` or ``"center"``.
    """
    target_h, target_w = self.height, self.width
    reshape_mode = self.video_reshape_mode

    # Fail fast on an unknown mode instead of paying for a bicubic resize first.
    if reshape_mode not in ("random", "none", "center"):
        raise NotImplementedError(f"Unsupported video_reshape_mode: {reshape_mode!r}")

    src_h, src_w = arr.shape[2], arr.shape[3]
    if src_w / src_h > target_w / target_h:
        # Source is wider than the target: match heights, width overshoots.
        new_size = [target_h, int(src_w * target_h / src_h)]
    else:
        # Source is taller than (or same ratio as) the target: match widths.
        new_size = [int(src_h * target_w / src_w), target_w]
    arr = resize(arr, size=new_size, interpolation=InterpolationMode.BICUBIC)

    # NOTE: the tensor is deliberately NOT squeezed on dim 0. Squeezing would
    # drop the frame axis for single-frame clips and return [C, H, W], while
    # being a no-op for every other clip length.
    delta_h = arr.shape[2] - target_h
    delta_w = arr.shape[3] - target_w

    if reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        # "random" and "none" both pick a uniformly random crop offset.
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)

    return TT.functional.crop(arr, top=top, left=left, height=target_h, width=target_w)
606606
607607 def _preprocess_data (self ):
608608 try :
@@ -622,8 +622,7 @@ def _preprocess_data(self):
622622 videos = []
623623
624624 for filename in self .instance_video_paths :
625- progress_dataset_bar .update (1 )
626- video_reader = decord .VideoReader (uri = filename .as_posix (), width = self .width , height = self .height )
625+ video_reader = decord .VideoReader (uri = filename .as_posix ())
627626 video_num_frames = len (video_reader )
628627
629628 start_frame = min (self .skip_frames_start , video_num_frames )
@@ -651,8 +650,12 @@ def _preprocess_data(self):
651650 # Training transforms
652651 frames = (frames - 127.5 ) / 127.5
653652 frames = frames .permute (0 , 3 , 1 , 2 ) # [F, C, H, W]
653+ progress_dataset_bar .set_description (
654+ f"Loading progress Resizing video from { frames .shape [2 ]} x{ frames .shape [3 ]} to { self .height } x{ self .width } "
655+ )
654656 frames = self ._resize_for_rectangle_crop (frames )
655657 videos .append (frames .contiguous ()) # [F, C, H, W]
658+ progress_dataset_bar .update (1 )
656659
657660 progress_dataset_bar .close ()
658661
0 commit comments