diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 487018b7c69..0f985ab9604 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -7348,6 +7348,107 @@ def test_no_label(self):
         assert isinstance(out_img, tv_tensors.Image)
         assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
 
+    def test_semantic_masks_passthrough(self):
+        # Test that semantic masks (2D) pass through unchanged
+        H, W = 256, 128
+        boxes = tv_tensors.BoundingBoxes(
+            [[0, 0, 50, 50], [60, 60, 100, 100]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(H, W),
+        )
+
+        # Create semantic segmentation mask (H, W) - should NOT be sanitized
+        semantic_mask = tv_tensors.Mask(torch.randint(0, 10, size=(H, W)))
+
+        sample = {
+            "boxes": boxes,
+            "semantic_mask": semantic_mask,
+        }
+
+        out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
+
+        # Check that semantic mask passed through unchanged
+        assert isinstance(out["semantic_mask"], tv_tensors.Mask)
+        assert out["semantic_mask"].shape == (H, W)
+        assert_equal(out["semantic_mask"], semantic_mask)
+
+    def test_semantic_mask_with_height_equal_to_num_boxes_passthrough(self):
+        # Regression test: a 2D semantic mask whose height equals the number of boxes can be boolean-indexed
+        # along dim 0 without error, so it must be recognized by its shape and still pass through unchanged.
+        H, W = 3, 128
+        boxes = tv_tensors.BoundingBoxes(
+            [[0, 0, 50, 2], [10, 1, 10, 1], [60, 0, 100, 2]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(H, W),
+        )
+
+        semantic_mask = tv_tensors.Mask(torch.randint(0, 10, size=(H, W)))
+
+        sample = {
+            "boxes": boxes,
+            "semantic_mask": semantic_mask,
+        }
+
+        out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
+
+        # The degenerate (zero-size) box is removed, but no row of the semantic mask is dropped
+        assert out["boxes"].shape[0] == 2
+        assert isinstance(out["semantic_mask"], tv_tensors.Mask)
+        assert out["semantic_mask"].shape == (H, W)
+        assert_equal(out["semantic_mask"], semantic_mask)
+
+    def test_masks_with_mismatched_shape_passthrough(self):
+        # Test that masks with shapes that don't match the number of boxes are passed through
+        H, W = 256, 128
+        boxes = tv_tensors.BoundingBoxes(
+            [[0, 0, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]],
+            format=tv_tensors.BoundingBoxFormat.XYXY,
+            canvas_size=(H, W),
+        )
+
+        # Create masks with different number of instances than boxes
+        mismatched_masks = tv_tensors.Mask(torch.randint(0, 2, size=(5, H, W)))  # 5 masks but 3 boxes
+
+        sample = {
+            "boxes": boxes,
+            "masks": mismatched_masks,
+        }
+
+        # Should not raise an error, masks should pass through unchanged
+        out = transforms.SanitizeBoundingBoxes(labels_getter=None)(sample)
+
+        assert isinstance(out["masks"], tv_tensors.Mask)
+        assert out["masks"].shape == (5, H, W)
+        assert_equal(out["masks"], mismatched_masks)
+
+    def test_per_instance_masks_sanitized(self):
+        # Test that per-instance masks (N, H, W) are correctly sanitized
+        H, W = 256, 128
+        boxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=H, W=W, min_size=10, min_area=10)
+        valid_indices = [i for (i, is_valid) in enumerate(expected_valid_mask) if is_valid]
+        num_boxes = boxes.shape[0]
+
+        # Create per-instance masks matching the number of boxes
+        per_instance_masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_boxes, H, W)))
+        labels = torch.arange(num_boxes)
+
+        sample = {
+            "boxes": boxes,
+            "masks": per_instance_masks,
+            "labels": labels,
+        }
+
+        out = transforms.SanitizeBoundingBoxes(min_size=10, min_area=10)(sample)
+
+        # Check that masks were sanitized correctly
+        assert isinstance(out["masks"], tv_tensors.Mask)
+        assert out["masks"].shape[0] == len(valid_indices)
+        assert out["masks"].shape[0] == out["boxes"].shape[0] == out["labels"].shape[0]
+
+        # Verify correct masks were kept
+        for i, valid_idx in enumerate(valid_indices):
+            assert_equal(out["masks"][i], per_instance_masks[valid_idx])
+
     def test_errors_transform(self):
         good_bbox = tv_tensors.BoundingBoxes(
             [[0, 0, 10, 10]],
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index df68dd3d243..305149c87b1 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -369,6 +369,12 @@ class SanitizeBoundingBoxes(Transform):
     It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
     (see ``labels_getter`` parameter).
 
+    .. note::
+        **Mask handling**: This transform automatically detects and sanitizes per-instance masks
+        (shape ``[N, H, W]`` where N matches the number of bounding boxes). Semantic segmentation masks
+        (shape ``[H, W]``) or masks with mismatched dimensions are passed through unchanged.
+        You do not need to add masks to ``labels_getter`` for them to be sanitized.
+
     It is recommended to call it at the end of a pipeline, before passing the
     input to the models. It is critical to call this transform if
     :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
@@ -456,12 +462,20 @@ def forward(self, *inputs: Any) -> Any:
 
     def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
         is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
-        is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
+        is_bounding_boxes = isinstance(inpt, tv_tensors.BoundingBoxes)
+        is_mask = isinstance(inpt, tv_tensors.Mask)
 
-        if not (is_label or is_bounding_boxes_or_mask):
+        if not (is_label or is_bounding_boxes or is_mask):
             return inpt
 
-        output = inpt[params["valid"]]
+        valid = params["valid"]
+        if is_mask and not is_label and (inpt.ndim < 3 or inpt.shape[0] != len(valid)):
+            # Only per-instance masks ([N, ...] with N == number of boxes) are filtered; anything else,
+            # e.g. a [H, W] semantic mask, passes through. Check the shape explicitly rather than catching
+            # IndexError: a [H, W] mask with H == N would index fine and silently lose rows.
+            return inpt
+
+        output = inpt[valid]
 
         if is_label:
             return output