@@ -4661,7 +4661,36 @@ def test_bounding_boxes_correctness(self, padding, format, dtype, device, fn):
46614661
46624662 assert_equal (actual , expected )
46634663
4664- # TODOKP need keypoint correctness tests
4664+ def _reference_pad_keypoints (self , keypoints , * , padding ):
4665+ if isinstance (padding , int ):
4666+ padding = [padding ]
4667+ left , top , right , bottom = padding * (4 // len (padding ))
4668+
4669+ affine_matrix = np .array (
4670+ [
4671+ [1 , 0 , left ],
4672+ [0 , 1 , top ],
4673+ ],
4674+ )
4675+
4676+ height = keypoints .canvas_size [0 ] + top + bottom
4677+ width = keypoints .canvas_size [1 ] + left + right
4678+
4679+ return reference_affine_keypoints_helper (
4680+ keypoints , affine_matrix = affine_matrix , new_canvas_size = (height , width )
4681+ )
4682+
4683+ @pytest .mark .parametrize ("padding" , CORRECTNESS_PADDINGS )
4684+ @pytest .mark .parametrize ("dtype" , [torch .int64 , torch .float32 ])
4685+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
4686+ @pytest .mark .parametrize ("fn" , [F .pad , transform_cls_to_functional (transforms .Pad )])
4687+ def test_keypoints_correctness (self , padding , dtype , device , fn ):
4688+ keypoints = make_keypoints (dtype = dtype , device = device )
4689+
4690+ actual = fn (keypoints , padding = padding )
4691+ expected = self ._reference_pad_keypoints (keypoints , padding = padding )
4692+
4693+ assert_equal (actual , expected )
46654694
46664695
46674696class TestCenterCrop :
0 commit comments