@@ -1510,10 +1510,13 @@ class TestIouBase:
15101510 @staticmethod
15111511 def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected , fmt = "xyxy" ):
15121512 for dtype in dtypes :
1513- actual_box1 = ops .box_convert (torch .tensor (actual_box1 , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
1514- actual_box2 = ops .box_convert (torch .tensor (actual_box2 , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
1513+ _actual_box1 = ops .box_convert (torch .tensor (actual_box1 , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
1514+ _actual_box2 = ops .box_convert (torch .tensor (actual_box2 , dtype = dtype ), in_fmt = "xyxy" , out_fmt = fmt )
15151515 expected_box = torch .tensor (expected )
1516- out = target_fn (actual_box1 , actual_box2 )
1516+ out = target_fn (
1517+ _actual_box1 ,
1518+ _actual_box2 ,
1519+ )
15171520 torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
15181521
15191522 @staticmethod
@@ -1569,7 +1572,16 @@ def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected, fmt):
15691572
15701573 @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
15711574 def test_iou_jit (self , fmt ):
1572- self ._run_jit_test (partial (ops .box_iou , fmt = fmt ), INT_BOXES , fmt )
1575+ class IoUJit (torch .nn .Module ):
1576+ def __init__ (self , fmt ):
1577+ super ().__init__ ()
1578+ self .iou = ops .box_iou
1579+ self .fmt = fmt
1580+
1581+ def forward (self , boxes1 , boxes2 ):
1582+ return self .iou (boxes1 , boxes2 )
1583+
1584+ self ._run_jit_test (IoUJit (fmt = fmt ), INT_BOXES , fmt )
15731585
15741586 @pytest .mark .parametrize ("fmt" , ["xyxy" , "xywh" , "cxcywh" ])
15751587 def test_iou_cartesian (self , fmt ):
0 commit comments