Skip to content

Commit b677a61

Browse files
Zhitao Yu and facebook-github-bot
authored and committed
Fix SanitizeBoundingBoxes Handling of Semantic Masks
Summary:
Background: Currently, torchvision.transforms.v2.SanitizeBoundingBoxes fails when used inside a v2.Compose that receives both bounding boxes and a semantic segmentation mask as inputs. The transform attempts to apply a per-box boolean validity mask to all tv_tensors.Mask objects, including semantic masks (shape [H, W]), resulting in a shape mismatch and a crash.
Error example: "IndexError: The shape of the mask [3] at index 0 does not match the shape of the indexed tensor [1080, 1920] at index 0".
Expected behavior: The transform should only sanitize masks that have a 1:1 mapping with bounding boxes (e.g., per-instance masks). Semantic masks (2D, shape [H, W]) should be passed through unchanged.
Task objectives:
1. Update SanitizeBoundingBoxes logic: Detect whether a tv_tensors.Mask is a per-instance mask (shape [N, H, W] or [N, ...] where N == num_boxes) or a semantic mask (shape [H, W]). Only apply the per-box validity mask to per-instance masks. Pass through semantic masks unchanged. If a mask does not match the number of boxes, do not raise an error; instead, pass it through. Optionally, log a warning if a mask is skipped for sanitization due to shape mismatch.
2. Clarify documentation: Update the docstring for SanitizeBoundingBoxes to explicitly state that only per-instance masks are sanitized, that semantic masks are passed through unchanged, and that the transform does not require users to pass masks to labels_getter for them to be sanitized. Add examples for both use cases (per-instance and semantic masks).
3. Add/update unit tests: Test with both per-instance masks and semantic masks in a v2.Compose. Ensure semantic masks are not sanitized and do not cause errors. Ensure per-instance masks are sanitized correctly. These tests can be added in TestSanitizeBoundingBoxes.
4. Backward compatibility: Ensure that the change does not break existing datasets or user code that relies on current behavior.
Finally, submit a PR with the changes and link the issue in the description.
Differential Revision: D85840801
1 parent 249326b commit b677a61

File tree

2 files changed

+134
-3
lines changed

2 files changed

+134
-3
lines changed

test/test_transforms_v2.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7347,6 +7347,82 @@ def test_no_label(self):
73477347
assert isinstance(out_img, tv_tensors.Image)
73487348
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
73497349

7350+
def test_semantic_masks_passthrough(self):
    """Check that a 2D semantic mask is returned untouched by the transform.

    The sample holds two bounding boxes and one ``(H, W)`` mask. The mask has
    no per-box leading dimension, so the transform must hand it back with the
    same type, shape and values.
    """
    height, width = 256, 128
    bboxes = tv_tensors.BoundingBoxes(
        [[0, 0, 50, 50], [60, 60, 100, 100]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(height, width),
    )

    # Semantic segmentation map of shape (H, W) - it should NOT be sanitized.
    semantic_mask = tv_tensors.Mask(torch.randint(0, 10, size=(height, width)))

    sanitizer = transforms.SanitizeBoundingBoxes(labels_getter=None)
    out = sanitizer({"boxes": bboxes, "semantic_mask": semantic_mask})

    # The mask must keep its tv_tensor type, its shape and its content.
    returned_mask = out["semantic_mask"]
    assert isinstance(returned_mask, tv_tensors.Mask)
    assert returned_mask.shape == (height, width)
    assert_equal(returned_mask, semantic_mask)
7373+
7374+
def test_masks_with_mismatched_shape_passthrough(self):
    """Check that masks not matching the number of boxes pass through with a warning.

    Three boxes are paired with a ``(5, H, W)`` mask stack, so the per-box
    validity mask cannot index the masks. The transform must not raise: it
    must return the masks unchanged and emit a warning explaining that the
    mask could not be sanitized.
    """
    # Local import keeps this test self-contained regardless of the module's imports.
    import warnings

    H, W = 256, 128
    boxes = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(H, W),
    )

    # Create masks with a different number of instances than boxes: 5 masks but 3 boxes.
    mismatched_masks = tv_tensors.Mask(torch.randint(0, 2, size=(5, H, W)))

    sample = {
        "boxes": boxes,
        "masks": mismatched_masks,
    }

    # Record warnings instead of letting them escape, so the skip notice can be
    # asserted on even if the surrounding warning filters would hide or escalate it.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Should not raise an error, masks should pass through unchanged.
        out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)

    assert isinstance(out["masks"], tv_tensors.Mask)
    assert out["masks"].shape == (5, H, W)
    assert_equal(out["masks"], mismatched_masks)

    # The skipped sanitization must be reported to the user, not silently ignored.
    assert any("could not be sanitized" in str(w.message) for w in caught)
7397+
7398+
def test_per_instance_masks_sanitized(self):
    """Check that per-instance masks ``(N, H, W)`` are filtered together with the boxes.

    Boxes and their expected validity flags come from
    ``self._get_boxes_and_valid_mask``. One mask and one label are created per
    box. After the transform, masks, boxes and labels must share the same
    leading size, and the surviving masks must be the ones at the expected
    valid positions, in order.
    """
    height, width = 256, 128
    boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=height, W=width, min_size=10, min_area=10)
    num_boxes = boxes.shape[0]
    # Positions of the boxes that are expected to survive sanitization.
    kept_positions = [pos for pos, keep in enumerate(expected_valid_mask) if keep]

    # Create per-instance masks matching the number of boxes.
    per_instance_masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, height, width)))
    labels = torch.arange(num_boxes)

    sanitizer = transforms.SanitizeBoundingBoxes(min_size=10, min_area=10)
    out = sanitizer({"boxes": boxes, "masks": per_instance_masks, "labels": labels})

    # Masks must be filtered in lockstep with boxes and labels.
    out_masks = out["masks"]
    assert isinstance(out_masks, tv_tensors.Mask)
    assert out_masks.shape[0] == len(kept_positions)
    assert out_masks.shape[0] == out["boxes"].shape[0] == out["labels"].shape[0]

    # Each surviving mask must equal the input mask at the matching valid position.
    for out_pos, src_pos in enumerate(kept_positions):
        assert_equal(out_masks[out_pos], per_instance_masks[src_pos])
7425+
73507426
def test_errors_transform(self):
73517427
good_bbox = tv_tensors.BoundingBoxes(
73527428
[[0, 0, 10, 10]],

torchvision/transforms/v2/_misc.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,12 @@ class SanitizeBoundingBoxes(Transform):
369369
It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
370370
(see ``labels_getter`` parameter).
371371
372+
.. note::
373+
**Mask handling**: This transform automatically detects and sanitizes per-instance masks
374+
(shape ``[N, H, W]`` where N matches the number of bounding boxes). Semantic segmentation masks
375+
(shape ``[H, W]``) or masks with mismatched dimensions are passed through unchanged.
376+
You do not need to add masks to ``labels_getter`` for them to be sanitized.
377+
372378
It is recommended to call it at the end of a pipeline, before passing the
373379
input to the models. It is critical to call this transform if
374380
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
@@ -393,6 +399,36 @@ class SanitizeBoundingBoxes(Transform):
393399
from COCO.
394400
395401
If ``labels_getter`` is None then only bounding boxes are sanitized.
402+
403+
Example:
404+
>>> import torch
405+
>>> from torchvision import tv_tensors
406+
>>> from torchvision.transforms import v2
407+
>>>
408+
>>> # Per-instance masks (N boxes, each with H x W mask)
409+
>>> boxes = tv_tensors.BoundingBoxes(
410+
... [[0, 0, 10, 10], [5, 5, 15, 15], [0, 0, 5, 5]], # 3 boxes
411+
... format="XYXY", canvas_size=(20, 20)
412+
... )
413+
>>> masks = tv_tensors.Mask(torch.randint(0, 2, (3, 20, 20))) # 3 masks
414+
>>> labels = torch.tensor([1, 2, 3])
415+
>>>
416+
>>> # Both per-instance masks and labels will be sanitized
417+
>>> transform = v2.SanitizeBoundingBoxes(min_size=8)
418+
>>> sample = {"boxes": boxes, "masks": masks, "labels": labels}
419+
>>> output = transform(sample)
420+
>>> # Invalid boxes and their corresponding masks/labels are removed
421+
>>>
422+
>>> # Semantic mask (single 2D mask for entire image)
423+
>>> semantic_mask = tv_tensors.Mask(torch.randint(0, 10, (20, 20))) # H x W
424+
>>> sample_with_semantic = {
425+
... "boxes": boxes,
426+
... "masks": masks, # Per-instance masks - will be sanitized
427+
... "semantic_mask": semantic_mask, # Semantic mask - will NOT be sanitized
428+
... "labels": labels
429+
... }
430+
>>> output = transform(sample_with_semantic)
431+
>>> # semantic_mask passes through unchanged
396432
"""
397433

398434
def __init__(
@@ -456,12 +492,31 @@ def forward(self, *inputs: Any) -> Any:
456492

457493
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
458494
is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
459-
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
495+
is_bounding_boxes = isinstance(inpt, tv_tensors.BoundingBoxes)
496+
is_mask = isinstance(inpt, tv_tensors.Mask)
460497

461-
if not (is_label or is_bounding_boxes_or_mask):
498+
if not (is_label or is_bounding_boxes or is_mask):
462499
return inpt
463500

464-
output = inpt[params["valid"]]
501+
# Semantic masks (2D) should pass through unchanged
502+
if is_mask and inpt.ndim == 2:
503+
return inpt
504+
505+
# Try to apply the validity mask
506+
try:
507+
output = inpt[params["valid"]]
508+
except (IndexError) as e:
509+
# If indexing fails (e.g., shape mismatch for masks), pass through unchanged
510+
if is_mask:
511+
warnings.warn(
512+
f"Mask with shape {inpt.shape} could not be sanitized: {e}. "
513+
"The mask will be passed through unchanged. "
514+
"This transform only sanitizes per-instance masks with shape [N, H, W] where "
515+
"N matches the number of bounding boxes."
516+
)
517+
return inpt
518+
# For other types (labels, boxes), re-raise as this is unexpected
519+
raise
465520

466521
if is_label:
467522
return output

0 commit comments

Comments
 (0)