Skip to content

Commit 5138b69

Browse files
committed
jit
1 parent 2e47ee1 commit 5138b69

File tree

1 file changed

+69
-26
lines changed

1 file changed

+69
-26
lines changed

_unittests/ut_export/test_jit.py

Lines changed: 69 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import unittest
2-
from typing import Callable
32
import torch
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
54
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
65
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
76

7+
try:
8+
from experimental_experiment.torch_interpreter import to_onnx
9+
except ImportError:
10+
to_onnx = None
11+
812

913
@torch.jit.script_if_tracing
1014
def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
@@ -15,42 +19,53 @@ def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
1519
return copy
1620

1721

18-
def wrap_for_export(f: Callable) -> Callable:
19-
20-
class _wrapped(torch.nn.Module):
21-
def __init__(self):
22-
super().__init__()
23-
self.f = f
22+
def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor):
23+
def pad_row(padded, p):
24+
row = torch.zeros((padded.shape[0],))
25+
torch._check(p.item() > 0)
26+
torch._check(p.item() < padded.shape[0])
27+
# this check is not always true, we add it anyway to make this dimension >= 2
28+
# and avoid raising an exception about dynamic dimension in {0, 1}
29+
if is_torchdynamo_exporting():
30+
torch._check(p.item() > 1)
31+
row[: p.item()] = padded[: p.item()]
32+
return (row,)
2433

25-
def forward(self, *args, **kwargs):
26-
return self.f(*args, **kwargs)
34+
return torch.ops.higher_order.scan(
35+
pad_row,
36+
[],
37+
[padded, pos],
38+
[],
39+
)
2740

28-
return _wrapped()
2941

30-
31-
def select_when_exporting(mod, f):
32-
if is_torchdynamo_exporting():
33-
return mod
34-
return f
42+
def select_when_exporting(f, f_scan):
43+
return f_scan if is_torchdynamo_exporting() else f
3544

3645

3746
class TestJit(ExtTestCase):
47+
def test_dummy_loop(self):
48+
x = torch.randn((5, 6))
49+
y = torch.arange(5, dtype=torch.int64) + 1
50+
res = dummy_loop(x, y)
51+
res_scan = dummy_loop_with_scan(x, y)
52+
self.assertEqualArray(res, res_scan[0])
53+
3854
@hide_stdout()
39-
def test_export_loop(self):
55+
@ignore_warnings(UserWarning)
56+
def test_export_loop_onnxscript(self):
4057
class Model(torch.nn.Module):
41-
def __init__(self):
42-
super().__init__()
43-
self.wrapped_f = wrap_for_export(dummy_loop)
44-
4558
def forward(self, images, position):
46-
return select_when_exporting(self.wrapped_f, dummy_loop)(images, position)
59+
return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
60+
images, position
61+
)
4762

4863
model = Model()
4964
x = torch.randn((5, 6))
5065
y = torch.arange(5, dtype=torch.int64) + 1
5166
expected = model(x, y)
5267

53-
name = self.get_dump_file("test_export_loop.onnx")
68+
name = self.get_dump_file("test_export_loop_onnxscript.onnx")
5469
torch.onnx.export(
5570
model,
5671
(x, y),
@@ -68,15 +83,16 @@ def forward(self, images, position):
6883
model,
6984
(x, y),
7085
dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
86+
strict=False,
7187
)
72-
print(ep)
88+
self.assertNotEmpty(ep)
7389

74-
name2 = self.get_dump_file("test_export_loop.dynamo.onnx")
90+
name2 = self.get_dump_file("test_export_loop_onnxscript.dynamo.onnx")
7591
torch.onnx.export(
7692
model,
7793
(x, y),
7894
name2,
79-
dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
95+
dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
8096
dynamo=True,
8197
fallback=False,
8298
)
@@ -85,6 +101,33 @@ def forward(self, images, position):
85101
got = ref.run(None, feeds)[0]
86102
self.assertEqualArray(expected, got)
87103

104+
@hide_stdout()
105+
@ignore_warnings(UserWarning)
106+
@unittest.skipIf(to_onnx is None, "missing to_onnx")
107+
def test_export_loop_custom(self):
108+
class Model(torch.nn.Module):
109+
def forward(self, images, position):
110+
return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
111+
images, position
112+
)
113+
114+
model = Model()
115+
x = torch.randn((5, 6))
116+
y = torch.arange(5, dtype=torch.int64) + 1
117+
expected = model(x, y)
118+
119+
name2 = self.get_dump_file("test_export_loop.custom.onnx")
120+
to_onnx(
121+
model,
122+
(x, y),
123+
filename=name2,
124+
dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
125+
)
126+
ref = ExtendedReferenceEvaluator(name2)
127+
feeds = dict(images=x.numpy(), position=y.numpy())
128+
got = ref.run(None, feeds)[0]
129+
self.assertEqualArray(expected, got)
130+
88131

89132
if __name__ == "__main__":
90133
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)