@@ -33,9 +33,9 @@ def __init__(
3333 vae_decode_only : bool = False ,
3434 vae_tiling : bool = False ,
3535 n_threads : int = - 1 ,
36- wtype : Union [str , GGMLType , int , float , None ] = "default" ,
37- rng_type : Union [str , RNGType , int , float , None ] = "cuda" ,
38- schedule : Union [str , Schedule , int , float , None ] = "default" ,
36+ wtype : Optional [ Union [str , GGMLType , int , float ] ] = "default" ,
37+ rng_type : Optional [ Union [str , RNGType , int , float ] ] = "cuda" ,
38+ schedule : Optional [ Union [str , Schedule , int , float ] ] = "default" ,
3939 keep_clip_on_cpu : bool = False ,
4040 keep_control_net_cpu : bool = False ,
4141 keep_vae_on_cpu : bool = False ,
@@ -189,7 +189,7 @@ def txt_to_img(
189189 guidance : float = 3.5 ,
190190 width : int = 512 ,
191191 height : int = 512 ,
192- sample_method : Union [str , SampleMethod , int , float , None ] = "euler_a" ,
192+ sample_method : Optional [ Union [str , SampleMethod , int , float ] ] = "euler_a" ,
193193 sample_steps : int = 20 ,
194194 seed : int = 42 ,
195195 batch_count : int = 1 ,
@@ -315,13 +315,14 @@ def img_to_img(
315315 self ,
316316 image : Union [Image .Image , str ],
317317 prompt : str ,
318+ mask_image : Optional [Union [Image .Image , str ]] = None ,
318319 negative_prompt : str = "" ,
319320 clip_skip : int = - 1 ,
320321 cfg_scale : float = 7.0 ,
321322 guidance : float = 3.5 ,
322323 width : int = 512 ,
323324 height : int = 512 ,
324- sample_method : Union [str , SampleMethod , int , float , None ] = "euler_a" ,
325+ sample_method : Optional [ Union [str , SampleMethod , int , float ] ] = "euler_a" ,
325326 sample_steps : int = 20 ,
326327 strength : float = 0.75 ,
327328 seed : int = 42 ,
@@ -344,6 +345,7 @@ def img_to_img(
344345 Args:
345346 image: The input image path or Pillow Image to direct the generation.
346347 prompt: The prompt to render.
348+ mask_image: The inpainting mask image path or Pillow Image.
347349 negative_prompt: The negative prompt.
348350 clip_skip: Ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer.
349351 cfg_scale: Unconditional guidance scale.
@@ -414,8 +416,25 @@ def sd_progress_callback(
414416 # Resize the input image
415417 image = self ._resize_image (image , width , height ) # Input image and generated image must have the same size
416418
417- # Convert the image to a byte array
419+ def _create_blank_mask_image (width : int , height : int ):
420+ """Create a blank white mask image in c_uint8 format."""
421+ mask_image_buffer = (ctypes .c_uint8 * (width * height ))(* [255 ] * (width * height ))
422+ return mask_image_buffer
423+
424+ # Convert the image and mask image to a byte array
418425 image_pointer = self ._image_to_sd_image_t_p (image )
426+ if mask_image :
427+ # Resize the mask image (however the mask should ideally already be the same size as the input image)
428+ mask_image = self ._resize_image (mask_image , width , height )
429+ mask_image_pointer = self ._image_to_sd_image_t_p (mask_image , channel = 1 )
430+ else :
431+ # Create a blank white mask image
432+ mask_image_pointer = self ._c_uint8_to_sd_image_t_p (
433+ image = _create_blank_mask_image (width , height ),
434+ width = width ,
435+ height = height ,
436+ channel = 1 ,
437+ )
419438
420439 # Convert skip_layers to a ctypes array
421440 skip_layers_array = (ctypes .c_int * len (skip_layers ))(* skip_layers )
@@ -426,6 +445,7 @@ def sd_progress_callback(
426445 c_images = sd_cpp .img2img (
427446 self .model ,
428447 image_pointer ,
448+ mask_image_pointer ,
429449 prompt .encode ("utf-8" ),
430450 negative_prompt .encode ("utf-8" ),
431451 clip_skip ,
@@ -466,7 +486,7 @@ def img_to_vid(
466486 augmentation_level : float = 0.0 ,
467487 min_cfg : float = 1.0 ,
468488 cfg_scale : float = 7.0 ,
469- sample_method : Union [str , SampleMethod , int , float , None ] = "euler_a" ,
489+ sample_method : Optional [ Union [str , SampleMethod , int , float ] ] = "euler_a" ,
470490 sample_steps : int = 20 ,
471491 strength : float = 0.75 ,
472492 seed : int = 42 ,
@@ -661,7 +681,6 @@ def sd_progress_callback(
661681 # ==================== Upscale images ====================
662682
663683 upscaled_images = []
664-
665684 for image in images :
666685
667686 # Convert the image to a byte array
@@ -698,19 +717,24 @@ def _resize_image(self, image: Union[Image.Image, str], width: int, height: int)
698717 def _format_image (
699718 self ,
700719 image : Union [Image .Image , str ],
720+ channel : int = 3 ,
701721 ) -> Image .Image :
702- """Convert an image path or Pillow Image to a Pillow Image of RGBA format."""
722+ """Convert an image path or Pillow Image to a Pillow Image in RGB format, or grayscale format for inpainting masks."""
703723 # Convert image path to image if str
704724 if isinstance (image , str ):
705725 image = Image .open (image )
706726
707- # Convert any non RGBA to RGBA
708- if image .format != "PNG" :
709- image = image .convert ("RGBA" )
727+ if channel == 1 :
728+ # Grayscale the image if channel is 1
729+ image = image .convert ("L" )
730+ else :
731+ # Convert any non RGBA to RGBA
732+ if image .format != "PNG" :
733+ image = image .convert ("RGBA" )
710734
711- # Ensure the image is in RGB mode
712- if image .mode != "RGB" :
713- image = image .convert ("RGB" )
735+ # Ensure the image is in RGB mode
736+ if image .mode != "RGB" :
737+ image = image .convert ("RGB" )
714738
715739 return image , image .width , image .height
716740
@@ -741,14 +765,12 @@ def _format_control_cond(
741765
742766 # ============= Image to C uint8 pointer =============
743767
744- def _cast_image (self , image : Union [Image .Image , str ]):
768+ def _cast_image (self , image : Union [Image .Image , str ], channel : int = 3 ):
745769 """Cast a PIL Image to a C uint8 pointer."""
746-
747- image , width , height = self ._format_image (image )
770+ image , width , height = self ._format_image (image , channel )
748771
749772 # Convert the PIL Image to a byte array
750773 image_bytes = image .tobytes ()
751-
752774 data = ctypes .cast (
753775 (ctypes .c_byte * len (image_bytes ))(* image_bytes ),
754776 ctypes .POINTER (ctypes .c_uint8 ),
@@ -757,8 +779,8 @@ def _cast_image(self, image: Union[Image.Image, str]):
757779
758780 # ============= Image to C sd_image_t =============
759781
760- def _c_uint8_to_sd_image_t_p (self , image : ctypes .c_uint8 , width , height , channel : int = 3 ):
761- # Create a new C sd_image_t
782+ def _c_uint8_to_sd_image_t_p (self , image : ctypes .c_uint8 , width : int , height : int , channel : int = 3 ) -> sd_cpp . sd_image_t :
783+ """Convert a C uint8 pointer to a C sd_image_t."""
762784 c_image = sd_cpp .sd_image_t (
763785 width = width ,
764786 height = height ,
@@ -767,21 +789,18 @@ def _c_uint8_to_sd_image_t_p(self, image: ctypes.c_uint8, width, height, channel
767789 )
768790 return c_image
769791
770- def _image_to_sd_image_t_p (self , image : Union [Image .Image , str ]) :
792+ def _image_to_sd_image_t_p (self , image : Union [Image .Image , str ], channel : int = 3 ) -> sd_cpp . sd_image_t :
771793 """Convert a PIL Image or image path to a C sd_image_t."""
772-
773- data , width , height = self ._cast_image (image )
774-
775- # Create a new C sd_image_t
776- c_image = self ._c_uint8_to_sd_image_t_p (data , width , height )
794+ data , width , height = self ._cast_image (image , channel )
795+ c_image = self ._c_uint8_to_sd_image_t_p (data , width , height , channel )
777796 return c_image
778797
779798 # ============= C sd_image_t to Image =============
780799
781- def _c_array_to_bytes (self , c_array , buffer_size : int ):
800+ def _c_array_to_bytes (self , c_array , buffer_size : int ) -> bytes :
782801 return bytearray (ctypes .cast (c_array , ctypes .POINTER (ctypes .c_byte * buffer_size )).contents )
783802
784- def _dereference_sd_image_t_p (self , c_image : sd_cpp .sd_image_t ):
803+ def _dereference_sd_image_t_p (self , c_image : sd_cpp .sd_image_t ) -> Dict :
785804 """Dereference a C sd_image_t pointer to a Python dictionary with height, width, channel and data (bytes)."""
786805
787806 # Calculate the size of the data buffer
@@ -795,7 +814,7 @@ def _dereference_sd_image_t_p(self, c_image: sd_cpp.sd_image_t):
795814 }
796815 return image
797816
798- def _image_slice (self , c_images : sd_cpp .sd_image_t , count : int , upscale_factor : int ):
817+ def _image_slice (self , c_images : sd_cpp .sd_image_t , count : int , upscale_factor : int ) -> List [ Dict ] :
799818 """Slice a C array of images."""
800819 image_array = ctypes .cast (c_images , ctypes .POINTER (sd_cpp .sd_image_t * count )).contents
801820
@@ -821,7 +840,7 @@ def _image_slice(self, c_images: sd_cpp.sd_image_t, count: int, upscale_factor:
821840 # Return the list of images
822841 return images
823842
824- def _sd_image_t_p_to_images (self , c_images : sd_cpp .sd_image_t , count : int , upscale_factor : int ):
843+ def _sd_image_t_p_to_images (self , c_images : sd_cpp .sd_image_t , count : int , upscale_factor : int ) -> List [ Image . Image ] :
825844 """Convert C sd_image_t_p images to a Python list of images."""
826845
827846 # Convert C array to Python list of images
@@ -836,20 +855,30 @@ def _sd_image_t_p_to_images(self, c_images: sd_cpp.sd_image_t, count: int, upsca
836855
837856 # ============= Bytes to Image =============
838857
839- def _bytes_to_image (self , byte_data : bytes , width : int , height : int ) :
858+ def _bytes_to_image (self , byte_data : bytes , width : int , height : int , channel : int = 3 ) -> Image . Image :
840859 """Convert a byte array to a PIL Image."""
860+ # Initialize the image with RGBA mode
841861 image = Image .new ("RGBA" , (width , height ))
842862
843863 for y in range (height ):
844864 for x in range (width ):
845- idx = (y * width + x ) * 3
846- image .putpixel (
847- (x , y ),
848- (byte_data [idx ], byte_data [idx + 1 ], byte_data [idx + 2 ], 255 ),
849- )
865+ idx = (y * width + x ) * channel
866+ # Dynamically create the color tuple
867+ color = tuple (byte_data [idx + i ] if idx + i < len (byte_data ) else 0 for i in range (channel ))
868+ if channel == 1 : # Grayscale
869+ color = (color [0 ],) * 3 + (255 ,) # Convert to (R, G, B, A)
870+ elif channel == 3 : # RGB
871+ color = color + (255 ,) # Add alpha channel
872+ elif channel == 4 : # RGBA
873+ pass # Use color as is
874+ else :
875+ raise ValueError (f"Unsupported channel value: { channel } " )
876+ # Set the pixel
877+ image .putpixel ((x , y ), color )
878+
850879 return image
851880
852- def __setstate__ (self , state ):
881+ def __setstate__ (self , state ) -> None :
853882 self .__init__ (** state )
854883
855884 def close (self ) -> None :
@@ -865,7 +894,7 @@ def __del__(self) -> None:
865894# ============================================
866895
867896
868- def validate_dimensions (dimension : int | float , attribute_name : str ) -> int :
897+ def validate_dimensions (dimension : Union [ int , float ] , attribute_name : str ) -> int :
869898 """Dimensions must be a multiple of 64 otherwise a GGML_ASSERT error is encountered."""
870899 dimension = int (dimension )
871900 if dimension <= 0 or dimension % 64 != 0 :
0 commit comments