@@ -7401,24 +7401,22 @@ class TestSanitizeKeyPoints:
74017401    def  _make_keypoints_with_validity (
74027402        self ,
74037403        canvas_size = (100 , 100 ),
7404-         min_valid_edge_distance = 0 ,
7405-         min_invalid_points = 1 ,
74067404        shape = "2d" ,  # "2d", "3d", "4d" for different keypoint shapes 
74077405    ):
74087406        """Create keypoints with known validity for testing.""" 
74097407        canvas_h , canvas_w  =  canvas_size 
74107408
74117409        if  shape  ==  "2d" :  # [N_points, 2] 
74127410            keypoints_data  =  [
7413-                 ([5 , 5 ], min_valid_edge_distance   <=   5 ),  # Valid point inside image 
7414-                 ([canvas_w  -  6 , canvas_h  -  6 ], min_valid_edge_distance   <=   5 ),  # Valid point near corner 
7411+                 ([5 , 5 ], True ),  # Valid point inside image 
7412+                 ([canvas_w  -  6 , canvas_h  -  6 ], True ),  # Valid point near corner 
74157413                ([canvas_w  //  2 , canvas_h  //  2 ], True ),  # Valid point in center 
74167414                ([- 1 , canvas_h  //  2 ], False ),  # Invalid: x < 0 
74177415                ([canvas_w  //  2 , - 1 ], False ),  # Invalid: y < 0 
74187416                ([canvas_w , canvas_h  //  2 ], False ),  # Invalid: x >= canvas_w 
74197417                ([canvas_w  //  2 , canvas_h ], False ),  # Invalid: y >= canvas_h 
7420-                 ([0 , 0 ], min_valid_edge_distance   <=   0 ),  # Edge case: exactly on edge 
7421-                 ([canvas_w  -  1 , canvas_h  -  1 ], min_valid_edge_distance   <=   0 ),  # Edge case: exactly on edge 
7418+                 ([0 , 0 ], True ),  # Edge case: exactly on edge 
7419+                 ([canvas_w  -  1 , canvas_h  -  1 ], True ),  # Edge case: exactly on edge 
74227420            ]
74237421            points , validity  =  zip (* keypoints_data )
74247422            keypoints  =  torch .tensor (points , dtype = torch .float32 )
@@ -7429,11 +7427,11 @@ def _make_keypoints_with_validity(
74297427                # Group 1: All points valid 
74307428                ([[10 , 10 ], [20 , 20 ], [30 , 30 ]], True ),
74317429                # Group 2: One invalid point (should be removed if min_invalid_points=1) 
7432-                 ([[10 , 10 ], [20 , 20 ], [- 5 , 30 ]], min_invalid_points   >   1 ),
7430+                 ([[10 , 10 ], [20 , 20 ], [- 5 , 30 ]], False ),
74337431                # Group 3: All points invalid 
74347432                ([[- 1 , - 1 ], [- 2 , - 2 ], [- 3 , - 3 ]], False ),
74357433                # Group 4: Mix of valid and invalid (depends on min_invalid_points) 
7436-                 ([[10 , 10 ], [- 1 , 20 ], [- 2 , 30 ]], min_invalid_points   >   2 ),
7434+                 ([[10 , 10 ], [- 1 , 20 ], [- 2 , 30 ]], False ),
74377435            ]
74387436            groups , validity  =  zip (* keypoints_data )
74397437            keypoints  =  torch .tensor (groups , dtype = torch .float32 )
@@ -7444,7 +7442,7 @@ def _make_keypoints_with_validity(
74447442                # Object 1: All bones valid 
74457443                ([[[10 , 10 ], [15 , 15 ]], [[20 , 20 ], [25 , 25 ]]], True ),
74467444                # Object 2: One bone with invalid point 
7447-                 ([[[10 , 10 ], [15 , 15 ]], [[- 1 , 20 ], [25 , 25 ]]], min_invalid_points   >   1 ),
7445+                 ([[[10 , 10 ], [15 , 15 ]], [[- 1 , 20 ], [25 , 25 ]]], False ),
74487446                # Object 3: All bones invalid 
74497447                ([[[- 1 , - 1 ], [- 2 , - 2 ]], [[- 3 , - 3 ], [- 4 , - 4 ]]], False ),
74507448            ]
@@ -7457,26 +7455,14 @@ def _make_keypoints_with_validity(
74577455        return  keypoints , validity 
74587456
74597457    @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ]) 
7460-     @pytest .mark .parametrize ("min_valid_edge_distance" , [0 , 1 , 5 , 6 ]) 
7461-     @pytest .mark .parametrize ("min_invalid_points" , [1 , 2 , 0.5 ]) 
74627458    @pytest .mark .parametrize ("input_type" , [torch .Tensor , tv_tensors .KeyPoints ]) 
7463-     def  test_functional (self , shape , min_valid_edge_distance ,  min_invalid_points ,  input_type ):
7459+     def  test_functional (self , shape , input_type ):
74647460        """Test the sanitize_keypoints functional interface.""" 
7465-         # Check for invalid configuration 
7466-         if  shape  ==  "2d"  and  min_invalid_points  >  1 :
7467-             pytest .xfail ("min_invalid_points > 1 does not make sense for 2D keypoints" )
74687461
74697462        # Create inputs 
74707463        canvas_size  =  (50 , 50 )
7471-         if  isinstance (min_invalid_points , float ):
7472-             num_groups  =  4  if  shape  ==  "4d"  else  3 
7473-             min_invalid_points_int  =  math .ceil (min_invalid_points  *  num_groups )
7474-         else :
7475-             min_invalid_points_int  =  min_invalid_points 
74767464        keypoints , expected_validity  =  self ._make_keypoints_with_validity (
74777465            canvas_size = canvas_size ,
7478-             min_valid_edge_distance = min_valid_edge_distance ,
7479-             min_invalid_points = min_invalid_points_int ,
74807466            shape = shape ,
74817467        )
74827468
@@ -7490,8 +7476,6 @@ def test_functional(self, shape, min_valid_edge_distance, min_invalid_points, in
74907476        result_keypoints , valid_mask  =  F .sanitize_keypoints (
74917477            keypoints ,
74927478            canvas_size = canvas_size_arg ,
7493-             min_valid_edge_distance = min_valid_edge_distance ,
7494-             min_invalid_points = min_invalid_points ,
74957479        )
74967480
74977481        # Check return types 
@@ -7522,8 +7506,6 @@ def test_kernel(self, shape):
75227506        )
75237507
75247508    @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ]) 
7525-     @pytest .mark .parametrize ("min_valid_edge_distance" , [0 , 2 ]) 
7526-     @pytest .mark .parametrize ("min_invalid_points" , [1 , 0.3 ]) 
75277509    @pytest .mark .parametrize ( 
75287510        "labels_getter" , 
75297511        ( 
@@ -7536,26 +7518,15 @@ def test_kernel(self, shape):
75367518        ), 
75377519    ) 
75387520    @pytest .mark .parametrize ("sample_type" , (tuple , dict )) 
7539-     def  test_transform (self , shape , min_valid_edge_distance ,  min_invalid_points ,  labels_getter , sample_type ):
7521+     def  test_transform (self , shape , labels_getter , sample_type ):
75407522        """Test the SanitizeKeyPoints transform class.""" 
75417523        if  sample_type  is  tuple  and  not  isinstance (labels_getter , str ):
75427524            # Lambda-based labels_getter doesn't work with tuple input 
75437525            return 
75447526
7545-         # Check for invalid configuration 
7546-         if  shape  ==  "2d"  and  min_invalid_points  >  1 :
7547-             pytest .xfail ("min_invalid_points > 1 does not make sense for 2D keypoints" )
7548- 
75497527        canvas_size  =  (40 , 40 )
7550-         if  isinstance (min_invalid_points , float ):
7551-             num_groups  =  4  if  shape  ==  "4d"  else  3 
7552-             min_invalid_points_int  =  math .ceil (min_invalid_points  *  num_groups )
7553-         else :
7554-             min_invalid_points_int  =  min_invalid_points 
75557528        keypoints , expected_validity  =  self ._make_keypoints_with_validity (
75567529            canvas_size = canvas_size ,
7557-             min_valid_edge_distance = min_valid_edge_distance ,
7558-             min_invalid_points = min_invalid_points_int ,
75597530            shape = shape ,
75607531        )
75617532
@@ -7585,8 +7556,6 @@ def test_transform(self, shape, min_valid_edge_distance, min_invalid_points, lab
75857556
75867557        # Apply transform 
75877558        transform  =  transforms .SanitizeKeyPoints (
7588-             min_valid_edge_distance = min_valid_edge_distance ,
7589-             min_invalid_points = min_invalid_points ,
75907559            labels_getter = labels_getter ,
75917560        )
75927561        out  =  transform (sample )
@@ -7644,6 +7613,7 @@ def test_edge_cases(self):
76447613        # Test empty keypoints 
76457614        empty_keypoints  =  tv_tensors .KeyPoints (torch .empty (0 , 2 ), canvas_size = canvas_size )
76467615        result , valid_mask  =  F .sanitize_keypoints (empty_keypoints )
7616+         print (empty_keypoints , result , valid_mask )
76477617        assert  tuple (result .shape ) ==  (0 , 2 )
76487618        assert  valid_mask .shape [0 ] ==  0 
76497619
@@ -7659,43 +7629,6 @@ def test_edge_cases(self):
76597629        assert  tuple (result .shape ) ==  (0 , 2 )
76607630        assert  not  valid_mask .any ()
76617631
7662-     def  test_min_invalid_points_fraction (self ):
7663-         """Test min_invalid_points as a fraction.""" 
7664-         canvas_size  =  (20 , 20 )
7665- 
7666-         # Create 3D keypoints with 4 points per object 
7667-         keypoints  =  torch .tensor (
7668-             [
7669-                 # Object 1: 1 invalid point out of 4 (25% invalid) 
7670-                 [[5 , 5 ], [10 , 10 ], [15 , 15 ], [- 1 , - 1 ]],
7671-                 # Object 2: 2 invalid points out of 4 (50% invalid) 
7672-                 [[5 , 5 ], [10 , 10 ], [- 1 , - 1 ], [- 2 , - 2 ]],
7673-                 # Object 3: 3 invalid points out of 4 (75% invalid) 
7674-                 [[5 , 5 ], [- 1 , - 1 ], [- 2 , - 2 ], [- 3 , - 3 ]],
7675-             ],
7676-             dtype = torch .float32 ,
7677-         )
7678- 
7679-         keypoints  =  tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7680- 
7681-         # Test with 30% threshold - should keep object 1 
7682-         result , valid_mask  =  F .sanitize_keypoints (keypoints , min_invalid_points = 0.3 )
7683-         expected_valid  =  torch .tensor ([True , False , False ])
7684-         assert_equal (valid_mask , expected_valid )
7685-         assert  result .shape [0 ] ==  1 
7686- 
7687-         # Test with 60% threshold - should keep objects 1 and 2 
7688-         result , valid_mask  =  F .sanitize_keypoints (keypoints , min_invalid_points = 0.6 )
7689-         expected_valid  =  torch .tensor ([True , True , False ])
7690-         assert_equal (valid_mask , expected_valid )
7691-         assert  result .shape [0 ] ==  2 
7692- 
7693-         # Test with 100% threshold - should keep all objects 
7694-         result , valid_mask  =  F .sanitize_keypoints (keypoints , min_invalid_points = 1.0 )
7695-         expected_valid  =  torch .tensor ([True , True , True ])
7696-         assert_equal (valid_mask , expected_valid )
7697-         assert  result .shape [0 ] ==  3 
7698- 
76997632    def  test_errors_functional (self ):
77007633        """Test error conditions for the functional interface.""" 
77017634        good_keypoints  =  tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = (10 , 10 ))
@@ -7708,16 +7641,6 @@ def test_errors_functional(self):
77087641        with  pytest .raises (ValueError , match = "canvas_size must be None" ):
77097642            F .sanitize_keypoints (good_keypoints , canvas_size = (10 , 10 ))
77107643
7711-         # Test invalid min_invalid_points 
7712-         with  pytest .raises (ValueError , match = "min_invalid_points must be > 0" ):
7713-             F .sanitize_keypoints (good_keypoints , min_invalid_points = 0 )
7714- 
7715-         with  pytest .raises (ValueError , match = "min_invalid_points must be > 0" ):
7716-             F .sanitize_keypoints (good_keypoints , min_invalid_points = - 1 )
7717- 
7718-         with  pytest .raises (ValueError , match = "so min_invalid_points must be 1" ):
7719-             F .sanitize_keypoints (good_keypoints , min_invalid_points = 2 )
7720- 
77217644    def  test_errors_transform (self ):
77227645        """Test error conditions for the transform class.""" 
77237646        good_keypoints  =  tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = (10 , 10 ))
@@ -7726,10 +7649,6 @@ def test_errors_transform(self):
77267649        with  pytest .raises (ValueError , match = "labels_getter should either be" ):
77277650            transforms .SanitizeKeyPoints (labels_getter = "invalid_type" )  # type: ignore 
77287651
7729-         # Test invalid min_invalid_points 
7730-         with  pytest .raises (ValueError , match = "min_invalid_points must be > 0" ):
7731-             transforms .SanitizeKeyPoints (min_invalid_points = 0 )
7732- 
77337652        # Test missing labels key 
77347653        with  pytest .raises (ValueError , match = "Could not infer where the labels are" ):
77357654            bad_sample  =  {"keypoints" : good_keypoints , "BAD_KEY" : torch .tensor ([0 ])}
@@ -7745,10 +7664,6 @@ def test_errors_transform(self):
77457664            bad_sample  =  {"keypoints" : good_keypoints , "labels" : torch .tensor ([0 , 1 , 2 ])}
77467665            transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
77477666
7748-         # Test min_invalid_points > 1 for 2D keypoints 
7749-         with  pytest .raises (ValueError , match = "so min_invalid_points must be 1" ):
7750-             transforms .SanitizeKeyPoints (min_invalid_points = 2 )(good_keypoints )
7751- 
77527667    def  test_no_label (self ):
77537668        """Test transform without labels.""" 
77547669        img  =  make_image ()
0 commit comments