Skip to content

Commit ef48a6c

Browse files
committed
update index
1 parent a3fdab0 commit ef48a6c

File tree

4 files changed

+123
-1
lines changed

4 files changed

+123
-1
lines changed
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
"""
2+
.. _l-plot-intermediate-results:
3+
4+
Dumps intermediate results of a torch model
5+
===========================================
6+
7+
8+
codellama/CodeLlama-7b-Python-hf
9+
++++++++++++++++++++++++++++++++
10+
11+
"""
12+
13+
import onnx
14+
import torch
15+
from onnx_array_api.plotting.graphviz_helper import plot_dot
16+
from onnx_diagnostic import doc
17+
from onnx_diagnostic.helpers import string_type
18+
from onnx_diagnostic.helpers.torch_helper import dummy_llm
19+
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
20+
from onnx_diagnostic.helpers.torch_helper import steal_forward
21+
22+
23+
model, inputs, ds = dummy_llm(dynamic_shapes=True)
24+
25+
print(f"type(model)={type(model)}")
26+
print(f"inputs={string_type(inputs, with_shape=True)}")
27+
print(f"ds={string_type(ds, with_shape=True)}")
28+
29+
# %%
30+
# It contains the following submodules.
31+
32+
for name, mod in model.named_modules():
33+
print(f"- {name}: {type(mod)}")
34+
35+
# %%
36+
# Steal and dump the output of submodules
37+
# +++++++++++++++++++++++++++++++++++++++
38+
#
39+
# The following context spies on the intermediate results
40+
# for the following module and submodules. It stores
41+
# in one onnx file all the input/output for those.
42+
43+
with steal_forward(
44+
[
45+
("model", model),
46+
("model.decoder", model.decoder),
47+
("model.decoder.attention", model.decoder.attention),
48+
("model.decoder.feed_forward", model.decoder.feed_forward),
49+
("model.decoder.norm_1", model.decoder.norm_1),
50+
("model.decoder.norm_2", model.decoder.norm_2),
51+
],
52+
dump_file="plot_dump_intermediate_results.inputs.onnx",
53+
verbose=1,
54+
storage_limit=2**28,
55+
):
56+
model(*inputs)
57+
58+
59+
# %%
60+
# Restores saved inputs/outputs
61+
# +++++++++++++++++++++++++++++
62+
#
63+
# All the intermediate tensors were saved in one unique onnx model,
64+
# every tensor is stored in a constant node.
65+
# The model can be run with any runtime to restore the inputs
66+
# and function :func:`onnx_diagnostic.mini_onnx_builder.create_input_tensors_from_onnx_model`
67+
# can restore their names.
68+
69+
saved_tensors = create_input_tensors_from_onnx_model(
70+
"plot_dump_intermediate_results.inputs.onnx"
71+
)
72+
for k, v in saved_tensors.items():
73+
print(f"{k} -- {string_type(v, with_shape=True)}")
74+
75+
# %%
76+
# Let's explained the naming convention.
77+
#
78+
# ::
79+
# ('model.decoder.norm_2', 0, 'I') -- ((T1s2x30x16,),{})
80+
# | | |
81+
# | | +--> input, the format is args, kwargs
82+
# | |
83+
# | +--> iteration, 0 means the first time the execution
84+
# | went through that module
85+
# | it is possible to call multiple times,
86+
# | the model to store more
87+
# |
88+
# +--> the name given to steal forward
89+
#
90+
# The same goes for output except ``'I'`` is replaced by ``'O'``.
91+
#
92+
# ::
93+
#
94+
# ('model.decoder.norm_2', 0, 'O') -- T1s2x30x16
95+
#
96+
# This trick can be used to compare intermediate results coming
97+
# from pytorch to any other implementation of the same model
98+
# as long as it is possible to map the stored inputs/outputs.
99+
100+
# %%
101+
# Conversion to ONNX
102+
# ++++++++++++++++++
103+
#
104+
# The difficult point is to be able to map the saved intermediate
105+
# results to intermediate results in ONNX.
106+
# Let's create the ONNX model.
107+
108+
epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
109+
epo.optimize()
110+
epo.save("plot_dump_intermediate_results.onnx")
111+
112+
# %%
113+
# It looks like the following.
114+
onx = onnx.load("plot_dump_intermediate_results.onnx")
115+
plot_dot(onx)
116+
117+
# %%
118+
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ Enlightening Examples
9292
* :ref:`l-plot-failing-reference-evaluator`
9393
* :ref:`l-plot-failing-onnxruntime-evaluator`
9494
* :ref:`l-plot-failing-model-extract`
95+
* :ref:`l-plot-intermediate-results`
9596

9697
Some Usefuls Tools
9798
==================

onnx_diagnostic/helpers/mini_onnx_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,8 @@ def create_input_tensors_from_onnx_model(
532532
:param engine: runtime to use, onnx, the default value, onnxruntime
533533
:param sep: separator
534534
:return: restored data
535+
536+
See example :ref:`l-plot-intermediate-results` for an example.
535537
"""
536538
if engine == "ExtendedReferenceEvaluator":
537539
from ..reference import ExtendedReferenceEvaluator

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,8 @@ def steal_forward(
288288
"""
289289
The necessary modification to steem forward method and prints out inputs
290290
and outputs using :func:`onnx_diagnostic.helpers.string_type`.
291-
See example :ref:`l-plot-tiny-llm-export`.
291+
See example :ref:`l-plot-tiny-llm-export` or
292+
:ref:`l-plot-intermediate-results`.
292293
293294
:param model: a model or a list of models to monitor,
294295
every model can also be a tuple(name, model), name is displayed well.

0 commit comments

Comments
 (0)