@@ -5444,6 +5444,34 @@ def test_errors(self):
54445444 def test_transform (self ):
54455445 check_transform (transforms .ClampBoundingBoxes (), make_bounding_boxes ())
54465446
5447+ class TestClampKeyPoints :
5448+ @pytest .mark .parametrize ("dtype" , [torch .int64 , torch .float32 ])
5449+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
5450+ def test_kernel (self , dtype , device ):
5451+ keypoints = make_keypoints (dtype = dtype , device = device )
5452+ check_kernel (
5453+ F .clamp_keypoints ,
5454+ keypoints ,
5455+ canvas_size = keypoints .canvas_size ,
5456+ )
5457+
5458+ def test_functional (self ):
5459+ check_functional (F .clamp_keypoints , make_keypoints ())
5460+
5461+ def test_errors (self ):
5462+ input_tv_tensor = make_keypoints ()
5463+ input_pure_tensor = input_tv_tensor .as_subclass (torch .Tensor )
5464+
5465+ with pytest .raises (ValueError , match = "`canvas_size` has to be passed" ):
5466+ F .clamp_keypoints (input_pure_tensor , canvas_size = None )
5467+
5468+ with pytest .raises (ValueError , match = "`canvas_size` must not be passed" ):
5469+ F .clamp_keypoints (input_tv_tensor , canvas_size = input_tv_tensor .canvas_size )
5470+
5471+ def test_transform (self ):
5472+ check_transform (transforms .ClampKeyPoints (), make_keypoints ())
5473+
5474+
54475475
54485476class TestInvert :
54495477 @pytest .mark .parametrize ("dtype" , [torch .uint8 , torch .int16 , torch .float32 ])
0 commit comments