Skip to content

Commit b820aad

Browse files
committed
Change default labels_getter of SanitizeKeypoints to None.
1 parent 125ab39 commit b820aad

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

test/test_transforms_v2.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7740,22 +7740,21 @@ def test_errors_transform(self):
77407740
# Test missing labels key
77417741
with pytest.raises(ValueError, match="Could not infer where the labels are"):
77427742
bad_sample = {"keypoints": good_keypoints, "BAD_KEY": torch.tensor([0])}
7743-
transforms.SanitizeKeyPoints()(bad_sample)
7743+
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)
77447744

77457745
# Test labels not a tensor
77467746
with pytest.raises(ValueError, match="must be a tensor"):
77477747
bad_sample = {"keypoints": good_keypoints, "labels": [0]}
7748-
transforms.SanitizeKeyPoints()(bad_sample)
7748+
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)
77497749

77507750
# Test mismatched sizes
77517751
with pytest.raises(ValueError, match="Number of"):
77527752
bad_sample = {"keypoints": good_keypoints, "labels": torch.tensor([0, 1, 2])}
7753-
transforms.SanitizeKeyPoints()(bad_sample)
7753+
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)
77547754

77557755
# Test min_invalid_points > 1 for 2D keypoints
77567756
with pytest.raises(ValueError, match="so min_invalid_points must be 1"):
7757-
sample = {"keypoints": good_keypoints, "labels": torch.tensor([0])}
7758-
transforms.SanitizeKeyPoints(min_invalid_points=2)(sample)
7757+
transforms.SanitizeKeyPoints(min_invalid_points=2)(good_keypoints)
77597758

77607759
def test_no_label(self):
77617760
"""Test transform without labels."""
@@ -7764,7 +7763,7 @@ def test_no_label(self):
77647763

77657764
# Should raise error without labels_getter=None
77667765
with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
7767-
transforms.SanitizeKeyPoints()(img, keypoints)
7766+
transforms.SanitizeKeyPoints(labels_getter="default")(img, keypoints)
77687767

77697768
# Should work with labels_getter=None
77707769
out_img, out_keypoints = transforms.SanitizeKeyPoints(labels_getter=None)(img, keypoints)

torchvision/transforms/v2/_misc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,22 +503,22 @@ class SanitizeKeyPoints(Transform):
503503
Default is 1.
504504
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
505505
(or anything else that needs to be sanitized along with the keypoints).
506-
By default, this will try to find a "labels" key in the input (case-insensitive), if
506+
If set to the string ``"default"``, this will try to find a "labels" key in the input (case-insensitive), if
507507
the input is a dict or it is a tuple whose second element is a dict.
508508
509509
It can also be a callable that takes the same input as the transform, and returns either:
510510
511511
- A single tensor (the labels)
512512
- A tuple/list of tensors, each of which will be subject to the same sanitization as the keypoints.
513513
514-
If ``labels_getter`` is None then only keypoints are sanitized.
514+
If ``labels_getter`` is None (the default), then only keypoints are sanitized.
515515
"""
516516

517517
def __init__(
518518
self,
519519
min_valid_edge_distance: int = 0,
520520
min_invalid_points: int | float = 1,
521-
labels_getter: Union[Callable[[Any], Any], str, None] = "default",
521+
labels_getter: Union[Callable[[Any], Any], str, None] = None,
522522
) -> None:
523523
super().__init__()
524524
self.min_valid_edge_distance = min_valid_edge_distance

0 commit comments

Comments
 (0)