@@ -639,6 +639,15 @@ def parse_args(input_args=None):
639639 action = "store_true" ,
640640 help = "Enable model cpu offload and save memory." ,
641641 )
642+ parser .add_argument (
643+ "--image_interpolation_mode" ,
644+ type = str ,
645+ default = "lanczos" ,
646+ choices = [
647+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
648+ ],
649+ help = "The image interpolation method to use for resizing images." ,
650+ )
642651
643652 if input_args is not None :
644653 args = parser .parse_args (input_args )
@@ -736,9 +745,13 @@ def get_train_dataset(args, accelerator):
736745
737746
738747def prepare_train_dataset (dataset , accelerator ):
748+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
749+ if interpolation is None :
750+ raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
751+
739752 image_transforms = transforms .Compose (
740753 [
741- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
754+ transforms .Resize (args .resolution , interpolation = interpolation ),
742755 transforms .CenterCrop (args .resolution ),
743756 transforms .ToTensor (),
744757 transforms .Normalize ([0.5 ], [0.5 ]),
@@ -747,7 +760,7 @@ def prepare_train_dataset(dataset, accelerator):
747760
748761 conditioning_image_transforms = transforms .Compose (
749762 [
750- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
763+ transforms .Resize (args .resolution , interpolation = interpolation ),
751764 transforms .CenterCrop (args .resolution ),
752765 transforms .ToTensor (),
753766 transforms .Normalize ([0.5 ], [0.5 ]),
0 commit comments