11import  math 
22import  os 
33from  abc  import  ABC , abstractmethod 
4- from  functools  import  lru_cache 
4+ from  functools  import  lru_cache ,  partial 
55from  itertools  import  product 
66from  typing  import  Callable 
77
@@ -242,7 +242,7 @@ def _helper_boxes_shape(self, func):
242242            boxes  =  torch .tensor ([[0 , 0 , 3 , 3 ]], dtype = a .dtype )
243243            func (a , boxes , output_size = (2 , 2 ))
244244
245-         # test boxes as List [Tensor[N, 4]] 
245+         # test boxes as list [Tensor[N, 4]] 
246246        with  pytest .raises (AssertionError ):
247247            a  =  torch .linspace (1 , 8  *  8 , 8  *  8 ).reshape (1 , 1 , 8 , 8 )
248248            boxes  =  torch .tensor ([[0 , 0 , 3 ]], dtype = a .dtype )
@@ -1446,34 +1446,60 @@ def test_bbox_convert_jit(self):
14461446
14471447
14481448class  TestBoxArea :
1449-     def  area_check (self , box , expected , atol = 1e-4 ):
1450-         out  =  ops .box_area (box )
1449+     def  area_check (self , box , expected , fmt = "xyxy" ,  atol = 1e-4 ):
1450+         out  =  ops .box_area (box ,  fmt = fmt )
14511451        torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
14521452
14531453    @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ]) 
1454-     def  test_int_boxes (self , dtype ):
1455-         box_tensor  =  torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype )
1454+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1455+     def  test_int_boxes (self , dtype , fmt ):
1456+         box_tensor  =  ops .box_convert (
1457+             torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt 
1458+         )
14561459        expected  =  torch .tensor ([10000 , 0 ], dtype = torch .int32 )
1457-         self .area_check (box_tensor , expected )
1460+         self .area_check (box_tensor , expected ,  fmt )
14581461
14591462    @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ]) 
1460-     def  test_float_boxes (self , dtype ):
1461-         box_tensor  =  torch .tensor (FLOAT_BOXES , dtype = dtype )
1463+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1464+     def  test_float_boxes (self , dtype , fmt ):
1465+         box_tensor  =  ops .box_convert (torch .tensor (FLOAT_BOXES , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
14621466        expected  =  torch .tensor ([604723.0806 , 600965.4666 , 592761.0085 ], dtype = dtype )
1463-         self .area_check (box_tensor , expected )
1464- 
1465-     def  test_float16_box (self ):
1466-         box_tensor  =  torch .tensor (
1467-             [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]], dtype = torch .float16 
1467+         self .area_check (box_tensor , expected , fmt )
1468+ 
1469+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1470+     def  test_float16_box (self , fmt ):
1471+         box_tensor  =  ops .box_convert (
1472+             torch .tensor (
1473+                 [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]],
1474+                 dtype = torch .float16 ,
1475+             ),
1476+             in_fmt = "xyxy" ,
1477+             out_fmt = fmt ,
14681478        )
14691479
14701480        expected  =  torch .tensor ([3.2170 , 3.7108 , 18.5071 ], dtype = torch .float16 )
1471-         self .area_check (box_tensor , expected , atol = 0.01 )
1481+         self .area_check (box_tensor , expected , fmt , atol = 0.01 )
1482+ 
1483+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1484+     def  test_box_area_jit (self , fmt ):
1485+         box_tensor  =  ops .box_convert (
1486+             torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ), in_fmt = "xyxy" , out_fmt = fmt 
1487+         )
1488+         expected  =  ops .box_area (box_tensor , fmt )
1489+ 
1490+         class  BoxArea (torch .nn .Module ):
1491+             # We are using this intermediate class 
1492+             # since torchscript does not support 
1493+             # neither partial nor lambda functions for this test. 
1494+             def  __init__ (self , fmt ):
1495+                 super ().__init__ ()
1496+                 self .area  =  ops .box_area 
1497+                 self .fmt  =  fmt 
14721498
1473-     def  test_box_area_jit (self ):
1474-         box_tensor   =   torch . tensor ([[ 0 ,  0 ,  100 ,  100 ], [ 0 ,  0 ,  0 ,  0 ]],  dtype = torch . float )
1475-          expected   =   ops . box_area ( box_tensor ) 
1476-         scripted_fn  =  torch .jit .script (ops . box_area )
1499+              def  forward (self ,  boxes ):
1500+                  return   self . area ( boxes ,  self . fmt )
1501+ 
1502+         scripted_fn  =  torch .jit .script (BoxArea ( fmt ) )
14771503        scripted_area  =  scripted_fn (box_tensor )
14781504        torch .testing .assert_close (scripted_area , expected )
14791505
@@ -1487,25 +1513,28 @@ def test_box_area_jit(self):
14871513]
14881514
14891515
1490- def  gen_box (size , dtype = torch .float ) ->  Tensor :
1516+ def  gen_box (size , dtype = torch .float ,  fmt = "xyxy" ) ->  Tensor :
14911517    xy1  =  torch .rand ((size , 2 ), dtype = dtype )
14921518    xy2  =  xy1  +  torch .rand ((size , 2 ), dtype = dtype )
1493-     return  torch .cat ([xy1 , xy2 ], axis = - 1 )
1519+     return  ops . box_convert ( torch .cat ([xy1 , xy2 ], axis = - 1 ),  in_fmt = "xyxy" ,  out_fmt = fmt )
14941520
14951521
14961522class  TestIouBase :
14971523    @staticmethod  
1498-     def  _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1524+     def  _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ,  fmt = "xyxy" ):
14991525        for  dtype  in  dtypes :
1500-             actual_box1  =  torch .tensor (actual_box1 , dtype = dtype )
1501-             actual_box2  =  torch .tensor (actual_box2 , dtype = dtype )
1526+             _actual_box1  =  ops . box_convert ( torch .tensor (actual_box1 , dtype = dtype ),  in_fmt = "xyxy" ,  out_fmt = fmt )
1527+             _actual_box2  =  ops . box_convert ( torch .tensor (actual_box2 , dtype = dtype ),  in_fmt = "xyxy" ,  out_fmt = fmt )
15021528            expected_box  =  torch .tensor (expected )
1503-             out  =  target_fn (actual_box1 , actual_box2 )
1529+             out  =  target_fn (
1530+                 _actual_box1 ,
1531+                 _actual_box2 ,
1532+             )
15041533            torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
15051534
15061535    @staticmethod  
1507-     def  _run_jit_test (target_fn : Callable , actual_box : list ):
1508-         box_tensor  =  torch .tensor (actual_box , dtype = torch .float )
1536+     def  _run_jit_test (target_fn : Callable , actual_box : list ,  fmt = "xyxy" ):
1537+         box_tensor  =  ops . box_convert ( torch .tensor (actual_box , dtype = torch .float ),  in_fmt = "xyxy" ,  out_fmt = fmt )
15091538        expected  =  target_fn (box_tensor , box_tensor )
15101539        scripted_fn  =  torch .jit .script (target_fn )
15111540        scripted_out  =  scripted_fn (box_tensor , box_tensor )
@@ -1522,17 +1551,17 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15221551        return  result 
15231552
15241553    @staticmethod  
1525-     def  _run_cartesian_test (target_fn : Callable ):
1526-         boxes1  =  gen_box (5 )
1527-         boxes2  =  gen_box (7 )
1554+     def  _run_cartesian_test (target_fn : Callable ,  fmt :  str   =   "xyxy" ):
1555+         boxes1  =  gen_box (5 ,  fmt = fmt )
1556+         boxes2  =  gen_box (7 ,  fmt = fmt )
15281557        a  =  TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
15291558        b  =  target_fn (boxes1 , boxes2 )
15301559        torch .testing .assert_close (a , b )
15311560
15321561    @staticmethod  
1533-     def  _run_batch_test (target_fn : Callable ):
1534-         boxes1  =  torch .stack ([gen_box (5 ) for  _  in  range (3 )], dim = 0 )
1535-         boxes2  =  torch .stack ([gen_box (5 ) for  _  in  range (3 )], dim = 0 )
1562+     def  _run_batch_test (target_fn : Callable ,  fmt :  str   =   "xyxy" ):
1563+         boxes1  =  torch .stack ([gen_box (5 ,  fmt = fmt ) for  _  in  range (3 )], dim = 0 )
1564+         boxes2  =  torch .stack ([gen_box (5 ,  fmt = fmt ) for  _  in  range (3 )], dim = 0 )
15361565        native : Tensor  =  target_fn (boxes1 , boxes2 )
15371566        iterative : Tensor  =  torch .stack ([target_fn (* pairs ) for  pairs  in  zip (boxes1 , boxes2 )], dim = 0 )
15381567        torch .testing .assert_close (native , iterative )
@@ -1550,17 +1579,33 @@ class TestBoxIou(TestIouBase):
15501579            pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ), 
15511580        ], 
15521581    ) 
1553-     def  test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1554-         self ._run_test (ops .box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
1555- 
1556-     def  test_iou_jit (self ):
1557-         self ._run_jit_test (ops .box_iou , INT_BOXES )
1558- 
1559-     def  test_iou_cartesian (self ):
1560-         self ._run_cartesian_test (ops .box_iou )
1561- 
1562-     def  test_iou_batch (self ):
1563-         self ._run_batch_test (ops .box_iou )
1582+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1583+     def  test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected , fmt ):
1584+         self ._run_test (partial (ops .box_iou , fmt = fmt ), actual_box1 , actual_box2 , dtypes , atol , expected , fmt )
1585+ 
1586+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1587+     def  test_iou_jit (self , fmt ):
1588+         class  IoUJit (torch .nn .Module ):
1589+             # We are using this intermediate class 
1590+             # since torchscript does not support 
1591+             # neither partial nor lambda functions for this test. 
1592+             def  __init__ (self , fmt ):
1593+                 super ().__init__ ()
1594+                 self .iou  =  ops .box_iou 
1595+                 self .fmt  =  fmt 
1596+ 
1597+             def  forward (self , boxes1 , boxes2 ):
1598+                 return  self .iou (boxes1 , boxes2 , fmt = self .fmt )
1599+ 
1600+         self ._run_jit_test (IoUJit (fmt = fmt ), INT_BOXES , fmt )
1601+ 
1602+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1603+     def  test_iou_cartesian (self , fmt ):
1604+         self ._run_cartesian_test (partial (ops .box_iou , fmt = fmt ))
1605+ 
1606+     @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ]) 
1607+     def  test_iou_batch (self , fmt ):
1608+         self ._run_batch_test (partial (ops .box_iou , fmt = fmt ))
15641609
15651610
15661611class  TestGeneralizedBoxIou (TestIouBase ):
0 commit comments