diff --git a/anylabeling/configs/auto_labeling/models.yaml b/anylabeling/configs/auto_labeling/models.yaml index 7eadfd3..0941016 100644 --- a/anylabeling/configs/auto_labeling/models.yaml +++ b/anylabeling/configs/auto_labeling/models.yaml @@ -1,3 +1,14 @@ +- name: "sam2_1_coreml_large" + display_name: Segment Anything 2.1 (Large) CoreML + download_url: https://huggingface.co/apple/coreml-sam2.1-large + encoder_model_path: SAM2_1LargeImageEncoderFLOAT16.mlpackage + decoder_model_path: SAM2_1LargeMaskDecoderFLOAT16.mlpackage + image_encoder_model_path: SAM2_1LargeImageEncoderFLOAT16.mlpackage + prompt_encoder_model_path: SAM2_1LargePromptEncoderFLOAT16.mlpackage + input_size: 1024 + max_height: 1024 + max_width: 1024 + type: segment_anything - name: "sam2_hiera_tiny_20240803" display_name: Segment Anything 2 (Hiera-Tiny) download_url: https://huggingface.co/vietanhdev/segment-anything-2-onnx-models/resolve/main/sam2_hiera_tiny.zip diff --git a/anylabeling/services/auto_labeling/model_manager.py b/anylabeling/services/auto_labeling/model_manager.py index 86f903e..41abb69 100644 --- a/anylabeling/services/auto_labeling/model_manager.py +++ b/anylabeling/services/auto_labeling/model_manager.py @@ -21,6 +21,7 @@ from anylabeling.config import get_config, save_config import ssl +from huggingface_hub import snapshot_download ssl._create_default_https_context = ( ssl._create_unverified_context @@ -267,22 +268,7 @@ def load_model(self, config_file): self.model_download_thread.started.connect(self.model_download_worker.run) self.model_download_thread.start() - def _download_and_extract_model(self, model_config): - """Download and extract a model from model config""" - config_file = model_config["config_file"] - # Check if model is already downloaded - if not os.path.exists(config_file): - raise ValueError(self.tr("Error in loading config file.")) - with open(config_file, "r") as f: - model_config = yaml.safe_load(f) - if model_config.get("has_downloaded", False): - return 
def download_hf(self, tmp_dir, download_url, model_config):
    """Download a Hugging Face repository snapshot and add a config.yaml to it.

    The repository id is whatever follows ``https://huggingface.co/`` in
    ``download_url`` (surrounding slashes removed). The snapshot is stored
    in ``<tmp_dir>/extract`` and ``model_config`` is dumped next to the
    downloaded files as ``config.yaml``.

    NOTE(review): ``_download_and_extract_model`` tests
    ``endswith('.zip')`` and ``startswith('https://huggingface.co')`` with
    two independent ``if`` statements, so a Hugging Face
    ``.../resolve/main/<file>.zip`` URL reaches this method as well and
    produces a repo id that is not a repository — the caller probably
    wants ``elif``.

    Args:
        tmp_dir: Existing scratch directory that receives the snapshot.
        download_url: URL of the Hugging Face model repository.
        model_config: Mapping serialised to ``config.yaml``.

    Returns:
        Path of the directory holding the snapshot and ``config.yaml``.
    """
    repo_id = download_url.split("https://huggingface.co/")[-1].strip("/")
    tmp_extract_dir = os.path.join(tmp_dir, "extract")

    # The returned snapshot path is not needed: local_dir pins the location.
    snapshot_download(repo_id=repo_id, local_dir=tmp_extract_dir)

    config_path = os.path.join(tmp_extract_dir, "config.yaml")
    with open(config_path, "w") as f:
        # yaml.dump returns None when given a stream, so the result is not
        # bound to anything (the original overwrote model_config with it).
        yaml.dump(model_config, f, default_flow_style=False)
    return tmp_extract_dir
class SegmentAnything2CoreML:
    """Segment Anything 2.1 inference through three Core ML packages.

    The pipeline is split into an image encoder, a prompt encoder and a
    mask decoder, each loaded from an ``.mlpackage`` inside ``model_path``.

    NOTE(review): setup.py installs ``coremltools`` only when
    ``platform_system == "Darwin"`` — confirm that modules importing this
    class do not do so unconditionally on other platforms.
    """

    # Package names of the "Large" FLOAT16 release; used as defaults.
    DEFAULT_IMAGE_ENCODER = "SAM2_1LargeImageEncoderFLOAT16.mlpackage"
    DEFAULT_MASK_DECODER = "SAM2_1LargeMaskDecoderFLOAT16.mlpackage"
    DEFAULT_PROMPT_ENCODER = "SAM2_1LargePromptEncoderFLOAT16.mlpackage"

    def __init__(
        self,
        model_path: str,
        *,
        image_encoder_file: str = DEFAULT_IMAGE_ENCODER,
        mask_decoder_file: str = DEFAULT_MASK_DECODER,
        prompt_encoder_file: str = DEFAULT_PROMPT_ENCODER,
        input_size: tuple[int, int] = (1024, 1024),
    ) -> None:
        """Load the three Core ML models found in ``model_path``.

        Args:
            model_path: Directory containing the ``.mlpackage`` bundles.
            image_encoder_file: File name of the image encoder package.
            mask_decoder_file: File name of the mask decoder package.
            prompt_encoder_file: File name of the prompt encoder package.
            input_size: ``(width, height)`` the image is resized to before
                encoding; prompt coordinates are scaled to the same frame.
        """
        print("using CoreML", model_path)
        image_encoder_path = os.path.join(model_path, image_encoder_file)
        mask_decoder_path = os.path.join(model_path, mask_decoder_file)
        prompt_encoder_path = os.path.join(model_path, prompt_encoder_file)

        self.image_encoder = ct.models.MLModel(image_encoder_path)
        self.mask_decoder = ct.models.MLModel(mask_decoder_path)
        self.prompt_encoder = ct.models.MLModel(prompt_encoder_path)
        self.input_size = tuple(input_size)

    def encode(self, cv_image: np.ndarray) -> dict:
        """Compute the image embedding for an OpenCV (BGR) image.

        Returns:
            Dict with the encoder outputs (``high_res_feats_0``,
            ``high_res_feats_1``, ``image_embedding``) and
            ``original_size`` as the PIL ``(width, height)`` of the input,
            which ``predict_masks`` needs to map prompts and masks.
        """
        pil_image = Image.fromarray(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
        original_size = pil_image.size

        # The image is stretched to input_size (aspect ratio not kept);
        # predict_masks scales prompt coordinates per axis to match.
        resized_image = pil_image.resize(
            self.input_size, Image.Resampling.LANCZOS
        )
        embeddings = self.image_encoder.predict({"image": resized_image})

        return {
            "high_res_feats_0": embeddings["feats_s0"],
            "high_res_feats_1": embeddings["feats_s1"],
            "image_embedding": embeddings["image_embedding"],
            "original_size": original_size,
        }

    def predict_masks(self, embedding: dict, prompt: list) -> np.ndarray:
        """Predict a binary mask for ``prompt`` on an encoded image.

        Args:
            embedding: Result of ``encode`` for the image.
            prompt: Marks, each a dict with ``"type"`` and ``"data"``.
                ``"point"`` marks carry ``[x, y]`` plus a ``"label"``;
                ``"rectangle"`` marks carry ``[x1, y1, x2, y2]``. Marks of
                any other type are ignored.

        Returns:
            ``uint8`` array of shape ``(1, height, width)`` in original
            image coordinates, with values 0 or 255.
        """
        original_width, original_height = embedding["original_size"]
        # Per-axis factors from original image pixels to model input pixels.
        scale_x = self.input_size[0] / original_width
        scale_y = self.input_size[1] / original_height

        points = []
        labels = []
        for mark in prompt:
            if mark["type"] == "point":
                x, y = mark["data"][0], mark["data"][1]
                points.append([x * scale_x, y * scale_y])
                labels.append(mark["label"])
            elif mark["type"] == "rectangle":
                x1, y1, x2, y2 = mark["data"][:4]
                points.append([x1 * scale_x, y1 * scale_y])
                points.append([x2 * scale_x, y2 * scale_y])
                # A box is passed as two labelled corner points:
                # 2 = top-left, 3 = bottom-right.
                labels.extend((2, 3))

        points_array = np.array(points, dtype=np.float32).reshape(
            1, len(points), 2
        )
        labels_array = np.array(labels, dtype=np.int32).reshape(1, len(labels))

        prompt_embeddings = self.prompt_encoder.predict(
            {"points": points_array, "labels": labels_array}
        )
        mask_output = self.mask_decoder.predict(
            {
                "image_embedding": embedding["image_embedding"],
                "sparse_embedding": prompt_embeddings["sparse_embeddings"],
                "dense_embedding": prompt_embeddings["dense_embeddings"],
                "feats_s0": embedding["high_res_feats_0"],
                "feats_s1": embedding["high_res_feats_1"],
            }
        )

        # Keep the highest-scoring candidate; indexing assumes batch size 1.
        best_mask_idx = np.argmax(mask_output["scores"])
        mask = mask_output["low_res_masks"][0, best_mask_idx]

        # cv2.resize accepts float32 but not half precision; the packages
        # are FLOAT16, so normalise the dtype first (no-op if already
        # float32) — TODO confirm the actual output dtype.
        mask = np.asarray(mask, dtype=np.float32)
        mask = cv2.resize(
            mask,
            (original_width, original_height),
            interpolation=cv2.INTER_LINEAR,
        )

        # Positive logits are foreground.
        mask = (mask > 0).astype(np.uint8) * 255

        # Leading axis of length 1 so the result has shape (1, H, W).
        return np.array([mask])
QCoreApplication.translate( "Model", @@ -76,7 +76,10 @@ def __init__(self, config_path, on_message) -> None: ) # Load models - if self.detect_model_variant(decoder_model_abs_path) == "sam2": + if "coreml" in decoder_model_abs_path: + config_folder = os.path.dirname(decoder_model_abs_path) + self.model = SegmentAnything2CoreML(config_folder) + elif self.detect_model_variant(decoder_model_abs_path) == "sam2": self.model = SegmentAnything2ONNX( encoder_model_abs_path, decoder_model_abs_path ) diff --git a/requirements-macos.txt b/requirements-macos.txt index 85db885..5d63f69 100644 --- a/requirements-macos.txt +++ b/requirements-macos.txt @@ -8,3 +8,4 @@ onnx==1.16.1 onnxruntime==1.18.1 qimage2ndarray==1.10.0 darkdetect==0.8.0 +coremltools==8.3.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 2bb6e80..58c0d36 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ def get_install_requires(): "onnx==1.16.1", "qimage2ndarray==1.10.0", "darkdetect==0.8.0", + 'coremltools==8.3.0; platform_system == "Darwin"', ] # Add onnxruntime-gpu if GPU is preferred