diff --git a/Cushy_Nodes.py b/Cushy_Nodes.py index ec3a9dc..d7a8750 100644 --- a/Cushy_Nodes.py +++ b/Cushy_Nodes.py @@ -323,6 +323,7 @@ def INPUT_TYPES(s): return { "required": { "image": ("IMAGE",), + "vit_model": (["default", "vit_h", "vit_l", "vit_b"],), "sam_model_name": (get_files_in_directory(sam_models_directory), { "default": "sam_vit_h_4b8939.pth" }), }, } @@ -334,12 +335,11 @@ def INPUT_TYPES(s): CATEGORY = "CushyNodes" - def execute_segmentation(self, image, sam_model_name): + def execute_segmentation(self, image, vit_model, sam_model_name): # Load the SAM model sam_checkpoint = os.path.join(sam_models_directory, sam_model_name) - model_type = "vit_h" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) + sam = sam_model_registry[vit_model](checkpoint=sam_checkpoint) sam.to(device=device) # Use the input image