Skip to content

Commit 29012cd

Browse files
authored
Experiment around jit (#57)
* jit * jit * fix ut * ut
1 parent edec507 commit 29012cd

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed

_unittests/ut_export/test_jit.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import inspect
2+
import unittest
3+
import torch
4+
from onnx_diagnostic.ext_test_case import (
5+
ExtTestCase,
6+
hide_stdout,
7+
ignore_warnings,
8+
requires_onnxscript,
9+
)
10+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
11+
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
12+
13+
try:
14+
from experimental_experiment.torch_interpreter import to_onnx
15+
except ImportError:
16+
to_onnx = None
17+
18+
19+
has_scan_reverse = "reverse" in set(inspect.signature(torch.ops.higher_order.scan).parameters)
20+
21+
22+
@torch.jit.script_if_tracing
23+
def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
24+
copy = torch.zeros(padded.shape)
25+
for i in range(pos.shape[0]):
26+
p = pos[i]
27+
copy[i, :p] = padded[i, :p]
28+
return copy
29+
30+
31+
def dummy_loop_with_scan(padded: torch.Tensor, pos: torch.Tensor):
32+
def pad_row(padded, p):
33+
row = torch.zeros((padded.shape[0],))
34+
torch._check(p.item() > 0)
35+
torch._check(p.item() < padded.shape[0])
36+
# this check is not always true, we add it anyway to make this dimension >= 2
37+
# and avoid raising an exception about dynamic dimension in {0, 1}
38+
if is_torchdynamo_exporting():
39+
torch._check(p.item() > 1)
40+
row[: p.item()] = padded[: p.item()]
41+
return (row,)
42+
43+
if has_scan_reverse:
44+
# torch==2.6
45+
return torch.ops.higher_order.scan(
46+
pad_row, [], [padded, pos], additional_inputs=[], reverse=False, dim=0
47+
)
48+
return torch.ops.higher_order.scan(pad_row, [], [padded, pos], [])
49+
50+
51+
def select_when_exporting(f, f_scan):
52+
return f_scan if is_torchdynamo_exporting() else f
53+
54+
55+
class TestJit(ExtTestCase):
56+
def test_dummy_loop(self):
57+
x = torch.randn((5, 6))
58+
y = torch.arange(5, dtype=torch.int64) + 1
59+
res = dummy_loop(x, y)
60+
res_scan = dummy_loop_with_scan(x, y)
61+
self.assertEqualArray(res, res_scan[0])
62+
63+
@hide_stdout()
64+
@ignore_warnings(UserWarning)
65+
@requires_onnxscript("0.4")
66+
def test_export_loop_onnxscript(self):
67+
class Model(torch.nn.Module):
68+
def forward(self, images, position):
69+
return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
70+
images, position
71+
)
72+
73+
model = Model()
74+
x = torch.randn((5, 6))
75+
y = torch.arange(5, dtype=torch.int64) + 1
76+
expected = model(x, y)
77+
78+
name = self.get_dump_file("test_export_loop_onnxscript.onnx")
79+
torch.onnx.export(
80+
model,
81+
(x, y),
82+
name,
83+
dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
84+
dynamo=False,
85+
)
86+
ref = ExtendedReferenceEvaluator(name)
87+
feeds = dict(images=x.numpy(), position=y.numpy())
88+
got = ref.run(None, feeds)[0]
89+
self.assertEqualArray(expected, got)
90+
91+
DYN = torch.export.Dim.DYNAMIC
92+
ep = torch.export.export(
93+
model,
94+
(x, y),
95+
dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
96+
strict=False,
97+
)
98+
self.assertNotEmpty(ep)
99+
100+
name2 = self.get_dump_file("test_export_loop_onnxscript.dynamo.onnx")
101+
torch.onnx.export(
102+
model,
103+
(x, y),
104+
name2,
105+
dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
106+
dynamo=True,
107+
fallback=False,
108+
)
109+
import onnxruntime
110+
111+
ref = onnxruntime.InferenceSession(name2, providers=["CPUExecutionProvider"])
112+
feeds = dict(images=x.numpy(), position=y.numpy())
113+
got = ref.run(None, feeds)[0]
114+
self.assertEqualArray(expected, got)
115+
116+
@hide_stdout()
117+
@ignore_warnings(UserWarning)
118+
@unittest.skipIf(to_onnx is None, "missing to_onnx")
119+
def test_export_loop_custom(self):
120+
class Model(torch.nn.Module):
121+
def forward(self, images, position):
122+
return select_when_exporting(dummy_loop, dummy_loop_with_scan)(
123+
images, position
124+
)
125+
126+
model = Model()
127+
x = torch.randn((5, 6))
128+
y = torch.arange(5, dtype=torch.int64) + 1
129+
expected = model(x, y)
130+
131+
name2 = self.get_dump_file("test_export_loop.custom.onnx")
132+
to_onnx(
133+
model,
134+
(x, y),
135+
filename=name2,
136+
dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
137+
)
138+
import onnxruntime
139+
140+
ref = onnxruntime.InferenceSession(name2, providers=["CPUExecutionProvider"])
141+
feeds = dict(images=x.numpy(), position=y.numpy())
142+
got = ref.run(None, feeds)[0]
143+
self.assertEqualArray(expected, got)
144+
145+
146+
if __name__ == "__main__":
147+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)