@@ -7740,22 +7740,21 @@ def test_errors_transform(self):
77407740        # Test missing labels key 
77417741        with  pytest .raises (ValueError , match = "Could not infer where the labels are" ):
77427742            bad_sample  =  {"keypoints" : good_keypoints , "BAD_KEY" : torch .tensor ([0 ])}
7743-             transforms .SanitizeKeyPoints ()(bad_sample )
7743+             transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
77447744
77457745        # Test labels not a tensor 
77467746        with  pytest .raises (ValueError , match = "must be a tensor" ):
77477747            bad_sample  =  {"keypoints" : good_keypoints , "labels" : [0 ]}
7748-             transforms .SanitizeKeyPoints ()(bad_sample )
7748+             transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
77497749
77507750        # Test mismatched sizes 
77517751        with  pytest .raises (ValueError , match = "Number of" ):
77527752            bad_sample  =  {"keypoints" : good_keypoints , "labels" : torch .tensor ([0 , 1 , 2 ])}
7753-             transforms .SanitizeKeyPoints ()(bad_sample )
7753+             transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
77547754
77557755        # Test min_invalid_points > 1 for 2D keypoints 
77567756        with  pytest .raises (ValueError , match = "so min_invalid_points must be 1" ):
7757-             sample  =  {"keypoints" : good_keypoints , "labels" : torch .tensor ([0 ])}
7758-             transforms .SanitizeKeyPoints (min_invalid_points = 2 )(sample )
7757+             transforms .SanitizeKeyPoints (min_invalid_points = 2 )(good_keypoints )
77597758
77607759    def  test_no_label (self ):
77617760        """Test transform without labels.""" 
@@ -7764,7 +7763,7 @@ def test_no_label(self):
77647763
77657764        # Should raise error without labels_getter=None 
77667765        with  pytest .raises (ValueError , match = "or a two-tuple whose second item is a dict" ):
7767-             transforms .SanitizeKeyPoints ()(img , keypoints )
7766+             transforms .SanitizeKeyPoints (labels_getter = "default" )(img , keypoints )
77687767
77697768        # Should work with labels_getter=None 
77707769        out_img , out_keypoints  =  transforms .SanitizeKeyPoints (labels_getter = None )(img , keypoints )
0 commit comments