
Commit 839d117

Erol444, pre-commit-ci[bot], Borda, and Copilot authored
Added support for creating Detections instances from SAM3 (#2103)
* Added support for creating Detections instances from SAM3 output - both from `inference` and from RF hosted server (dict)
* added tests, addressed pr comments
* fix(pre_commit): 🎨 auto format pre-commit hooks
* Apply suggestions from code review

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 9955f01 commit 839d117
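
The new entry point is `Detections.from_sam3`, which accepts either a `Sam3PromptResult` object from the `inference` package or the plain dict returned by the Roboflow hosted server. A minimal sketch of the dict path is shown below; the response shape mirrors the fixtures added in `tests/detection/test_from_sam.py`, and the polygon, confidence, and resolution values are illustrative placeholders rather than real model output.

```python
import supervision as sv

# Dict shaped like a Roboflow hosted-server SAM 3 response (see the test
# fixtures added in this commit): one text prompt, one polygon prediction.
hosted_sam3_response = {
    "prompt_results": [
        {
            "prompt_index": 0,
            "predictions": [
                {
                    "masks": [[[0, 0], [10, 0], [10, 10], [0, 10]]],
                    "confidence": 0.9,
                    "format": "polygon",
                }
            ],
        }
    ],
}

detections = sv.Detections.from_sam3(
    sam3_result=hosted_sam3_response,
    resolution_wh=(100, 100),  # (width, height) of the source image
)
print(detections.xyxy)        # roughly [[ 0.  0. 10. 10.]]
print(detections.confidence)  # [0.9]
print(detections.class_id)    # [0] -> index of the prompt that produced the mask
```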

File tree

2 files changed: +317 -14 lines changed


src/supervision/detection/core.py

Lines changed: 139 additions & 14 deletions
@@ -17,7 +17,11 @@
     process_transformers_v4_segmentation_result,
     process_transformers_v5_segmentation_result,
 )
-from supervision.detection.utils.converters import mask_to_xyxy, xywh_to_xyxy
+from supervision.detection.utils.converters import (
+    mask_to_xyxy,
+    polygon_to_mask,
+    xywh_to_xyxy,
+)
 from supervision.detection.utils.internal import (
     extract_ultralytics_masks,
     get_data_item,
@@ -52,7 +56,7 @@
 )
 from supervision.geometry.core import Position
 from supervision.utils.internal import deprecated, get_instance_variables
-from supervision.validators import validate_detections_fields
+from supervision.validators import validate_detections_fields, validate_resolution
 
 
 @dataclass
@@ -280,9 +284,11 @@ def from_ultralytics(cls, ultralytics_results) -> Detections:
                 xyxy=ultralytics_results.obb.xyxy.cpu().numpy(),
                 confidence=ultralytics_results.obb.conf.cpu().numpy(),
                 class_id=class_id,
-                tracker_id=ultralytics_results.obb.id.int().cpu().numpy()
-                if ultralytics_results.obb.id is not None
-                else None,
+                tracker_id=(
+                    ultralytics_results.obb.id.int().cpu().numpy()
+                    if ultralytics_results.obb.id is not None
+                    else None
+                ),
                 data={
                     ORIENTED_BOX_COORDINATES: oriented_box_coordinates,
                     CLASS_NAME_DATA_FIELD: class_names,
@@ -308,9 +314,11 @@ def from_ultralytics(cls, ultralytics_results) -> Detections:
             confidence=ultralytics_results.boxes.conf.cpu().numpy(),
             class_id=class_id,
            mask=extract_ultralytics_masks(ultralytics_results),
-            tracker_id=ultralytics_results.boxes.id.int().cpu().numpy()
-            if ultralytics_results.boxes.id is not None
-            else None,
+            tracker_id=(
+                ultralytics_results.boxes.id.int().cpu().numpy()
+                if ultralytics_results.boxes.id is not None
+                else None
+            ),
             data={CLASS_NAME_DATA_FIELD: class_names},
         )
 
@@ -464,9 +472,11 @@ def from_mmdetection(cls, mmdet_results) -> Detections:
             xyxy=mmdet_results.pred_instances.bboxes.cpu().numpy(),
             confidence=mmdet_results.pred_instances.scores.cpu().numpy(),
             class_id=mmdet_results.pred_instances.labels.cpu().numpy().astype(int),
-            mask=mmdet_results.pred_instances.masks.cpu().numpy()
-            if "masks" in mmdet_results.pred_instances
-            else None,
+            mask=(
+                mmdet_results.pred_instances.masks.cpu().numpy()
+                if "masks" in mmdet_results.pred_instances
+                else None
+            ),
         )
 
     @classmethod
@@ -584,9 +594,11 @@ class IDs, and confidences of the predictions.
         return cls(
             xyxy=detectron2_results["instances"].pred_boxes.tensor.cpu().numpy(),
             confidence=detectron2_results["instances"].scores.cpu().numpy(),
-            mask=detectron2_results["instances"].pred_masks.cpu().numpy()
-            if hasattr(detectron2_results["instances"], "pred_masks")
-            else None,
+            mask=(
+                detectron2_results["instances"].pred_masks.cpu().numpy()
+                if hasattr(detectron2_results["instances"], "pred_masks")
+                else None
+            ),
             class_id=detectron2_results["instances"]
             .pred_classes.cpu()
             .numpy()
@@ -687,6 +699,119 @@ def from_sam(cls, sam_result: list[dict]) -> Detections:
         xyxy = xywh_to_xyxy(xywh=xywh)
         return cls(xyxy=xyxy, mask=mask)
 
+    @classmethod
+    def from_sam3(
+        cls, sam3_result: dict | Any, resolution_wh: tuple[int, int]
+    ) -> Detections:
+        """
+        Creates a Detections instance from
+        [SAM 3](https://github.com/facebookresearch/sam3) inference result.
+
+        Args:
+            sam3_result (dict | Any): The output result from SAM 3 inference,
+                either Sam3PromptResult from inference package or dict containing
+                prompt_results with polygon predictions.
+            resolution_wh (Tuple[int, int]): The width and height of the image
+                used for mask generation.
+
+        Returns:
+            Detections: A new Detections object.
+                The `class_id` field contains the prompt index for each polygon.
+
+        Example:
+            ```python
+            import cv2
+            import supervision as sv
+            from inference.models.sam3 import SegmentAnything3
+            from inference.core.entities.requests.sam3 import Sam3Prompt
+
+            image = cv2.imread("<SOURCE_IMAGE_PATH>")
+            model = SegmentAnything3(
+                model_id="sam3/sam3_final",
+                api_key="<ROBOFLOW_API_KEY>"
+            )
+
+            prompts = [
+                Sam3Prompt(type="text", text="car"),
+                Sam3Prompt(type="text", text="tire"),
+            ]
+
+            result = model.segment_image(
+                image=image,
+                prompts=prompts,
+                output_prob_thresh=0.5,
+                format="polygon"
+            )
+
+            height, width = image.shape[:2]
+            detections = sv.Detections.from_sam3(
+                sam3_result=result,
+                resolution_wh=(width, height)
+            )
+            ```
+        """
+        width, height = validate_resolution(resolution_wh)
+
+        masks = []
+        confidences = []
+        class_ids = []
+
+        if isinstance(sam3_result, dict):
+            prompt_results = sam3_result.get("prompt_results", [])
+        else:
+            prompt_results = getattr(sam3_result, "prompt_results", [])
+
+        for i, prompt_result in enumerate(prompt_results):
+            if isinstance(prompt_result, dict):
+                predictions = prompt_result.get("predictions", [])
+                prompt_index = prompt_result.get("prompt_index", i)
+            else:
+                predictions = getattr(prompt_result, "predictions", [])
+                prompt_index = getattr(prompt_result, "prompt_index", i)
+
+            for prediction in predictions:
+                if isinstance(prediction, dict):
+                    prediction_format = prediction.get("format")
+                    if prediction_format and prediction_format != "polygon":
+                        continue
+                    pred_masks = prediction.get("masks", [])
+                    confidence = prediction.get("confidence", 1.0)
+                else:
+                    prediction_format = getattr(prediction, "format", None)
+                    if prediction_format and prediction_format != "polygon":
+                        continue
+                    pred_masks = getattr(prediction, "masks", [])
+                    confidence = getattr(prediction, "confidence", 1.0)
+
+                if not pred_masks:
+                    continue
+
+                full_mask = np.zeros((height, width), dtype=bool)
+                for poly in pred_masks:
+                    polygon = np.array(poly, dtype=np.int32)
+                    mask = polygon_to_mask(
+                        polygon=polygon, resolution_wh=(width, height)
+                    )
+                    mask = mask.astype(bool, copy=False)
+                    np.logical_or(full_mask, mask, out=full_mask)
+
+                masks.append(full_mask)
+                confidences.append(confidence)
+                class_ids.append(prompt_index)
+
+        if not masks:
+            return cls.empty()
+
+        masks_np = np.stack(masks, axis=0)
+        xyxy = mask_to_xyxy(masks_np)
+
+        return cls(
+            xyxy=xyxy.astype(np.float32),
+            mask=masks_np,
+            confidence=np.array(confidences, dtype=np.float32),
+            class_id=np.array(class_ids, dtype=int),
+        )
+
     @classmethod
     def from_azure_analyze_image(
         cls, azure_result: dict, class_map: dict[int, str] | None = None
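
The heart of `from_sam3` is the per-prediction mask handling: every polygon listed under a prediction's `masks` key is rasterized and OR-ed into a single boolean mask, and the bounding box is then recovered from that union. A small sketch of this step, using the same converter helpers that `core.py` imports above; the 20x20 resolution and the two square polygons are made up for illustration.

```python
import numpy as np

# Same helpers imported by core.py in this commit.
from supervision.detection.utils.converters import mask_to_xyxy, polygon_to_mask

resolution_wh = (20, 20)  # (width, height), as passed to from_sam3
polygons = [
    np.array([[0, 0], [5, 0], [5, 5], [0, 5]], dtype=np.int32),
    np.array([[10, 10], [15, 10], [15, 15], [10, 15]], dtype=np.int32),
]

# A single prediction may carry several polygons; they are merged into one mask.
full_mask = np.zeros((20, 20), dtype=bool)  # (height, width)
for polygon in polygons:
    mask = polygon_to_mask(polygon=polygon, resolution_wh=resolution_wh)
    np.logical_or(full_mask, mask.astype(bool), out=full_mask)

# mask_to_xyxy takes an (N, H, W) stack and returns one box per mask, so the
# detection's box spans the union of its polygons.
print(mask_to_xyxy(full_mask[np.newaxis]))  # roughly [[ 0  0 15 15]]
```

One `Detections` row is produced per prediction, with `class_id` set to the prompt index, which is why the serverless fixture in the tests below expects a `class_id` of `[0, 1]`.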

tests/detection/test_from_sam.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
+from __future__ import annotations
+
+import numpy as np
+import pytest
+
+from supervision.detection.core import Detections
+
+SERVERLESS_SAM3_DICT = {
+    "prompt_results": [
+        {
+            "prompt_index": 0,
+            "echo": {
+                "prompt_index": 0,
+                "type": "text",
+                "text": "person",
+                "num_boxes": 0,
+            },
+            "predictions": [
+                {
+                    "masks": [[[295, 675], [294, 676]], [[496, 617], [495, 618]]],
+                    "confidence": 0.94921875,
+                    "format": "polygon",
+                }
+            ],
+        },
+        {
+            "prompt_index": 1,
+            "echo": {"prompt_index": 1, "type": "text", "text": "dog", "num_boxes": 0},
+            "predictions": [
+                {
+                    "masks": [[[316, 561], [316, 562]], [[345, 251], [344, 252]]],
+                    "confidence": 0.89453125,
+                    "format": "polygon",
+                }
+            ],
+        },
+    ],
+    "time": 0.14756996370851994,
+}
+HOSTED_SAM3_DICT = {
+    "prompt_results": [
+        {
+            "prompt_index": 0,
+            "echo": {
+                "prompt_index": 0,
+                "type": "text",
+                "text": "bottle",
+                "num_boxes": 0,
+            },
+            "predictions": [
+                {
+                    "masks": [[[1364, 200], [1365, 201]]],
+                    "confidence": 0.8984375,
+                    "format": "polygon",
+                },
+                {
+                    "masks": [[[1140, 171], [1139, 170]]],
+                    "confidence": 0.94140625,
+                    "format": "polygon",
+                },
+            ],
+        }
+    ],
+    "time": 0.7277156260097399,
+}
+
+
+@pytest.mark.parametrize(
+    ("sam_result", "expected_xyxy", "expected_mask_shape"),
+    [
+        (
+            [
+                {
+                    "segmentation": np.ones((10, 10), dtype=bool),
+                    "bbox": [0, 0, 10, 10],
+                    "area": 100,
+                }
+            ],
+            np.array([[0, 0, 10, 10]], dtype=np.float32),
+            (1, 10, 10),
+        ),
+        ([], np.empty((0, 4), dtype=np.float32), None),
+    ],
+)
+def test_from_sam(
+    sam_result: list[dict],
+    expected_xyxy: np.ndarray,
+    expected_mask_shape: tuple[int, ...] | None,
+) -> None:
+    detections = Detections.from_sam(sam_result=sam_result)
+
+    assert np.array_equal(detections.xyxy, expected_xyxy)
+    if expected_mask_shape is not None:
+        assert detections.mask.shape == expected_mask_shape
+    else:
+        assert detections.mask is None
+
+
+@pytest.mark.parametrize(
+    (
+        "sam3_result",
+        "resolution_wh",
+        "expected_xyxy",
+        "expected_confidence",
+        "expected_class_id",
+    ),
+    [
+        (
+            {
+                "prompt_results": [
+                    {
+                        "predictions": [
+                            {
+                                "format": "polygon",
+                                "masks": [[[0, 0], [10, 0], [10, 10], [0, 10]]],
+                                "confidence": 0.9,
+                            }
+                        ],
+                        "prompt_index": 0,
+                    }
+                ]
+            },
+            (100, 100),
+            np.array([[0, 0, 10, 10]], dtype=np.float32),
+            np.array([0.9], dtype=np.float32),
+            np.array([0], dtype=int),
+        ),
+        (
+            {"prompt_results": []},
+            (100, 100),
+            np.empty((0, 4), dtype=np.float32),
+            np.empty((0,), dtype=np.float32),
+            np.empty((0,), dtype=int),
+        ),
+        (
+            SERVERLESS_SAM3_DICT,
+            (1000, 1000),
+            np.array(
+                [[294.0, 617.0, 496.0, 676.0], [316.0, 251.0, 345.0, 562.0]],
+                dtype=np.float32,
+            ),
+            np.array([0.94921875, 0.89453125], dtype=np.float32),
+            np.array([0, 1], dtype=int),
+        ),
+        (
+            HOSTED_SAM3_DICT,
+            (2000, 2000),
+            np.array(
+                [[1364.0, 200.0, 1365.0, 201.0], [1139.0, 170.0, 1140.0, 171.0]],
+                dtype=np.float32,
+            ),
+            np.array([0.898438, 0.941406], dtype=np.float32),
+            np.array([0, 0], dtype=int),
+        ),
+    ],
+)
+def test_from_sam3(
+    sam3_result: dict,
+    resolution_wh: tuple[int, int],
+    expected_xyxy: np.ndarray,
+    expected_confidence: np.ndarray,
+    expected_class_id: np.ndarray,
+) -> None:
+    detections = Detections.from_sam3(
+        sam3_result=sam3_result, resolution_wh=resolution_wh
+    )
+
+    np.testing.assert_allclose(detections.xyxy, expected_xyxy, atol=1e-5)
+    np.testing.assert_allclose(detections.confidence, expected_confidence, atol=1e-5)
+    np.testing.assert_array_equal(detections.class_id, expected_class_id)
+
+
+def test_from_sam3_invalid_resolution() -> None:
+    sam3_result = {"prompt_results": []}
+    with pytest.raises(
+        ValueError, match=r"Both dimensions in resolution must be positive\."
+    ):
+        Detections.from_sam3(sam3_result=sam3_result, resolution_wh=(-100, 100))
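
Beyond the parametrized happy paths, the tests pin down two edge cases: an empty `prompt_results` list falls back to `Detections.empty()`, and `validate_resolution` rejects non-positive dimensions before any polygon is processed. A quick sketch of both, mirroring the fixtures above:

```python
import supervision as sv

# Empty prompt_results -> Detections.empty(): zero detections.
empty = sv.Detections.from_sam3(
    sam3_result={"prompt_results": []}, resolution_wh=(100, 100)
)
assert len(empty) == 0

# Non-positive resolution is rejected up front by validate_resolution.
try:
    sv.Detections.from_sam3(
        sam3_result={"prompt_results": []}, resolution_wh=(-100, 100)
    )
except ValueError as error:
    print(error)  # "Both dimensions in resolution must be positive."
```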
