@@ -418,6 +418,15 @@ def parse_args():
418418 default = 4 ,
419419 help = ("The dimension of the LoRA update matrices." ),
420420 )
421+ parser .add_argument (
422+ "--image_interpolation_mode" ,
423+ type = str ,
424+ default = "lanczos" ,
425+ choices = [
426+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
427+ ],
428+ help = "The image interpolation method to use for resizing images." ,
429+ )
421430
422431 args = parser .parse_args ()
423432 env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -649,10 +658,17 @@ def tokenize_captions(examples, is_train=True):
649658 )
650659 return inputs .input_ids
651660
652- # Preprocessing the datasets.
661+ # Get the specified interpolation method from the args
662+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
663+
664+ # Raise an error if the interpolation method is invalid
665+ if interpolation is None :
666+ raise ValueError (f"Unsupported interpolation mode { args .image_interpolation_mode } ." )
667+
668+ # Data preprocessing transformations
653669 train_transforms = transforms .Compose (
654670 [
655- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
671+ transforms .Resize (args .resolution , interpolation = interpolation ), # Use dynamic interpolation method
656672 transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution ),
657673 transforms .RandomHorizontalFlip () if args .random_flip else transforms .Lambda (lambda x : x ),
658674 transforms .ToTensor (),
0 commit comments