@@ -523,6 +523,7 @@ def resize(
523523                size = (height , width ),
524524            )
525525            image  =  self .pt_to_numpy (image )
526+ 
526527        return  image 
527528
528529    def  binarize (self , image : PIL .Image .Image ) ->  PIL .Image .Image :
@@ -838,6 +839,137 @@ def apply_overlay(
838839        return  image 
839840
840841
class InpaintProcessor(ConfigMixin):
    """
    Image processor for inpainting.

    Wraps two `VaeImageProcessor` instances — one for the image, one for the mask — so
    both are resized with the same geometry while keeping mask-specific behavior
    (binarize + grayscale, no normalization) separate from image behavior.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        vae_latent_channels: int = 4,
        resample: str = "lanczos",
        reducing_gap: Optional[int] = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
        mask_do_normalize: bool = False,
        mask_do_binarize: bool = True,
        mask_do_convert_grayscale: bool = True,
    ):
        super().__init__()

        # Image and mask share resize geometry (scale factor, resample, reducing_gap)
        # but use independent normalize/binarize/grayscale settings: by default the
        # image is normalized to [-1, 1] while the mask is binarized and grayscale.
        self._image_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
        self._mask_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=mask_do_normalize,
            do_binarize=mask_do_binarize,
            do_convert_grayscale=mask_do_convert_grayscale,
        )

    def preprocess(
        self,
        image: PIL.Image.Image,
        mask: Optional[PIL.Image.Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        padding_mask_crop: Optional[int] = None,
    ):
        """
        Preprocess the image and, optionally, the mask.

        Args:
            image: Input image.
            mask: Inpainting mask. If `None`, behaves exactly like the plain image
                processor and only the preprocessed image tensor is returned.
            height: Target height; inferred by the underlying processor when `None`.
            width: Target width; inferred by the underlying processor when `None`.
            padding_mask_crop: If set, crop image and mask to the padded bounding box
                of the mask region before resizing (requires `mask`).

        Returns:
            `torch.Tensor` when `mask is None`; otherwise a 3-tuple
            `(processed_image, processed_mask, postprocessing_kwargs)` where
            `postprocessing_kwargs` holds the crop coordinates and originals needed by
            `postprocess` to paste the result back (all `None` when no crop was done).

        Raises:
            ValueError: If `padding_mask_crop` is provided without a `mask`.
        """
        if mask is None and padding_mask_crop is not None:
            raise ValueError("mask must be provided if padding_mask_crop is provided")

        # No mask: same behavior as the regular image processor.
        if mask is None:
            return self._image_processor.preprocess(image, height=height, width=width)

        if padding_mask_crop is not None:
            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
            # "fill" stretches the crop to the target size instead of letterboxing.
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        processed_image = self._image_processor.preprocess(
            image,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )

        processed_mask = self._mask_processor.preprocess(
            mask,
            height=height,
            width=width,
            resize_mode=resize_mode,
            crops_coords=crops_coords,
        )

        # Keep the originals only when a crop happened: postprocess needs them to
        # paste the generated patch back into the full-size image.
        if crops_coords is not None:
            postprocessing_kwargs = {
                "crops_coords": crops_coords,
                "original_image": image,
                "original_mask": mask,
            }
        else:
            postprocessing_kwargs = {
                "crops_coords": None,
                "original_image": None,
                "original_mask": None,
            }

        return processed_image, processed_mask, postprocessing_kwargs

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        original_image: Optional[PIL.Image.Image] = None,
        original_mask: Optional[PIL.Image.Image] = None,
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ):
        """
        Postprocess the image tensor and optionally paste it back into the original
        image using the crop coordinates produced by `preprocess`.

        Args:
            image: Image tensor to postprocess.
            output_type: Output format forwarded to the image processor (must be
                `"pil"` when `crops_coords` is provided).
            original_image: Full-size source image; required with `crops_coords`.
            original_mask: Full-size mask; required with `crops_coords`.
            crops_coords: Crop box from `preprocess`; when set, the overlay is applied
                so only the masked region of the original image is replaced.

        Returns:
            The postprocessed image(s) in the requested `output_type` (a list of
            `PIL.Image.Image` when `output_type="pil"`).

        Raises:
            ValueError: If `crops_coords` is provided without the originals, or with an
                `output_type` other than `"pil"`.
        """
        # Validate before doing any work so bad arguments fail fast.
        if crops_coords is not None and (original_image is None or original_mask is None):
            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")

        if crops_coords is not None and output_type != "pil":
            raise ValueError("output_type must be 'pil' if crops_coords is provided")

        image = self._image_processor.postprocess(
            image,
            output_type=output_type,
        )

        # Optionally paste each result back into the original via the mask overlay.
        if crops_coords is not None:
            image = [
                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
            ]

        return image
972+ 
841973class  VaeImageProcessorLDM3D (VaeImageProcessor ):
842974    """ 
843975    Image processor for VAE LDM3D. 
0 commit comments