Skip to content

Commit 2444381

Browse files
authored
Extends steal_forward to dump input, outputs in onnx models (#88)
* update * changelogs * dump stolen * fix type
1 parent a5f0c6e commit 2444381

File tree

9 files changed

+885
-22
lines changed

9 files changed

+885
-22
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.5.0
55
+++++
66

7+
* :pr:`88`: extends ``steal_forward`` to dump input, outputs in onnx models
78
* :pr:`83`, :pr:`85`: improves the automated rewriting of control flow (test)
89

910
0.4.4

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ onnx_diagnostic.helpers
1212
config_helper
1313
helper
1414
memory_peak
15+
mini_onnx_builder
1516
onnx_helper
1617
ort_session
1718
rt_helper
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.helpers.mini_onnx_builder
3+
=========================================
4+
5+
.. automodule:: onnx_diagnostic.helpers.mini_onnx_builder
6+
:members:
7+
:no-undoc-members:
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
import unittest
2+
import numpy as np
3+
import torch
4+
from onnx_diagnostic.ext_test_case import ExtTestCase
5+
from onnx_diagnostic.reference import ExtendedReferenceEvaluator
6+
from onnx_diagnostic.helpers.mini_onnx_builder import (
7+
create_onnx_model_from_input_tensors,
8+
create_input_tensors_from_onnx_model,
9+
MiniOnnxBuilder,
10+
)
11+
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
12+
from onnx_diagnostic.helpers import string_type
13+
14+
15+
class TestMiniOnnxBuilder(ExtTestCase):
16+
def test_mini_onnx_builder_sequence_onnx(self):
17+
builder = MiniOnnxBuilder()
18+
builder.append_output_sequence("name", [np.array([6, 7])])
19+
onx = builder.to_onnx()
20+
ref = ExtendedReferenceEvaluator(onx)
21+
got = ref.run(None, {})
22+
self.assertEqualAny([np.array([6, 7])], got[0])
23+
24+
def test_mini_onnx_builder_sequence_ort(self):
25+
from onnxruntime import InferenceSession
26+
27+
builder = MiniOnnxBuilder()
28+
builder.append_output_sequence("name", [np.array([6, 7])])
29+
onx = builder.to_onnx()
30+
ref = InferenceSession(onx.SerializeToString(), providers=["CPUExecutionProvider"])
31+
got = ref.run(None, {})
32+
self.assertEqualAny([np.array([6, 7])], got[0])
33+
34+
def test_mini_onnx_builder(self):
35+
data = [
36+
(
37+
np.array([1, 2], dtype=np.int64),
38+
torch.tensor([4, 5], dtype=torch.float32),
39+
{
40+
"tt1": np.array([-1, -2], dtype=np.int64),
41+
"tt2": torch.tensor([-4, -5], dtype=torch.float32),
42+
},
43+
{},
44+
),
45+
{
46+
"t1": np.array([1, 2], dtype=np.int64),
47+
"t2": torch.tensor([4, 5], dtype=torch.float32),
48+
"d1": {
49+
"tt1": np.array([-1, -2], dtype=np.int64),
50+
"tt2": torch.tensor([-4, -5], dtype=torch.float32),
51+
},
52+
"d2": {},
53+
},
54+
(
55+
np.array([1, 2], dtype=np.int64),
56+
torch.tensor([4, 5], dtype=torch.float32),
57+
(
58+
np.array([-1, -2], dtype=np.int64),
59+
torch.tensor([-4, -5], dtype=torch.float32),
60+
),
61+
tuple(),
62+
),
63+
{
64+
"t1": np.array([1, 2], dtype=np.int64),
65+
"t2": torch.tensor([4, 5], dtype=torch.float32),
66+
"l1": (
67+
np.array([-1, -2], dtype=np.int64),
68+
torch.tensor([-4, -5], dtype=torch.float32),
69+
),
70+
"l2": tuple(),
71+
},
72+
# nested
73+
(
74+
{
75+
"t1": np.array([1, 2], dtype=np.int64),
76+
"t2": torch.tensor([4, 5], dtype=torch.float32),
77+
"l1": (
78+
np.array([-1, -2], dtype=np.int64),
79+
torch.tensor([-4, -5], dtype=torch.float32),
80+
),
81+
"l2": tuple(),
82+
},
83+
(
84+
np.array([1, 2], dtype=np.int64),
85+
torch.tensor([4, 5], dtype=torch.float32),
86+
(
87+
np.array([-1, -2], dtype=np.int64),
88+
torch.tensor([-4, -5], dtype=torch.float32),
89+
),
90+
tuple(),
91+
),
92+
),
93+
# simple
94+
np.array([1, 2], dtype=np.int64),
95+
torch.tensor([4, 5], dtype=torch.float32),
96+
(np.array([1, 2], dtype=np.int64), torch.tensor([4, 5], dtype=torch.float32)),
97+
[np.array([1, 2], dtype=np.int64), torch.tensor([4, 5], dtype=torch.float32)],
98+
{
99+
"t1": np.array([1, 2], dtype=np.int64),
100+
"t2": torch.tensor([4, 5], dtype=torch.float32),
101+
},
102+
(
103+
np.array([1, 2], dtype=np.int64),
104+
torch.tensor([4, 5], dtype=torch.float32),
105+
[
106+
np.array([-1, -2], dtype=np.int64),
107+
torch.tensor([-4, -5], dtype=torch.float32),
108+
],
109+
[],
110+
),
111+
{
112+
"t1": np.array([1, 2], dtype=np.int64),
113+
"t2": torch.tensor([4, 5], dtype=torch.float32),
114+
"l1": [
115+
np.array([-1, -2], dtype=np.int64),
116+
torch.tensor([-4, -5], dtype=torch.float32),
117+
],
118+
"l2": [],
119+
},
120+
]
121+
122+
for inputs in data:
123+
with self.subTest(types=string_type(inputs)):
124+
model = create_onnx_model_from_input_tensors(inputs)
125+
restored = create_input_tensors_from_onnx_model(model)
126+
self.assertEqualAny(inputs, restored)
127+
128+
def test_mini_onnx_builder_transformers(self):
129+
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
130+
self.assertEqual(len(cache.key_cache), 1)
131+
self.assertEqual(len(cache.value_cache), 1)
132+
133+
data = [(cache,), cache]
134+
135+
for inputs in data:
136+
with self.subTest(types=string_type(inputs)):
137+
model = create_onnx_model_from_input_tensors(inputs)
138+
restored = create_input_tensors_from_onnx_model(model)
139+
self.assertEqualAny(inputs, restored)
140+
141+
def test_mini_onnx_builder_transformers_sep(self):
142+
cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
143+
self.assertEqual(len(cache.key_cache), 1)
144+
self.assertEqual(len(cache.value_cache), 1)
145+
146+
data = [(cache,), cache]
147+
148+
for inputs in data:
149+
with self.subTest(types=string_type(inputs)):
150+
model = create_onnx_model_from_input_tensors(inputs, sep="#")
151+
restored = create_input_tensors_from_onnx_model(model, sep="#")
152+
self.assertEqualAny(inputs, restored)
153+
154+
155+
if __name__ == "__main__":
156+
unittest.main(verbosity=2)

_unittests/ut_helpers/test_torch_test_helper.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
make_mamba_cache,
2121
make_sliding_window_cache,
2222
)
23+
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
2324

2425
TFLOAT = onnx.TensorProto.FLOAT
2526

@@ -59,6 +60,90 @@ def forward(self, x, y):
5960
with steal_forward(model):
6061
model(*inputs)
6162

63+
@hide_stdout()
64+
def test_steal_forward_multi(self):
65+
class SubModel(torch.nn.Module):
66+
def forward(self, x):
67+
return x * x
68+
69+
class Model(torch.nn.Module):
70+
def __init__(self):
71+
super().__init__()
72+
self.s1 = SubModel()
73+
self.s2 = SubModel()
74+
75+
def forward(self, x, y):
76+
return self.s1(x) + self.s2(y)
77+
78+
inputs = torch.rand(3, 4), torch.rand(3, 4)
79+
model = Model()
80+
with steal_forward(
81+
[
82+
(
83+
"main",
84+
model,
85+
),
86+
(" s1", model.s1),
87+
(" s2", model.s2),
88+
]
89+
):
90+
model(*inputs)
91+
92+
@hide_stdout()
93+
def test_steal_forward_dump_file(self):
94+
class SubModel(torch.nn.Module):
95+
def forward(self, x):
96+
return x * x
97+
98+
class Model(torch.nn.Module):
99+
def __init__(self):
100+
super().__init__()
101+
self.s1 = SubModel()
102+
self.s2 = SubModel()
103+
104+
def forward(self, x, y):
105+
return self.s1(x) + self.s2(y)
106+
107+
inputs = torch.rand(3, 4), torch.rand(3, 4)
108+
model = Model()
109+
dump_file = self.get_dump_file("test_steal_forward_dump_file.onnx")
110+
with steal_forward(
111+
[
112+
(
113+
"main",
114+
model,
115+
),
116+
(" s1", model.s1),
117+
(" s2", model.s2),
118+
],
119+
dump_file=dump_file,
120+
):
121+
res1 = model(*inputs)
122+
res2 = model(*inputs)
123+
self.assertExists(dump_file)
124+
restored = create_input_tensors_from_onnx_model(dump_file)
125+
self.assertEqual(
126+
[
127+
("main", 0, "I"),
128+
("main", 0, "O"),
129+
("main", 1, "I"),
130+
("main", 1, "O"),
131+
("s1", 0, "I"),
132+
("s1", 0, "O"),
133+
("s1", 1, "I"),
134+
("s1", 1, "O"),
135+
("s2", 0, "I"),
136+
("s2", 0, "O"),
137+
("s2", 1, "I"),
138+
("s2", 1, "O"),
139+
],
140+
sorted(restored),
141+
)
142+
self.assertEqualAny(restored["main", 0, "I"], (inputs, {}))
143+
self.assertEqualAny(restored["main", 1, "I"], (inputs, {}))
144+
self.assertEqualAny(restored["main", 0, "O"], res1)
145+
self.assertEqualAny(restored["main", 0, "O"], res2)
146+
62147
def test_replace_string_by_dynamic(self):
63148
example = {
64149
"input_ids": {0: "batch_size", 1: "sequence_length"},

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,6 @@ def test_phi2_export_interpreter(self):
340340
)
341341
# ep = ep.run_decompositions()
342342

343-
# from experimental_experiment.torch_interpreter.tracing import CustomTracer
344-
# CustomTracer.remove_unnecessary_slices(ep.graph)
345343
memorize = []
346344

347345
class MyInterpreter(torch.fx.Interpreter):

onnx_diagnostic/helpers/memory_peak.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def start_spying_on(
223223
224224
.. code-block:: python
225225
226-
from experimental_experiment.memory_peak import start_spying_on, flatten
226+
from onnx_diagnostic.helpers.memory_peak import start_spying_on, flatten
227227
228228
p = start_spying_on()
229229
# ...

0 commit comments

Comments
 (0)