@@ -5506,20 +5506,23 @@ def test_correctness_image(self, mean, std, dtype, fn):
55065506
55075507class TestClampBoundingBoxes :
55085508 @pytest .mark .parametrize ("format" , list (tv_tensors .BoundingBoxFormat ))
5509+ @pytest .mark .parametrize ("clamping_mode" , ("hard" , "none" )) # TODOBB add soft
55095510 @pytest .mark .parametrize ("dtype" , [torch .int64 , torch .float32 ])
55105511 @pytest .mark .parametrize ("device" , cpu_and_cuda ())
5511- def test_kernel (self , format , dtype , device ):
5512- bounding_boxes = make_bounding_boxes (format = format , dtype = dtype , device = device )
5512+ def test_kernel (self , format , clamping_mode , dtype , device ):
5513+ bounding_boxes = make_bounding_boxes (format = format , clamping_mode = clamping_mode , dtype = dtype , device = device )
55135514 check_kernel (
55145515 F .clamp_bounding_boxes ,
55155516 bounding_boxes ,
55165517 format = bounding_boxes .format ,
55175518 canvas_size = bounding_boxes .canvas_size ,
5519+ clamping_mode = clamping_mode ,
55185520 )
55195521
55205522 @pytest .mark .parametrize ("format" , list (tv_tensors .BoundingBoxFormat ))
5521- def test_functional (self , format ):
5522- check_functional (F .clamp_bounding_boxes , make_bounding_boxes (format = format ))
5523+ @pytest .mark .parametrize ("clamping_mode" , ("hard" , "none" )) # TODOBB add soft
5524+ def test_functional (self , format , clamping_mode ):
5525+ check_functional (F .clamp_bounding_boxes , make_bounding_boxes (format = format , clamping_mode = clamping_mode ))
55235526
55245527 def test_errors (self ):
55255528 input_tv_tensor = make_bounding_boxes ()
@@ -5540,6 +5543,47 @@ def test_errors(self):
55405543
55415544 def test_transform (self ):
55425545 check_transform (transforms .ClampBoundingBoxes (), make_bounding_boxes ())
5546+
5547+ @pytest .mark .parametrize ("rotated" , (True , False ))
5548+ @pytest .mark .parametrize ("constructor_clamping_mode" , ("hard" , "none" ))
5549+ @pytest .mark .parametrize ("clamping_mode" , ("hard" , "none" , None )) # TODOBB add soft here.
5550+ @pytest .mark .parametrize ("pass_pure_tensor" , (True , False ))
5551+ @pytest .mark .parametrize ("fn" , [F .clamp_bounding_boxes , transform_cls_to_functional (transforms .ClampBoundingBoxes )])
5552+ def test_clamping_mode (self , rotated , constructor_clamping_mode , clamping_mode , pass_pure_tensor , fn ):
5553+ # This test checks 2 things:
5554+ # - That passing clamping_mode=None to the clamp_bounding_boxes
5555+ # functional (or to the class) relies on the box's `.clamping_mode`
5556+ # attribute
5557+ # - That clamping happens when it should, and only when it should, i.e.
5558+ # when the clamping mode is not "none". It doesn't validate the
5559+ # nunmerical results, only that clamping happened. For that, we create
5560+ # a large 100x100 box inside of a small 10x10 image.
5561+
5562+ if pass_pure_tensor and fn is not F .clamp_bounding_boxes :
5563+ # Only the functional supports pure tensors, not the class
5564+ return
5565+ if pass_pure_tensor and clamping_mode is None :
5566+ # cannot leave clamping_mode=None when passing pure tensor
5567+ return
5568+
5569+ if rotated :
5570+ boxes = tv_tensors .BoundingBoxes ([0 , 0 , 100 , 100 , 0 ], format = "XYWHR" , canvas_size = (10 , 10 ), clamping_mode = constructor_clamping_mode )
5571+ expected_clamped_output = torch .tensor ([[0 , 0 , 10 , 10 , 0 ]])
5572+ else :
5573+ boxes = tv_tensors .BoundingBoxes ([0 , 100 , 0 , 100 ], format = "XYXY" , canvas_size = (10 , 10 ), clamping_mode = constructor_clamping_mode )
5574+ expected_clamped_output = torch .tensor ([[0 , 10 , 0 , 10 ]])
5575+
5576+ if pass_pure_tensor :
5577+ out = fn (boxes .as_subclass (torch .Tensor ), format = boxes .format , canvas_size = boxes .canvas_size , clamping_mode = clamping_mode )
5578+ else :
5579+ out = fn (boxes , clamping_mode = clamping_mode )
5580+
5581+ clamping_mode_prevailing = constructor_clamping_mode if clamping_mode is None else clamping_mode
5582+ if clamping_mode_prevailing == "none" :
5583+ assert_equal (boxes , out ) # should be a pass-through
5584+ else :
5585+ assert_equal (out , expected_clamped_output )
5586+
55435587
55445588
55455589class TestClampKeyPoints :
0 commit comments