@@ -5805,7 +5805,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
5805
5805
5806
5806
5807
5807
class TestSanitizeBoundingBoxes :
5808
- def _get_boxes_and_valid_mask (self , H = 256 , W = 128 , min_size = 10 ):
5808
+ def _get_boxes_and_valid_mask (self , H = 256 , W = 128 , min_size = 10 , min_area = 10 ):
5809
5809
boxes_and_validity = [
5810
5810
([0 , 1 , 10 , 1 ], False ), # Y1 == Y2
5811
5811
([0 , 1 , 0 , 20 ], False ), # X1 == X2
@@ -5816,17 +5816,16 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
5816
5816
([- 1 , 1 , 10 , 20 ], False ), # any < 0
5817
5817
([0 , 0 , - 1 , 20 ], False ), # any < 0
5818
5818
([0 , 0 , - 10 , - 1 ], False ), # any < 0
5819
- ([0 , 0 , min_size , 10 ], True ), # H < min_size
5820
- ([0 , 0 , 10 , min_size ], True ), # W < min_size
5821
- ([0 , 0 , W , H ], True ), # TODO: Is that actually OK?? Should it be -1?
5822
- ([1 , 1 , 30 , 20 ], True ),
5823
- ([0 , 0 , 10 , 10 ], True ),
5824
- ([1 , 1 , 30 , 20 ], True ),
5819
+ ([0 , 0 , min_size , 10 ], min_size * 10 >= min_area ), # H < min_size
5820
+ ([0 , 0 , 10 , min_size ], min_size * 10 >= min_area ), # W < min_size
5821
+ ([0 , 0 , W , H ], W * H >= min_area ),
5822
+ ([1 , 1 , 30 , 20 ], 29 * 19 >= min_area ),
5823
+ ([0 , 0 , 10 , 10 ], 9 * 9 >= min_area ),
5824
+ ([1 , 1 , 30 , 20 ], 29 * 19 >= min_area ),
5825
5825
]
5826
5826
5827
5827
random .shuffle (boxes_and_validity ) # For test robustness: mix order of wrong and correct cases
5828
5828
boxes , expected_valid_mask = zip (* boxes_and_validity )
5829
-
5830
5829
boxes = tv_tensors .BoundingBoxes (
5831
5830
boxes ,
5832
5831
format = tv_tensors .BoundingBoxFormat .XYXY ,
@@ -5835,7 +5834,7 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
5835
5834
5836
5835
return boxes , expected_valid_mask
5837
5836
5838
- @pytest .mark .parametrize ("min_size" , (1 , 10 ))
5837
+ @pytest .mark .parametrize ("min_size, min_area " , (( 1 , 1 ), ( 10 , 1 ), ( 10 , 101 ) ))
5839
5838
@pytest .mark .parametrize (
5840
5839
"labels_getter" ,
5841
5840
(
@@ -5848,15 +5847,15 @@ def _get_boxes_and_valid_mask(self, H=256, W=128, min_size=10):
5848
5847
),
5849
5848
)
5850
5849
@pytest .mark .parametrize ("sample_type" , (tuple , dict ))
5851
- def test_transform (self , min_size , labels_getter , sample_type ):
5850
+ def test_transform (self , min_size , min_area , labels_getter , sample_type ):
5852
5851
5853
5852
if sample_type is tuple and not isinstance (labels_getter , str ):
5854
5853
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
5855
5854
# doesn't work if the input is a tuple.
5856
5855
return
5857
5856
5858
5857
H , W = 256 , 128
5859
- boxes , expected_valid_mask = self ._get_boxes_and_valid_mask (H = H , W = W , min_size = min_size )
5858
+ boxes , expected_valid_mask = self ._get_boxes_and_valid_mask (H = H , W = W , min_size = min_size , min_area = min_area )
5860
5859
valid_indices = [i for (i , is_valid ) in enumerate (expected_valid_mask ) if is_valid ]
5861
5860
5862
5861
labels = torch .arange (boxes .shape [0 ])
@@ -5880,7 +5879,9 @@ def test_transform(self, min_size, labels_getter, sample_type):
5880
5879
img = sample .pop ("image" )
5881
5880
sample = (img , sample )
5882
5881
5883
- out = transforms .SanitizeBoundingBoxes (min_size = min_size , labels_getter = labels_getter )(sample )
5882
+ out = transforms .SanitizeBoundingBoxes (min_size = min_size , min_area = min_area , labels_getter = labels_getter )(
5883
+ sample
5884
+ )
5884
5885
5885
5886
if sample_type is tuple :
5886
5887
out_image = out [0 ]
@@ -5977,6 +5978,8 @@ def test_errors_transform(self):
5977
5978
5978
5979
with pytest .raises (ValueError , match = "min_size must be >= 1" ):
5979
5980
transforms .SanitizeBoundingBoxes (min_size = 0 )
5981
+ with pytest .raises (ValueError , match = "min_area must be >= 1" ):
5982
+ transforms .SanitizeBoundingBoxes (min_area = 0 )
5980
5983
with pytest .raises (ValueError , match = "labels_getter should either be 'default'" ):
5981
5984
transforms .SanitizeBoundingBoxes (labels_getter = 12 )
5982
5985
0 commit comments