Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7348,6 +7348,82 @@ def test_no_label(self):
assert isinstance(out_img, tv_tensors.Image)
assert isinstance(out_boxes, tv_tensors.BoundingBoxes)

def test_semantic_masks_passthrough(self):
    """Check that a 2D semantic mask is returned unchanged by ``SanitizeBoundingBoxes``."""
    height, width = 256, 128
    bboxes = tv_tensors.BoundingBoxes(
        [[0, 0, 50, 50], [60, 60, 100, 100]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(height, width),
    )

    # A semantic segmentation mask has shape (H, W) and no per-box leading
    # dimension, so the transform is expected to leave it alone.
    seg_mask = tv_tensors.Mask(torch.randint(0, 10, size=(height, width)))

    sanitizer = transforms.SanitizeBoundingBoxes(labels_getter=None)
    result = sanitizer({"boxes": bboxes, "semantic_mask": seg_mask})

    # Type, shape and content must all be preserved.
    returned_mask = result["semantic_mask"]
    assert isinstance(returned_mask, tv_tensors.Mask)
    assert returned_mask.shape == (height, width)
    assert_equal(returned_mask, seg_mask)

def test_masks_with_mismatched_shape_passthrough(self):
    """Check that masks whose leading dimension differs from the box count pass through."""
    height, width = 256, 128
    num_masks = 5  # deliberately different from the 3 boxes defined below
    bboxes = tv_tensors.BoundingBoxes(
        [[0, 0, 10, 10], [20, 20, 30, 30], [50, 50, 60, 60]],
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(height, width),
    )

    # 5 masks for 3 boxes: the number of instances does not line up.
    extra_masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_masks, height, width)))

    # The call itself must not raise; the masks should come back untouched.
    sanitizer = transforms.SanitizeBoundingBoxes(labels_getter=None)
    result = sanitizer({"boxes": bboxes, "masks": extra_masks})

    returned_masks = result["masks"]
    assert isinstance(returned_masks, tv_tensors.Mask)
    assert returned_masks.shape == (num_masks, height, width)
    assert_equal(returned_masks, extra_masks)

def test_per_instance_masks_sanitized(self):
    """Check that per-instance masks of shape (N, H, W) are filtered together with their boxes."""
    height, width = 256, 128
    bboxes, expected_valid_mask = self._get_boxes_and_valid_mask(H=height, W=width, min_size=10, min_area=10)
    kept_positions = [pos for pos, keep in enumerate(expected_valid_mask) if keep]
    box_count = bboxes.shape[0]

    # One mask and one label per box, so all three entries share the leading dimension.
    instance_masks = tv_tensors.Mask(torch.randint(0, 2, size=(box_count, height, width)))
    class_ids = torch.arange(box_count)

    sanitizer = transforms.SanitizeBoundingBoxes(min_size=10, min_area=10)
    result = sanitizer({"boxes": bboxes, "masks": instance_masks, "labels": class_ids})

    # The masks must have been filtered down to the expected number of valid boxes ...
    assert isinstance(result["masks"], tv_tensors.Mask)
    remaining = result["masks"].shape[0]
    assert remaining == len(kept_positions)
    # ... and must stay aligned with the sanitized boxes and labels.
    assert remaining == result["boxes"].shape[0] == result["labels"].shape[0]

    # Each surviving mask must be the one that belonged to the surviving box.
    for out_pos, src_pos in enumerate(kept_positions):
        assert_equal(result["masks"][out_pos], instance_masks[src_pos])

def test_errors_transform(self):
good_bbox = tv_tensors.BoundingBoxes(
[[0, 0, 10, 10]],
Expand Down
17 changes: 14 additions & 3 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,12 @@ class SanitizeBoundingBoxes(Transform):
It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
(see ``labels_getter`` parameter).
.. note::
**Mask handling**: This transform automatically detects and sanitizes per-instance masks
(shape ``[N, H, W]`` where N matches the number of bounding boxes). Semantic segmentation masks
(shape ``[H, W]``) or masks with mismatched dimensions are passed through unchanged.
You do not need to add masks to ``labels_getter`` for them to be sanitized.
It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
Expand Down Expand Up @@ -456,12 +462,17 @@ def forward(self, *inputs: Any) -> Any:

def transform(self, inpt: Any, params: dict[str, Any]) -> Any:
is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
is_bounding_boxes = isinstance(inpt, tv_tensors.BoundingBoxes)
is_mask = isinstance(inpt, tv_tensors.Mask)

if not (is_label or is_bounding_boxes_or_mask):
if not (is_label or is_bounding_boxes or is_mask):
return inpt

output = inpt[params["valid"]]
try:
output = inpt[params["valid"]]
except (IndexError):
# If indexing fails (e.g., shape mismatch), pass through unchanged
return inpt

if is_label:
return output
Expand Down
Loading