@@ -1486,9 +1486,21 @@ def test_box_area_jit(self, fmt):
14861486 torch .tensor ([[0 , 0 , 100 , 100 ], [0 , 0 , 0 , 0 ]], dtype = torch .float ), in_fmt = "xyxy" , out_fmt = fmt
14871487 )
14881488 expected = ops .box_area (box_tensor , fmt )
1489- scripted_fn = torch .jit .script (ops .box_area )
1489+ class BoxArea (torch .nn .Module ):
1490+ # We are using this intermediate class
1491+ # since torchscript does not support
1492+ # neither partial nor lambda functions for this test.
1493+ def __init__ (self , fmt ):
1494+ super ().__init__ ()
1495+ self .area = ops .box_area
1496+ self .fmt = fmt
1497+
1498+ def forward (self , boxes ):
1499+ return self .area (boxes , self .fmt )
1500+
1501+ scripted_fn = torch .jit .script (BoxArea (fmt ))
14901502 scripted_area = scripted_fn (box_tensor )
1491- torch .testing .assert_close (scripted_area , expected , fmt )
1503+ torch .testing .assert_close (scripted_area , expected )
14921504
14931505
14941506INT_BOXES = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ], [0 , 0 , 25 , 25 ]]
@@ -1582,7 +1594,7 @@ def __init__(self, fmt):
15821594 self .fmt = fmt
15831595
15841596 def forward (self , boxes1 , boxes2 ):
1585- return self .iou (boxes1 , boxes2 )
1597+ return self .iou (boxes1 , boxes2 , fmt = self . fmt )
15861598
15871599 self ._run_jit_test (IoUJit (fmt = fmt ), INT_BOXES , fmt )
15881600
0 commit comments