Skip to content

Commit 65b5b53

Browse files
Zhitao Yufacebook-github-bot
authored andcommitted
Fix SanitizeBoundingBoxes Handling of Semantic Masks (pytorch#9256)
Summary: Background Currently, torchvision.transforms.v2.SanitizeBoundingBoxes fails when used inside a v2.Compose that receives both bounding boxes and a semantic segmentation mask as inputs. The transform attempts to apply a per-box boolean validity mask to all tv_tensors.Mask objects, including semantic masks (shape [H, W]), resulting in a shape mismatch and a crash. Error Example: IndexError: The shape of the mask [3] at index 0 does not match the shape of the indexed tensor [1080, 1920] at index 0 Expected Behavior The transform should only sanitize masks that have a 1:1 mapping with bounding boxes (e.g., per-instance masks). Semantic masks (2D, shape [H, W]) should be passed through unchanged. Task Objectives Update SanitizeBoundingBoxes Logic: Detect whether a tv_tensors.Mask is a per-instance mask (shape [N, H, W] or [N, ...] where N == num_boxes) or a semantic mask (shape [H, W]). Only apply the per-box validity mask to per-instance masks. Pass through semantic masks unchanged. If a mask does not match the number of boxes, do not raise an error; instead, pass it through. Optionally, log a warning if a mask is skipped for sanitization due to shape mismatch. Clarify Documentation: Update the docstring for SanitizeBoundingBoxes to explicitly state: Only per-instance masks are sanitized. Semantic masks are passed through unchanged. The transform does not require users to pass masks to labels_getter for them to be sanitized. Add/Update Unit Tests: Test with both per-instance masks and semantic masks in a v2.Compose. Ensure semantic masks are not sanitized and do not cause errors. Ensure per-instance masks are sanitized correctly. This can be added in TestSanitizeBoundingBoxes Backward Compatibility: Ensure that the change does not break existing datasets or user code that relies on current behavior. Finally submit a PR with the changes and link the issue in the description. Differential Revision: D85840801
1 parent cfbc5c2 commit 65b5b53

File tree

2 files changed

+90
-3
lines changed

2 files changed

+90
-3
lines changed

test/test_transforms_v2.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7347,6 +7347,82 @@ def test_no_label(self):
73477347
assert isinstance(out_img, tv_tensors.Image)
73487348
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
73497349

7350+
def test_semantic_masks_passthrough(self):
7351+
# Test that semantic masks (2D) pass through unchanged
7352+
H, W = 256, 128
7353+
boxes = tv_tensors.BoundingBoxes(
7354+
[[0, 0, 50, 50], [60, 60, 100, 100]],
7355+
format=tv_tensors.BoundingBoxFormat.XYXY,
7356+
canvas_size=(H, W),
7357+
)
7358+
7359+
# Create semantic segmentation mask (H, W) - should NOT be sanitized
7360+
semantic_mask = tv_tensors.Mask(torch.randint(0, 10, size=(H, W)))
7361+
7362+
sample = {
7363+
"boxes": boxes,
7364+
"semantic_mask": semantic_mask,
7365+
}
7366+
7367+
out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
7368+
7369+
# Check that semantic mask passed through unchanged
7370+
assert isinstance(out["semantic_mask"], tv_tensors.Mask)
7371+
assert out["semantic_mask"].shape == (H, W)
7372+
assert_equal(out["semantic_mask"], semantic_mask)
7373+
7374+
def test_masks_with_mismatched_shape_passthrough(self):
7375+
# Test that masks with shapes that don't match the number of boxes are passed through
7376+
H, W = 256, 128
7377+
boxes = tv_tensors.BoundingBoxes(
7378+
[[0, 0, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]],
7379+
format=tv_tensors.BoundingBoxFormat.XYXY,
7380+
canvas_size=(H, W),
7381+
)
7382+
7383+
# Create masks with different number of instances than boxes
7384+
mismatched_masks = tv_tensors.Mask(torch.randint(0, 2, size=(5, H, W))) # 5 masks but 3 boxes
7385+
7386+
sample = {
7387+
"boxes": boxes,
7388+
"masks": mismatched_masks,
7389+
}
7390+
7391+
# Should not raise an error, masks should pass through unchanged
7392+
out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
7393+
7394+
assert isinstance(out["masks"], tv_tensors.Mask)
7395+
assert out["masks"].shape == (5, H, W)
7396+
assert_equal(out["masks"], mismatched_masks)
7397+
7398+
def test_per_instance_masks_sanitized(self):
7399+
# Test that per-instance masks (N, H, W) are correctly sanitized
7400+
H, W = 256, 128
7401+
boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=10, min_area=10)
7402+
valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]
7403+
num_boxes = boxes.shape[0]
7404+
7405+
# Create per-instance masks matching the number of boxes
7406+
per_instance_masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W)))
7407+
labels = torch.arange(num_boxes)
7408+
7409+
sample = {
7410+
"boxes": boxes,
7411+
"masks": per_instance_masks,
7412+
"labels": labels,
7413+
}
7414+
7415+
out = transforms.SanitizeBoundingBoxes(min_size=10, min_area=10)(sample)
7416+
7417+
# Check that masks were sanitized correctly
7418+
assert isinstance(out["masks"], tv_tensors.Mask)
7419+
assert out["masks"].shape[0] == len(valid_indices)
7420+
assert out["masks"].shape[0] == out["boxes"].shape[0] == out["labels"].shape[0]
7421+
7422+
# Verify correct masks were kept
7423+
for i, valid_idx in enumerate(valid_indices):
7424+
assert_equal(out["masks"][i], per_instance_masks[valid_idx])
7425+
73507426
def test_errors_transform(self):
73517427
good_bbox = tv_tensors.BoundingBoxes(
73527428
[[0, 0, 10, 10]],

torchvision/transforms/v2/_misc.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,12 @@ class SanitizeBoundingBoxes(Transform):
369369
It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
370370
(see ``labels_getter`` parameter).
371371
372+
.. note::
373+
**Mask handling**: This transform automatically detects and sanitizes per-instance masks
374+
(shape ``[N, H, W]`` where N matches the number of bounding boxes). Semantic segmentation masks
375+
(shape ``[H, W]``) or masks with mismatched dimensions are passed through unchanged.
376+
You do not need to add masks to ``labels_getter`` for them to be sanitized.
377+
372378
It is recommended to call it at the end of a pipeline, before passing the
373379
input to the models. It is critical to call this transform if
374380
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
@@ -456,12 +462,17 @@ def forward(self, *inputs: Any) -> Any:
456462

457463
def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
458464
is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
459-
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
465+
is_bounding_boxes = isinstance(inpt, tv_tensors.BoundingBoxes)
466+
is_mask = isinstance(inpt, tv_tensors.Mask)
460467

461-
if not (is_label or is_bounding_boxes_or_mask):
468+
if not (is_label or is_bounding_boxes or is_mask):
462469
return inpt
463470

464-
output = inpt[params["valid"]]
471+
try:
472+
output = inpt[params["valid"]]
473+
except (IndexError):
474+
# If indexing fails (e.g., shape mismatch), pass through unchanged
475+
return inpt
465476

466477
if is_label:
467478
return output

0 commit comments

Comments
 (0)