@@ -77,6 +77,33 @@ def test_bbox_format(format, is_rotated_expected, scripted):
7777 fn = torch .jit .script (fn )
7878 assert fn (format ) == is_rotated_expected
7979
80+ @pytest .mark .parametrize (
81+ "format, support_integer_dtype" ,
82+ [
83+ ("XYXY" , True ),
84+ ("XYWH" , True ),
85+ ("CXCYWH" , True ),
86+ ("XYXYXYXY" , False ),
87+ ("XYWHR" , False ),
88+ ("CXCYWHR" , False ),
89+ (tv_tensors .BoundingBoxFormat .XYXY , True ),
90+ (tv_tensors .BoundingBoxFormat .XYWH , True ),
91+ (tv_tensors .BoundingBoxFormat .CXCYWH , True ),
92+ (tv_tensors .BoundingBoxFormat .XYXYXYXY , False ),
93+ (tv_tensors .BoundingBoxFormat .XYWHR , False ),
94+ (tv_tensors .BoundingBoxFormat .CXCYWHR , False ),
95+ ],
96+ )
97+ @pytest .mark .parametrize ("input_dtype" , [torch .float32 , torch .float64 , torch .uint8 ])
98+ def test_bbox_format_dtype (format , support_integer_dtype , input_dtype ):
99+ print (format , support_integer_dtype , input_dtype )
100+ if not input_dtype .is_floating_point and not support_integer_dtype :
101+ pytest .xfail ("Rotated bounding boxes should be floating point tensors" )
102+ bboxes = tv_tensors .BoundingBoxes (torch .randint (0 , 32 , size = (5 , 2 )), format = format , canvas_size = (32 , 32 ))
103+ else :
104+ bboxes = tv_tensors .BoundingBoxes (torch .rand (size = (5 , 2 )), format = format , canvas_size = (32 , 32 ))
105+ assert isinstance (bboxes , torch .Tensor )
106+
80107
81108def test_bbox_dim_error ():
82109 data_3d = [[[1 , 2 , 3 , 4 ]]]
0 commit comments