@@ -1418,9 +1418,9 @@ def test_bbox_convert_jit(self):
14181418 torch .testing .assert_close (scripted_cxcywh , box_cxcywh )
14191419
14201420
1421- class TestBoxArea :
1421+ class TestBoxAreaXYXY :
14221422 def area_check (self , box , expected , atol = 1e-4 ):
1423- out = ops .box_area (box )
1423+ out = ops .box_area (box , fmt = "xyxy" )
14241424 torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
14251425
14261426 @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
@@ -1445,15 +1445,15 @@ def test_float16_box(self):
14451445
14461446 def test_box_area_jit (self ):
14471447 box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float )
1448- expected = ops .box_area (box_tensor )
1448+ expected = ops .box_area (box_tensor , fmt = "xyxy" )
14491449 scripted_fn = torch .jit .script (ops .box_area )
14501450 scripted_area = scripted_fn (box_tensor )
14511451 torch .testing .assert_close (scripted_area , expected )
14521452
14531453
1454- class TestBoxAreaCenter :
1454+ class TestBoxAreaCXCYWH :
14551455 def area_check (self , box , expected , atol = 1e-4 ):
1456- out = ops .box_area_center (box )
1456+ out = ops .box_area (box , fmt = "cxcywh" )
14571457 torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
14581458
14591459 @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
@@ -1480,9 +1480,9 @@ def test_float16_box(self):
14801480 def test_box_area_jit (self ):
14811481 box_tensor = ops .box_convert (torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ),
14821482 in_fmt = "xyxy" , out_fmt = "cxcywh" )
1483- expected = ops .box_area_center (box_tensor )
1484- scripted_fn = torch .jit .script (ops .box_area_center )
1485- scripted_area = scripted_fn (box_tensor )
1483+ expected = ops .box_area (box_tensor , fmt = "cxcywh" )
1484+ scripted_fn = torch .jit .script (ops .box_area )
1485+ scripted_area = scripted_fn (box_tensor , fmt = "cxcywh" )
14861486 torch .testing .assert_close (scripted_area , expected )
14871487
14881488
@@ -1509,22 +1509,22 @@ def gen_box(size, dtype=torch.float):
15091509 return torch .cat ([xy1 , xy2 ], axis = - 1 )
15101510
15111511
1512- class TestIouBase :
1512+ class TestIouXYXYBase :
15131513 @staticmethod
15141514 def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
15151515 for dtype in dtypes :
15161516 actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
15171517 actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
15181518 expected_box = torch .tensor (expected )
1519- out = target_fn (actual_box1 , actual_box2 )
1519+ out = target_fn (actual_box1 , actual_box2 , fmt = "xyxy" )
15201520 torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
15211521
15221522 @staticmethod
15231523 def _run_jit_test (target_fn : Callable , actual_box : List ):
15241524 box_tensor = torch .tensor (actual_box , dtype = torch .float )
1525- expected = target_fn (box_tensor , box_tensor )
1525+ expected = target_fn (box_tensor , box_tensor , fmt = "xyxy" )
15261526 scripted_fn = torch .jit .script (target_fn )
1527- scripted_out = scripted_fn (box_tensor , box_tensor )
1527+ scripted_out = scripted_fn (box_tensor , box_tensor , fmt = "xyxy" )
15281528 torch .testing .assert_close (scripted_out , expected )
15291529
15301530 @staticmethod
@@ -1534,19 +1534,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15341534 result = torch .zeros ((N , M ))
15351535 for i in range (N ):
15361536 for j in range (M ):
1537- result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1537+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ), fmt = "xyxy" )
15381538 return result
15391539
15401540 @staticmethod
15411541 def _run_cartesian_test (target_fn : Callable ):
15421542 boxes1 = gen_box (5 )
15431543 boxes2 = gen_box (7 )
1544- a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1545- b = target_fn (boxes1 , boxes2 )
1544+ a = TestIouXYXYBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1545+ b = target_fn (boxes1 , boxes2 , fmt = "xyxy" )
15461546 torch .testing .assert_close (a , b )
15471547
15481548
1549- class TestBoxIou ( TestIouBase ):
1549+ class TestBoxIouXYXY ( TestIouXYXYBase ):
15501550 int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.0625 , 0.25 , 0.0 ]]
15511551 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
15521552
@@ -1568,22 +1568,22 @@ def test_iou_cartesian(self):
15681568 self ._run_cartesian_test (ops .box_iou )
15691569
15701570
1571- class TestIouCenterBase :
1571+ class TestIouCXCYWHBase :
15721572 @staticmethod
15731573 def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
15741574 for dtype in dtypes :
15751575 actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
15761576 actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
15771577 expected_box = torch .tensor (expected )
1578- out = target_fn (actual_box1 , actual_box2 )
1578+ out = target_fn (actual_box1 , actual_box2 , fmt = "cxcywh" )
15791579 torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
15801580
15811581 @staticmethod
15821582 def _run_jit_test (target_fn : Callable , actual_box : List ):
15831583 box_tensor = torch .tensor (actual_box , dtype = torch .float )
1584- expected = target_fn (box_tensor , box_tensor )
1584+ expected = target_fn (box_tensor , box_tensor , fmt = "cxcywh" )
15851585 scripted_fn = torch .jit .script (target_fn )
1586- scripted_out = scripted_fn (box_tensor , box_tensor )
1586+ scripted_out = scripted_fn (box_tensor , box_tensor , fmt = "cxcywh" )
15871587 torch .testing .assert_close (scripted_out , expected )
15881588
15891589 @staticmethod
@@ -1593,19 +1593,19 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15931593 result = torch .zeros ((N , M ))
15941594 for i in range (N ):
15951595 for j in range (M ):
1596- result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1596+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ), fmt = "cxcywh" )
15971597 return result
15981598
15991599 @staticmethod
16001600 def _run_cartesian_test (target_fn : Callable ):
16011601 boxes1 = ops .box_convert (gen_box (5 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
16021602 boxes2 = ops .box_convert (gen_box (7 ), in_fmt = "xyxy" , out_fmt = "cxcywh" )
1603- a = TestIouCenterBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1604- b = target_fn (boxes1 , boxes2 )
1603+ a = TestIouCXCYWHBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1604+ b = target_fn (boxes1 , boxes2 , fmt = "cxcywh" )
16051605 torch .testing .assert_close (a , b )
16061606
16071607
1608- class TestBoxIouCenter ( TestIouBase ):
1608+ class TestBoxIouCXCYWH ( TestIouCXCYWHBase ):
16091609 int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.04 , 0.16 , 0.0 ]]
16101610 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
16111611
@@ -1618,13 +1618,49 @@ class TestBoxIouCenter(TestIouBase):
16181618 ],
16191619 )
16201620 def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1621- self ._run_test (ops .box_iou_center , actual_box1 , actual_box2 , dtypes , atol , expected )
1621+ self ._run_test (ops .box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
16221622
16231623 def test_iou_jit (self ):
1624- self ._run_jit_test (ops .box_iou_center , INT_BOXES_CXCYWH )
1624+ self ._run_jit_test (ops .box_iou , INT_BOXES_CXCYWH )
16251625
16261626 def test_iou_cartesian (self ):
1627- self ._run_cartesian_test (ops .box_iou_center )
1627+ self ._run_cartesian_test (ops .box_iou )
1628+
1629+ class TestIouBase :
1630+ @staticmethod
1631+ def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1632+ for dtype in dtypes :
1633+ actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1634+ actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1635+ expected_box = torch .tensor (expected )
1636+ out = target_fn (actual_box1 , actual_box2 )
1637+ torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
1638+
1639+ @staticmethod
1640+ def _run_jit_test (target_fn : Callable , actual_box : List ):
1641+ box_tensor = torch .tensor (actual_box , dtype = torch .float )
1642+ expected = target_fn (box_tensor , box_tensor )
1643+ scripted_fn = torch .jit .script (target_fn )
1644+ scripted_out = scripted_fn (box_tensor , box_tensor )
1645+ torch .testing .assert_close (scripted_out , expected )
1646+
1647+ @staticmethod
1648+ def _cartesian_product (boxes1 , boxes2 , target_fn : Callable ):
1649+ N = boxes1 .size (0 )
1650+ M = boxes2 .size (0 )
1651+ result = torch .zeros ((N , M ))
1652+ for i in range (N ):
1653+ for j in range (M ):
1654+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1655+ return result
1656+
1657+ @staticmethod
1658+ def _run_cartesian_test (target_fn : Callable ):
1659+ boxes1 = gen_box (5 )
1660+ boxes2 = gen_box (7 )
1661+ a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1662+ b = target_fn (boxes1 , boxes2 )
1663+ torch .testing .assert_close (a , b )
16281664
16291665
16301666class TestGeneralizedBoxIou (TestIouBase ):
0 commit comments