@@ -71,6 +71,11 @@ def tensor_to_pil(tensor):
7171 return toPIL (tensor .permute (2 , 0 , 1 ))
7272
7373
74+ def tensor_to_pil_single (tensor ):
75+ # [H, W] to [ H, W]
76+ return toPIL (tensor )
77+
78+
7479def tensor_to_batch (tensor , h , w , c ):
7580 tensor = torch .cat (tensor )
7681 tensor = tensor .reshape (- 1 , h , w , c )
@@ -982,3 +987,37 @@ def INPUT_TYPES(cls):
982987
983988 def node_function (self , direction ):
984989 return (direction ,)
990+
991+
992+ class ImageRemoveAlphaNode :
993+ def __init__ (self ):
994+ pass
995+
996+ @classmethod
997+ def INPUT_TYPES (cls ):
998+ return {
999+ "required" : {
1000+ "images" : (IO .IMAGE , {"defaultInput" : True }),
1001+ "masks" : (IO .MASK , {"defaultInput" : True }),
1002+ "fill_color" : (IO .STRING , {"default" : "#FFFFFF" }),
1003+ }
1004+ }
1005+
1006+ FUNCTION = "node_function"
1007+ CATEGORY = "Fair/image"
1008+ RETURN_TYPES = (IO .IMAGE ,)
1009+ RETURN_NAMES = ("images" ,)
1010+
1011+ def node_function (self , images , masks , fill_color ):
1012+ out_images = []
1013+ for image , mask in zip (images , masks ):
1014+ pil = tensor_to_pil (image )
1015+ pil_mask = tensor_to_pil_single (mask )
1016+
1017+ new_pil = Image .new ("RGBA" , pil .size , fill_color )
1018+ new_pil .paste (pil , pil_mask )
1019+
1020+ image = pil_to_tensor (new_pil )
1021+ out_images .append (image )
1022+ out_images = torch .stack (out_images , dim = 0 )
1023+ return (out_images ,)
0 commit comments