Skip to content

Commit ea8da0e

Browse files
committed
add MiniOnnxBuilder
2 parents 89c6ee7 + a5f0c6e commit ea8da0e

File tree

9 files changed

+739
-21
lines changed

9 files changed

+739
-21
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ jobs:
3030
print_all: false
3131
timeout: 2
3232
retry_count# : 2
33-
exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475
34-
exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/
33+
exclude_urls: https://github.com/pytorch/pytorch/pull/117009,https://github.com/huggingface/transformers/pull/29285,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://github.com/huggingface/transformers/pull/36652
34+
exclude_patterns: https://dumps.wikimedia.org/,https://github.com/pytorch/pytorch/pull/,https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1475,https://huggingface.co/,https://huggingface.co/,https://github.com/huggingface/transformers/
3535
# force_pass : true
3636

3737
- name: urls-checker-docs

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ onnx_diagnostic.helpers
1212
config_helper
1313
helper
1414
memory_peak
15+
mini_onnx_builder
1516
onnx_helper
1617
ort_session
1718
rt_helper
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.mini_onnx_builder
3+
=========================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.mini_onnx_builder
6+
:members:
7+
:no-undoc-members:

_doc/recipes/plot_dynamic_shapes_nonzero.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,23 @@ def adaptive_enc_mask(self, x_len, chunk_start_idx, left_window=0, right_window=
3535
mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
3636
return mask_left & mask_right
3737

38-
def forward(self, x):
39-
return self.adaptive_enc_mask(x.shape[1], [])
38+
def forward(self, x, y):
39+
return self.adaptive_enc_mask(
40+
x.shape[1], torch.tensor([], dtype=torch.int64), left_window=y.shape[0]
41+
)
4042

4143

4244
model = Model()
43-
x = torch.rand((5, 8))
44-
y = model(x)
45-
print(f"x.shape={x.shape}, y.shape={y.shape}")
45+
x, y = torch.rand((2, 546)), torch.rand((18,))
46+
z = model(x, y)
47+
print(f"y.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
4648

4749
# %%
4850
# Export
4951
# ++++++
5052

5153
DYN = torch.export.Dim.DYNAMIC
52-
ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
54+
ep = torch.export.export(model, (x, y), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN}))
5355
print(ep)
5456

5557

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import unittest
2+
import numpy as np
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
6+
from onnx_diagnostic.helpers.mini_onnx_builder import (
7+
create_onnx_model_from_input_tensors,
8+
create_input_tensors_from_onnx_model,
9+
MiniOnnxBuilder,
10+
)
11+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
12+
from onnx_diagnostic.helpers import string_type
13+
14+
15+
class TestMiniOnnxBuilder(ExtTestCase):
16+
def test_mini_onnx_builder_sequence_onnx(self):
17+
builder = MiniOnnxBuilder()
18+
builder.append_output_sequence("name", [np.array([6, 7])])
19+
onx = builder.to_onnx()
20+
ref = ExtendedReferenceEvaluator(onx)
21+
got = ref.run(None, {})
22+
self.assertEqualAny([np.array([6, 7])], got[0])
23+
24+
def test_mini_onnx_builder_sequence_ort(self):
25+
from onnxruntime import InferenceSession
26+
27+
builder = MiniOnnxBuilder()
28+
builder.append_output_sequence("name", [np.array([6, 7])])
29+
onx = builder.to_onnx()
30+
ref = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
31+
got = ref.run(None, {})
32+
self.assertEqualAny([np.array([6, 7])], got[0])
33+
34+
def test_mini_onnx_builder(self):
35+
data = [
36+
(
37+
np.array([1, 2], dtype=np.int64),
38+
torch.tensor([4, 5], dtype=torch.float32),
39+
{
40+
"tt1": np.array([-1, -2], dtype=np.int64),
41+
"tt2": torch.tensor([-4, -5], dtype=torch.float32),
42+
},
43+
{},
44+
),
45+
{
46+
"t1": np.array([1, 2], dtype=np.int64),
47+
"t2": torch.tensor([4, 5], dtype=torch.float32),
48+
"d1": {
49+
"tt1": np.array([-1, -2], dtype=np.int64),
50+
"tt2": torch.tensor([-4, -5], dtype=torch.float32),
51+
},
52+
"d2": {},
53+
},
54+
(
55+
np.array([1, 2], dtype=np.int64),
56+
torch.tensor([4, 5], dtype=torch.float32),
57+
(
58+
np.array([-1, -2], dtype=np.int64),
59+
torch.tensor([-4, -5], dtype=torch.float32),
60+
),
61+
tuple(),
62+
),
63+
{
64+
"t1": np.array([1, 2], dtype=np.int64),
65+
"t2": torch.tensor([4, 5], dtype=torch.float32),
66+
"l1": (
67+
np.array([-1, -2], dtype=np.int64),
68+
torch.tensor([-4, -5], dtype=torch.float32),
69+
),
70+
"l2": tuple(),
71+
},
72+
# nested
73+
(
74+
{
75+
"t1": np.array([1, 2], dtype=np.int64),
76+
"t2": torch.tensor([4, 5], dtype=torch.float32),
77+
"l1": (
78+
np.array([-1, -2], dtype=np.int64),
79+
torch.tensor([-4, -5], dtype=torch.float32),
80+
),
81+
"l2": tuple(),
82+
},
83+
(
84+
np.array([1, 2], dtype=np.int64),
85+
torch.tensor([4, 5], dtype=torch.float32),
86+
(
87+
np.array([-1, -2], dtype=np.int64),
88+
torch.tensor([-4, -5], dtype=torch.float32),
89+
),
90+
tuple(),
91+
),
92+
),
93+
# simple
94+
np.array([1, 2], dtype=np.int64),
95+
torch.tensor([4, 5], dtype=torch.float32),
96+
(np.array([1, 2], dtype=np.int64), torch.tensor([4, 5], dtype=torch.float32)),
97+
[np.array([1, 2], dtype=np.int64), torch.tensor([4, 5], dtype=torch.float32)],
98+
{
99+
"t1": np.array([1, 2], dtype=np.int64),
100+
"t2": torch.tensor([4, 5], dtype=torch.float32),
101+
},
102+
(
103+
np.array([1, 2], dtype=np.int64),
104+
torch.tensor([4, 5], dtype=torch.float32),
105+
[
106+
np.array([-1, -2], dtype=np.int64),
107+
torch.tensor([-4, -5], dtype=torch.float32),
108+
],
109+
[],
110+
),
111+
{
112+
"t1": np.array([1, 2], dtype=np.int64),
113+
"t2": torch.tensor([4, 5], dtype=torch.float32),
114+
"l1": [
115+
np.array([-1, -2], dtype=np.int64),
116+
torch.tensor([-4, -5], dtype=torch.float32),
117+
],
118+
"l2": [],
119+
},
120+
]
121+
122+
for inputs in data:
123+
with self.subTest(types=string_type(inputs)):
124+
model = create_onnx_model_from_input_tensors(inputs)
125+
restored = create_input_tensors_from_onnx_model(model)
126+
self.assertEqualAny(inputs, restored)
127+
128+
def test_mini_onnx_builder_transformers(self):
129+
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
130+
self.assertEqual(len(cache.key_cache), 1)
131+
self.assertEqual(len(cache.value_cache), 1)
132+
133+
data = [(cache,), cache]
134+
135+
for inputs in data:
136+
with self.subTest(types=string_type(inputs)):
137+
model = create_onnx_model_from_input_tensors(inputs)
138+
restored = create_input_tensors_from_onnx_model(model)
139+
self.assertEqualAny(inputs, restored)
140+
141+
def test_mini_onnx_builder_transformers_sep(self):
142+
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
143+
self.assertEqual(len(cache.key_cache), 1)
144+
self.assertEqual(len(cache.value_cache), 1)
145+
146+
data = [(cache,), cache]
147+
148+
for inputs in data:
149+
with self.subTest(types=string_type(inputs)):
150+
model = create_onnx_model_from_input_tensors(inputs, sep="#")
151+
restored = create_input_tensors_from_onnx_model(model, sep="#")
152+
self.assertEqualAny(inputs, restored)
153+
154+
155+
if __name__ == "__main__":
156+
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,6 @@ def test_phi2_export_interpreter(self):
340340
)
341341
# ep = ep.run_decompositions()
342342

343-
# from experimental_experiment.torch_interpreter.tracing import CustomTracer
344-
# CustomTracer.remove_unnecessary_slices(ep.graph)
345343
memorize = []
346344

347345
class MyInterpreter(torch.fx.Interpreter):

onnx_diagnostic/helpers/memory_peak.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def start_spying_on(
223223
224224
.. code-block:: python
225225
226-
from experimental_experiment.memory_peak import start_spying_on, flatten
226+
from onnx_diagnostic.helpers.memory_peak import start_spying_on, flatten
227227
228228
p = start_spying_on()
229229
# ...

0 commit comments

Comments
 (0)