Skip to content

Commit ad6e61c

Browse files
committed
Merge branch 'main' of https://github.com/sdpython/onnx-diagnostic into cmp
2 parents 29cfaf6 + 2318738 commit ad6e61c

File tree

6 files changed

+150
-3
lines changed

6 files changed

+150
-3
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""
2+
.. _l-plot-intermediate-results:
3+
4+
Dumps intermediate results of a torch model
5+
===========================================
6+
7+
Looking for discrepancies is quickly annoying. Discrepancies
8+
come from two results obtained with the same models
9+
implemented in two different ways, :epkg:`pytorch` and :epkg:`onnx`.
10+
Models are big so where do they come from? That's the
11+
unavoidable question. Unless there is an obvious reason,
12+
the only way is to compare intermediate outputs alon the computation.
13+
The first step into that direction is to dump the intermediate results
14+
coming from :epkg:`pytorch`.
15+
We use :func:`onnx_diagnostic.helpers.torch_helper.steal_forward` for that.
16+
17+
A simple LLM Model
18+
++++++++++++++++++
19+
20+
See :func:`onnx_diagnostic.helpers.torch_helper.dummy_llm`
21+
for its definition. It is mostly used for unit test or example.
22+
23+
"""
24+
25+
import onnx
26+
import torch
27+
from onnx_array_api.plotting.graphviz_helper import plot_dot
28+
from onnx_diagnostic import doc
29+
from onnx_diagnostic.helpers import string_type
30+
from onnx_diagnostic.helpers.torch_helper import dummy_llm
31+
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
32+
from onnx_diagnostic.helpers.torch_helper import steal_forward
33+
34+
35+
model, inputs, ds = dummy_llm(dynamic_shapes=True)
36+
37+
print(f"type(model)={type(model)}")
38+
print(f"inputs={string_type(inputs, with_shape=True)}")
39+
print(f"ds={string_type(ds, with_shape=True)}")
40+
41+
# %%
42+
# It contains the following submodules.
43+
44+
for name, mod in model.named_modules():
45+
print(f"- {name}: {type(mod)}")
46+
47+
# %%
48+
# Steal and dump the output of submodules
49+
# +++++++++++++++++++++++++++++++++++++++
50+
#
51+
# The following context spies on the intermediate results
52+
# for the following module and submodules. It stores
53+
# in one onnx file all the input/output for those.
54+
55+
with steal_forward(
56+
[
57+
("model", model),
58+
("model.decoder", model.decoder),
59+
("model.decoder.attention", model.decoder.attention),
60+
("model.decoder.feed_forward", model.decoder.feed_forward),
61+
("model.decoder.norm_1", model.decoder.norm_1),
62+
("model.decoder.norm_2", model.decoder.norm_2),
63+
],
64+
dump_file="plot_dump_intermediate_results.inputs.onnx",
65+
verbose=1,
66+
storage_limit=2**28,
67+
):
68+
model(*inputs)
69+
70+
71+
# %%
72+
# Restores saved inputs/outputs
73+
# +++++++++++++++++++++++++++++
74+
#
75+
# All the intermediate tensors were saved in one unique onnx model,
76+
# every tensor is stored in a constant node.
77+
# The model can be run with any runtime to restore the inputs
78+
# and function :func:`create_input_tensors_from_onnx_model
79+
# <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
80+
# can restore their names.
81+
82+
saved_tensors = create_input_tensors_from_onnx_model(
83+
"plot_dump_intermediate_results.inputs.onnx"
84+
)
85+
for k, v in saved_tensors.items():
86+
print(f"{k} -- {string_type(v, with_shape=True)}")
87+
88+
# %%
89+
# Let's explained the naming convention.
90+
#
91+
# ::
92+
#
93+
# ('model.decoder.norm_2', 0, 'I') -- ((T1s2x30x16,),{})
94+
# | | |
95+
# | | +--> input, the format is args, kwargs
96+
# | |
97+
# | +--> iteration, 0 means the first time the execution
98+
# | went through that module
99+
# | it is possible to call multiple times,
100+
# | the model to store more
101+
# |
102+
# +--> the name given to function steal_forward
103+
#
104+
# The same goes for output except ``'I'`` is replaced by ``'O'``.
105+
#
106+
# ::
107+
#
108+
# ('model.decoder.norm_2', 0, 'O') -- T1s2x30x16
109+
#
110+
# This trick can be used to compare intermediate results coming
111+
# from pytorch to any other implementation of the same model
112+
# as long as it is possible to map the stored inputs/outputs.
113+
114+
# %%
115+
# Conversion to ONNX
116+
# ++++++++++++++++++
117+
#
118+
# The difficult point is to be able to map the saved intermediate
119+
# results to intermediate results in ONNX.
120+
# Let's create the ONNX model.
121+
122+
epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
123+
epo.optimize()
124+
epo.save("plot_dump_intermediate_results.onnx")
125+
126+
# %%
127+
# It looks like the following.
128+
onx = onnx.load("plot_dump_intermediate_results.onnx")
129+
plot_dot(onx)
130+
131+
# %%
132+
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Enlightening Examples
9292
* :ref:`l-plot-failing-reference-evaluator`
9393
* :ref:`l-plot-failing-onnxruntime-evaluator`
9494
* :ref:`l-plot-failing-model-extract`
95+
* :ref:`l-plot-intermediate-results`
9596

9697
Some Usefuls Tools
9798
==================

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,15 @@ def add_test_methods(cls):
7474
this = os.path.abspath(os.path.dirname(__file__))
7575
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples"))
7676
found = os.listdir(fold)
77+
has_dot = int(os.environ.get("UNITTEST_DOT", "0"))
7778
for name in found:
7879
if not name.endswith(".py") or not name.startswith("plot_"):
7980
continue
8081
reason = None
8182

83+
if not reason and not has_dot and name in {"plot_dump_intermediate_results.py"}:
84+
reason = "dot not installed"
85+
8286
if (
8387
not reason
8488
and name in {"plot_export_tiny_llm.py"}

_unittests/ut_xrun_doc/test_documentation_recipes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,13 +73,16 @@ def add_test_methods(cls):
7373
this = os.path.abspath(os.path.dirname(__file__))
7474
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "recipes"))
7575
found = os.listdir(fold)
76+
has_dot = int(os.environ.get("UNITTEST_DOT", "0"))
7677
for name in found:
7778
if not name.endswith(".py") or not name.startswith("plot_"):
7879
continue
7980
reason = None
8081

8182
if not reason and not has_torch("4.7"):
8283
reason = "torch<2.7"
84+
if not reason and not has_dot and name in {"plot_dump_intermediate_results.py"}:
85+
reason = "dot not installed"
8386

8487
if reason:
8588

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,8 @@ def create_onnx_model_from_input_tensors(
393393
Creates a model proto including all the value as initializers.
394394
They can be restored by executing the model.
395395
We assume these inputs are not bigger than 2Gb,
396-
the limit of protobuf.
396+
the limit of protobuf. Nothing is implemented yet to get around
397+
that limit.
397398
398399
:param inputs: anything
399400
:param switch_low_high: if None, it is equal to ``switch_low_high=sys.byteorder != "big"``
@@ -532,6 +533,8 @@ def create_input_tensors_from_onnx_model(
532533
:param engine: runtime to use, onnx, the default value, onnxruntime
533534
:param sep: separator
534535
:return: restored data
536+
537+
See example :ref:`l-plot-intermediate-results` for an example.
535538
"""
536539
if engine == "ExtendedReferenceEvaluator":
537540
from ..reference import ExtendedReferenceEvaluator

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ def steal_forward(
288288
"""
289289
The necessary modification to steem forward method and prints out inputs
290290
and outputs using :func:`onnx_diagnostic.helpers.string_type`.
291-
See example :ref:`l-plot-tiny-llm-export`.
291+
See example :ref:`l-plot-tiny-llm-export` or
292+
:ref:`l-plot-intermediate-results`.
292293
293294
:param model: a model or a list of models to monitor,
294295
every model can also be a tuple(name, model), name is displayed well.
@@ -410,12 +411,15 @@ def forward(self, x, y):
410411
proto = create_onnx_model_from_input_tensors(storage)
411412
if verbose:
412413
print("-- dumps stored objects")
414+
location = f"{os.path.split(dump_file)[-1]}.data"
415+
if os.path.exists(location):
416+
os.remove(location)
413417
onnx.save(
414418
proto,
415419
dump_file,
416420
save_as_external_data=True,
417421
all_tensors_to_one_file=True,
418-
location=f"{os.path.split(dump_file)[-1]}.data",
422+
location=location,
419423
)
420424
if verbose:
421425
print("-- done dump stored objects")

0 commit comments

Comments
 (0)