"""
.. _l-plot-intermediate-results:

Dumps intermediate results of a torch model
===========================================

Looking for discrepancies is quickly annoying. Discrepancies
come from two results obtained with the same models
implemented in two different ways, :epkg:`pytorch` and :epkg:`onnx`.
Models are big so where do they come from? That's the
unavoidable question. Unless there is an obvious reason,
the only way is to compare intermediate outputs along the computation.
The first step in that direction is to dump the intermediate results
coming from :epkg:`pytorch`.
We use :func:`onnx_diagnostic.helpers.torch_helper.steal_forward` for that.

A simple LLM Model
++++++++++++++++++

See :func:`onnx_diagnostic.helpers.torch_helper.dummy_llm`
for its definition. It is mostly used for unit tests or examples.

"""

import onnx
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.torch_helper import steal_forward
| 34 | + |
# ``dummy_llm`` returns the model, a set of example inputs and the
# associated dynamic shapes (``ds`` is used later for the ONNX export).
model, inputs, ds = dummy_llm(dynamic_shapes=True)

print(f"type(model)={type(model)}")
print(f"inputs={string_type(inputs, with_shape=True)}")
print(f"ds={string_type(ds, with_shape=True)}")

# %%
# It contains the following submodules.

for name, mod in model.named_modules():
    print(f"- {name}: {type(mod)}")
| 46 | + |
# %%
# Steal and dump the output of submodules
# +++++++++++++++++++++++++++++++++++++++
#
# The following context spies on the intermediate results
# for the following module and submodules. It stores
# in one onnx file all the input/output for those.

with steal_forward(
    [
        # Pairs ``(name, module)``: the name is the key under which this
        # module's inputs/outputs are stored in the dump file.
        ("model", model),
        ("model.decoder", model.decoder),
        ("model.decoder.attention", model.decoder.attention),
        ("model.decoder.feed_forward", model.decoder.feed_forward),
        ("model.decoder.norm_1", model.decoder.norm_1),
        ("model.decoder.norm_2", model.decoder.norm_2),
    ],
    # All stolen inputs/outputs are serialized into this single onnx file.
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    # 2**28 = 256 MiB -- presumably a cap on the amount of data stored,
    # check steal_forward's documentation to confirm the unit.
    storage_limit=2**28,
):
    # One forward call; calls executed inside the context are recorded.
    model(*inputs)
| 70 | + |
# %%
# Restores saved inputs/outputs
# +++++++++++++++++++++++++++++
#
# All the intermediate tensors were saved in one unique onnx model,
# every tensor is stored in a constant node.
# The model can be run with any runtime to restore the inputs
# and function :func:`create_input_tensors_from_onnx_model
# <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
# can restore their names.

saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
# Keys are tuples ``(module name, iteration, 'I' or 'O')`` -- the exact
# convention is detailed in the next section.
for k, v in saved_tensors.items():
    print(f"{k} -- {string_type(v, with_shape=True)}")
| 87 | + |
# %%
# Let's explain the naming convention.
#
# ::
#
#    ('model.decoder.norm_2', 0, 'I') -- ((T1s2x30x16,),{})
#          |                        |    |
#          |                        |    +--> input, the format is args, kwargs
#          |                        |
#          |                        +--> iteration, 0 means the first time the execution
#          |                             went through that module,
#          |                             it is possible to call the model
#          |                             multiple times to store more
#          |
#          +--> the name given to function steal_forward
#
# The same goes for output except ``'I'`` is replaced by ``'O'``.
#
# ::
#
#    ('model.decoder.norm_2', 0, 'O') -- T1s2x30x16
#
# This trick can be used to compare intermediate results coming
# from pytorch to any other implementation of the same model
# as long as it is possible to map the stored inputs/outputs.
# %%
# Conversion to ONNX
# ++++++++++++++++++
#
# The difficult point is to be able to map the saved intermediate
# results to intermediate results in ONNX.
# Let's create the ONNX model.

# ``dynamo=True`` selects the torch.export based (dynamo) exporter.
epo = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
# Optimize the exported program before saving it to disk.
epo.optimize()
epo.save("plot_dump_intermediate_results.onnx")

# %%
# It looks like the following.
onx = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onx)

# %%
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")