Skip to content

Commit 1023987

Browse files
Add min_area to SanitizeBoundingBoxes (#7735)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent f7d9e75 commit 1023987

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

test/test_transforms_v2.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
58055805

58065806

58075807
class TestSanitizeBoundingBoxes:
5808-
def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
5808+
def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10, min_area=10):
58095809
boxes_and_validity = [
58105810
([0, 1, 10, 1], False), # Y1 == Y2
58115811
([0, 1, 0, 20], False), # X1 == X2
@@ -5816,17 +5816,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
58165816
([-1, 1, 10, 20], False), # any < 0
58175817
([0, 0, -1, 20], False), # any < 0
58185818
([0, 0, -10, -1], False), # any < 0
5819-
([0, 0, min_size, 10], True), # H < min_size
5820-
([0, 0, 10, min_size], True), # W < min_size
5821-
([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1?
5822-
([1, 1, 30, 20], True),
5823-
([0, 0, 10, 10], True),
5824-
([1, 1, 30, 20], True),
5819+
([0, 0, min_size, 10], min_size * 10 >= min_area), # H < min_size
5820+
([0, 0, 10, min_size], min_size * 10 >= min_area), # W < min_size
5821+
([0, 0, W, H], W * H >= min_area),
5822+
([1, 1, 30, 20], 29 * 19 >= min_area),
5823+
([0, 0, 10, 10], 9 * 9 >= min_area),
5824+
([1, 1, 30, 20], 29 * 19 >= min_area),
58255825
]
58265826

58275827
random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
58285828
boxes, expected_valid_mask = zip(*boxes_and_validity)
5829-
58305829
boxes = tv_tensors.BoundingBoxes(
58315830
boxes,
58325831
format=tv_tensors.BoundingBoxFormat.XYXY,
@@ -5835,7 +5834,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
58355834

58365835
return boxes, expected_valid_mask
58375836

5838-
@pytest.mark.parametrize("min_size", (1, 10))
5837+
@pytest.mark.parametrize("min_size, min_area", ((1, 1), (10, 1), (10, 101)))
58395838
@pytest.mark.parametrize(
58405839
"labels_getter",
58415840
(
@@ -5848,15 +5847,15 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
58485847
),
58495848
)
58505849
@pytest.mark.parametrize("sample_type", (tuple, dict))
5851-
def test_transform(self, min_size, labels_getter, sample_type):
5850+
def test_transform(self, min_size, min_area, labels_getter, sample_type):
58525851

58535852
if sample_type is tuple and not isinstance(labels_getter, str):
58545853
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
58555854
# doesn't work if the input is a tuple.
58565855
return
58575856

58585857
H, W = 256, 128
5859-
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size)
5858+
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=min_size, min_area=min_area)
58605859
valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]
58615860

58625861
labels = torch.arange(boxes.shape[0])
@@ -5880,7 +5879,9 @@ def test_transform(self, min_size, labels_getter, sample_type):
58805879
img = sample.pop("image")
58815880
sample = (img, sample)
58825881

5883-
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
5882+
out = transforms.SanitizeBoundingBoxes(min_size=min_size, min_area=min_area, labels_getter=labels_getter)(
5883+
sample
5884+
)
58845885

58855886
if sample_type is tuple:
58865887
out_image = out[0]
@@ -5977,6 +5978,8 @@ def test_errors_transform(self):
59775978

59785979
with pytest.raises(ValueError, match="min_size must be >= 1"):
59795980
transforms.SanitizeBoundingBoxes(min_size=0)
5981+
with pytest.raises(ValueError, match="min_area must be >= 1"):
5982+
transforms.SanitizeBoundingBoxes(min_area=0)
59805983
with pytest.raises(ValueError, match="labels_getter should either be 'default'"):
59815984
transforms.SanitizeBoundingBoxes(labels_getter=12)
59825985

torchvision/transforms/v2/_misc.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ class SanitizeBoundingBoxes(Transform):
344344
345345
This transform removes bounding boxes and their associated labels/masks that:
346346
347-
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
347+
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
348348
- have any coordinate outside of their corresponding image. You may want to
349349
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
350350
@@ -359,7 +359,8 @@ class SanitizeBoundingBoxes(Transform):
359359
cases.
360360
361361
Args:
362-
min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
362+
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
363+
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.
363364
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
364365
(or anything else that needs to be sanitized along with the bounding boxes).
365366
By default, this will try to find a "labels" key in the input (case-insensitive), if
@@ -379,6 +380,7 @@ class SanitizeBoundingBoxes(Transform):
379380
def __init__(
380381
self,
381382
min_size: float = 1.0,
383+
min_area: float = 1.0,
382384
labels_getter: Union[Callable[[Any], Any], str, None] = "default",
383385
) -> None:
384386
super().__init__()
@@ -387,6 +389,10 @@ def __init__(
387389
raise ValueError(f"min_size must be >= 1, got {min_size}.")
388390
self.min_size = min_size
389391

392+
if min_area < 1:
393+
raise ValueError(f"min_area must be >= 1, got {min_area}.")
394+
self.min_area = min_area
395+
390396
self.labels_getter = labels_getter
391397
self._labels_getter = _parse_labels_getter(labels_getter)
392398

@@ -422,7 +428,9 @@ def forward(self, *inputs: Any) -> Any:
422428
format=boxes.format,
423429
canvas_size=boxes.canvas_size,
424430
min_size=self.min_size,
431+
min_area=self.min_area,
425432
)
433+
426434
params = dict(valid=valid, labels=labels)
427435
flat_outputs = [self._transform(inpt, params) for inpt in flat_inputs]
428436

torchvision/transforms/v2/functional/_misc.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -322,12 +322,13 @@ def sanitize_bounding_boxes(
322322
format: Optional[tv_tensors.BoundingBoxFormat] = None,
323323
canvas_size: Optional[Tuple[int, int]] = None,
324324
min_size: float = 1.0,
325+
min_area: float = 1.0,
325326
) -> Tuple[torch.Tensor, torch.Tensor]:
326327
"""Remove degenerate/invalid bounding boxes and return the corresponding indexing mask.
327328
328329
This removes bounding boxes that:
329330
330-
- are below a given ``min_size``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
331+
- are below a given ``min_size`` or ``min_area``: by default this also removes degenerate boxes that have e.g. X2 <= X1.
331332
- have any coordinate outside of their corresponding image. You may want to
332333
call :func:`~torchvision.transforms.v2.functional.clamp_bounding_boxes` first to avoid undesired removals.
333334
@@ -346,6 +347,7 @@ def sanitize_bounding_boxes(
346347
(size of the corresponding image/video).
347348
Must be left to none if ``bounding_boxes`` is a :class:`~torchvision.tv_tensors.BoundingBoxes` object.
348349
min_size (float, optional): The size below which bounding boxes are removed. Default is 1.
350+
min_area (float, optional): The area below which bounding boxes are removed. Default is 1.
349351
350352
Returns:
351353
out (tuple of Tensors): The subset of valid bounding boxes, and the corresponding indexing mask.
@@ -361,7 +363,7 @@ def sanitize_bounding_boxes(
361363
if isinstance(format, str):
362364
format = tv_tensors.BoundingBoxFormat[format.upper()]
363365
valid = _get_sanitize_bounding_boxes_mask(
364-
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size
366+
bounding_boxes, format=format, canvas_size=canvas_size, min_size=min_size, min_area=min_area
365367
)
366368
bounding_boxes = bounding_boxes[valid]
367369
else:
@@ -374,7 +376,11 @@ def sanitize_bounding_boxes(
374376
"Leave those to None or pass bounding_boxes as a pure tensor."
375377
)
376378
valid = _get_sanitize_bounding_boxes_mask(
377-
bounding_boxes, format=bounding_boxes.format, canvas_size=bounding_boxes.canvas_size, min_size=min_size
379+
bounding_boxes,
380+
format=bounding_boxes.format,
381+
canvas_size=bounding_boxes.canvas_size,
382+
min_size=min_size,
383+
min_area=min_area,
378384
)
379385
bounding_boxes = tv_tensors.wrap(bounding_boxes[valid], like=bounding_boxes)
380386

@@ -386,6 +392,7 @@ def _get_sanitize_bounding_boxes_mask(
386392
format: tv_tensors.BoundingBoxFormat,
387393
canvas_size: Tuple[int, int],
388394
min_size: float = 1.0,
395+
min_area: float = 1.0,
389396
) -> torch.Tensor:
390397

391398
bounding_boxes = _convert_bounding_box_format(
@@ -394,7 +401,7 @@ def _get_sanitize_bounding_boxes_mask(
394401

395402
image_h, image_w = canvas_size
396403
ws, hs = bounding_boxes[:, 2] - bounding_boxes[:, 0], bounding_boxes[:, 3] - bounding_boxes[:, 1]
397-
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1)
404+
valid = (ws >= min_size) & (hs >= min_size) & (bounding_boxes >= 0).all(dim=-1) & (ws * hs >= min_area)
398405
# TODO: Do we really need to check for out of bounds here? All
399406
# transforms should be clamping anyway, so this should never happen?
400407
image_h, image_w = canvas_size

0 commit comments

Comments
 (0)