@@ -523,6 +523,7 @@ def resize(
523523 size = (height , width ),
524524 )
525525 image = self .pt_to_numpy (image )
526+
526527 return image
527528
528529 def binarize (self , image : PIL .Image .Image ) -> PIL .Image .Image :
@@ -838,6 +839,137 @@ def apply_overlay(
838839 return image
839840
840841
class InpaintProcessor(ConfigMixin):
    """
    Image processor for inpainting: preprocesses an image/mask pair and
    postprocesses the generated result, optionally cropping to the masked
    region and pasting the result back via an overlay.

    Internally wraps two `VaeImageProcessor` instances — one for the image
    (normalized to [-1, 1] by default) and one for the mask (binarized and
    converted to grayscale by default, not normalized).
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        vae_latent_channels: int = 4,
        resample: str = "lanczos",
        reducing_gap: Optional[int] = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
        mask_do_normalize: bool = False,
        mask_do_binarize: bool = True,
        mask_do_convert_grayscale: bool = True,
    ):
        super().__init__()

        # Processor for the RGB image.
        self._image_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
        # Processor for the mask; shares the geometry settings but uses the
        # mask-specific normalize/binarize/grayscale flags.
        self._mask_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=mask_do_normalize,
            do_binarize=mask_do_binarize,
            do_convert_grayscale=mask_do_convert_grayscale,
        )

    def preprocess(
        self,
        image: PIL.Image.Image,
        mask: Optional[PIL.Image.Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        padding_mask_crop: Optional[int] = None,
    ):
        """
        Preprocess the image and, optionally, the mask.

        Args:
            image: The input image.
            mask: The inpainting mask. If `None`, this behaves exactly like
                the plain image processor's `preprocess`.
            height: Target height forwarded to the underlying processors.
            width: Target width forwarded to the underlying processors.
            padding_mask_crop: If provided, both image and mask are cropped to
                the mask's region (padded by this many pixels) using
                `get_crop_region`, and resized with `resize_mode="fill"`.
                Requires `mask`.

        Returns:
            If `mask` is `None`: the processed image tensor only.
            Otherwise a 3-tuple `(processed_image, processed_mask,
            postprocessing_kwargs)`, where `postprocessing_kwargs` holds the
            `crops_coords`/`original_image`/`original_mask` values that
            `postprocess` needs to apply the overlay (all `None` when no crop
            was applied).

        Raises:
            ValueError: If `padding_mask_crop` is given without a `mask`.
        """
        if mask is None and padding_mask_crop is not None:
            raise ValueError("mask must be provided if padding_mask_crop is provided")

        # If mask is None, same behavior as the regular image processor.
        if mask is None:
            return self._image_processor.preprocess(image, height=height, width=width)

        if padding_mask_crop is not None:
            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        processed_image = self._image_processor.preprocess(
            image,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )

        processed_mask = self._mask_processor.preprocess(
            mask,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )

        # Keep the originals only when a crop was applied: `postprocess` needs
        # them to paste the generated region back into the full-size image.
        if crops_coords is not None:
            postprocessing_kwargs = {
                "crops_coords": crops_coords,
                "original_image": image,
                "original_mask": mask,
            }
        else:
            postprocessing_kwargs = {
                "crops_coords": None,
                "original_image": None,
                "original_mask": None,
            }

        return processed_image, processed_mask, postprocessing_kwargs

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        original_image: Optional[PIL.Image.Image] = None,
        original_mask: Optional[PIL.Image.Image] = None,
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ):
        """
        Postprocess the generated image and optionally apply the mask overlay.

        Args:
            image: The generated image tensor to postprocess.
            output_type: Output format forwarded to the underlying image
                processor (must be `"pil"` when `crops_coords` is provided).
            original_image: The uncropped input image; required together with
                `original_mask` when `crops_coords` is provided.
            original_mask: The uncropped input mask; see `original_image`.
            crops_coords: The crop box returned by `preprocess` (via
                `postprocessing_kwargs`); when given, the generated crop is
                pasted back into `original_image` through `apply_overlay`.

        Returns:
            The postprocessed image(s) in the requested `output_type`.

        Raises:
            ValueError: If `crops_coords` is provided without both
                `original_image` and `original_mask`, or with an
                `output_type` other than `"pil"`.
        """
        # Validate the overlay arguments before doing any work (fail fast).
        if crops_coords is not None:
            if original_image is None or original_mask is None:
                raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
            if output_type != "pil":
                raise ValueError("output_type must be 'pil' if crops_coords is provided")

        image = self._image_processor.postprocess(
            image,
            output_type=output_type,
        )

        # Optionally apply the mask overlay: paste each generated crop back
        # into the original image, blended through the original mask.
        if crops_coords is not None:
            image = [
                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
            ]

        return image
972+
841973class VaeImageProcessorLDM3D (VaeImageProcessor ):
842974 """
843975 Image processor for VAE LDM3D.
0 commit comments