From 8940dc3f231145b81e82c0e25f29ee812a7ff917 Mon Sep 17 00:00:00 2001 From: Zhitao Yu Date: Wed, 5 Nov 2025 08:05:55 -0800 Subject: [PATCH] Fix SanitizeBoundingBoxes Handling of Semantic Masks (#9256) Summary: Background Currently, torchvision.transforms.v2.SanitizeBoundingBoxes fails when used inside a v2.Compose that receives both bounding boxes and a semantic segmentation mask as inputs. The transform attempts to apply a per-box boolean validity mask to all tv_tensors.Mask objects, including semantic masks (shape [H, W]), resulting in a shape mismatch and a crash. Error Example: IndexError: The shape of the mask [3] at index 0 does not match the shape of the indexed tensor [1080, 1920] at index 0 Expected Behavior The transform should only sanitize masks that have a 1:1 mapping with bounding boxes (e.g., per-instance masks). Semantic masks (2D, shape [H, W]) should be passed through unchanged. Task Objectives Update SanitizeBoundingBoxes Logic: Detect whether a tv_tensors.Mask is a per-instance mask (shape [N, H, W] or [N, ...] where N == num_boxes) or a semantic mask (shape [H, W]). Only apply the per-box validity mask to per-instance masks. Pass through semantic masks unchanged. If a mask does not match the number of boxes, do not raise an error; instead, pass it through. Optionally, log a warning if a mask is skipped for sanitization due to shape mismatch. Clarify Documentation: Update the docstring for SanitizeBoundingBoxes to explicitly state: Only per-instance masks are sanitized. Semantic masks are passed through unchanged. The transform does not require users to pass masks to labels_getter for them to be sanitized. Add/Update Unit Tests: Test with both per-instance masks and semantic masks in a v2.Compose. Ensure semantic masks are not sanitized and do not cause errors. Ensure per-instance masks are sanitized correctly. This can be added in TestSanitizeBoundingBoxes Backward Compatibility: Ensure that the change does not break existing datasets or user code that relies on current behavior. Finally submit a PR with the changes and link the issue in the description. Differential Revision: D85840801 --- test/test_transforms_v2.py | 76 ++++++++++++++++++++++++++++++ torchvision/transforms/v2/_misc.py | 17 +++++-- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 487018b7c69..0f985ab9604 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -7348,6 +7348,82 @@ def test_no_label(self): assert isinstance(out_img, tv_tensors.Image) assert isinstance(out_boxes, tv_tensors.BoundingBoxes) + def test_semantic_masks_passthrough(self): + # Test that semantic masks (2D) pass through unchanged + H, W = 256, 128 + boxes = tv_tensors.BoundingBoxes( + [[0, 0, 50, 50], [60, 60, 100, 100]], + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=(H, W), + ) + + # Create semantic segmentation mask (H, W) - should NOT be sanitized + semantic_mask = tv_tensors.Mask(torch.randint(0, 10, size=(H, W))) + + sample = { + "boxes": boxes, + "semantic_mask": semantic_mask, + } + + out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample) + + # Check that semantic mask passed through unchanged + assert isinstance(out["semantic_mask"], tv_tensors.Mask) + assert out["semantic_mask"].shape == (H, W) + assert_equal(out["semantic_mask"], semantic_mask) + + def test_masks_with_mismatched_shape_passthrough(self): + # Test that masks with shapes that don't match the number of boxes are passed through + H, W = 256, 128 + boxes = tv_tensors.BoundingBoxes( + [[0, 0, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]], + format=tv_tensors.BoundingBoxFormat.XYXY, + canvas_size=(H, W), + ) + + # Create masks with different number of instances than boxes + mismatched_masks = tv_tensors.Mask(torch.randint(0, 2, size=(5, H, W))) # 5 masks but 3 boxes + + sample = { + "boxes": boxes, + "masks": mismatched_masks, + } + + # Should not raise an error, masks should pass through unchanged + out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample) + + assert isinstance(out["masks"], tv_tensors.Mask) + assert out["masks"].shape == (5, H, W) + assert_equal(out["masks"], mismatched_masks) + + def test_per_instance_masks_sanitized(self): + # Test that per-instance masks (N, H, W) are correctly sanitized + H, W = 256, 128 + boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=10, min_area=10) + valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid] + num_boxes = boxes.shape[0] + + # Create per-instance masks matching the number of boxes + per_instance_masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W))) + labels = torch.arange(num_boxes) + + sample = { + "boxes": boxes, + "masks": per_instance_masks, + "labels": labels, + } + + out = transforms.SanitizeBoundingBoxes(min_size=10, min_area=10)(sample) + + # Check that masks were sanitized correctly + assert isinstance(out["masks"], tv_tensors.Mask) + assert out["masks"].shape[0] == len(valid_indices) + assert out["masks"].shape[0] == out["boxes"].shape[0] == out["labels"].shape[0] + + # Verify correct masks were kept + for i, valid_idx in enumerate(valid_indices): + assert_equal(out["masks"][i], per_instance_masks[valid_idx]) + def test_errors_transform(self): good_bbox = tv_tensors.BoundingBoxes( [[0, 0, 10, 10]], diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index df68dd3d243..305149c87b1 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -369,6 +369,12 @@ class SanitizeBoundingBoxes(Transform): It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO (see ``labels_getter`` parameter). + .. note:: + **Mask handling**: This transform automatically detects and sanitizes per-instance masks + (shape ``[N, H, W]`` where N matches the number of bounding boxes). Semantic segmentation masks + (shape ``[H, W]``) or masks with mismatched dimensions are passed through unchanged. + You do not need to add masks to ``labels_getter`` for them to be sanitized. + It is recommended to call it at the end of a pipeline, before passing the input to the models. It is critical to call this transform if :class:`~torchvision.transforms.v2.RandomIoUCrop` was called. @@ -456,12 +462,17 @@ def forward(self, *inputs: Any) -> Any: def transform(self, inpt: Any, params: dict[str, Any]) -> Any: is_label = params["labels"] is not None and any(inpt is label for label in params["labels"]) - is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) + is_bounding_boxes = isinstance(inpt, tv_tensors.BoundingBoxes) + is_mask = isinstance(inpt, tv_tensors.Mask) - if not (is_label or is_bounding_boxes_or_mask): + if not (is_label or is_bounding_boxes or is_mask): return inpt - output = inpt[params["valid"]] + try: + output = inpt[params["valid"]] + except (IndexError): + # If indexing fails (e.g., shape mismatch), pass through unchanged + return inpt if is_label: return output