"""
.. _l-plot-intermediate-results:

Dumps intermediate results of a torch model
===========================================


codellama/CodeLlama-7b-Python-hf
++++++++++++++++++++++++++++++++

"""

import onnx
import torch
from onnx_array_api.plotting.graphviz_helper import plot_dot
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import dummy_llm
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
from onnx_diagnostic.helpers.torch_helper import steal_forward

# Build a small dummy LLM together with example inputs and dynamic shapes.
model, inputs, ds = dummy_llm(dynamic_shapes=True)

# Display what dummy_llm returned.
for label, rendered in (
    ("type(model)", type(model)),
    ("inputs", string_type(inputs, with_shape=True)),
    ("ds", string_type(ds, with_shape=True)),
):
    print(f"{label}={rendered}")

# %%
# It contains the following submodules.

for submodule_name, submodule in model.named_modules():
    print(f"- {submodule_name}: {type(submodule)}")
# %%
# Steal and dump the output of submodules
# +++++++++++++++++++++++++++++++++++++++
#
# The following context spies on the intermediate results
# for the following module and submodules. It stores
# in one onnx file all the input/output for those.

# Modules to spy on, each paired with a readable name used in the dump.
spied_modules = [
    ("model", model),
    ("model.decoder", model.decoder),
    ("model.decoder.attention", model.decoder.attention),
    ("model.decoder.feed_forward", model.decoder.feed_forward),
    ("model.decoder.norm_1", model.decoder.norm_1),
    ("model.decoder.norm_2", model.decoder.norm_2),
]

with steal_forward(
    spied_modules,
    dump_file="plot_dump_intermediate_results.inputs.onnx",
    verbose=1,
    storage_limit=2**28,  # upper bound on the amount of data kept
):
    model(*inputs)

# %%
# Restores saved inputs/outputs
# +++++++++++++++++++++++++++++
#
# All the intermediate tensors were serialized into one unique onnx model;
# every tensor is stored in a constant node. Any runtime can execute that
# model to restore the tensors, and function
# :func:`onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model`
# restores their names as well.

saved_tensors = create_input_tensors_from_onnx_model(
    "plot_dump_intermediate_results.inputs.onnx"
)
for tensor_key, tensor_value in saved_tensors.items():
    print(f"{tensor_key} -- {string_type(tensor_value, with_shape=True)}")

# %%
# Let's explain the naming convention.
#
# ::
#
#    ('model.decoder.norm_2', 0, 'I') -- ((T1s2x30x16,),{})
#          |                  |  |
#          |                  |  +--> input, the format is args, kwargs
#          |                  |
#          |                  +--> iteration, 0 means the first time the execution
#          |                       went through that module;
#          |                       it is possible to call the model multiple times
#          |                       to store more
#          |
#          +--> the name given to steal_forward
#
# The same goes for output except ``'I'`` is replaced by ``'O'``.
#
# ::
#
#    ('model.decoder.norm_2', 0, 'O') -- T1s2x30x16
#
# This trick can be used to compare intermediate results coming
# from pytorch to any other implementation of the same model
# as long as it is possible to map the stored inputs/outputs.
| 99 | + |
# %%
# Conversion to ONNX
# ++++++++++++++++++
#
# The difficult point is to be able to map the saved intermediate
# results to intermediate results in ONNX.
# Let's create the ONNX model.

# Export with the dynamo-based exporter, optimize, then write to disk.
onnx_program = torch.onnx.export(model, inputs, dynamic_shapes=ds, dynamo=True)
onnx_program.optimize()
onnx_program.save("plot_dump_intermediate_results.onnx")

# %%
# It looks like the following.
onnx_model = onnx.load("plot_dump_intermediate_results.onnx")
plot_dot(onnx_model)

# %%
doc.plot_legend("steal and dump\nintermediate\nresults", "steal_forward", "blue")