Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions _doc/examples/plot_dump_intermediate_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""
.. _l-plot-intermediate-results:

Dumps intermediate results of a torch model
===========================================

Looking for discrepancies quickly becomes tedious. Discrepancies
appear when the same model, implemented in two different ways,
:epkg:`pytorch` and :epkg:`onnx`, returns two different results.
Models are big, so where do the discrepancies come from? That's the
unavoidable question. Unless there is an obvious reason,
the only way is to compare intermediate outputs along the computation.
The first step into that direction is to dump the intermediate results
coming from :epkg:`pytorch`.
We use :func:`onnx_diagnostic.helpers.torch_helper.steal_forward` for that.

A simple LLM Model
++++++++++++++++++

See :func:`onnx_diagnostic.helpers.torch_helper.dummy_llm`
for its definition. It is mostly used for unit test or example.

"""

import onnx
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.torch_helper import steal_forward


# Instantiate the dummy LLM along with sample inputs and dynamic shapes.
model, inputs, ds = dummy_llm(dynamic_shapes=True)

# Summarize what was created before going any further.
print(f"type(model)={type(model)}")
inputs_summary = string_type(inputs, with_shape=True)
print(f"inputs={inputs_summary}")
ds_summary = string_type(ds, with_shape=True)
print(f"ds={ds_summary}")

# %%
# It contains the following submodules.

# One line per (sub)module: its qualified name and its class.
for module_name, module in model.named_modules():
    print(f"- {module_name}: {type(module)}")

# %%
# Steal and dump the output of submodules
# +++++++++++++++++++++++++++++++++++++++
#
# The following context spies on the intermediate results
# for the following module and submodules. It stores
# in one onnx file all the input/output for those.

# Modules to spy on, each paired with the name identifying it in the dump.
monitored_modules = [
    ("model", model),
    ("model.decoder", model.decoder),
    ("model.decoder.attention", model.decoder.attention),
    ("model.decoder.feed_forward", model.decoder.feed_forward),
    ("model.decoder.norm_1", model.decoder.norm_1),
    ("model.decoder.norm_2", model.decoder.norm_2),
]

# Running the model inside the context records the inputs/outputs
# of every monitored module into a single onnx file.
with steal_forward(
    monitored_modules,
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,
):
    model(*inputs)


# %%
# Restores saved inputs/outputs
# +++++++++++++++++++++++++++++
#
# All the intermediate tensors were saved in one unique onnx model,
# every tensor is stored in a constant node.
# The model can be run with any runtime to restore the inputs
# and function :func:`create_input_tensors_from_onnx_model
# <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
# can restore their names.

# Reload everything that was dumped while the model was running.
dump_path = "plot_dump_intermediate_results.inputs.onnx"
saved_tensors = create_input_tensors_from_onnx_model(dump_path)

# Every entry maps a key to the stored inputs or outputs.
for key, value in saved_tensors.items():
    print(f"{key} -- {string_type(value, with_shape=True)}")

# %%
# Let's explain the naming convention.
#
# ::
#
# ('model.decoder.norm_2', 0, 'I') -- ((T1s2x30x16,),{})
# | | |
# | | +--> input, the format is args, kwargs
# | |
# | +--> iteration, 0 means the first time the execution
# | went through that module
# | the model can be called multiple times
# | to store more iterations
# |
# +--> the name given to function steal_forward
#
# The same goes for output except ``'I'`` is replaced by ``'O'``.
#
# ::
#
# ('model.decoder.norm_2', 0, 'O') -- T1s2x30x16
#
# This trick can be used to compare intermediate results coming
# from pytorch to any other implementation of the same model
# as long as it is possible to map the stored inputs/outputs.

# %%
# Conversion to ONNX
# ++++++++++++++++++
#
# The difficult point is to be able to map the saved intermediate
# results to intermediate results in ONNX.
# Let's create the ONNX model.

# Export with the dynamo-based exporter, passing the dynamic shapes along.
epo = torch.onnx.export(
    model,
    inputs,
    dynamic_shapes=ds,
    dynamo=True,
)
# Optimize the exported program before writing it to disk.
epo.optimize()
onnx_path = "plot_dump_intermediate_results.onnx"
epo.save(onnx_path)

# %%
# It looks like the following.
onx = onnx.load("plot_dump_intermediate_results.onnx")
# Render the graph of the exported model; ``plot_dot`` comes from the
# graphviz helper, so the ``dot`` executable is presumably required.
plot_dot(onx)

# %%
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")
1 change: 1 addition & 0 deletions _doc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Enlightening Examples
* :ref:`l-plot-failing-reference-evaluator`
* :ref:`l-plot-failing-onnxruntime-evaluator`
* :ref:`l-plot-failing-model-extract`
* :ref:`l-plot-intermediate-results`

Some Useful Tools
==================
Expand Down
4 changes: 4 additions & 0 deletions _unittests/ut_xrun_doc/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,15 @@ def add_test_methods(cls):
this = os.path.abspath(os.path.dirname(__file__))
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "examples"))
found = os.listdir(fold)
has_dot = int(os.environ.get("UNITTEST_DOT", "0"))
for name in found:
if not name.endswith(".py") or not name.startswith("plot_"):
continue
reason = None

if not reason and not has_dot and name in {"plot_dump_intermediate_results.py"}:
reason = "dot not installed"

if (
not reason
and name in {"plot_export_tiny_llm.py"}
Expand Down
3 changes: 3 additions & 0 deletions _unittests/ut_xrun_doc/test_documentation_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,16 @@ def add_test_methods(cls):
this = os.path.abspath(os.path.dirname(__file__))
fold = os.path.normpath(os.path.join(this, "..", "..", "_doc", "recipes"))
found = os.listdir(fold)
has_dot = int(os.environ.get("UNITTEST_DOT", "0"))
for name in found:
if not name.endswith(".py") or not name.startswith("plot_"):
continue
reason = None

# Skip when torch is older than 2.7. The version checked here must match
# the version quoted in the skip reason, otherwise every recipe would be
# skipped on any torch 2.x release.
if not reason and not has_torch("2.7"):
    reason = "torch<2.7"
if not reason and not has_dot and name in {"plot_dump_intermediate_results.py"}:
reason = "dot not installed"

if reason:

Expand Down
5 changes: 4 additions & 1 deletion onnx_diagnostic/helpers/mini_onnx_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,8 @@ def create_onnx_model_from_input_tensors(
Creates a model proto including all the value as initializers.
They can be restored by executing the model.
We assume these inputs are not bigger than 2Gb,
the limit of protobuf.
the limit of protobuf. Nothing is implemented yet to get around
that limit.

:param inputs: anything
:param switch_low_high: if None, it is equal to ``switch_low_high=sys.byteorder != "big"``
Expand Down Expand Up @@ -532,6 +533,8 @@ def create_input_tensors_from_onnx_model(
:param engine: runtime to use, onnx, the default value, onnxruntime
:param sep: separator
:return: restored data

See :ref:`l-plot-intermediate-results` for an example.
"""
if engine == "ExtendedReferenceEvaluator":
from ..reference import ExtendedReferenceEvaluator
Expand Down
8 changes: 6 additions & 2 deletions onnx_diagnostic/helpers/torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def steal_forward(
"""
The necessary modification to steal the forward method and prints out inputs
and outputs using :func:`onnx_diagnostic.helpers.string_type`.
See example :ref:`l-plot-tiny-llm-export`.
See example :ref:`l-plot-tiny-llm-export` or
:ref:`l-plot-intermediate-results`.

:param model: a model or a list of models to monitor,
every model can also be a tuple(name, model), name is displayed well.
Expand Down Expand Up @@ -410,12 +411,15 @@ def forward(self, x, y):
proto = create_onnx_model_from_input_tensors(storage)
if verbose:
print("-- dumps stored objects")
location = f"{os.path.split(dump_file)[-1]}.data"
if os.path.exists(location):
os.remove(location)
onnx.save(
proto,
dump_file,
save_as_external_data=True,
all_tensors_to_one_file=True,
location=f"{os.path.split(dump_file)[-1]}.data",
location=location,
)
if verbose:
print("-- done dump stored objects")
Expand Down
Loading