@@ -1485,8 +1485,9 @@ def area_check(self, box, expected, atol=1e-4):
14851485
14861486 @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
14871487 def test_int_boxes (self , dtype ):
1488- box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype ),
1489- in_fmt = "xyxy" , out_fmt = "cxcywh" )
1488+ box_tensor = ops .box_convert (
1489+ torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype ), in_fmt = "xyxy" , out_fmt = "cxcywh"
1490+ )
14901491 expected = torch .tensor ([10000 , 0 ], dtype = torch .int32 )
14911492 self .area_check (box_tensor , expected )
14921493
@@ -1497,16 +1498,22 @@ def test_float_boxes(self, dtype):
14971498 self .area_check (box_tensor , expected )
14981499
14991500 def test_float16_box (self ):
1500- box_tensor = ops .box_convert (torch .tensor (
1501- [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]], dtype = torch .float16
1502- ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1501+ box_tensor = ops .box_convert (
1502+ torch .tensor (
1503+ [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]],
1504+ dtype = torch .float16 ,
1505+ ),
1506+ in_fmt = "xyxy" ,
1507+ out_fmt = "cxcywh" ,
1508+ )
15031509
15041510 expected = torch .tensor ([3.2170 , 3.7108 , 18.5071 ], dtype = torch .float16 )
15051511 self .area_check (box_tensor , expected , atol = 0.01 )
15061512
15071513 def test_box_area_jit (self ):
1508- box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ),
1509- in_fmt = "xyxy" , out_fmt = "cxcywh" )
1514+ box_tensor = ops .box_convert (
1515+ torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ), in_fmt = "xyxy" , out_fmt = "cxcywh"
1516+ )
15101517 expected = ops .box_area (box_tensor , fmt = "cxcywh" )
15111518 scripted_fn = torch .jit .script (ops .box_area )
15121519 scripted_area = scripted_fn (box_tensor , fmt = "cxcywh" )
@@ -1526,7 +1533,7 @@ def test_box_area_jit(self):
15261533FLOAT_BOXES_CXCYWH = [
15271534 [739.4324 , 518.5154 , 908.1572 , 665.8793 ],
15281535 [738.8228 , 519.9021 , 907.3512 , 662.3295 ],
1529- [734.3593 , 523.5916 , 910.2306 , 651.2207 ]
1536+ [734.3593 , 523.5916 , 910.2306 , 651.2207 ],
15301537]
15311538
15321539
@@ -1650,7 +1657,9 @@ class TestBoxIouCXCYWH(TestIouCXCYWHBase):
16501657 @pytest .mark .parametrize (
16511658 "actual_box1, actual_box2, dtypes, atol, expected" ,
16521659 [
1653- pytest .param (INT_BOXES_CXCYWH , INT_BOXES2_CXCYWH , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1660+ pytest .param (
1661+ INT_BOXES_CXCYWH , INT_BOXES2_CXCYWH , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected
1662+ ),
16541663 pytest .param (FLOAT_BOXES_CXCYWH , FLOAT_BOXES_CXCYWH , [torch .float16 ], 0.002 , float_expected ),
16551664 pytest .param (FLOAT_BOXES_CXCYWH , FLOAT_BOXES_CXCYWH , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
16561665 ],
@@ -1664,6 +1673,7 @@ def test_iou_jit(self):
16641673 def test_iou_cartesian (self ):
16651674 self ._run_cartesian_test (ops .box_iou )
16661675
1676+
16671677class TestIouBase :
16681678 @staticmethod
16691679 def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
0 commit comments