@@ -633,6 +633,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
633633
634634def  reference_affine_keypoints_helper (keypoints , * , affine_matrix , new_canvas_size = None , clamp = True ):
635635    canvas_size  =  new_canvas_size  or  keypoints .canvas_size 
636+     clamping_mode  =  keypoints .clamping_mode 
636637
637638    def  affine_keypoints (keypoints ):
638639        dtype  =  keypoints .dtype 
@@ -652,15 +653,15 @@ def affine_keypoints(keypoints):
652653        )
653654
654655        if  clamp :
655-             output  =  F .clamp_keypoints (output , canvas_size = canvas_size )
656+             output  =  F .clamp_keypoints (output , canvas_size = canvas_size ,  clamping_mode = clamping_mode )
656657        else :
657658            dtype  =  output .dtype 
658659
659660        return  output .to (dtype = dtype , device = device )
660661
661662    return  tv_tensors .KeyPoints (
662663        torch .cat ([affine_keypoints (k ) for  k  in  keypoints .reshape (- 1 , 2 ).unbind ()], dim = 0 ).reshape (keypoints .shape ),
663-         canvas_size = canvas_size ,
664+         canvas_size = canvas_size ,  clamping_mode = clamping_mode 
664665    )
665666
666667
@@ -3309,7 +3310,6 @@ def test_functional(self, make_input):
33093310            (F .elastic_image , tv_tensors .Image ), 
33103311            (F .elastic_mask , tv_tensors .Mask ), 
33113312            (F .elastic_video , tv_tensors .Video ), 
3312-             (F .elastic_keypoints , tv_tensors .KeyPoints ), 
33133313        ], 
33143314    ) 
33153315    def  test_functional_signature (self , kernel , input_type ):
@@ -5325,6 +5325,7 @@ def test_correctness_perspective_bounding_boxes(self, startpoints, endpoints, fo
53255325
53265326    def  _reference_perspective_keypoints (self , keypoints , * , startpoints , endpoints ):
53275327        canvas_size  =  keypoints .canvas_size 
5328+         clamping_mode  =  keypoints .clamping_mode 
53285329        dtype  =  keypoints .dtype 
53295330        device  =  keypoints .device 
53305331
@@ -5364,13 +5365,15 @@ def perspective_keypoints(keypoints):
53645365            return  F .clamp_keypoints (
53655366                output ,
53665367                canvas_size = canvas_size ,
5368+                 clamping_mode = clamping_mode 
53675369            ).to (dtype = dtype , device = device )
53685370
53695371        return  tv_tensors .KeyPoints (
53705372            torch .cat ([perspective_keypoints (k ) for  k  in  keypoints .reshape (- 1 , 2 ).unbind ()], dim = 0 ).reshape (
53715373                keypoints .shape 
53725374            ),
53735375            canvas_size = canvas_size ,
5376+             clamping_mode = clamping_mode ,
53745377        )
53755378
53765379    @pytest .mark .parametrize (("startpoints" , "endpoints" ), START_END_POINTS ) 
@@ -5733,32 +5736,80 @@ def test_error(self):
57335736
57345737
57355738class  TestClampKeyPoints :
5739+     @pytest .mark .parametrize ("clamping_mode" , ("soft" , "hard" , None )) 
57365740    @pytest .mark .parametrize ("dtype" , [torch .int64 , torch .float32 ]) 
57375741    @pytest .mark .parametrize ("device" , cpu_and_cuda ()) 
5738-     def  test_kernel (self , dtype , device ):
5739-         keypoints  =  make_keypoints (dtype = dtype , device = device )
5742+     def  test_kernel (self , clamping_mode ,  dtype , device ):
5743+         keypoints  =  make_keypoints (dtype = dtype , device = device ,  clamping_mode = clamping_mode )
57405744        check_kernel (
57415745            F .clamp_keypoints ,
57425746            keypoints ,
57435747            canvas_size = keypoints .canvas_size ,
5748+             clamping_mode = clamping_mode ,
57445749        )
57455750
5746-     def  test_functional (self ):
5747-         check_functional (F .clamp_keypoints , make_keypoints ())
5751+     @pytest .mark .parametrize ("clamping_mode" , ("soft" , "hard" , None )) 
5752+     def  test_functional (self , clamping_mode ):
5753+         check_functional (F .clamp_keypoints , make_keypoints (clamping_mode = clamping_mode ))
57485754
57495755    def  test_errors (self ):
57505756        input_tv_tensor  =  make_keypoints ()
57515757        input_pure_tensor  =  input_tv_tensor .as_subclass (torch .Tensor )
57525758
5753-         with  pytest .raises (ValueError , match = "`canvas_size` has  to be passed" ):
5759+         with  pytest .raises (ValueError , match = "`canvas_size` and `clamping_mode` have  to be passed. " ):
57545760            F .clamp_keypoints (input_pure_tensor , canvas_size = None )
57555761
57565762        with  pytest .raises (ValueError , match = "`canvas_size` must not be passed" ):
57575763            F .clamp_keypoints (input_tv_tensor , canvas_size = input_tv_tensor .canvas_size )
5764+         with  pytest .raises (ValueError , match = "clamping_mode must be soft," ):
5765+             F .clamp_keypoints (input_tv_tensor , clamping_mode = "bad" )
5766+         with  pytest .raises (ValueError , match = "clamping_mode must be soft," ):
5767+             transforms .ClampKeyPoints (clamping_mode = "bad" )(input_tv_tensor )
57585768
57595769    def  test_transform (self ):
57605770        check_transform (transforms .ClampKeyPoints (), make_keypoints ())
57615771
5772+     @pytest .mark .parametrize ("constructor_clamping_mode" , ("soft" , "hard" , None )) 
5773+     @pytest .mark .parametrize ("clamping_mode" , ("soft" , "hard" , None , "auto" )) 
5774+     @pytest .mark .parametrize ("pass_pure_tensor" , (True , False )) 
5775+     @pytest .mark .parametrize ("fn" , [F .clamp_keypoints , transform_cls_to_functional (transforms .ClampKeyPoints )]) 
5776+     def  test_clamping_mode (self , constructor_clamping_mode , clamping_mode , pass_pure_tensor , fn ):
5777+         # This test checks 2 things: 
5778+         # - That passing clamping_mode=None to the clamp_keypointss 
5779+         #   functional (or to the class) relies on the box's `.clamping_mode` 
5780+         #   attribute 
5781+         # - That clamping happens when it should, and only when it should, i.e. 
5782+         #   when the clamping mode is not None. It doesn't validate the 
5783+         #   numerical results, only that clamping happened. For that, we create 
5784+         #   a keypoints with large coordinates (100) inside of a small 10x10 image. 
5785+ 
5786+         if  pass_pure_tensor  and  fn  is  not F .clamp_keypoints :
5787+             # Only the functional supports pure tensors, not the class 
5788+             return 
5789+         if  pass_pure_tensor  and  clamping_mode  ==  "auto" :
5790+             # cannot leave clamping_mode="auto" when passing pure tensor 
5791+             return 
5792+ 
5793+         keypoints  =  tv_tensors .KeyPoints (
5794+             [[0 , 100 ], [0 , 100 ]],canvas_size = (10 , 10 ), clamping_mode = constructor_clamping_mode 
5795+         )
5796+         expected_clamped_output  =  torch .tensor ([[0 , 9 ], [0 , 9 ]]) if  clamping_mode  ==  "hard"  else  torch .tensor ([[0 , 100 ], [0 , 100 ]])
5797+ 
5798+         if  pass_pure_tensor :
5799+             out  =  fn (
5800+                 keypoints .as_subclass (torch .Tensor ),
5801+                 canvas_size = keypoints .canvas_size ,
5802+                 clamping_mode = clamping_mode ,
5803+             )
5804+         else :
5805+             out  =  fn (keypoints , clamping_mode = clamping_mode )
5806+ 
5807+         clamping_mode_prevailing  =  constructor_clamping_mode  if  clamping_mode  ==  "auto"  else  clamping_mode 
5808+         if  clamping_mode_prevailing  is  None :
5809+             assert_equal (keypoints , out )  # should be a pass-through 
5810+         else :
5811+             assert_equal (out , expected_clamped_output )
5812+ 
57625813
57635814class  TestInvert :
57645815    @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .int16 , torch .float32 ]) 
0 commit comments