@@ -499,6 +499,15 @@ def parse_args():
499499 " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
500500 ),
501501 )
502+ parser .add_argument (
503+ "--image_interpolation_mode" ,
504+ type = str ,
505+ default = "lanczos" ,
506+ choices = [
507+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
508+ ],
509+ help = "The image interpolation method to use for resizing images." ,
510+ )
502511
503512 args = parser .parse_args ()
504513 env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True):
787796 )
788797 return inputs .input_ids
789798
790- # Preprocessing the datasets.
799+ # Get the specified interpolation method from the args
800+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
801+
802+ # Raise an error if the interpolation method is invalid
803+ if interpolation is None :
804+ raise ValueError (f"Unsupported interpolation mode { args .image_interpolation_mode } ." )
805+
806+ # Data preprocessing transformations
791807 train_transforms = transforms .Compose (
792808 [
793- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
809+ transforms .Resize (args .resolution , interpolation = interpolation ), # Use dynamic interpolation method
794810 transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution ),
795811 transforms .RandomHorizontalFlip () if args .random_flip else transforms .Lambda (lambda x : x ),
796812 transforms .ToTensor (),
0 commit comments