Commit 2419114

Refactors sbs to save memory (#316)
* Refactors sbs to save memory
* add
* a few changes
* fix a few things
* fix ut
* remove doc
* fix import issues
* disable
* mypy
1 parent 2db1ba1 commit 2419114

10 files changed, +619 -287 lines changed

CHANGELOGS.rst

Lines changed: 15 additions & 36 deletions
@@ -8,7 +8,7 @@ Change Logs
 * :pr:`311`: use custom and local function to use PackedMultiHeadAttention from onnxruntime
 * :pr:`310`: splits patches into multiple files
 * :pr:`308`: add option --save_ep to dump the exported program as well as torch input
-* :pr:`304`, :pr:`306`: improves side-by-side comparison, creates command line sbs
+* :pr:`304`, :pr:`306`, :pr:`316`: improves side-by-side comparison, creates command line sbs

 0.8.2
 +++++
@@ -112,8 +112,7 @@ Change Logs
 * :pr:`203`: Add option to disable patches for torch in command line validate
 * :pr:`202`: add models DeepseekV3ForCausalLM, Gemma3ForCausalLM, Glm4vMoeForConditionalGeneration
 * :pr:`201`: switch CI to 4.55.4
-* :pr:`200`: fixes patches for 4.55.1+, DynamicCache is no longer registered by default,
-  this code moved to executorch.py in transformers
+* :pr:`200`: fixes patches for 4.55.1+, DynamicCache is no longer registered by default, this code moved to executorch.py in transformers
 * :pr:`199`: delete hidden_size and num_attention_heads modification in a config
 * :pr:`198`: support gpt-oss
 * :pr:`197`: updates CI for torch 2.8
@@ -124,15 +123,13 @@ Change Logs

 * :pr:`193`: validates with 4.53.3
 * :pr:`189`: support for task mask-generation
-* :pr:`192`: add support for Gemma-3, add serialization for HybridCache,
-  changes to support ``transformers>=4.54``
+* :pr:`192`: add support for Gemma-3, add serialization for HybridCache, changes to support ``transformers>=4.54``

 0.7.5
 +++++

 * :pr:`186`: add parameter --output_names to command line validate to change the output names of the onnx exported model
-* :pr:`185`: remove the use of _seen_tokens in DynamicCache (removed in transformers>4.53),
-  updates dummpy inputs for feature-extraction
+* :pr:`185`: remove the use of _seen_tokens in DynamicCache (removed in ``transformers>4.53``), updates dummpy inputs for feature-extraction
 * :pr:`184`: implements side-by-side

 0.7.4
@@ -172,12 +169,8 @@ Change Logs
 * :pr:`147`: simplified log processing
 * :pr:`146`: patch for IdeficsAttention, IdeficsEmbedding
 * :pr:`145`: patch for _compute_dynamic_ntk_parameters (Phi3RotaryEmbedding)
-* :pr:`144`: support for second inputs with different dimension,
-  rename test_helper into validate,
-  support ``interpolate_pos_encoding`` for ``VitModel``,
-  update model builder helpers for this PR
-  `Use ONNX IR for model builder
-  <https://github.com/microsoft/onnxruntime-genai/pull/1416>`_
+* :pr:`144`: support for second inputs with different dimension, rename test_helper into validate, support ``interpolate_pos_encoding`` for ``VitModel``, update model builder helpers for this PR
+  `Use ONNX IR for model builder <https://github.com/microsoft/onnxruntime-genai/pull/1416>`_
 * :pr:`143`: compares intermediate results,

 0.6.3
@@ -199,8 +192,7 @@ Change Logs
 * :pr:`123`: add subgraphs to TorchOnnxEvaluator
 * :pr:`122`: add local functions to TorchOnnxEvaluator
 * :pr:`120`: enables TorchOnnxEvaluator in command line ``python -m onnx_diagnostic validate ...``
-* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`, :pr:`119`, :pr:`127`:
-  first steps for TorchOnnxEvaluator
+* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`, :pr:`119`, :pr:`127`: first steps for TorchOnnxEvaluator
 * :pr:`114`: extends the list of known rewritings
 * :pr:`113`: fixes a couple of issues with ModelBuilder

@@ -257,10 +249,7 @@ Change Logs
 * :pr:`65`: support SlidingWindowCache
 * :pr:`63`: support option ``--trained``
 * :pr:`61`: improves dynamic shapes for EncoderDecoderCache
-* :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``,
-  use string instead of ``torch.export.Dim.DYNAMIC`` when returning the dynamic shapes
-  for a specific models, it is a valid definition for ``torch.onnx.export``
-  which can reuse the names
+* :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``, use string instead of ``torch.export.Dim.DYNAMIC`` when returning the dynamic shapes for a specific models, it is a valid definition for ``torch.onnx.export`` which can reuse the names
 * :pr:`55`: add support for text-classification
 * :pr:`54`: add support for fill-mask, refactoring
 * :pr:`52`: add support for zero-shot-image-classification
@@ -274,28 +263,18 @@ Change Logs
 * :pr:`43`: uses custom patches
 * :pr:`38`: uses the registered serialization functions when it is available
 * :pr:`30`, :pr:`31`: adds command to test a model id, validate the export
-* :pr:`29`: adds helpers to measure the memory peak and run benchmark
-  on different processes
-* :pr:`28`: adds command line to print out the configuration for a model id,
-  support image-text-to-text
-* :pr:`26`: creates a folder ``helpers`` to gather all the functions
-  used in many places
-* :pr:`25`: improve patches for DynamicCache
-  (issue with register_pytree_flatten_spec being deprecated)
-* :pr:`24`: dummy inputs for ``text2text-generation``, add new function
-  ``convert_dynamic_axes_into_dynamic_shapes`` to convert dynamic axes
-  into dynamic shapes, add support for ``T5ForConditionalGeneration``
+* :pr:`29`: adds helpers to measure the memory peak and run benchmark on different processes
+* :pr:`28`: adds command line to print out the configuration for a model id, support image-text-to-text
+* :pr:`26`: creates a folder ``helpers`` to gather all the functions used in many places
+* :pr:`25`: improve patches for DynamicCache (issue with register_pytree_flatten_spec being deprecated)
+* :pr:`24`: dummy inputs for ``text2text-generation``, add new function ``convert_dynamic_axes_into_dynamic_shapes`` to convert dynamic axes into dynamic shapes, add support for ``T5ForConditionalGeneration``
 * :pr:`23`: dummy inputs for ``image-classification``
-* :pr:`22`, :pr:`27`: api to create untrained model copying the architecture
-  of the trained models and dummy inputs for them,
-  support for ``text-generation``
+* :pr:`22`, :pr:`27`: api to create untrained model copying the architecture of the trained models and dummy inputs for them, support for ``text-generation``

 0.2.1
 +++++

-* :pr:`16`: refactors patches, add model Phi2, implements
-  a tweak to raise an exception with a dynamic dimension
-  becomes static when exporting a model
+* :pr:`16`: refactors patches, add model Phi2, implements a tweak to raise an exception with a dynamic dimension becomes static when exporting a model

 0.2.0
 +++++

_doc/cmds/index.rst

Lines changed: 1 addition & 0 deletions
@@ -9,4 +9,5 @@ Command Lines
     :maxdepth: 1

     config
+    sbs
     validate

_doc/cmds/sbs.rst

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+-m onnx_diagnostic sbs ... runs a side-by-side torch/onnx
+=========================================================
+
+Description
++++++++++++
+
+It compares the intermediate results between an exported program saved with
+:func:`torch.export.save` and an exported model on saved inputs
+with :func:`torch.save`. It assumes intermediate results share the same
+names.
+
+.. runpython::
+
+    from onnx_diagnostic._command_lines_parser import get_parser_sbs
+
+    get_parser_sbs().print_help()
+
+CPU, CUDA
++++++++++
+
+Inputs are saved :func:`torch.save`. The execution will run on CUDA
+if the device of the inputs is CUDA, same goes on CPU.
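
The new page does not show an invocation, so here is a minimal sketch of how the argument list for the command might be assembled. The flags (`-v`, `--first`, `-i`, `-e`, `-o`, `-m`) and the `main` entry point are the ones exercised by `test_h_parser_sbs` in this commit. The helper name `build_sbs_args` and the file names are hypothetical, and the role given to each flag in the comments is inferred from the variable names used in that test, not from the parser's help text.

```python
from typing import List


def build_sbs_args(
    inputs: str, ep: str, onnx_model: str, output: str, verbose: int = 2
) -> List[str]:
    """Hypothetical helper: builds the list of arguments given to main([...])."""
    return [
        "sbs",
        "-v", str(verbose),   # verbosity, as in the test
        "--first",            # flag used by the test; its exact meaning is not shown in the diff
        "-i", inputs,         # inputs saved with torch.save (assumed from `input_file`)
        "-e", ep,             # exported program (assumed from `ep_file`)
        "-o", output,         # output spreadsheet (assumed from `output`)
        "-m", onnx_model,     # onnx model (assumed from `onnx_file`)
    ]


# Placeholder file names, for illustration only.
args = build_sbs_args("inputs.pt", "model.ep.pt2", "model.onnx", "sbs.xlsx")
print(" ".join(["python", "-m", "onnx_diagnostic", *args]))

# With onnx_diagnostic installed, the same list could be handed to the entry point
# used by the unit test:
# from onnx_diagnostic._command_lines_parser import main
# main(args)
```

Building the list in one place keeps the command line and the programmatic call to `main` in sync.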

_unittests/ut_torch_onnx/test_sbs.py

Lines changed: 15 additions & 19 deletions
@@ -29,15 +29,15 @@ def test_run_aligned_record(self):
             onnx_name="B",
             ep_target="C",
             onnx_op_type="D",
-            shape_type="E",
+            ep_shape_type="E",
             err_abs=0.1,
             err_rel=0.2,
             err_dev=0.3,
             err_nan=0.4,
         )
         sr = str(r)
         self.assertIn("RunAlignedRecord(", sr)
-        self.assertIn("shape_type='E'", sr)
+        self.assertIn("ep_shape_type='E'", sr)

     @hide_stdout()
     @unittest.skipIf(to_onnx is None, "to_onnx not installed")
@@ -303,8 +303,8 @@ def forward(self, x):
         )
         self.assertEqual(len(results), 14)
         self.assertEqual(
-            [r.err_dev for r in results],
             [None, None, None, None, None, None, None, None, 0, 0, 0, 0, 0, 0],
+            [r.err_dev for r in results],
         )

     @hide_stdout()
@@ -349,29 +349,27 @@ def forward(self, x):
             [
                 "ep_id_node",
                 "ep_name",
+                "ep_shape_type",
                 "ep_target",
                 "ep_time_run",
                 "err_abs",
                 "err_dev",
+                "err_h01",
                 "err_nan",
                 "err_rel",
                 "onnx_id_node",
                 "onnx_id_output",
                 "onnx_name",
                 "onnx_op_type",
+                "onnx_shape_type",
                 "onnx_time_run",
-                "shape_type",
             ],
             sorted(df.columns),
         )
-        self.assertEqual(len(results), 12)
-        self.assertEqual(
-            [r.err_dev for r in results],
-            [None, None, None, None, None, None, None, None, None, 0, 0, 0],
-        )
+        self.assertEqual(len(results), 8)
+        self.assertEqual([0, 0, 0, 0, None, 0, 0, 0], [r.err_dev for r in results])
         self.assertEqual(
-            [-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
-            df["onnx_id_node"].fillna(-10).tolist(),
+            [-1, -1, -1, -1, -1, 0, 1, 2], df["onnx_id_node"].fillna(-10).tolist()
         )
         self.clean_dump()

@@ -417,29 +415,27 @@ def forward(self, x):
             [
                 "ep_id_node",
                 "ep_name",
+                "ep_shape_type",
                 "ep_target",
                 "ep_time_run",
                 "err_abs",
                 "err_dev",
+                "err_h01",
                 "err_nan",
                 "err_rel",
                 "onnx_id_node",
                 "onnx_id_output",
                 "onnx_name",
                 "onnx_op_type",
+                "onnx_shape_type",
                 "onnx_time_run",
-                "shape_type",
             ],
             sorted(df.columns),
         )
-        self.assertEqual(len(results), 12)
-        self.assertEqual(
-            [r.err_dev for r in results],
-            [None, None, None, None, None, None, None, None, None, 0, 0, 0],
-        )
+        self.assertEqual(len(results), 8)
+        self.assertEqual([0, 0, 0, 0, None, 0, 0, 0], [r.err_dev for r in results])
         self.assertEqual(
-            [-1.0, -1.0, -1.0, -1.0, -10.0, -10.0, -10.0, -10.0, -1.0, 0.0, 1.0, 2.0],
-            df["onnx_id_node"].fillna(-10).tolist(),
+            [-1, -1, -1, -1, -1, 0, 1, 2], df["onnx_id_node"].fillna(-10).tolist()
         )
         self.clean_dump()


_unittests/ut_xrun_doc/test_command_lines_exe.py

Lines changed: 70 additions & 1 deletion
@@ -2,10 +2,12 @@
 import unittest
 from contextlib import redirect_stdout
 from io import StringIO
+import pandas
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
 from onnx_diagnostic._command_lines_parser import main
 from onnx_diagnostic.helpers.log_helper import enumerate_csv_files
+from onnx_diagnostic.export.api import to_onnx


 class TestCommandLines(ExtTestCase):
@@ -88,6 +90,73 @@ def test_g_parser_agg(self):
         self.assertIn("[CubeLogs.to_excel] plots 1 plots", text)
         self.assertExists(output)

+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.53")
+    def test_h_parser_sbs(self):
+        import torch
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.fc1 = torch.nn.Linear(10, 32)  # input size 10 → hidden size 32
+                self.relu = torch.nn.ReLU()
+                self.fc2 = torch.nn.Linear(32, 1)  # hidden → output
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                x = self.fc2(x)
+                return x
+
+        inputs = dict(x=torch.randn((5, 10)))
+        ds = dict(x={0: "batch"})
+        input_file = self.get_dump_file("test_h_parser_sbs.inputs.pt")
+        ep_file = self.get_dump_file("test_h_parser_sbs.ep")
+        onnx_file = self.get_dump_file("test_h_parser_sbs.model.onnx")
+        torch.save(inputs, input_file)
+        to_onnx(
+            Model(),
+            kwargs=inputs,
+            dynamic_shapes=ds,
+            exporter="custom",
+            save_ep=(ep_file, 2**30),
+            filename=onnx_file,
+        )
+
+        output = self.get_dump_file("test_h_parser_sbs.xlsx")
+        st = StringIO()
+        with redirect_stdout(st):
+            main(
+                [
+                    "sbs",
+                    "-v",
+                    "2",
+                    "--first",
+                    "-i",
+                    input_file,
+                    "-e",
+                    f"{ep_file}.ep.pt2",
+                    "-o",
+                    output,
+                    "-m",
+                    onnx_file,
+                ]
+            )
+        text = st.getvalue()
+        self.assertIn("[run_aligned", text)
+        self.assertExists(output)
+        df = pandas.read_excel(output).apply(
+            lambda col: col.fillna("") if col.dtype == "object" else col
+        )
+        self.assertLess(df["err_abs"].max(), 1e-5)
+        self.assertEqual(df["err_h01"].max(), 0)
+        self.assertIn("p_fc1_weight", set(df["ep_name"]))
+        self.assertIn("fc1.bias", set(df["onnx_name"]))
+        self.assertNotIn("NaN", set(df["ep_name"]))
+        # print(f"{df}\n{st.getvalue()}")
+        self.assertIn("[run_aligned] done", st.getvalue())
+        sdf = df[(df.ep_target == "placeholder") & (df.onnx_op_type == "initializer")]
+        self.assertEqual(sdf.shape[0], 4)
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

clean_onnx.sh

Lines changed: 3 additions & 1 deletion
@@ -30,7 +30,8 @@ rm _plot_torch_sklearn_201_knnpy.py

 rm _doc/sg_execution_times.rst

-rm _doc/examples/plot*.onnx
+rm _doc/examples/_debug*
+rm _doc/examples/plot*.onnx*
 rm _doc/examples/plot*.txt
 rm _doc/examples/ort*.onnx
 rm _doc/examples/*.sarif
@@ -83,6 +84,7 @@ rm _doc/technical/*.dynamo.onnx
 rm _doc/technical/*.script.onnx
 rm _doc/technical/dump_models -rf
 rm _doc/technical/dump_onx_*
+rm _doc/technical/model_*.onnx* -rf

 rm _tools/bin -rf
 rm _tools/mambaroot -rf
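
The widened pattern `plot*.onnx*` matches any file whose name continues after `.onnx`, in addition to the files the old pattern `plot*.onnx` already matched. A small sketch with Python's `fnmatch`, which follows the same wildcard rules as the shell for these patterns. The file names are made up for illustration, and the idea that the extra files are external weight files written next to the model is an assumption, not something the diff states.

```python
from fnmatch import fnmatchcase

OLD_PATTERN = "plot*.onnx"
NEW_PATTERN = "plot*.onnx*"

# Hypothetical file names, for illustration only.
candidates = ["plot_export.onnx", "plot_export.onnx.data", "plot_export.txt"]

old_matches = [name for name in candidates if fnmatchcase(name, OLD_PATTERN)]
new_matches = [name for name in candidates if fnmatchcase(name, NEW_PATTERN)]

print("old pattern removes:", old_matches)
print("new pattern removes:", new_matches)
```

The trailing `*` also matches the empty string, so plain `.onnx` files are still removed.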
