@@ -7397,6 +7397,326 @@ def test_errors_functional(self):
73977397            F .sanitize_bounding_boxes (good_bbox .tolist ())
73987398
73997399
7400+ class  TestSanitizeKeyPoints :
7401+     def  _make_keypoints_with_validity (
7402+         self ,
7403+         canvas_size = (100 , 100 ),
7404+         shape = "2d" ,  # "2d", "3d", "4d" for different keypoint shapes 
7405+     ):
7406+         """Create keypoints with known validity for testing.""" 
7407+         canvas_h , canvas_w  =  canvas_size 
7408+ 
7409+         if  shape  ==  "2d" :  # [N_points, 2] 
7410+             keypoints_data  =  [
7411+                 ([5 , 5 ], True ),  # Valid point inside image 
7412+                 ([canvas_w  -  6 , canvas_h  -  6 ], True ),  # Valid point near corner 
7413+                 ([canvas_w  //  2 , canvas_h  //  2 ], True ),  # Valid point in center 
7414+                 ([- 1 , canvas_h  //  2 ], False ),  # Invalid: x < 0 
7415+                 ([canvas_w  //  2 , - 1 ], False ),  # Invalid: y < 0 
7416+                 ([canvas_w , canvas_h  //  2 ], False ),  # Invalid: x >= canvas_w 
7417+                 ([canvas_w  //  2 , canvas_h ], False ),  # Invalid: y >= canvas_h 
7418+                 ([0 , 0 ], True ),  # Edge case: exactly on edge 
7419+                 ([canvas_w  -  1 , canvas_h  -  1 ], True ),  # Edge case: exactly on edge 
7420+             ]
7421+             points , validity  =  zip (* keypoints_data )
7422+             keypoints  =  torch .tensor (points , dtype = torch .float32 )
7423+ 
7424+         elif  shape  ==  "3d" :  # [N_objects, N_points, 2] 
7425+             # Create groups of keypoints with different validity patterns 
7426+             keypoints_data  =  [
7427+                 # Group 1: All points valid 
7428+                 ([[10 , 10 ], [20 , 20 ], [30 , 30 ]], True ),
7429+                 # Group 2: One invalid point (should be removed if min_invalid_points=1) 
7430+                 ([[10 , 10 ], [20 , 20 ], [- 5 , 30 ]], False ),
7431+                 # Group 3: All points invalid 
7432+                 ([[- 1 , - 1 ], [- 2 , - 2 ], [- 3 , - 3 ]], False ),
7433+                 # Group 4: Mix of valid and invalid (depends on min_invalid_points) 
7434+                 ([[10 , 10 ], [- 1 , 20 ], [- 2 , 30 ]], False ),
7435+             ]
7436+             groups , validity  =  zip (* keypoints_data )
7437+             keypoints  =  torch .tensor (groups , dtype = torch .float32 )
7438+ 
7439+         elif  shape  ==  "4d" :  # [N_objects, N_bones, 2, 2] 
7440+             # Create bone-like structures (pairs of points) 
7441+             keypoints_data  =  [
7442+                 # Object 1: All bones valid 
7443+                 ([[[10 , 10 ], [15 , 15 ]], [[20 , 20 ], [25 , 25 ]]], True ),
7444+                 # Object 2: One bone with invalid point 
7445+                 ([[[10 , 10 ], [15 , 15 ]], [[- 1 , 20 ], [25 , 25 ]]], False ),
7446+                 # Object 3: All bones invalid 
7447+                 ([[[- 1 , - 1 ], [- 2 , - 2 ]], [[- 3 , - 3 ], [- 4 , - 4 ]]], False ),
7448+             ]
7449+             objects , validity  =  zip (* keypoints_data )
7450+             keypoints  =  torch .tensor (objects , dtype = torch .float32 )
7451+ 
7452+         else :
7453+             raise  ValueError (f"Unsupported shape: { shape }  )
7454+ 
7455+         return  keypoints , validity 
7456+ 
7457+     @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ]) 
7458+     @pytest .mark .parametrize ("input_type" , [torch .Tensor , tv_tensors .KeyPoints ]) 
7459+     def  test_functional (self , shape , input_type ):
7460+         """Test the sanitize_keypoints functional interface.""" 
7461+ 
7462+         # Create inputs 
7463+         canvas_size  =  (50 , 50 )
7464+         keypoints , expected_validity  =  self ._make_keypoints_with_validity (
7465+             canvas_size = canvas_size ,
7466+             shape = shape ,
7467+         )
7468+ 
7469+         if  input_type  is  tv_tensors .KeyPoints :
7470+             keypoints  =  tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7471+             canvas_size_arg  =  None 
7472+         else :
7473+             canvas_size_arg  =  canvas_size 
7474+ 
7475+         # Apply function to be tested 
7476+         result_keypoints , valid_mask  =  F .sanitize_keypoints (
7477+             keypoints ,
7478+             canvas_size = canvas_size_arg ,
7479+         )
7480+ 
7481+         # Check return types 
7482+         assert  isinstance (result_keypoints , input_type )
7483+         assert  isinstance (valid_mask , torch .Tensor )
7484+         assert  valid_mask .dtype  ==  torch .bool 
7485+ 
7486+         # Check that valid mask matches expected validity 
7487+         assert_equal (valid_mask , torch .tensor (expected_validity ))
7488+ 
7489+         # Check that result has correct number of valid keypoints 
7490+         assert  result_keypoints .shape [0 ] ==  valid_mask .sum ().item ()
7491+ 
7492+         # Check that remaining keypoints shape is preserved 
7493+         assert  result_keypoints .shape [1 :] ==  keypoints .shape [1 :]
7494+ 
7495+     @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ]) 
7496+     def  test_kernel (self , shape ):
7497+         """Test kernel functionality.""" 
7498+         canvas_size  =  (30 , 30 )
7499+         keypoints , _  =  self ._make_keypoints_with_validity (canvas_size = canvas_size , shape = shape )
7500+ 
7501+         check_kernel (
7502+             F .sanitize_keypoints ,
7503+             input = keypoints ,
7504+             canvas_size = canvas_size ,
7505+             check_batched_vs_unbatched = False ,  # This function doesn't support batching 
7506+         )
7507+ 
7508+     @pytest .mark .parametrize ("shape" , ["2d" , "3d" , "4d" ]) 
7509+     @pytest .mark .parametrize ( 
7510+         "labels_getter" , 
7511+         ( 
7512+             "default" , 
7513+             lambda  inputs : inputs ["labels" ], 
7514+             lambda  inputs : (inputs ["labels" ], inputs ["other_labels" ]), 
7515+             lambda  inputs : [inputs ["labels" ], inputs ["other_labels" ]], 
7516+             None , 
7517+             lambda  inputs : None , 
7518+         ), 
7519+     ) 
7520+     @pytest .mark .parametrize ("sample_type" , (tuple , dict )) 
7521+     def  test_transform (self , shape , labels_getter , sample_type ):
7522+         """Test the SanitizeKeyPoints transform class.""" 
7523+         if  sample_type  is  tuple  and  not  isinstance (labels_getter , str ):
7524+             # Lambda-based labels_getter doesn't work with tuple input 
7525+             return 
7526+ 
7527+         canvas_size  =  (40 , 40 )
7528+         keypoints , expected_validity  =  self ._make_keypoints_with_validity (
7529+             canvas_size = canvas_size ,
7530+             shape = shape ,
7531+         )
7532+ 
7533+         keypoints  =  tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7534+         num_keypoints  =  keypoints .shape [0 ]
7535+ 
7536+         # Create associated labels and other data 
7537+         labels  =  torch .arange (num_keypoints )
7538+         other_labels  =  torch .arange (num_keypoints ) *  2 
7539+         masks  =  tv_tensors .Mask (torch .randint (0 , 2 , size = (num_keypoints , * canvas_size )))
7540+         whatever  =  torch .rand (10 )
7541+         input_img  =  torch .randint (0 , 256 , size = (1 , 3 , * canvas_size ), dtype = torch .uint8 )
7542+ 
7543+         sample  =  {
7544+             "image" : input_img ,
7545+             "labels" : labels ,
7546+             "keypoints" : keypoints ,
7547+             "other_labels" : other_labels ,
7548+             "whatever" : whatever ,
7549+             "None" : None ,
7550+             "masks" : masks ,
7551+         }
7552+ 
7553+         if  sample_type  is  tuple :
7554+             img  =  sample .pop ("image" )
7555+             sample  =  (img , sample )
7556+ 
7557+         # Apply transform 
7558+         transform  =  transforms .SanitizeKeyPoints (
7559+             labels_getter = labels_getter ,
7560+         )
7561+         out  =  transform (sample )
7562+ 
7563+         # Extract outputs 
7564+         if  sample_type  is  tuple :
7565+             out_image  =  out [0 ]
7566+             out_labels  =  out [1 ]["labels" ]
7567+             out_other_labels  =  out [1 ]["other_labels" ]
7568+             out_keypoints  =  out [1 ]["keypoints" ]
7569+             out_masks  =  out [1 ]["masks" ]
7570+             out_whatever  =  out [1 ]["whatever" ]
7571+         else :
7572+             out_image  =  out ["image" ]
7573+             out_labels  =  out ["labels" ]
7574+             out_other_labels  =  out ["other_labels" ]
7575+             out_keypoints  =  out ["keypoints" ]
7576+             out_masks  =  out ["masks" ]
7577+             out_whatever  =  out ["whatever" ]
7578+ 
7579+         # Verify unchanged elements 
7580+         assert_equal (out_image , input_img )
7581+         assert_equal (out_whatever , whatever )
7582+         assert_equal (out_masks , masks )
7583+ 
7584+         # Verify types 
7585+         assert  isinstance (out_keypoints , tv_tensors .KeyPoints )
7586+         assert  isinstance (out_masks , tv_tensors .Mask )
7587+ 
7588+         # Calculate expected valid indices 
7589+         valid_indices  =  [i  for  i , is_valid  in  enumerate (expected_validity ) if  is_valid ]
7590+ 
7591+         # Test label handling 
7592+         if  labels_getter  is  None  or  (callable (labels_getter ) and  labels_getter (sample ) is  None ):
7593+             # Labels should be unchanged 
7594+             assert  out_labels  is  labels 
7595+             assert  out_other_labels  is  other_labels 
7596+         else :
7597+             # Labels should be filtered 
7598+             assert  isinstance (out_labels , torch .Tensor )
7599+             assert  out_keypoints .shape [0 ] ==  out_labels .shape [0 ]
7600+             assert  out_labels .tolist () ==  valid_indices 
7601+ 
7602+             if  callable (labels_getter ) and  isinstance (labels_getter (sample ), (tuple , list )):
7603+                 # other_labels should also be filtered 
7604+                 assert_equal (out_other_labels , out_labels  *  2 )  # Since other_labels = labels * 2 
7605+             else :
7606+                 # other_labels and masks should be unchanged 
7607+                 assert_equal (out_other_labels , other_labels )
7608+ 
7609+     def  test_edge_cases (self ):
7610+         """Test edge cases and boundary conditions.""" 
7611+         canvas_size  =  (10 , 10 )
7612+ 
7613+         # Test empty keypoints 
7614+         empty_keypoints  =  tv_tensors .KeyPoints (torch .empty (0 , 2 ), canvas_size = canvas_size )
7615+         result , valid_mask  =  F .sanitize_keypoints (empty_keypoints )
7616+         print (empty_keypoints , result , valid_mask )
7617+         assert  tuple (result .shape ) ==  (0 , 2 )
7618+         assert  valid_mask .shape [0 ] ==  0 
7619+ 
7620+         # Test single valid keypoint 
7621+         single_valid  =  tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = canvas_size )
7622+         result , valid_mask  =  F .sanitize_keypoints (single_valid )
7623+         assert  tuple (result .shape ) ==  (1 , 2 )
7624+         assert  valid_mask .all ()
7625+ 
7626+         # Test single invalid keypoint 
7627+         single_invalid  =  tv_tensors .KeyPoints ([[- 1 , - 1 ]], canvas_size = canvas_size )
7628+         result , valid_mask  =  F .sanitize_keypoints (single_invalid )
7629+         assert  tuple (result .shape ) ==  (0 , 2 )
7630+         assert  not  valid_mask .any ()
7631+ 
7632+     def  test_errors_functional (self ):
7633+         """Test error conditions for the functional interface.""" 
7634+         good_keypoints  =  tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = (10 , 10 ))
7635+ 
7636+         # Test missing canvas_size for pure tensor 
7637+         with  pytest .raises (ValueError , match = "canvas_size cannot be None" ):
7638+             F .sanitize_keypoints (good_keypoints .as_subclass (torch .Tensor ), canvas_size = None )
7639+ 
7640+         # Test canvas_size provided for tv_tensor 
7641+         with  pytest .raises (ValueError , match = "canvas_size must be None" ):
7642+             F .sanitize_keypoints (good_keypoints , canvas_size = (10 , 10 ))
7643+ 
7644+     def  test_errors_transform (self ):
7645+         """Test error conditions for the transform class.""" 
7646+         good_keypoints  =  tv_tensors .KeyPoints ([[5 , 5 ]], canvas_size = (10 , 10 ))
7647+ 
7648+         # Test invalid labels_getter 
7649+         with  pytest .raises (ValueError , match = "labels_getter should either be" ):
7650+             transforms .SanitizeKeyPoints (labels_getter = "invalid_type" )  # type: ignore 
7651+ 
7652+         # Test missing labels key 
7653+         with  pytest .raises (ValueError , match = "Could not infer where the labels are" ):
7654+             bad_sample  =  {"keypoints" : good_keypoints , "BAD_KEY" : torch .tensor ([0 ])}
7655+             transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
7656+ 
7657+         # Test labels not a tensor 
7658+         with  pytest .raises (ValueError , match = "must be a tensor" ):
7659+             bad_sample  =  {"keypoints" : good_keypoints , "labels" : [0 ]}
7660+             transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
7661+ 
7662+         # Test mismatched sizes 
7663+         with  pytest .raises (ValueError , match = "Number of" ):
7664+             bad_sample  =  {"keypoints" : good_keypoints , "labels" : torch .tensor ([0 , 1 , 2 ])}
7665+             transforms .SanitizeKeyPoints (labels_getter = "default" )(bad_sample )
7666+ 
7667+     def  test_no_label (self ):
7668+         """Test transform without labels.""" 
7669+         img  =  make_image ()
7670+         keypoints  =  make_keypoints ()
7671+ 
7672+         # Should raise error without labels_getter=None 
7673+         with  pytest .raises (ValueError , match = "or a two-tuple whose second item is a dict" ):
7674+             transforms .SanitizeKeyPoints (labels_getter = "default" )(img , keypoints )
7675+ 
7676+         # Should work with labels_getter=None 
7677+         out_img , out_keypoints  =  transforms .SanitizeKeyPoints (labels_getter = None )(img , keypoints )
7678+         assert  isinstance (out_img , tv_tensors .Image )
7679+         assert  isinstance (out_keypoints , tv_tensors .KeyPoints )
7680+ 
7681+     @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
7682+     def  test_device_and_dtype_consistency (self , device ):
7683+         """Test that device and dtype are preserved.""" 
7684+         canvas_size  =  (20 , 20 )
7685+         keypoints  =  torch .tensor ([[5 , 5 ], [15 , 15 ], [- 1 , - 1 ]], dtype = torch .float32 , device = device )
7686+         keypoints  =  tv_tensors .KeyPoints (keypoints , canvas_size = canvas_size )
7687+ 
7688+         result , valid_mask  =  F .sanitize_keypoints (keypoints )
7689+ 
7690+         assert  result .device  ==  keypoints .device 
7691+         assert  result .dtype  ==  keypoints .dtype 
7692+         assert  valid_mask .device  ==  keypoints .device 
7693+ 
7694+     def  test_keypoint_shapes_consistency (self ):
7695+         """Test that different keypoint shapes are handled correctly.""" 
7696+         canvas_size  =  (50 , 50 )
7697+ 
7698+         # Test 2D shape [N_points, 2] 
7699+         kp_2d  =  torch .tensor ([[10 , 10 ], [20 , 20 ], [- 1 , - 1 ]], dtype = torch .float32 )
7700+         kp_2d  =  tv_tensors .KeyPoints (kp_2d , canvas_size = canvas_size )
7701+         result_2d , valid_2d  =  F .sanitize_keypoints (kp_2d )
7702+         assert  result_2d .ndim  ==  2 
7703+         assert  result_2d .shape [1 :] ==  kp_2d .shape [1 :]
7704+ 
7705+         # Test 3D shape [N_objects, N_points, 2] 
7706+         kp_3d  =  torch .tensor ([[[10 , 10 ], [20 , 20 ]], [[- 1 , - 1 ], [30 , 30 ]]], dtype = torch .float32 )
7707+         kp_3d  =  tv_tensors .KeyPoints (kp_3d , canvas_size = canvas_size )
7708+         result_3d , valid_3d  =  F .sanitize_keypoints (kp_3d )
7709+         assert  result_3d .ndim  ==  3 
7710+         assert  result_3d .shape [1 :] ==  kp_3d .shape [1 :]
7711+ 
7712+         # Test 4D shape [N_objects, N_bones, 2, 2] 
7713+         kp_4d  =  torch .tensor ([[[[10 , 10 ], [20 , 20 ]]], [[[- 1 , - 1 ], [30 , 30 ]]]], dtype = torch .float32 )
7714+         kp_4d  =  tv_tensors .KeyPoints (kp_4d , canvas_size = canvas_size )
7715+         result_4d , valid_4d  =  F .sanitize_keypoints (kp_4d )
7716+         assert  result_4d .ndim  ==  4 
7717+         assert  result_4d .shape [1 :] ==  kp_4d .shape [1 :]
7718+ 
7719+ 
74007720class  TestJPEG :
74017721    @pytest .mark .parametrize ("quality" , [5 , 75 ]) 
74027722    @pytest .mark .parametrize ("color_space" , ["RGB" , "GRAY" ]) 
0 commit comments