diff --git a/docker/dockerfiles/Dockerfile.onnx.gpu b/docker/dockerfiles/Dockerfile.onnx.gpu index af4a8a2799..480459792e 100644 --- a/docker/dockerfiles/Dockerfile.onnx.gpu +++ b/docker/dockerfiles/Dockerfile.onnx.gpu @@ -115,4 +115,4 @@ ENV ENABLE_STREAM_API=True ENV ENABLE_PROMETHEUS=True ENV STREAM_API_PRELOADED_PROCESSES=2 -ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT +ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT \ No newline at end of file diff --git a/docker/dockerfiles/Dockerfile.onnx.gpu.3d b/docker/dockerfiles/Dockerfile.onnx.gpu.3d new file mode 100644 index 0000000000..38ff240699 --- /dev/null +++ b/docker/dockerfiles/Dockerfile.onnx.gpu.3d @@ -0,0 +1,141 @@ +FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 as builder + +WORKDIR /app + +RUN rm -rf /var/lib/apt/lists/* && apt-get clean && apt-get update -y && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libxext6 \ + libopencv-dev \ + uvicorn \ + python3-pip \ + git \ + libgdal-dev \ + libvips-dev \ + wget \ + rustc \ + cargo \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements/requirements.sam.txt \ + requirements/requirements.sam3.txt \ + requirements/requirements.clip.txt \ + requirements/requirements.http.txt \ + requirements/requirements.gpu.txt \ + requirements/requirements.gaze.txt \ + requirements/requirements.doctr.txt \ + requirements/requirements.groundingdino.txt \ + requirements/requirements.yolo_world.txt \ + requirements/_requirements.txt \ + requirements/requirements.transformers.txt \ + requirements/requirements.pali.flash_attn.txt \ + requirements/requirements.easyocr.txt \ + requirements/requirements.modal.txt \ + requirements/requirements.sam3_3d.txt \ + ./ + +RUN python3 -m pip install -U pip uv +RUN uv pip install --system \ + -r _requirements.txt \ + -r requirements.doctr.txt \ + -r requirements.sam.txt \ + -r requirements.sam3.txt \ + -r requirements.clip.txt \ + -r requirements.http.txt \ + -r requirements.gpu.txt \ + -r requirements.gaze.txt \ + -r requirements.groundingdino.txt \ + -r requirements.yolo_world.txt \ + -r requirements.transformers.txt \ + -r requirements.easyocr.txt \ + -r requirements.modal.txt \ + jupyterlab \ + "setuptools<=75.5.0" \ + --upgrade \ + && rm -rf ~/.cache/pip + +# Install setup.py requirements for flash_attn +RUN python3 -m pip install packaging==24.1 && rm -rf ~/.cache/pip + +# Install flash_attn required for Paligemma and Florence2 +RUN python3 -m pip install -r requirements.pali.flash_attn.txt --no-dependencies --no-build-isolation && rm -rf ~/.cache/pip + +ENV TORCH_CUDA_ARCH_LIST="7.5;8.0;8.6;8.9;9.0" +RUN python3 -m pip install --no-cache-dir --no-build-isolation -r requirements.sam3_3d.txt && rm -rf ~/.cache/pip +# Start runtime stage +FROM nvcr.io/nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04 as runtime + +WORKDIR /app + +# Copy Python and installed packages from builder +COPY --from=builder /usr/local/lib/python3.10 /usr/local/lib/python3.10 +COPY --from=builder /usr/local/bin /usr/local/bin + +# Install runtime dependencies +ADD https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb /tmp/cuda-keyring.deb +RUN set -eux; \ + rm -rf /var/lib/apt/lists/*; apt-get clean; \ + dpkg -i /tmp/cuda-keyring.deb || true; \ + rm -f /tmp/cuda-keyring.deb; \ + apt-get update -y; \ + DEBIAN_FRONTEND=noninteractive apt-get install -y \ + libxext6 \ + libopencv-dev \ + uvicorn \ + python3-pip \ + git \ + libgdal-dev \ + libvips-dev \ + wget \ + 
rustc \ + cargo \ + libgl1 \ + libegl1 \ + libgles2 \ + libglvnd0 \ + libglx0 \ + cuda-nvcc-12-4 \ + cuda-cudart-dev-12-4 \ + libcusparse-dev-12-4 \ + libcublas-dev-12-4 \ + libcusolver-dev-12-4 \ + libcurand-dev-12-4 \ + libcufft-dev-12-4; \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /build +COPY . . +RUN ln -s /usr/bin/python3 /usr/bin/python +RUN /bin/make create_wheels_for_gpu_notebook +RUN pip3 install --no-cache-dir dist/inference_cli*.whl dist/inference_core*.whl dist/inference_gpu*.whl dist/inference_sdk*.whl "setuptools<=75.5.0" + + +WORKDIR /notebooks +COPY examples/notebooks . + +WORKDIR /app/ +COPY inference inference +COPY docker/config/gpu_http.py gpu_http.py + +ENV VERSION_CHECK_MODE=continuous +ENV PROJECT=roboflow-platform +ENV NUM_WORKERS=1 +ENV HOST=0.0.0.0 +ENV PORT=9001 +ENV WORKFLOWS_STEP_EXECUTION_MODE=local +ENV WORKFLOWS_MAX_CONCURRENT_STEPS=4 +ENV API_LOGGING_ENABLED=True +ENV LMM_ENABLED=True +ENV CORE_MODEL_SAM2_ENABLED=True +ENV CORE_MODEL_SAM3_ENABLED=True +ENV CORE_MODEL_OWLV2_ENABLED=True +ENV ENABLE_STREAM_API=True +ENV ENABLE_PROMETHEUS=True +ENV STREAM_API_PRELOADED_PROCESSES=2 +ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility,graphics +ENV NVIDIA_VISIBLE_DEVICES=all +ENV TORCH_EXTENSIONS_DIR=/tmp/torch_extensions +ENV SPARSE_ATTN_BACKEND="flash_attn" +ENV ATTN_BACKEND="flash_attn" +ENV MODEL_LOCK_ACQUIRE_TIMEOUT="300" +ENV SAM3_3D_OBJECTS_ENABLED=True + +ENTRYPOINT uvicorn gpu_http:app --workers $NUM_WORKERS --host $HOST --port $PORT diff --git a/inference/core/entities/requests/sam3_3d.py b/inference/core/entities/requests/sam3_3d.py new file mode 100644 index 0000000000..d279823204 --- /dev/null +++ b/inference/core/entities/requests/sam3_3d.py @@ -0,0 +1,37 @@ +from typing import Any, Dict, List, Optional, Union + +from pydantic import Field, validator + +from inference.core.entities.requests.inference import ( + BaseRequest, + InferenceRequestImage, +) + + +class Sam3_3D_Objects_InferenceRequest(BaseRequest): + """SAM3D inference request for 3D object generation. + + Attributes: + api_key (Optional[str]): Roboflow API Key. + image (InferenceRequestImage): The input image to be used for 3D generation. + mask_input: Mask(s) in any supported format - polygon, binary mask, or RLE. + """ + + image: InferenceRequestImage = Field( + description="The input image to be used for 3D generation.", + ) + + mask_input: Any = Field( + description="Mask input in any supported format: " + "polygon [x1,y1,x2,y2,...], binary mask (base64), RLE dict, or list of these.", + ) + + model_id: Optional[str] = Field( + default="sam3-3d-objects", description="The model ID for SAM3_3D." 
+ ) + + @validator("model_id", always=True) + def validate_model_id(cls, value): + if value is not None: + return value + return "sam3-3d-objects" diff --git a/inference/core/entities/responses/sam3_3d.py b/inference/core/entities/responses/sam3_3d.py new file mode 100644 index 0000000000..9f20d89f1b --- /dev/null +++ b/inference/core/entities/responses/sam3_3d.py @@ -0,0 +1,54 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +class Sam3_3D_Objects_Metadata(BaseModel): + rotation: Optional[List[float]] = Field( + default=None, + description="Rotation transformation parameters (quaternion, 4 floats)", + ) + translation: Optional[List[float]] = Field( + default=None, description="Translation transformation parameters (x, y, z)" + ) + scale: Optional[List[float]] = Field( + default=None, description="Scale transformation parameters (x, y, z)" + ) + + +class Sam3_3D_Object_Item(BaseModel): + """Individual 3D object output with mesh, gaussian, and transformation metadata.""" + + mesh_glb: Optional[bytes] = Field( + default=None, description="The 3D mesh in GLB format (binary)" + ) + gaussian_ply: Optional[bytes] = Field( + default=None, description="The Gaussian splatting in PLY format (binary)" + ) + metadata: Sam3_3D_Objects_Metadata = Field( + default_factory=Sam3_3D_Objects_Metadata, + description="3D transformation metadata (rotation, translation, scale)", + ) + + class Config: + arbitrary_types_allowed = True + + +class Sam3_3D_Objects_Response(BaseModel): + mesh_glb: Optional[bytes] = Field( + default=None, description="The 3D scene mesh in GLB format (binary)" + ) + gaussian_ply: Optional[bytes] = Field( + default=None, + description="The combined Gaussian splatting in PLY format (binary)", + ) + objects: List[Sam3_3D_Object_Item] = Field( + default=[], + description="List of individual 3D objects with their meshes, gaussians, and metadata", + ) + time: float = Field( + description="The time in seconds it took to produce the 3D outputs including preprocessing" + ) + + class Config: + arbitrary_types_allowed = True diff --git a/inference/core/env.py b/inference/core/env.py index 26358764fa..c665879077 100644 --- a/inference/core/env.py +++ b/inference/core/env.py @@ -201,6 +201,8 @@ FLORENCE2_ENABLED = str2bool(os.getenv("FLORENCE2_ENABLED", True)) +SAM3_3D_OBJECTS_ENABLED = str2bool(os.getenv("SAM3_3D_OBJECTS_ENABLED", False)) + # Flag to enable YOLO-World core model, default is True CORE_MODEL_YOLO_WORLD_ENABLED = str2bool( os.getenv("CORE_MODEL_YOLO_WORLD_ENABLED", True) diff --git a/inference/core/registries/roboflow.py b/inference/core/registries/roboflow.py index 37806b33ab..c105c508ef 100644 --- a/inference/core/registries/roboflow.py +++ b/inference/core/registries/roboflow.py @@ -49,6 +49,7 @@ "sam2": ("embed", "sam2"), "sam3": ("embed", "sam3"), "sam3/sam3_interactive": ("interactive-segmentation", "sam3"), + "sam3-3d-objects": ("3d-reconstruction", "sam3-3d-objects"), "gaze": ("gaze", "l2cs"), "doctr": ("ocr", "doctr"), "easy_ocr": ("ocr", "easy_ocr"), @@ -158,8 +159,12 @@ def get_model_type( MissingDefaultModelError: If default model is not configured and API does not provide this info MalformedRoboflowAPIResponseError: Roboflow API responds in invalid format. 
""" + model_id = resolve_roboflow_model_alias(model_id=model_id) dataset_id, version_id = get_model_id_chunks(model_id=model_id) + print( + f"Resolved model_id: {model_id}, dataset_id: {dataset_id}, version_id: {version_id}" + ) # first check if the model id as a whole is in the GENERIC_MODELS dictionary if model_id in GENERIC_MODELS: diff --git a/inference/core/utils/roboflow.py b/inference/core/utils/roboflow.py index 9fd4f8baea..f16c9db48d 100644 --- a/inference/core/utils/roboflow.py +++ b/inference/core/utils/roboflow.py @@ -34,6 +34,7 @@ def get_model_id_chunks( "moondream2", "depth-anything-v2", "perception_encoder", + "sam3-3d-objects", }: return dataset_id, version_id diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py index 80b3fa17c6..9f2f2b8269 100644 --- a/inference/core/workflows/core_steps/loader.py +++ b/inference/core/workflows/core_steps/loader.py @@ -5,6 +5,7 @@ ALLOW_WORKFLOW_BLOCKS_ACCESSING_ENVIRONMENTAL_VARIABLES, ALLOW_WORKFLOW_BLOCKS_ACCESSING_LOCAL_STORAGE, API_KEY, + SAM3_3D_OBJECTS_ENABLED, WORKFLOW_BLOCKS_WRITE_DIRECTORY, WORKFLOWS_STEP_EXECUTION_MODE, ) @@ -242,6 +243,9 @@ from inference.core.workflows.core_steps.models.foundation.segment_anything3.v2 import ( SegmentAnything3BlockV2, ) +from inference.core.workflows.core_steps.models.foundation.segment_anything3_3d.v1 import ( + SegmentAnything3_3D_ObjectsBlockV1, +) from inference.core.workflows.core_steps.models.foundation.smolvlm.v1 import ( SmolVLM2BlockV1, ) @@ -670,6 +674,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]: SegmentAnything2BlockV1, SegmentAnything3BlockV1, SegmentAnything3BlockV2, + SegmentAnything3_3D_ObjectsBlockV1, SegPreviewBlockV1, StabilityAIInpaintingBlockV1, StabilityAIImageGenBlockV1, diff --git a/inference/core/workflows/core_steps/models/foundation/segment_anything3_3d/__init__.py b/inference/core/workflows/core_steps/models/foundation/segment_anything3_3d/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/inference/core/workflows/core_steps/models/foundation/segment_anything3_3d/v1.py b/inference/core/workflows/core_steps/models/foundation/segment_anything3_3d/v1.py new file mode 100644 index 0000000000..a2339e3d39 --- /dev/null +++ b/inference/core/workflows/core_steps/models/foundation/segment_anything3_3d/v1.py @@ -0,0 +1,205 @@ +import base64 +from typing import Any, List, Literal, Optional, Type, Union + +import numpy as np +import supervision as sv +from pydantic import ConfigDict, Field + +from inference.core.entities.requests.sam3_3d import Sam3_3D_Objects_InferenceRequest +from inference.core.entities.responses.sam3_3d import Sam3_3D_Objects_Response +from inference.core.managers.base import ModelManager +from inference.core.workflows.core_steps.common.entities import StepExecutionMode +from inference.core.workflows.execution_engine.entities.base import ( + Batch, + OutputDefinition, + WorkflowImageData, +) +from inference.core.workflows.execution_engine.entities.types import ( + FLOAT_KIND, + IMAGE_KIND, + INSTANCE_SEGMENTATION_PREDICTION_KIND, + LIST_OF_VALUES_KIND, + STRING_KIND, + ImageInputField, + Selector, +) +from inference.core.workflows.prototypes.block import ( + BlockResult, + WorkflowBlock, + WorkflowBlockManifest, +) + +LONG_DESCRIPTION = """ +Generate 3D meshes and Gaussian splatting from 2D images with mask prompts. + +Accepts masks as: sv.Detections (from SAM2 etc), polygon lists, binary masks, or RLE dicts. 
+""" + + +class BlockManifest(WorkflowBlockManifest): + model_config = ConfigDict( + json_schema_extra={ + "name": "SAM3D", + "version": "v1", + "short_description": "Generate 3D meshes and Gaussian splatting from 2D images with mask prompts.", + "long_description": LONG_DESCRIPTION, + "license": "Apache-2.0", + "block_type": "model", + "search_keywords": ["SAM3_3D", "3D", "mesh", "gaussian splatting"], + "ui_manifest": { + "section": "model", + "icon": "far fa-cube", + "blockPriority": 9.0, + "needsGPU": True, + "inference": True, + }, + }, + protected_namespaces=(), + ) + + type: Literal["roboflow_core/segment_anything3_3d_objects@v1"] + images: Selector(kind=[IMAGE_KIND]) = ImageInputField + mask_input: Selector( + kind=[LIST_OF_VALUES_KIND, INSTANCE_SEGMENTATION_PREDICTION_KIND] + ) = Field( + description="Mask input - either instance segmentation predictions (e.g., from SAM2) or a flat list of polygon coordinates in COCO format [x1, y1, x2, y2, x3, y3, ...]", + examples=["$steps.sam2.predictions", "$steps.detections.mask_polygon"], + ) + + @classmethod + def get_parameters_accepting_batches(cls) -> List[str]: + return ["images", "mask_input"] + + @classmethod + def describe_outputs(cls) -> List[OutputDefinition]: + return [ + OutputDefinition( + name="mesh_glb", + kind=[STRING_KIND], + description="Scene mesh in GLB format (base64 encoded)", + ), + OutputDefinition( + name="gaussian_ply", + kind=[STRING_KIND], + description="Combined Gaussian splatting in PLY format (base64 encoded)", + ), + OutputDefinition( + name="objects", + kind=[LIST_OF_VALUES_KIND], + description="List of individual objects, each with mesh_glb, gaussian_ply, and metadata (rotation, translation, scale)", + ), + OutputDefinition( + name="inference_time", + kind=[FLOAT_KIND], + ), + ] + + @classmethod + def get_execution_engine_compatibility(cls) -> Optional[str]: + return ">=1.3.0,<2.0.0" + + +class SegmentAnything3_3D_ObjectsBlockV1(WorkflowBlock): + + def __init__( + self, + model_manager: ModelManager, + api_key: Optional[str], + step_execution_mode: StepExecutionMode, + ): + self._model_manager = model_manager + self._api_key = api_key + self._step_execution_mode = step_execution_mode + + @classmethod + def get_init_parameters(cls) -> List[str]: + return ["model_manager", "api_key", "step_execution_mode"] + + @classmethod + def get_manifest(cls) -> Type[WorkflowBlockManifest]: + return BlockManifest + + def run( + self, + images: Batch[WorkflowImageData], + mask_input: Batch[Union[sv.Detections, List[float]]], + ) -> BlockResult: + if self._step_execution_mode is StepExecutionMode.LOCAL: + return self.run_locally( + images=images, + mask_input=mask_input, + ) + elif self._step_execution_mode is StepExecutionMode.REMOTE: + raise NotImplementedError( + "Remote execution is not supported for Segment Anything 3_3D. Run a local or dedicated inference server to use this block (GPU strongly recommended)." 
+ ) + else: + raise ValueError( + f"Unknown step execution mode: {self._step_execution_mode}" + ) + + def run_locally( + self, + images: Batch[WorkflowImageData], + mask_input: Batch[Union[sv.Detections, List[float]]], + ) -> BlockResult: + results = [] + model_id = "sam3-3d-objects" + + self._model_manager.add_model(model_id=model_id, api_key=self._api_key) + + for single_image, single_mask_input in zip(images, mask_input): + converted_mask = extract_masks_from_input(single_mask_input) + + inference_request = Sam3_3D_Objects_InferenceRequest( + image=single_image.to_inference_format(numpy_preferred=True), + mask_input=converted_mask, + api_key=self._api_key, + model_id=model_id, + ) + + response: Sam3_3D_Objects_Response = ( + self._model_manager.infer_from_request_sync(model_id, inference_request) + ) + + results.append(_format_response(response)) + + return results + + +def extract_masks_from_input(mask_input: Any) -> Any: + """Extract binary masks from sv.Detections, pass through other formats.""" + if isinstance(mask_input, sv.Detections): + if len(mask_input) == 0: + raise ValueError("sv.Detections contains no detections.") + if mask_input.mask is not None and len(mask_input.mask) > 0: + return list(mask_input.mask) + raise ValueError("sv.Detections has no mask data.") + return mask_input + + +def _format_response(response: Sam3_3D_Objects_Response) -> dict: + """Format response with base64 encoded outputs.""" + + def encode(data): + return base64.b64encode(data).decode("utf-8") if data else None + + objects_list = [ + { + "mesh_glb": encode(obj.mesh_glb), + "gaussian_ply": encode(obj.gaussian_ply), + "metadata": { + "rotation": obj.metadata.rotation, + "translation": obj.metadata.translation, + "scale": obj.metadata.scale, + }, + } + for obj in response.objects + ] + + return { + "mesh_glb": encode(response.mesh_glb), + "gaussian_ply": encode(response.gaussian_ply), + "objects": objects_list, + "inference_time": response.time, + } diff --git a/inference/models/README.md b/inference/models/README.md index 6c2847e05f..e0167b849a 100644 --- a/inference/models/README.md +++ b/inference/models/README.md @@ -34,7 +34,7 @@ The models supported by Roboflow Inference have their own licenses. View the lic | `inference/models/moondream2` | [Apache 2.0](https://github.com/vikhyat/moondream/blob/main/LICENSE) | 👍 | | `inference/models/perception_encoder` | [Apache 2.0](https://github.com/facebookresearch/perception_models/blob/main/LICENSE.PE) | 👍 | | `inference/models/sam3` | [SAM License](https://github.com/facebookresearch/sam3/blob/main/LICENSE) | 👍 | - +| `inference/models/sam3_3d` | [SAM License](https://github.com/facebookresearch/sam-3d-objects/blob/main/LICENSE) | 👍 | ## Commercial Licenses Models listed with a 👍 above are permissively licensed for commercial use by default. Typically no additional license is needed. 
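The new workflow block above is fully described by its manifest: it consumes `images` plus a `mask_input` selector and exposes `mesh_glb`, `gaussian_ply`, `objects`, and `inference_time` outputs. Below is a minimal sketch (not part of this change) of a workflow definition that chains an instance-segmentation step into the block; the upstream step's `type` string is a placeholder, since only `roboflow_core/segment_anything3_3d_objects@v1` is defined here, so substitute whichever SAM2/SAM3 block your installation provides.

```python
# Sketch only: a workflow definition wiring instance-segmentation masks into
# the new 3D block. The upstream "type" below is a placeholder; the "sam3d"
# step and its outputs follow the manifest defined in v1.py above.
SAM3D_WORKFLOW = {
    "version": "1.0",
    "inputs": [{"type": "WorkflowImage", "name": "image"}],
    "steps": [
        {
            # placeholder: use the SAM2/SAM3 segmentation block available in your build
            "type": "roboflow_core/segment_anything@v1",
            "name": "sam2",
            "images": "$inputs.image",
        },
        {
            "type": "roboflow_core/segment_anything3_3d_objects@v1",
            "name": "sam3d",
            "images": "$inputs.image",
            # matches the manifest example: instance segmentation predictions as mask input
            "mask_input": "$steps.sam2.predictions",
        },
    ],
    "outputs": [
        {"type": "JsonField", "name": "mesh_glb", "selector": "$steps.sam3d.mesh_glb"},
        {"type": "JsonField", "name": "gaussian_ply", "selector": "$steps.sam3d.gaussian_ply"},
        {"type": "JsonField", "name": "objects", "selector": "$steps.sam3d.objects"},
        {"type": "JsonField", "name": "inference_time", "selector": "$steps.sam3d.inference_time"},
    ],
}
```

Note that `SAM3_3D_OBJECTS_ENABLED` must be set (it defaults to `False` in `env.py`) and the block only supports local or dedicated-server execution; the remote branch raises `NotImplementedError`.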
diff --git a/inference/models/__init__.py b/inference/models/__init__.py index 2cffc106ce..9365bd19a1 100644 --- a/inference/models/__init__.py +++ b/inference/models/__init__.py @@ -1,5 +1,17 @@ import importlib from typing import Any +#Preinit nvdiffrast for SAM3D as it breaks if any flash attn model is loaded in first +try: + import torch + if torch.cuda.is_available(): + import utils3d.torch + _nvdiffrast_ctx = utils3d.torch.RastContext(backend='cuda') + _dummy_verts = torch.zeros(1, 3, 3, device='cuda') + _dummy_faces = torch.tensor([[0, 1, 2]], dtype=torch.int32, device='cuda') + _ = utils3d.torch.rasterize_triangle_faces(_nvdiffrast_ctx, _dummy_verts, _dummy_faces, 64, 64) + del _dummy_verts, _dummy_faces, _ +except: + pass from inference.core.env import ( CORE_MODEL_CLIP_ENABLED, @@ -13,6 +25,7 @@ CORE_MODEL_YOLO_WORLD_ENABLED, CORE_MODELS_ENABLED, DEPTH_ESTIMATION_ENABLED, + SAM3_3D_OBJECTS_ENABLED, ) _MODEL_REGISTRY: dict[str, Any] = {} @@ -23,6 +36,7 @@ "SegmentAnything": ("inference.models.sam", CORE_MODEL_SAM_ENABLED), "SegmentAnything2": ("inference.models.sam2", CORE_MODEL_SAM2_ENABLED), "SegmentAnything3": ("inference.models.sam3", CORE_MODEL_SAM3_ENABLED), + "SegmentAnything3_3D_Objects": ("inference.models.sam3_3d", SAM3_3D_OBJECTS_ENABLED), "Sam3ForInteractiveImageSegmentation": ( "inference.models.sam3", CORE_MODEL_SAM3_ENABLED, diff --git a/inference/models/aliases.py b/inference/models/aliases.py index 2c0cd3dea1..0c5ddba9e3 100644 --- a/inference/models/aliases.py +++ b/inference/models/aliases.py @@ -56,6 +56,10 @@ "smolvlm2": "smolvlm-2.2b-instruct", } +SAM3_3D_ALIASES = { + "sam3-3d-objects": "sam3-3d-weights-vc6vz/1", +} + RFDETR_ALIASES = { "rfdetr-base": "coco/36", "rfdetr-large": "coco/37", @@ -105,6 +109,7 @@ **YOLOV11_ALIASES, **QWEN_ALIASES, **RFDETR_ALIASES, + **SAM3_3D_ALIASES, } diff --git a/inference/models/sam3_3d/LICENSE.txt b/inference/models/sam3_3d/LICENSE.txt new file mode 100644 index 0000000000..1d31b947c2 --- /dev/null +++ b/inference/models/sam3_3d/LICENSE.txt @@ -0,0 +1,52 @@ +SAM License +Last Updated: November 19, 2025 + +“Agreement” means the terms and conditions for use, reproduction, distribution and modification of the SAM Materials set forth herein. + +“SAM Materials” means, collectively, Documentation and the models, software and algorithms, including machine-learning model code, trained model weights, inference-enabling code, training-enabling code, fine-tuning enabling code, and other elements of the foregoing distributed by Meta and made available under this Agreement. + +“Documentation” means the specifications, manuals and documentation accompanying +SAM Materials distributed by Meta. + +“Licensee” or “you” means you, or your employer or any other person or entity (if you are entering into this Agreement on such person or entity’s behalf), of the age required under applicable laws, rules or regulations to provide legal consent and that has legal authority to bind your employer or such other person or entity if you are entering in this Agreement on their behalf. + +“Meta” or “we” means Meta Platforms Ireland Limited (if you are located in or, if you are an entity, your principal place of business is in the EEA or Switzerland) or Meta Platforms, Inc. (if you are located outside of the EEA or Switzerland). + +“Sanctions” means any economic or trade sanctions or restrictions administered or enforced by the United States (including the Office of Foreign Assets Control of the U.S. Department of the Treasury (“OFAC”), the U.S. 
Department of State and the U.S. Department of Commerce), the United Nations, the European Union, or the United Kingdom. + +“Trade Controls” means any of the following: Sanctions and applicable export and import controls. + +By using or distributing any portion or element of the SAM Materials, you agree to be bound by this Agreement. + +1. License Rights and Redistribution. + +a. Grant of Rights. You are granted a non-exclusive, worldwide, non-transferable and royalty-free limited license under Meta’s intellectual property or other rights owned by Meta embodied in the SAM Materials to use, reproduce, distribute, copy, create derivative works of, and make modifications to the SAM Materials. + +i. Grant of Patent License. Subject to the terms and conditions of this License, you are granted a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by Meta that are necessarily infringed alone or by combination of their contribution(s) with the SAM 3 Materials. If you institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM 3 Materials incorporated within the work constitutes direct or contributory patent infringement, then any patent licenses granted to you under this License for that work shall terminate as of the date such litigation is filed. + +b. Redistribution and Use. + +i. Distribution of SAM Materials, and any derivative works thereof, are subject to the terms of this Agreement. If you distribute or make the SAM Materials, or any derivative works thereof, available to a third party, you may only do so under the terms of this Agreement and you shall provide a copy of this Agreement with any such SAM Materials. + +ii. If you submit for publication the results of research you perform on, using, or otherwise in connection with SAM Materials, you must acknowledge the use of SAM Materials in your publication. + +iii. Your use of the SAM Materials must comply with applicable laws and regulations, including Trade Control Laws and applicable privacy and data protection laws. +iv. Your use of the SAM Materials will not involve or encourage others to reverse engineer, decompile or discover the underlying components of the SAM Materials. +v. You are not the target of Trade Controls and your use of SAM Materials must comply with Trade Controls. You agree not to use, or permit others to use, SAM Materials for any activities subject to the International Traffic in Arms Regulations (ITAR) or end uses prohibited by Trade Controls, including those related to military or warfare purposes, nuclear industries or applications, espionage, or the development or use of guns or illegal weapons. +2. User Support. Your use of the SAM Materials is done at your own discretion; Meta does not process any information nor provide any service in relation to such use. Meta is under no obligation to provide any support services for the SAM Materials. Any support provided is “as is”, “with all faults”, and without warranty of any kind. + +3. Disclaimer of Warranty. 
UNLESS REQUIRED BY APPLICABLE LAW, THE SAM MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OF ANY KIND, AND META DISCLAIMS ALL WARRANTIES OF ANY KIND, BOTH EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SAM MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SAM MATERIALS AND ANY OUTPUT AND RESULTS. + +4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY DIRECT OR INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. Intellectual Property. + +a. Subject to Meta’s ownership of SAM Materials and derivatives made by or for Meta, with respect to any derivative works and modifications of the SAM Materials that are made by you, as between you and Meta, you are and will be the owner of such derivative works and modifications. + +b. If you institute litigation or other proceedings against Meta or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the SAM Materials, outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. You will indemnify and hold harmless Meta from and against any claim by any third party arising out of or related to your use or distribution of the SAM Materials. + +6. Term and Termination. The term of this Agreement will commence upon your acceptance of this Agreement or access to the SAM Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. Meta may terminate this Agreement if you are in breach of any term or condition of this Agreement. Upon termination of this Agreement, you shall delete and cease use of the SAM Materials. Sections 3, 4 and 7 shall survive the termination of this Agreement. + +7. Governing Law and Jurisdiction. This Agreement will be governed and construed under the laws of the State of California without regard to choice of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. The courts of California shall have exclusive jurisdiction of any dispute arising out of this Agreement. + +8. Modifications and Amendments. Meta may modify this Agreement from time to time; provided that they are similar in spirit to the current version of the Agreement, but may differ in detail to address new problems or concerns. All such changes will be effective immediately. Your continued use of the SAM Materials after any modification to this Agreement constitutes your agreement to such modification. Except as provided in this Agreement, no modification or addition to any provision of this Agreement will be binding unless it is in writing and signed by an authorized representative of both you and Meta. 
\ No newline at end of file diff --git a/inference/models/sam3_3d/README.md b/inference/models/sam3_3d/README.md new file mode 100644 index 0000000000..7b70dd0bb4 --- /dev/null +++ b/inference/models/sam3_3d/README.md @@ -0,0 +1,78 @@ +# SAM3-3D Model + +3D object generation model that converts 2D images with masks into 3D assets (meshes and Gaussian splats). + +## Installation + +```bash +pip install --no-cache-dir --no-build-isolation -r requirements/requirements.sam3_3d.txt +``` + +## Docker + +``` +docker/dockerfiles/Dockerfile.onnx.gpu.3d +``` + +## Input + +- **image**: RGB image (PIL Image, numpy array, or inference request format) +- **mask_input**: Mask(s) defining object regions. Supports multiple formats: + +| Format | Single | Multiple | +|--------|--------|----------| +| Binary mask | `np.ndarray` (H, W) | `np.ndarray` (N, H, W) or `List[np.ndarray]` | +| Polygon (COCO flat) | `[x1, y1, x2, y2, ...]` | `[[x1, y1, ...], [x1, y1, ...]]` | +| Polygon (points) | `[[x1, y1], [x2, y2], ...]` | N/A | +| RLE | `{"counts": "...", "size": [H, W]}` | `[{"counts": ...}, ...]` | +| sv.Detections | From SAM2 or other segmentation models | Extracts all masks | + +## Output + +- **mesh_glb**: Combined 3D scene mesh (GLB binary) +- **gaussian_ply**: Combined Gaussian splatting (PLY binary) +- **objects**: List of individual objects: + - `mesh_glb`: Object mesh (GLB binary) + - `gaussian_ply`: Object Gaussian splat (PLY binary) + - `metadata`: `{rotation, translation, scale}` +- **time**: Inference time in seconds + +## Example + +```python +import os +os.environ["SAM3_3D_OBJECTS_ENABLED"] = "true" + +from inference import get_model +from inference.core.entities.requests.sam3_3d import Sam3_3D_Objects_InferenceRequest +import numpy as np + +model = get_model("sam3-3d-objects", api_key="YOUR_API_KEY") +image = {"type": "file", "value": "path/to/image.jpg"} + +# Option 1: Polygon (COCO flat format) +mask = [100.0, 100.0, 200.0, 100.0, 200.0, 200.0, 100.0, 200.0] + +# Option 2: Binary mask +mask = np.zeros((480, 640), dtype=np.uint8) +mask[100:200, 100:200] = 255 + +# Option 3: Multiple masks +mask = [ + [100.0, 100.0, 200.0, 100.0, 200.0, 200.0, 100.0, 200.0], + [300.0, 300.0, 400.0, 300.0, 400.0, 400.0, 300.0, 400.0], +] + +request = Sam3_3D_Objects_InferenceRequest(image=image, mask_input=mask) +response = model.infer_from_request(request) + +# Save outputs +if response.mesh_glb: + with open("scene.glb", "wb") as f: + f.write(response.mesh_glb) + +for i, obj in enumerate(response.objects): + if obj.mesh_glb: + with open(f"object_{i}.glb", "wb") as f: + f.write(obj.mesh_glb) +``` diff --git a/inference/models/sam3_3d/__init__.py b/inference/models/sam3_3d/__init__.py new file mode 100644 index 0000000000..e0e37c420c --- /dev/null +++ b/inference/models/sam3_3d/__init__.py @@ -0,0 +1 @@ +from inference.models.sam3_3d.segment_anything_3d import SegmentAnything3_3D_Objects #, SegmentAnything3_3D_Body \ No newline at end of file diff --git a/inference/models/sam3_3d/segment_anything_3d.py b/inference/models/sam3_3d/segment_anything_3d.py new file mode 100644 index 0000000000..76050b6d19 --- /dev/null +++ b/inference/models/sam3_3d/segment_anything_3d.py @@ -0,0 +1,481 @@ +import os +import sys +import weakref +from io import BytesIO +from pathlib import Path +from threading import Lock +from time import perf_counter +from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union + +import cv2 +import numpy as np +import torch +from filelock import FileLock +from hydra.utils import instantiate 
+from omegaconf import OmegaConf +from PIL import Image, ImageDraw + +from inference.core.cache.model_artifacts import get_cache_dir +from inference.core.entities.requests.inference import InferenceRequestImage +from inference.core.entities.requests.sam3_3d import Sam3_3D_Objects_InferenceRequest +from inference.core.entities.responses.sam3_3d import ( + Sam3_3D_Object_Item, + Sam3_3D_Objects_Metadata, + Sam3_3D_Objects_Response, +) +from inference.core.env import DEVICE, MODEL_CACHE_DIR +from inference.core.exceptions import ModelArtefactError +from inference.core.models.roboflow import RoboflowCoreModel +from inference.core.roboflow_api import ( + ModelEndpointType, + get_roboflow_model_data, + stream_url_to_cache, +) +from inference.core.utils.image_utils import load_image_rgb + +try: + import pycocotools.mask as mask_utils + + PYCOCOTOOLS_AVAILABLE = True +except ImportError: + PYCOCOTOOLS_AVAILABLE = False + + +def convert_mask_to_binary(mask_input: Any, image_shape: Tuple[int, int]) -> np.ndarray: + """Convert polygon, RLE, or binary mask to binary mask (H, W) with values 0/255.""" + height, width = image_shape + + if isinstance(mask_input, np.ndarray): + return _normalize_binary_mask(mask_input, image_shape) + + if isinstance(mask_input, Image.Image): + return _normalize_binary_mask(np.array(mask_input.convert("L")), image_shape) + + if isinstance(mask_input, dict) and "counts" in mask_input: + if not PYCOCOTOOLS_AVAILABLE: + raise ImportError( + "pycocotools required for RLE. Install: pip install pycocotools" + ) + rle = dict(mask_input) + if isinstance(rle.get("counts"), str): + rle["counts"] = rle["counts"].encode("utf-8") + return _normalize_binary_mask(mask_utils.decode(rle), image_shape) + + if isinstance(mask_input, list): + points = _parse_polygon_to_points(mask_input) + if not points or len(points) < 3: + return np.zeros((height, width), dtype=np.uint8) + mask = Image.new("L", (width, height), 0) + ImageDraw.Draw(mask).polygon(points, outline=255, fill=255) + return np.array(mask, dtype=np.uint8) + + raise TypeError(f"Unsupported mask type: {type(mask_input)}") + + +def _normalize_binary_mask( + mask: np.ndarray, image_shape: Tuple[int, int] +) -> np.ndarray: + """Normalize mask to uint8 with values 0/255. 
Returns input unchanged if already correct.""" + if mask.ndim == 3: + mask = mask[:, :, 0] + + h, w = image_shape + needs_resize = mask.shape[0] != h or mask.shape[1] != w + + # Check if already in correct format (uint8, 0/255 range, correct size) + if mask.dtype == np.uint8 and mask.max() > 1 and not needs_resize: + return mask + + # Convert to uint8 0/255 + if mask.dtype == np.bool_: + mask = mask.astype(np.uint8) * 255 + elif mask.dtype != np.uint8: + mask = ((mask > 0).astype(np.uint8)) * 255 + elif mask.max() <= 1: + mask = mask * 255 + + if needs_resize: + mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) + + return mask + + +def _parse_polygon_to_points(polygon: List) -> List[Tuple[float, float]]: + if polygon is None or (isinstance(polygon, list) and len(polygon) == 0): + return [] + if isinstance(polygon, np.ndarray): + if polygon.size == 0: + return [] + if polygon.ndim == 2 and polygon.shape[1] == 2: + return [(float(p[0]), float(p[1])) for p in polygon] + return [ + (float(polygon[i]), float(polygon[i + 1])) + for i in range(0, len(polygon), 2) + ] + if isinstance(polygon[0], (int, float)): + return [ + (float(polygon[i]), float(polygon[i + 1])) + for i in range(0, len(polygon), 2) + ] + if isinstance(polygon[0], (list, tuple, np.ndarray)): + return [(float(p[0]), float(p[1])) for p in polygon] + return [] + + +def _is_single_mask_input(mask_input: Any) -> bool: + """Check if input is single mask vs list of masks.""" + if mask_input is None or ( + isinstance(mask_input, (list, np.ndarray)) and len(mask_input) == 0 + ): + return True + if isinstance(mask_input, np.ndarray): + return mask_input.ndim == 2 + if isinstance(mask_input, dict) and "counts" in mask_input: + return True + if isinstance(mask_input, list): + first = mask_input[0] + # Flat polygon: [x1, y1, x2, y2, ...] + if isinstance(first, (int, float)): + return True + # List of RLE dicts + if isinstance(first, dict) and "counts" in first: + return False + # List of 2D numpy arrays (binary masks) -> multiple masks + if isinstance(first, np.ndarray) and first.ndim == 2: + return False + # Check list/tuple elements + if isinstance(first, (list, tuple)): + # [[x1, y1], [x2, y2], ...] 
-> single polygon as points + if len(first) == 2 and isinstance(first[0], (int, float)): + return True + # [[x1, y1, x2, ...], [x1, y1, x2, ...]] -> multiple flat polygons + if len(first) > 2 and isinstance(first[0], (int, float)): + return False + return True + + +if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(device_count)) +else: + os.environ["CUDA_VISIBLE_DEVICES"] = "" + +from importlib.resources import files + +import tdfy.sam3d_v1 + +if DEVICE is None: + DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" + +import trimesh +from pytorch3d.transforms import quaternion_to_matrix +from tdfy.sam3d_v1.inference_utils import make_scene, ready_gaussian_for_video_rendering + + +class Sam3_3D_ObjectsPipelineSingleton: + """Singleton to cache the heavy 3D pipeline initialization.""" + + _instances = weakref.WeakValueDictionary() + _lock = Lock() + + def __new__(cls, config_key: str): + with cls._lock: + if config_key not in cls._instances: + instance = super().__new__(cls) + instance.config_key = config_key + cls._instances[config_key] = instance + return cls._instances[config_key] + + +class SegmentAnything3_3D_Objects(RoboflowCoreModel): + def __init__( + self, + *args, + model_id: str = "sam3-3d-objects", + torch_compile: bool = False, + compile_res: int = 518, + **kwargs, + ): + super().__init__(model_id=model_id, **kwargs) + + self.cache_dir = Path(get_cache_dir(model_id=self.endpoint)) + + tdfy_dir = files(tdfy.sam3d_v1) + pipeline_config_path = tdfy_dir / "checkpoints_configs" / "pipeline.yaml" + moge_checkpoint_path = self.cache_dir / "moge-vitl.pth" + ss_generator_checkpoint_path = self.cache_dir / "ss_generator.ckpt" + slat_generator_checkpoint_path = self.cache_dir / "slat_generator.ckpt" + ss_decoder_checkpoint_path = self.cache_dir / "ss_decoder.ckpt" + slat_decoder_checkpoint_path = self.cache_dir / "slat_decoder_gs.ckpt" + slat_decodergs4_checkpoint_path = self.cache_dir / "slat_decoder_gs_4.ckpt" + slat_decoder_mesh_checkpoint_path = self.cache_dir / "slat_decoder_mesh.pt" + dinov2_ckpt_path = self.cache_dir / "dinov2_vitl14_reg4_pretrain.pth" + + config_key = f"{DEVICE}_{pipeline_config_path}" + singleton = Sam3_3D_ObjectsPipelineSingleton(config_key) + + if not hasattr(singleton, "pipeline"): + self.pipeline_config = OmegaConf.load(str(pipeline_config_path)) + self.pipeline_config["device"] = DEVICE + self.pipeline_config["workspace_dir"] = str(tdfy_dir) + self.pipeline_config["compile_model"] = torch_compile + self.pipeline_config["compile_res"] = compile_res + self.pipeline_config["depth_model"]["model"][ + "pretrained_model_name_or_path" + ] = str(moge_checkpoint_path) + self.pipeline_config["ss_generator_ckpt_path"] = str( + ss_generator_checkpoint_path + ) + self.pipeline_config["slat_generator_ckpt_path"] = str( + slat_generator_checkpoint_path + ) + self.pipeline_config["ss_decoder_ckpt_path"] = str( + ss_decoder_checkpoint_path + ) + self.pipeline_config["slat_decoder_gs_ckpt_path"] = str( + slat_decoder_checkpoint_path + ) + self.pipeline_config["slat_decoder_gs_4_ckpt_path"] = str( + slat_decodergs4_checkpoint_path + ) + self.pipeline_config["slat_decoder_mesh_ckpt_path"] = str( + slat_decoder_mesh_checkpoint_path + ) + self.pipeline_config["dinov2_ckpt_path"] = str(dinov2_ckpt_path) + singleton.pipeline = instantiate(self.pipeline_config) + + # Reference the singleton's pipeline + self.pipeline = singleton.pipeline + self._state_lock = Lock() + + def 
get_infer_bucket_file_list(self) -> list: + """Get the list of required files for inference. + + Returns: + list: A list of required files for inference, e.g., ["environment.json"]. + """ + return [ + "moge-vitl.pth", + "ss_generator.ckpt", + "slat_generator.ckpt", + "ss_decoder.ckpt", + "slat_decoder_gs.ckpt", + "slat_decoder_gs_4.ckpt", + "slat_decoder_mesh.pt", + ] + + def download_model_from_roboflow_api(self) -> None: + """Override parent method to use streaming downloads for large SAM3_3D model files.""" + lock_dir = MODEL_CACHE_DIR + "/_file_locks" + os.makedirs(lock_dir, exist_ok=True) + lock_file = os.path.join(lock_dir, f"{os.path.basename(self.cache_dir)}.lock") + lock = FileLock(lock_file, timeout=120) + with lock: + api_data = get_roboflow_model_data( + api_key=self.api_key, + model_id="sam3-3d-weights-vc6vz/1", + endpoint_type=ModelEndpointType.ORT, + device_id=self.device_id, + )["ort"] + if "weights" not in api_data: + raise ModelArtefactError( + f"`weights` key not available in Roboflow API response while downloading model weights." + ) + for weights_url_key in api_data["weights"]: + weights_url = api_data["weights"][weights_url_key] + filename = weights_url.split("?")[0].split("/")[-1] + stream_url_to_cache( + url=weights_url, + filename=filename, + model_id=self.endpoint, + ) + + def infer_from_request( + self, request: Sam3_3D_Objects_InferenceRequest + ) -> Sam3_3D_Objects_Response: + with self._state_lock: + t1 = perf_counter() + raw_result = self.create_3d(**request.dict()) + inference_time = perf_counter() - t1 + return convert_3d_objects_result_to_api_response( + raw_result=raw_result, + inference_time=inference_time, + ) + + def create_3d( + self, + image: Optional[InferenceRequestImage], + mask_input: Optional[Any] = None, + **kwargs, + ): + """ + Generate 3D from image and mask(s). + + Args: + image: Input image + mask_input: Mask in any supported format: + - np.ndarray (H,W) or (N,H,W): Binary mask(s) + - List[float]: COCO polygon [x1,y1,x2,y2,...] 
+ - List[List[float]]: Multiple polygons + - Dict with 'counts'/'size': RLE mask + - List[Dict]: Multiple RLE masks + """ + with torch.inference_mode(): + if image is None or mask_input is None: + raise ValueError("Must provide image and mask!") + + image_np = load_image_rgb(image) + if image_np.dtype != np.uint8: + if image_np.max() <= 1: + image_np = (image_np * 255).astype(np.uint8) + else: + image_np = image_np.astype(np.uint8) + image_shape = (image_np.shape[0], image_np.shape[1]) + + # Convert to list of binary masks + if _is_single_mask_input(mask_input): + masks = [convert_mask_to_binary(mask_input, image_shape)] + elif isinstance(mask_input, np.ndarray) and mask_input.ndim == 3: + masks = [convert_mask_to_binary(m, image_shape) for m in mask_input] + else: + masks = [convert_mask_to_binary(m, image_shape) for m in mask_input] + + outputs = [] + for mask in masks: + result = self.pipeline.run(image=image_np, mask=mask) + outputs.append(result) + + if len(outputs) == 1: + result = outputs[0] + scene_gs = ready_gaussian_for_video_rendering(result["gs"]) + return { + "gs": scene_gs, + "glb": result["glb"], + "objects": outputs, + } + else: + scene_gs = make_scene(*outputs) + scene_gs = ready_gaussian_for_video_rendering(scene_gs) + scene_glb = make_scene_glb(*outputs) + return { + "gs": scene_gs, + "glb": scene_glb, + "objects": outputs, + } + + +def convert_tensor_to_list(tensor_data: torch.Tensor) -> Optional[List[float]]: + if tensor_data is None: + return None + if isinstance(tensor_data, torch.Tensor): + return tensor_data.cpu().flatten().tolist() + return tensor_data + + +def convert_3d_objects_result_to_api_response( + raw_result: Dict[str, Any], + inference_time: float, +) -> Sam3_3D_Objects_Response: + + mesh_glb_bytes = None + glb = raw_result.pop("glb", None) + if glb is not None: + glb_buffer = BytesIO() + glb.export(glb_buffer, "glb") + glb_buffer.seek(0) + mesh_glb_bytes = glb_buffer.getvalue() + + gaussian_ply_bytes = None + gaussian = raw_result.pop("gs", None) + if gaussian is not None: + gaussian_buffer = BytesIO() + gaussian.save_ply(gaussian_buffer) + gaussian_buffer.seek(0) + gaussian_ply_bytes = gaussian_buffer.getvalue() + + objects = [] + outputs_list = raw_result.pop("objects", []) + for output in outputs_list: + obj_glb_bytes = None + obj_glb = output.get("glb") + if obj_glb is not None: + obj_glb_buffer = BytesIO() + obj_glb.export(obj_glb_buffer, "glb") + obj_glb_buffer.seek(0) + obj_glb_bytes = obj_glb_buffer.getvalue() + + obj_ply_bytes = None + obj_gs = output.get("gs") + if obj_gs is not None: + obj_ply_buffer = BytesIO() + obj_gs.save_ply(obj_ply_buffer) + obj_ply_buffer.seek(0) + obj_ply_bytes = obj_ply_buffer.getvalue() + + obj_metadata = Sam3_3D_Objects_Metadata( + rotation=convert_tensor_to_list(output.get("rotation")), + translation=convert_tensor_to_list(output.get("translation")), + scale=convert_tensor_to_list(output.get("scale")), + ) + + objects.append( + Sam3_3D_Object_Item( + mesh_glb=obj_glb_bytes, + gaussian_ply=obj_ply_bytes, + metadata=obj_metadata, + ) + ) + + return Sam3_3D_Objects_Response( + mesh_glb=mesh_glb_bytes, + gaussian_ply=gaussian_ply_bytes, + objects=objects, + time=inference_time, + ) + + +def transform_glb_to_world(glb_mesh, rotation, translation, scale): + """ + Transform a GLB mesh from local to world coordinates. + + Note: to_glb() already applies z-up to y-up rotation, so we just need + to apply the pose transform (rotation, translation, scale). 
+ + Based on export_transformed_mesh_glb from layout_post_optimization_utils.py + """ + quat = rotation.squeeze() + quat_normalized = quat / quat.norm() + R = quaternion_to_matrix(quat_normalized).cpu().numpy() + t = translation.squeeze().cpu().numpy() + s = scale.squeeze().cpu().numpy()[0] + + verts = torch.from_numpy(glb_mesh.vertices.copy()).float() + + center = verts.mean(dim=0) + + verts = verts - center + verts = verts * s + verts = verts @ torch.from_numpy(R.T).float() + verts = verts + center + verts = verts + torch.from_numpy(t).float() + + glb_mesh.vertices = verts.numpy() + return glb_mesh + + +def make_scene_glb(*outputs): + scene = trimesh.Scene() + + for i, output in enumerate(outputs): + glb = output["glb"] + glb = glb.copy() + + glb = transform_glb_to_world( + glb, + output["rotation"], + output["translation"], + output["scale"], + ) + scene.add_geometry(glb, node_name=f"object_{i}") + + return scene diff --git a/inference/models/utils.py b/inference/models/utils.py index 23ae4c65d1..7247e4d620 100644 --- a/inference/models/utils.py +++ b/inference/models/utils.py @@ -19,6 +19,7 @@ MOONDREAM2_ENABLED, PALIGEMMA_ENABLED, QWEN_2_5_ENABLED, + SAM3_3D_OBJECTS_ENABLED, SMOLVLM2_ENABLED, USE_INFERENCE_EXP_MODELS, ) @@ -498,6 +499,23 @@ category=ModelDependencyMissing, ) +try: + if SAM3_3D_OBJECTS_ENABLED: + from inference.models.sam3_3d.segment_anything_3d import ( + SegmentAnything3_3D_Objects, + ) + + ROBOFLOW_MODEL_TYPES[("3d-reconstruction", "sam3-3d-objects")] = ( + SegmentAnything3_3D_Objects + ) +except: + warnings.warn( + "Your `inference` configuration does not support SAM3_3D_Objects model. " + "Use pip install 'inference[sam3_3d]' to install missing requirements. " + "To suppress this warning, set SAM3_3D_OBJECTS_ENABLED to False.", + category=ModelDependencyMissing, + ) + try: if CORE_MODEL_DOCTR_ENABLED: from inference.models import DocTR diff --git a/inference_experimental/inference_exp/utils/download.py b/inference_experimental/inference_exp/utils/download.py index 6b74a16d4f..86064049b5 100644 --- a/inference_experimental/inference_exp/utils/download.py +++ b/inference_experimental/inference_exp/utils/download.py @@ -604,4 +604,4 @@ def _handle_stream_download( if content_storage is not None: content_storage.update(chunk) if on_chunk_downloaded: - on_chunk_downloaded(len(chunk)) + on_chunk_downloaded(len(chunk)) \ No newline at end of file diff --git a/requirements/requirements.sam3_3d.txt b/requirements/requirements.sam3_3d.txt new file mode 100644 index 0000000000..f1a3c48e4c --- /dev/null +++ b/requirements/requirements.sam3_3d.txt @@ -0,0 +1,24 @@ +pytorch3d @ git+https://github.com/facebookresearch/pytorch3d.git@75ebeeaea0908c5527e7b1e305fbc7681382db47 +utils3d @ git+https://github.com/EasternJournalist/utils3d.git@3913c65d81e05e47b9f367250cf8c0f7462a0900 +optree>=0.14.1 +astor>=0.8.1 +OpenEXR>=3.3.1 +imath>=0.0.2 +open3d>=0.18.0 +trimesh>=4.6.10 +xatlas>=0.0.9 +pyvista>=0.45.2 +pymeshfix>=0.17.0 +igraph>=0.11.8 +easydict>=1.13 +nvdiffrast @ git+https://github.com/NVlabs/nvdiffrast.git@729261dc64c4241ea36efda84fbf532cc8b425b8 +moge @ git+https://github.com/microsoft/MoGe.git@a8c37341bc0325ca99b9d57981cc3bb2bd3e255b +spconv-cu121>=2.3.8 +omegaconf>=2.3.0 +hydra-core>=1.3.2 +loguru>=0.7.2 +lightning>=2.2.1 +flash_attn==2.7.4.post1 +diff_gaussian_rasterization @ git+https://github.com/autonomousvision/mip-splatting.git@dda02ab5ecf45d6edb8c540d9bb65c7e451345a9#subdirectory=submodules/diff-gaussian-rasterization +gsplat @ 
git+https://github.com/nerfstudio-project/gsplat.git@2323de5905d5e90e035f792fe65bad0fedd413e7 +tdfy @ git+https://github.com/roboflow/tdfy.git@a9753d31d8359dff372879b9f2697279a6b29d21 \ No newline at end of file
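Because `_format_response` in the workflow block base64-encodes every binary field, a client that receives the block's output (for example from an Execution Engine run) has to decode it before writing GLB/PLY files. A minimal decoding sketch, assuming `block_output` is the dict the block returns for a single image:

```python
import base64
from pathlib import Path


def save_sam3d_outputs(block_output: dict, out_dir: str = ".") -> None:
    """Decode the base64-encoded fields produced by _format_response and write them to disk."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Combined scene outputs
    if block_output.get("mesh_glb"):
        (out / "scene.glb").write_bytes(base64.b64decode(block_output["mesh_glb"]))
    if block_output.get("gaussian_ply"):
        (out / "scene.ply").write_bytes(base64.b64decode(block_output["gaussian_ply"]))

    # Per-object outputs with their pose metadata
    for i, obj in enumerate(block_output.get("objects", [])):
        if obj.get("mesh_glb"):
            (out / f"object_{i}.glb").write_bytes(base64.b64decode(obj["mesh_glb"]))
        if obj.get("gaussian_ply"):
            (out / f"object_{i}.ply").write_bytes(base64.b64decode(obj["gaussian_ply"]))
        # metadata holds rotation (quaternion), translation, and scale as plain lists
        print(i, obj.get("metadata"))
```

The direct-model path shown in `inference/models/sam3_3d/README.md` above does not need this step, since `Sam3_3D_Objects_Response` already carries raw GLB/PLY bytes.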