Skip to content
37 changes: 37 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from_florence_2,
from_google_gemini_2_0,
from_google_gemini_2_5,
from_kosmos,
from_moondream,
from_paligemma,
from_qwen_2_5_vl,
Expand Down Expand Up @@ -1452,6 +1453,36 @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio
# [1908.01, 1346.67, 2585.99, 2024.11]])
```

!!! example "Kosmos-2"
```python

import supervision as sv

kosmos_result = (
'An image of a small statue of a cat, with a gramophone and a man walking past in the background.',
[
('a small statue of a cat', (12, 35), [(0.265625, 0.015625, 0.703125, 0.984375)]),
('a gramophone', (42, 54), [(0.234375, 0.015625, 0.703125, 0.515625)]),
('a man', (59, 64), [(0.015625, 0.390625, 0.171875, 0.984375)])
]
)
Comment on lines +1612 to +1619
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to take all of this as input? Most of it is unnecessary for us.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to, but this is the data structure that KOSMOS-2's post processor returns. I personally think asking the users to enter a different data structure would introduce a layer of friction.

detections = sv.Detections.from_vlm(
sv.VLM.KOSMOS,
kosmos_result,
resolution_wh=image.size,
)
detections.xyxy
# array([[310.78125, 11.625 , 822.65625, 732.375 ],
# [274.21875, 11.625 , 822.65625, 383.625 ],
# [ 18.28125, 290.625 , 201.09375, 732.375 ]])

detections.class_id
# array([0, 1, 2])

detections.data
# {'class_name': array(['a small statue of a cat', 'a gramophone', 'a man'])}
```

""" # noqa: E501

vlm = validate_vlm_parameters(vlm, result, kwargs)
Expand Down Expand Up @@ -1501,6 +1532,12 @@ def from_vlm(cls, vlm: VLM | str, result: str | dict, **kwargs: Any) -> Detectio
data=data,
)

if vlm == VLM.KOSMOS:
xyxy, class_id, class_name = from_kosmos(result, **kwargs)
return cls(
xyxy=xyxy, class_id=class_id, data={CLASS_NAME_DATA_FIELD: class_name}
)

return cls.empty()

@classmethod
Expand Down
64 changes: 61 additions & 3 deletions supervision/detection/vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import re
from enum import Enum
from typing import Any
from typing import Any, get_origin

import numpy as np
from PIL import Image
Expand Down Expand Up @@ -39,6 +39,7 @@ class LMM(Enum):
GOOGLE_GEMINI_2_0 = "gemini_2_0"
GOOGLE_GEMINI_2_5 = "gemini_2_5"
MOONDREAM = "moondream"
KOSMOS = "kosmos"

@classmethod
def list(cls):
Expand Down Expand Up @@ -79,6 +80,7 @@ class VLM(Enum):
GOOGLE_GEMINI_2_0 = "gemini_2_0"
GOOGLE_GEMINI_2_5 = "gemini_2_5"
MOONDREAM = "moondream"
KOSMOS = "kosmos"

@classmethod
def list(cls):
Expand Down Expand Up @@ -107,6 +109,9 @@ def from_value(cls, value: VLM | str) -> VLM:
VLM.GOOGLE_GEMINI_2_0: str,
VLM.GOOGLE_GEMINI_2_5: str,
VLM.MOONDREAM: dict,
VLM.KOSMOS: tuple[
str, list[tuple[str, tuple[int, int], list[tuple[int, int, int, int]]]]
],
}

REQUIRED_ARGUMENTS: dict[VLM, list[str]] = {
Expand All @@ -116,6 +121,7 @@ def from_value(cls, value: VLM | str) -> VLM:
VLM.GOOGLE_GEMINI_2_0: ["resolution_wh"],
VLM.GOOGLE_GEMINI_2_5: ["resolution_wh"],
VLM.MOONDREAM: ["resolution_wh"],
VLM.KOSMOS: ["resolution_wh"],
}

ALLOWED_ARGUMENTS: dict[VLM, list[str]] = {
Expand All @@ -125,6 +131,7 @@ def from_value(cls, value: VLM | str) -> VLM:
VLM.GOOGLE_GEMINI_2_0: ["resolution_wh", "classes"],
VLM.GOOGLE_GEMINI_2_5: ["resolution_wh", "classes"],
VLM.MOONDREAM: ["resolution_wh"],
VLM.KOSMOS: ["resolution_wh"],
}

SUPPORTED_TASKS_FLORENCE_2 = [
Expand Down Expand Up @@ -164,9 +171,11 @@ def validate_vlm_parameters(vlm: VLM | str, result: Any, kwargs: dict[str, Any])
f"Invalid vlm value: {vlm}. Must be one of {[e.value for e in VLM]}"
)

if not isinstance(result, RESULT_TYPES[vlm]):
expected_type = RESULT_TYPES[vlm]
origin_type = get_origin(expected_type) or expected_type
if not isinstance(result, origin_type):
raise ValueError(
f"Invalid VLM result type: {type(result)}. Must be {RESULT_TYPES[vlm]}"
f"Invalid VLM result type: {type(result)}. Must be {expected_type}"
)

required_args = REQUIRED_ARGUMENTS.get(vlm, [])
Expand Down Expand Up @@ -717,3 +726,52 @@ def from_moondream(
return np.empty((0, 4))

return np.array(denormalize_xyxy, dtype=float)


def from_kosmos(
    result: tuple[
        str,
        list[tuple[str, tuple[int, int], list[tuple[float, float, float, float]]]],
    ],
    resolution_wh: tuple[int, int],
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Parse and scale bounding boxes from a Kosmos-2 result.

    The result is a tuple of a caption string and a list of entity tuples.
    Each entity tuple contains the class name (the grounded phrase), the start
    and end index of that phrase in the caption, and a list of bounding boxes
    with coordinates normalized to the range [0, 1].

    An entity may carry several bounding boxes, or none. Every box becomes its
    own detection; all boxes of one entity share that entity's class id and
    class name. Entities without boxes produce no detections.

    The result is supposed to be in the following format:
    ```python
    result = (
        'An image of a small statue of a cat, with a gramophone and a man walking past in the background.',
        [
            ('a small statue of a cat', (12, 35), [(0.265625, 0.015625, 0.703125, 0.984375)]),
            ('a gramophone', (42, 54), [(0.234375, 0.015625, 0.703125, 0.515625)]),
            ('a man', (59, 64), [(0.015625, 0.390625, 0.171875, 0.984375)])
        ]
    )
    ```

    Args:
        result: The result from the kosmos-2 model.
        resolution_wh: (output_width, output_height) to which we rescale the boxes.

    Returns:
        xyxy (np.ndarray): An array of shape `(n, 4)` containing
            the bounding boxes coordinates in format `[x1, y1, x2, y2]`
        class_id (np.ndarray): An array of shape `(n,)` containing
            the class indices for each bounding box (the position of the
            box's entity in the result list)
        class_name (np.ndarray): An array of shape `(n,)` containing
            the class labels for each bounding box
    """  # noqa: E501
    _, entity_locations = result

    xyxy, class_ids, class_names = [], [], []
    for entity_index, item in enumerate(entity_locations):
        class_name = item[0]
        # Emit one detection per box: taking only `item[2][0]` would silently
        # drop extra boxes and raise IndexError for an entity without boxes.
        for bbox in item[2]:
            xyxy.append(
                denormalize_boxes(
                    np.array(bbox, dtype=float), resolution_wh=resolution_wh
                )
            )
            class_ids.append(entity_index)
            class_names.append(class_name)

    # Explicit dtypes keep the empty case well-typed: (0, 4) float boxes,
    # integer ids and string names instead of float64 defaults.
    return (
        np.array(xyxy, dtype=float).reshape(-1, 4),
        np.array(class_ids, dtype=int),
        np.array(class_names, dtype=str),
    )
Loading