Skip to content

Commit 3da85a8

Browse files
committed
jit
1 parent 1eab135 commit 3da85a8

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

_unittests/ut_export/test_jit.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import unittest
2+
from typing import Callable
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
5+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
6+
from onnx_diagnostic.helpers.torch_test_helper import is_torchdynamo_exporting
7+
8+
9+
@torch.jit.script_if_tracing
10+
def dummy_loop(padded: torch.Tensor, pos: torch.Tensor):
11+
copy = torch.zeros(padded.shape)
12+
for i in range(pos.shape[0]):
13+
p = pos[i]
14+
copy[i, :p] = padded[i, :p]
15+
return copy
16+
17+
18+
def wrap_for_export(f: Callable) -> Callable:
19+
20+
class _wrapped(torch.nn.Module):
21+
def __init__(self):
22+
super().__init__()
23+
self.f = f
24+
25+
def forward(self, *args, **kwargs):
26+
return self.f(*args, **kwargs)
27+
28+
return _wrapped()
29+
30+
31+
def select_when_exporting(mod, f):
32+
if is_torchdynamo_exporting():
33+
return mod
34+
return f
35+
36+
37+
class TestJit(ExtTestCase):
38+
@hide_stdout()
39+
def test_export_loop(self):
40+
class Model(torch.nn.Module):
41+
def __init__(self):
42+
super().__init__()
43+
self.wrapped_f = wrap_for_export(dummy_loop)
44+
45+
def forward(self, images, position):
46+
return select_when_exporting(self.wrapped_f, dummy_loop)(images, position)
47+
48+
model = Model()
49+
x = torch.randn((5, 6))
50+
y = torch.arange(5, dtype=torch.int64) + 1
51+
expected = model(x, y)
52+
53+
name = self.get_dump_file("test_export_loop.onnx")
54+
torch.onnx.export(
55+
model,
56+
(x, y),
57+
name,
58+
dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
59+
dynamo=False,
60+
)
61+
ref = ExtendedReferenceEvaluator(name)
62+
feeds = dict(images=x.numpy(), position=y.numpy())
63+
got = ref.run(None, feeds)[0]
64+
self.assertEqualArray(expected, got)
65+
66+
DYN = torch.export.Dim.DYNAMIC
67+
ep = torch.export.export(
68+
model,
69+
(x, y),
70+
dynamic_shapes={"images": {0: DYN, 1: DYN}, "position": {0: DYN}},
71+
)
72+
print(ep)
73+
74+
name2 = self.get_dump_file("test_export_loop.dynamo.onnx")
75+
torch.onnx.export(
76+
model,
77+
(x, y),
78+
name2,
79+
dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
80+
dynamo=True,
81+
fallback=False,
82+
)
83+
ref = ExtendedReferenceEvaluator(name2)
84+
feeds = dict(images=x.numpy(), position=y.numpy())
85+
got = ref.run(None, feeds)[0]
86+
self.assertEqualArray(expected, got)
87+
88+
89+
if __name__ == "__main__":
90+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)