From ad8a13a0a4e080b324bc46ced333124b07c7649f Mon Sep 17 00:00:00 2001 From: WAS Date: Thu, 11 May 2023 14:19:28 -0700 Subject: [PATCH 1/3] Update Cushy_Nodes.py Add VIT selection --- Cushy_Nodes.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cushy_Nodes.py b/Cushy_Nodes.py index ec3a9dc..d7a8750 100644 --- a/Cushy_Nodes.py +++ b/Cushy_Nodes.py @@ -323,6 +323,7 @@ def INPUT_TYPES(s): return { "required": { "image": ("IMAGE",), + "vit_model": (["default", "vit_h", "vit_l", "vit_b"],), "sam_model_name": (get_files_in_directory(sam_models_directory), { "default": "sam_vit_h_4b8939.pth" }), }, } @@ -334,12 +335,11 @@ def INPUT_TYPES(s): CATEGORY = "CushyNodes" - def execute_segmentation(self, image, sam_model_name): + def execute_segmentation(self, image, vit_model, sam_model_name): # Load the SAM model sam_checkpoint = os.path.join(sam_models_directory, sam_model_name) - model_type = "vit_h" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) + sam = sam_model_registry[vit_model](checkpoint=sam_checkpoint) sam.to(device=device) # Use the input image From 3af3e969b3c702f3a375af94ce3d72aa61b069c4 Mon Sep 17 00:00:00 2001 From: WAS Date: Thu, 11 May 2023 14:23:55 -0700 Subject: [PATCH 2/3] Update Cushy_Nodes.py --- Cushy_Nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cushy_Nodes.py b/Cushy_Nodes.py index d7a8750..f9cfe4d 100644 --- a/Cushy_Nodes.py +++ b/Cushy_Nodes.py @@ -323,7 +323,7 @@ def INPUT_TYPES(s): return { "required": { "image": ("IMAGE",), - "vit_model": (["default", "vit_h", "vit_l", "vit_b"],), + "vit_model": (sam_model_registry.keys(),), "sam_model_name": (get_files_in_directory(sam_models_directory), { "default": "sam_vit_h_4b8939.pth" }), }, } From 12c4849e4fc97ecb7cd19d8ba36679b189506011 Mon Sep 17 00:00:00 2001 From: WAS Date: Thu, 11 May 2023 14:39:17 -0700 Subject: [PATCH 3/3] Update Cushy_Nodes.py Revert back to list --- Cushy_Nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cushy_Nodes.py b/Cushy_Nodes.py index f9cfe4d..d7a8750 100644 --- a/Cushy_Nodes.py +++ b/Cushy_Nodes.py @@ -323,7 +323,7 @@ def INPUT_TYPES(s): return { "required": { "image": ("IMAGE",), - "vit_model": (sam_model_registry.keys(),), + "vit_model": (["default", "vit_h", "vit_l", "vit_b"],), "sam_model_name": (get_files_in_directory(sam_models_directory), { "default": "sam_vit_h_4b8939.pth" }), }, }