Skip to content

Commit bc36d1e

Browse files
fix area test
1 parent 298cfa9 commit bc36d1e

File tree

1 file changed

+15
-3
lines changed

1 file changed

+15
-3
lines changed

test/test_ops.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,9 +1486,21 @@ def test_box_area_jit(self, fmt):
14861486
torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float), in_fmt="xyxy", out_fmt=fmt
14871487
)
14881488
expected = ops.box_area(box_tensor, fmt)
1489-
scripted_fn = torch.jit.script(ops.box_area)
1489+
class BoxArea(torch.nn.Module):
1490+
# We are using this intermediate class
1491+
# since torchscript does not support
1492+
# neither partial nor lambda functions for this test.
1493+
def __init__(self, fmt):
1494+
super().__init__()
1495+
self.area = ops.box_area
1496+
self.fmt = fmt
1497+
1498+
def forward(self, boxes):
1499+
return self.area(boxes, self.fmt)
1500+
1501+
scripted_fn = torch.jit.script(BoxArea(fmt))
14901502
scripted_area = scripted_fn(box_tensor)
1491-
torch.testing.assert_close(scripted_area, expected, fmt)
1503+
torch.testing.assert_close(scripted_area, expected)
14921504

14931505

14941506
INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
@@ -1582,7 +1594,7 @@ def __init__(self, fmt):
15821594
self.fmt = fmt
15831595

15841596
def forward(self, boxes1, boxes2):
1585-
return self.iou(boxes1, boxes2)
1597+
return self.iou(boxes1, boxes2, fmt=self.fmt)
15861598

15871599
self._run_jit_test(IoUJit(fmt=fmt), INT_BOXES, fmt)
15881600

0 commit comments

Comments
 (0)