11import math
22import os
33from abc import ABC , abstractmethod
4- from functools import lru_cache
4+ from functools import lru_cache , partial
55from itertools import product
66from typing import Callable
77
@@ -242,7 +242,7 @@ def _helper_boxes_shape(self, func):
242242 boxes = torch .tensor ([[0 , 0 , 3 , 3 ]], dtype = a .dtype )
243243 func (a , boxes , output_size = (2 , 2 ))
244244
245- # test boxes as List [Tensor[N, 4]]
245+ # test boxes as list [Tensor[N, 4]]
246246 with pytest .raises (AssertionError ):
247247 a = torch .linspace (1 , 8 * 8 , 8 * 8 ).reshape (1 , 1 , 8 , 8 )
248248 boxes = torch .tensor ([[0 , 0 , 3 ]], dtype = a .dtype )
@@ -1073,15 +1073,15 @@ def test_forward(self, device, contiguous, batch_sz, dtype=None):
10731073 expected = self .expected_fn (x , weight , offset , mask , bias , stride = stride , padding = padding , dilation = dilation )
10741074
10751075 torch .testing .assert_close (
1076- res .to (expected ), expected , rtol = tol , atol = tol , msg = f"\n res:\n { res } \n expected:\n { expected } "
1076+ res .to (expected ), expected , rtol = tol , atol = tol , msg = f"\n res: \n { res } \n expected: \n { expected } "
10771077 )
10781078
10791079 # no modulation test
10801080 res = layer (x , offset )
10811081 expected = self .expected_fn (x , weight , offset , None , bias , stride = stride , padding = padding , dilation = dilation )
10821082
10831083 torch .testing .assert_close (
1084- res .to (expected ), expected , rtol = tol , atol = tol , msg = f"\n res:\n { res } \n expected:\n { expected } "
1084+ res .to (expected ), expected , rtol = tol , atol = tol , msg = f"\n res: \n { res } \n expected: \n { expected } "
10851085 )
10861086
10871087 def test_wrong_sizes (self ):
@@ -1446,34 +1446,60 @@ def test_bbox_convert_jit(self):
14461446
14471447
14481448class TestBoxArea :
1449- def area_check (self , box , expected , atol = 1e-4 ):
1450- out = ops .box_area (box )
1449+ def area_check (self , box , expected , fmt = "xyxy" , atol = 1e-4 ):
1450+ out = ops .box_area (box , fmt = fmt )
14511451 torch .testing .assert_close (out , expected , rtol = 0.0 , check_dtype = False , atol = atol )
14521452
14531453 @pytest .mark .parametrize ("dtype" , [torch .int8 , torch .int16 , torch .int32 , torch .int64 ])
1454- def test_int_boxes (self , dtype ):
1455- box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype )
1454+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1455+ def test_int_boxes (self , dtype , fmt ):
1456+ box_tensor = ops .box_convert (
1457+ torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt
1458+ )
14561459 expected = torch .tensor ([10000 , 0 ], dtype = torch .int32 )
1457- self .area_check (box_tensor , expected )
1460+ self .area_check (box_tensor , expected , fmt )
14581461
14591462 @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .float64 ])
1460- def test_float_boxes (self , dtype ):
1461- box_tensor = torch .tensor (FLOAT_BOXES , dtype = dtype )
1463+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1464+ def test_float_boxes (self , dtype , fmt ):
1465+ box_tensor = ops .box_convert (torch .tensor (FLOAT_BOXES , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
14621466 expected = torch .tensor ([604723.0806 , 600965.4666 , 592761.0085 ], dtype = dtype )
1463- self .area_check (box_tensor , expected )
1464-
1465- def test_float16_box (self ):
1466- box_tensor = torch .tensor (
1467- [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]], dtype = torch .float16
1467+ self .area_check (box_tensor , expected , fmt )
1468+
1469+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1470+ def test_float16_box (self , fmt ):
1471+ box_tensor = ops .box_convert (
1472+ torch .tensor (
1473+ [[2.825 , 1.8625 , 3.90 , 4.85 ], [2.825 , 4.875 , 19.20 , 5.10 ], [2.925 , 1.80 , 8.90 , 4.90 ]],
1474+ dtype = torch .float16 ,
1475+ ),
1476+ in_fmt = "xyxy" ,
1477+ out_fmt = fmt ,
14681478 )
14691479
14701480 expected = torch .tensor ([3.2170 , 3.7108 , 18.5071 ], dtype = torch .float16 )
1471- self .area_check (box_tensor , expected , atol = 0.01 )
1481+ self .area_check (box_tensor , expected , fmt , atol = 0.01 )
1482+
1483+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1484+ def test_box_area_jit (self , fmt ):
1485+ box_tensor = ops .box_convert (
1486+ torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ), in_fmt = "xyxy" , out_fmt = fmt
1487+ )
1488+ expected = ops .box_area (box_tensor , fmt )
14721489
1473- def test_box_area_jit (self ):
1474- box_tensor = torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float )
1475- expected = ops .box_area (box_tensor )
1476- scripted_fn = torch .jit .script (ops .box_area )
1490+ class BoxArea (torch .nn .Module ):
1491+ # We are using this intermediate class
1492+ # since torchscript does not support
1493+ # neither partial nor lambda functions for this test.
1494+ def __init__ (self , fmt ):
1495+ super ().__init__ ()
1496+ self .area = ops .box_area
1497+ self .fmt = fmt
1498+
1499+ def forward (self , boxes ):
1500+ return self .area (boxes , self .fmt )
1501+
1502+ scripted_fn = torch .jit .script (BoxArea (fmt ))
14771503 scripted_area = scripted_fn (box_tensor )
14781504 torch .testing .assert_close (scripted_area , expected )
14791505
@@ -1487,25 +1513,28 @@ def test_box_area_jit(self):
14871513]
14881514
14891515
1490- def gen_box (size , dtype = torch .float ) :
1516+ def gen_box (size , dtype = torch .float , fmt = "xyxy" ) -> Tensor :
14911517 xy1 = torch .rand ((size , 2 ), dtype = dtype )
14921518 xy2 = xy1 + torch .rand ((size , 2 ), dtype = dtype )
1493- return torch .cat ([xy1 , xy2 ], axis = - 1 )
1519+ return ops . box_convert ( torch .cat ([xy1 , xy2 ], axis = - 1 ), in_fmt = "xyxy" , out_fmt = fmt )
14941520
14951521
14961522class TestIouBase :
14971523 @staticmethod
1498- def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
1524+ def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected , fmt = "xyxy" ):
14991525 for dtype in dtypes :
1500- actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1501- actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
1526+ _actual_box1 = ops . box_convert ( torch .tensor (actual_box1 , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
1527+ _actual_box2 = ops . box_convert ( torch .tensor (actual_box2 , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
15021528 expected_box = torch .tensor (expected )
1503- out = target_fn (actual_box1 , actual_box2 )
1529+ out = target_fn (
1530+ _actual_box1 ,
1531+ _actual_box2 ,
1532+ )
15041533 torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
15051534
15061535 @staticmethod
1507- def _run_jit_test (target_fn : Callable , actual_box : list ):
1508- box_tensor = torch .tensor (actual_box , dtype = torch .float )
1536+ def _run_jit_test (target_fn : Callable , actual_box : list , fmt = "xyxy" ):
1537+ box_tensor = ops . box_convert ( torch .tensor (actual_box , dtype = torch .float ), in_fmt = "xyxy" , out_fmt = fmt )
15091538 expected = target_fn (box_tensor , box_tensor )
15101539 scripted_fn = torch .jit .script (target_fn )
15111540 scripted_out = scripted_fn (box_tensor , box_tensor )
@@ -1522,13 +1551,21 @@ def _cartesian_product(boxes1, boxes2, target_fn: Callable):
15221551 return result
15231552
15241553 @staticmethod
1525- def _run_cartesian_test (target_fn : Callable ):
1526- boxes1 = gen_box (5 )
1527- boxes2 = gen_box (7 )
1554+ def _run_cartesian_test (target_fn : Callable , fmt : str = "xyxy" ):
1555+ boxes1 = gen_box (5 , fmt = fmt )
1556+ boxes2 = gen_box (7 , fmt = fmt )
15281557 a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
15291558 b = target_fn (boxes1 , boxes2 )
15301559 torch .testing .assert_close (a , b )
15311560
1561+ @staticmethod
1562+ def _run_batch_test (target_fn : Callable , fmt : str = "xyxy" ):
1563+ boxes1 = torch .stack ([gen_box (5 , fmt = fmt ) for _ in range (3 )], dim = 0 )
1564+ boxes2 = torch .stack ([gen_box (5 , fmt = fmt ) for _ in range (3 )], dim = 0 )
1565+ native : Tensor = target_fn (boxes1 , boxes2 )
1566+ iterative : Tensor = torch .stack ([target_fn (* pairs ) for pairs in zip (boxes1 , boxes2 )], dim = 0 )
1567+ torch .testing .assert_close (native , iterative )
1568+
15321569
15331570class TestBoxIou (TestIouBase ):
15341571 int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [0.0625 , 0.25 , 0.0 ]]
@@ -1542,14 +1579,33 @@ class TestBoxIou(TestIouBase):
15421579 pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
15431580 ],
15441581 )
1545- def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1546- self ._run_test (ops .box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
1582+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1583+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected , fmt ):
1584+ self ._run_test (partial (ops .box_iou , fmt = fmt ), actual_box1 , actual_box2 , dtypes , atol , expected , fmt )
15471585
1548- def test_iou_jit (self ):
1549- self ._run_jit_test (ops .box_iou , INT_BOXES )
1586+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1587+ def test_iou_jit (self , fmt ):
1588+ class IoUJit (torch .nn .Module ):
1589+ # We are using this intermediate class
1590+ # since torchscript does not support
1591+ # neither partial nor lambda functions for this test.
1592+ def __init__ (self , fmt ):
1593+ super ().__init__ ()
1594+ self .iou = ops .box_iou
1595+ self .fmt = fmt
15501596
1551- def test_iou_cartesian (self ):
1552- self ._run_cartesian_test (ops .box_iou )
1597+ def forward (self , boxes1 , boxes2 ):
1598+ return self .iou (boxes1 , boxes2 , fmt = self .fmt )
1599+
1600+ self ._run_jit_test (IoUJit (fmt = fmt ), INT_BOXES , fmt )
1601+
1602+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1603+ def test_iou_cartesian (self , fmt ):
1604+ self ._run_cartesian_test (partial (ops .box_iou , fmt = fmt ))
1605+
1606+ @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
1607+ def test_iou_batch (self , fmt ):
1608+ self ._run_batch_test (partial (ops .box_iou , fmt = fmt ))
15531609
15541610
15551611class TestGeneralizedBoxIou (TestIouBase ):
@@ -1573,6 +1629,9 @@ def test_iou_jit(self):
15731629 def test_iou_cartesian (self ):
15741630 self ._run_cartesian_test (ops .generalized_box_iou )
15751631
1632+ def test_iou_batch (self ):
1633+ self ._run_batch_test (ops .generalized_box_iou )
1634+
15761635
15771636class TestDistanceBoxIoU (TestIouBase ):
15781637 int_expected = [
@@ -1600,6 +1659,9 @@ def test_iou_jit(self):
16001659 def test_iou_cartesian (self ):
16011660 self ._run_cartesian_test (ops .distance_box_iou )
16021661
1662+ def test_iou_batch (self ):
1663+ self ._run_batch_test (ops .distance_box_iou )
1664+
16031665
16041666class TestCompleteBoxIou (TestIouBase ):
16051667 int_expected = [
@@ -1627,6 +1689,9 @@ def test_iou_jit(self):
16271689 def test_iou_cartesian (self ):
16281690 self ._run_cartesian_test (ops .complete_box_iou )
16291691
1692+ def test_iou_batch (self ):
1693+ self ._run_batch_test (ops .complete_box_iou )
1694+
16301695
16311696def get_boxes (dtype , device ):
16321697 box1 = torch .tensor ([- 1 , - 1 , 1 , 1 ], dtype = dtype , device = device )
0 commit comments