|
1 | 1 | import os |
2 | 2 | import io |
3 | | -from random import random |
4 | | -from turtle import color |
5 | | -import black |
6 | 3 | import requests |
7 | 4 | import json |
8 | 5 | import numpy as np |
@@ -73,13 +70,12 @@ def pil_to_tensor(pil): |
73 | 70 |
|
74 | 71 |
|
75 | 72 | def tensor_to_pil(tensor): |
76 | | - # [H, W, C] to [C, H, W] |
77 | | - return toPIL(tensor.permute(2, 0, 1)) |
78 | | - |
79 | | - |
80 | | -def tensor_to_pil_single(tensor): |
81 | | - # [H, W] to [ H, W] |
82 | | - return toPIL(tensor) |
| 73 | + if len(tensor.shape) == 2: |
| 74 | + # [H, W] to [H, W] |
| 75 | + return toPIL(tensor) |
| 76 | + else: |
| 77 | + # [H, W, C] to [C, H, W] |
| 78 | + return toPIL(tensor.permute(2, 0, 1)) |
83 | 79 |
|
84 | 80 |
|
85 | 81 | def tensor_to_batch(tensor, h, w, c): |
@@ -1019,7 +1015,7 @@ def node_function(self, images, masks, fill_color): |
1019 | 1015 | out_images = [] |
1020 | 1016 | for image, mask in zip(images, masks): |
1021 | 1017 | pil = tensor_to_pil(image) |
1022 | | - pil_mask = tensor_to_pil_single(1 - mask) |
| 1018 | + pil_mask = tensor_to_pil(1 - mask) |
1023 | 1019 |
|
1024 | 1020 | new_pil = Image.new("RGBA", pil.size, fill_color) |
1025 | 1021 | new_pil.paste(pil, pil_mask) |
@@ -1056,25 +1052,25 @@ def node_function(self, metallic, ambient_occlusion, detail_mask, smoothness): |
1056 | 1052 | mask_maps = [] |
1057 | 1053 | for metallic_tensor, ambient_occlusion_tensor, detail_mask_tensor, smoothness_tensor in zip(metallic, ambient_occlusion, detail_mask, smoothness): |
1058 | 1054 | if len(metallic_tensor.shape) == 2: |
1059 | | - metallic_pil_single = tensor_to_pil_single(metallic_tensor) |
| 1055 | + metallic_pil_single = tensor_to_pil(metallic_tensor) |
1060 | 1056 | metallic_pil = metallic_pil_single.convert("RGB") |
1061 | 1057 | else: |
1062 | 1058 | metallic_pil = tensor_to_pil(metallic_tensor) |
1063 | 1059 |
|
1064 | 1060 | if len(ambient_occlusion_tensor.shape) == 2: |
1065 | | - ambient_occlusion_pil_single = tensor_to_pil_single(ambient_occlusion_tensor) |
| 1061 | + ambient_occlusion_pil_single = tensor_to_pil(ambient_occlusion_tensor) |
1066 | 1062 | ambient_occlusion_pil = ambient_occlusion_pil_single.convert("RGB") |
1067 | 1063 | else: |
1068 | 1064 | ambient_occlusion_pil = tensor_to_pil(ambient_occlusion_tensor) |
1069 | 1065 |
|
1070 | 1066 | if len(detail_mask_tensor.shape) == 2: |
1071 | | - detail_mask_pil_single = tensor_to_pil_single(detail_mask_tensor) |
| 1067 | + detail_mask_pil_single = tensor_to_pil(detail_mask_tensor) |
1072 | 1068 | detail_mask_pil = detail_mask_pil_single.convert("RGB") |
1073 | 1069 | else: |
1074 | 1070 | detail_mask_pil = tensor_to_pil(detail_mask_tensor) |
1075 | 1071 |
|
1076 | 1072 | if len(smoothness_tensor.shape) == 2: |
1077 | | - smoothness_pil_single = tensor_to_pil_single(smoothness_tensor) |
| 1073 | + smoothness_pil_single = tensor_to_pil(smoothness_tensor) |
1078 | 1074 | smoothness_pil = smoothness_pil_single.convert("RGB") |
1079 | 1075 | else: |
1080 | 1076 | smoothness_pil = tensor_to_pil(smoothness_tensor) |
@@ -1114,19 +1110,19 @@ def node_function(self, albedo, normal, smoothness): |
1114 | 1110 | detail_maps = [] |
1115 | 1111 | for albedo_tensor, normal_tensor, smoothness_tensor in zip(albedo, normal, smoothness): |
1116 | 1112 | if len(albedo_tensor.shape) == 2: |
1117 | | - albedo_pil_single = tensor_to_pil_single(albedo_tensor) |
| 1113 | + albedo_pil_single = tensor_to_pil(albedo_tensor) |
1118 | 1114 | albedo_pil = albedo_pil_single.convert("RGB") |
1119 | 1115 | else: |
1120 | 1116 | albedo_pil = tensor_to_pil(albedo_tensor) |
1121 | 1117 |
|
1122 | 1118 | if len(normal_tensor.shape) == 2: |
1123 | | - normal_pil_single = tensor_to_pil_single(normal_tensor) |
| 1119 | + normal_pil_single = tensor_to_pil(normal_tensor) |
1124 | 1120 | normal_pil = normal_pil_single.convert("RGB") |
1125 | 1121 | else: |
1126 | 1122 | normal_pil = tensor_to_pil(normal_tensor) |
1127 | 1123 |
|
1128 | 1124 | if len(smoothness_tensor.shape) == 2: |
1129 | | - smoothness_pil_single = tensor_to_pil_single(smoothness_tensor) |
| 1125 | + smoothness_pil_single = tensor_to_pil(smoothness_tensor) |
1130 | 1126 | smoothness_pil = smoothness_pil_single.convert("RGB") |
1131 | 1127 | else: |
1132 | 1128 | smoothness_pil = tensor_to_pil(smoothness_tensor) |
@@ -1222,7 +1218,7 @@ def INPUT_TYPES(cls): |
1222 | 1218 | def node_function(self, images, folder_path, filename_prefix): |
1223 | 1219 | for index, image in enumerate(images): |
1224 | 1220 | if len(image.shape) == 2: |
1225 | | - pil_single = tensor_to_pil_single(image) |
| 1221 | + pil_single = tensor_to_pil(image) |
1226 | 1222 | pil = pil_single.convert("RGB") |
1227 | 1223 | else: |
1228 | 1224 | pil = tensor_to_pil(image) |
|
0 commit comments