Skip to content

Commit 9733773

Browse files
committed
image batch
1 parent 7797999 commit 9733773

File tree

3 files changed

+41
-16
lines changed

3 files changed

+41
-16
lines changed

__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .nodes.image_utility import ImageToVideoNode
2323
from .nodes.image_utility import LoadImageFromURLNode
2424
from .nodes.image_utility import LoadImageFromDirectoryNode
25+
from .nodes.image_utility import LoadImageBatchFromDirectoryNode
2526
from .nodes.image_utility import FillAlphaNode
2627
from .nodes.image_utility import ImageToBase64Node
2728
from .nodes.image_utility import Base64ToImageNode
@@ -52,6 +53,7 @@
5253
NODE_CLASS_MAPPINGS = {
5354
"StringTranslateNode": StringTranslateNode,
5455
"LoadImageFromDirectoryNode": LoadImageFromDirectoryNode,
56+
"LoadImageBatchFromDirectoryNode": LoadImageBatchFromDirectoryNode,
5557
"FillAlphaNode": FillAlphaNode,
5658
"SaveStringToDirectoryNode": SaveStringToDirectoryNode,
5759
"LoadStringFromDirectoryNode": LoadStringFromDirectoryNode,
@@ -99,6 +101,7 @@
99101
NODE_DISPLAY_NAME_MAPPINGS = {
100102
"StringTranslateNode": "String Translate",
101103
"LoadImageFromDirectoryNode": "Load Image From Directory",
104+
"LoadImageBatchFromDirectoryNode": "Load Image Batch From Directory",
102105
"FillAlphaNode": "Fill Alpha",
103106
"SaveStringToDirectoryNode": "Save String To Directory",
104107
"LoadStringFromDirectoryNode": "Load String From Directory",

nodes/image_utility.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def batch2list(tensor_batch):
8585
return tensors
8686

8787

88+
def list2batch(tensor_list):
89+
return torch.cat(tensor_list)
90+
91+
8892
def rgba2rgb(pil):
8993
bg = Image.new("RGB", pil.size, (255, 255, 255))
9094
bg.paste(pil, pil)
@@ -442,6 +446,37 @@ def node_function(self, directory, recursive, channels):
442446
return (out_image, out_dir, out_name)
443447

444448

449+
class LoadImageBatchFromDirectoryNode:
450+
def __init__(self):
451+
pass
452+
453+
@classmethod
454+
def INPUT_TYPES(cls):
455+
return {
456+
"required": {
457+
"directory": (IO.STRING, {"default": "", "forceInput": False}),
458+
"recursive": (IO.BOOLEAN, {"default": False}),
459+
"channels": (["RGB", "RGBA"], {"default": "RGB"}),
460+
}
461+
}
462+
463+
FUNCTION = "node_function"
464+
CATEGORY = "Fair/image"
465+
466+
RETURN_TYPES = (IO.IMAGE, IO.STRING, IO.STRING)
467+
RETURN_NAMES = ("images", "directory", "name")
468+
OUTPUT_IS_LIST = (False, True, True)
469+
470+
def node_function(self, directory, recursive, channels):
471+
if not directory or not os.path.isdir(directory):
472+
raise Exception("folder_path is not valid: " + directory)
473+
474+
(out_image, out_dir, out_name) = load_image_to_tensor(directory, recursive, channels)
475+
out_image = list2batch(out_image)
476+
477+
return (out_image, out_dir, out_name)
478+
479+
445480
class FillAlphaNode:
446481
def __init__(self):
447482
pass
@@ -485,21 +520,8 @@ def fill_alpha(self, tensor, alpha_threshold, fill_color):
485520
return image_tensor
486521

487522
def node_function(self, image, alpha_threshold, r, g, b):
488-
height = image[0].shape[0]
489-
width = image[0].shape[1]
490-
491-
progress_bar = ProgressBar(image.shape[0])
492-
493-
image_tensors = []
494-
495-
for tensor in batch2list(image):
496-
tensor_filled = self.fill_alpha(tensor, alpha_threshold, (r, g, b))
497-
image_tensors.append(tensor_filled)
498-
progress_bar.update(1)
499-
500-
image_tensors = tensor2batch(image_tensors, height, width, 3)
501-
502-
return (image_tensors,)
523+
tensor_filled = self.fill_alpha(image, alpha_threshold, (r, g, b))
524+
return (tensor_filled,)
503525

504526

505527
def pil_to_base64(pli_image, pnginfo=None, header=False):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Python Script
6363
6464
Load LoRA Dual
6565
"""
66-
version = "1.0.65"
66+
version = "1.0.66"
6767
license = { file = "LICENSE" }
6868
dependencies = [
6969
"googletrans",

0 commit comments

Comments
 (0)