.. _l-patches-explained:

=================
Patches Explained
=================

Function :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
implements four kinds of patches to make it easier to export a model, usually
coming from :epkg:`transformers`.
All patches take place in :mod:`onnx_diagnostic.torch_export_patches`.

.. code-block:: python

    with torch_export_patches(...) as f:
        ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)

1. **torch fixes**:
   it disables some exceptions or improves some functions related to dynamic shapes
   until :epkg:`torch` addresses the issues
   (see `mostly exporter issues
   <https://github.com/pytorch/pytorch/issues?q=is%3Aissue%20state%3Aopen%20author%3Axadupre>`_)
2. **transformers rewriting**:
   some methods are replaced with a version :func:`torch.export.export` can understand;
   some rewritings may migrate to :epkg:`transformers`, others are applied only
   at export time because they would make the implementation less efficient
3. **cache serialization**: :func:`torch.export.export` needs to know how to
   serialize custom classes such as :class:`transformers.cache_utils.DynamicCache`
4. **control flow rewriting**: control flow (if, for) cannot be exported as is;
   there is still some work to be done to process it automatically, and the
   automated rewriting this package offers is far from perfect.

All of them are triggered by :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`.

.. code-block:: bash

    python -m onnx_diagnostic validate \
        -m hf-tiny-model-private/tiny-random-PLBartForConditionalGeneration \
        --run -v 1 --export onnx-dynamo -o dump_test --dtype float16 --device cuda

All patches can be disabled with ``with torch_export_patches(patch=False)``.

torch fixes
===========

Implemented in :mod:`onnx_diagnostic.torch_export_patches.patches.patch_torch` and triggered with
``with torch_export_patches(patch_sympy=True, patch_torch=True, catch_constraints=True, stop_if_static=1...)``.

It fixes some issues found while exporting models; some of them might not be needed anymore.
It improves shape broadcasting or inserts an exception every time a dynamic dimension
becomes static (``stop_if_static=1``).
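
For instance, the following sketch (``model``, ``args``, ``kwargs`` and
``dynamic_shapes`` are assumed to be defined as in the first snippet) turns a
dynamic dimension silently becoming static into an explicit error:

.. code-block:: python

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    # stop_if_static=1 raises as soon as a dimension declared dynamic
    # becomes static during export, pointing at the responsible guard.
    with torch_export_patches(patch_sympy=True, patch_torch=True, stop_if_static=1):
        ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
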

transformers rewriting
======================

Implemented in :mod:`onnx_diagnostic.torch_export_patches.patches.patch_transformers` and triggered with
``with torch_export_patches(patch_transformers=True)``.

Every patched class is prefixed with ``patched_`` and contains two class attributes:
``_PATCHES_`` contains the list of methods to replace,
``_PATCHED_CLASS_`` is the class patched by this one.

.. code-block:: python

    class patched_AttentionMaskConverter:
        """
        Patches
        ``transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask``.
        """

        # This method was fixed in 4.51 at least.
        _PATCHES_ = ["_make_causal_mask"] if not has_transformers("4.48.3") else []
        _PATCHED_CLASS_ = AttentionMaskConverter

The package automatically parses this file to extract the patched methods.
More can be added by populating the argument ``custom_patches``:
``with torch_export_patches(custom_patches=[...])``, as shown in the sketch below.
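
A minimal sketch of such a custom patch; ``MyModule``, its export-friendly
``forward``, ``model`` and ``x`` are hypothetical and only illustrate the
expected structure:

.. code-block:: python

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches
    from my_package import MyModule  # hypothetical class whose method needs patching


    class patched_MyModule:
        # methods of _PATCHED_CLASS_ replaced while the context is active
        _PATCHES_ = ["forward"]
        _PATCHED_CLASS_ = MyModule

        def forward(self, x):
            # hypothetical export-friendly rewriting of MyModule.forward
            return torch.relu(x)


    with torch_export_patches(patch_transformers=True, custom_patches=[patched_MyModule]):
        ep = torch.export.export(model, (x,))
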

Cache serialization
===================

Implemented in :mod:`onnx_diagnostic.torch_export_patches.onnx_export_serialization`.
Any custom class manipulated by a model needs to be registered through
``torch.utils._pytree.register_pytree_node`` or with
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_class_serialization`,
triggered by ``with torch_export_patches(patch_transformers=True)``.
That function registers one class;
:func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.register_cache_serialization`
registers all known classes.
It can be undone with :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister`
or :func:`onnx_diagnostic.torch_export_patches.onnx_export_serialization.unregister_cache_serialization`.
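
A sketch of manual registration; it assumes the value returned by
``register_cache_serialization`` is the state expected back by
``unregister_cache_serialization`` (``model``, ``args`` and ``kwargs`` are
placeholders):

.. code-block:: python

    import torch
    from onnx_diagnostic.torch_export_patches.onnx_export_serialization import (
        register_cache_serialization,
        unregister_cache_serialization,
    )

    # Register serialization functions for all known cache classes.
    registered = register_cache_serialization()
    try:
        # torch.export.export can now flatten/unflatten cache inputs.
        ep = torch.export.export(model, args, kwargs=kwargs)
    finally:
        # Undo the registration once the export is done.
        unregister_cache_serialization(registered)
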

.. _l-control-flow-rewriting:

Control flow rewriting
======================

This is an attempt to automatically rewrite control flow using :mod:`ast`.
It is implemented in :mod:`onnx_diagnostic.torch_export_patches.patch_module` and
triggered with ``torch_export_patches(rewrite=<instance of torch.nn.Module>)``.
Option ``dump_rewriting=<folder>`` tells the function to dump all applied
rewritings.
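
A usage sketch (same placeholder names as above):

.. code-block:: python

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    # Rewrite the control flow found in the model and dump every applied
    # rewriting into the folder "dump_rewrite" for inspection.
    with torch_export_patches(rewrite=model, dump_rewriting="dump_rewrite"):
        ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
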

The following example contains the rewriting of method
:meth:`transformers.models.bart.modeling_bart.BartEncoderLayer.forward`.
The list of known rewritings to apply is returned by function
:func:`onnx_diagnostic.torch_export_patches.patch_module_helper.code_needing_rewriting`
and applied by function :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`.

While parsing the code, type information is not available, even though
:func:`torch.export.export` knows it. Because of that, the automation usually
needs manual tuning to filter out some tests (argument ``filter_node``) or to
pre- or post-process the rewriting (arguments ``pre_rewriter``, ``post_rewriter``
of function :func:`onnx_diagnostic.torch_export_patches.patch_module.transform_method`).
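
A hedged sketch of a direct call; the positional argument and the fact that
the result can be printed are assumptions based on the parameters listed above:

.. code-block:: python

    from transformers.models.bart.modeling_bart import BartEncoderLayer
    from onnx_diagnostic.torch_export_patches.patch_module import transform_method

    # Rewrite the method; the filter keeps every candidate test here, a real
    # use case would reject the tests that must stay untouched.
    rewritten = transform_method(
        BartEncoderLayer.forward,
        filter_node=lambda node: True,  # assumption: per-node predicate
    )
    print(rewritten)
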

The main entry point is the context manager
:func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`,
which applies the rewriting and undoes it on exit.
For example, the model :class:`transformers.BartForConditionalGeneration`
requires the following value for parameter ``rewrite``:

.. runpython::
    :showcode:

    from onnx_diagnostic.torch_export_patches.patch_module_helper import (
        code_needing_rewriting,
    )

    print(code_needing_rewriting("BartForConditionalGeneration"))

Applying that rewriting produces the following diff:

.. code-block:: diff

    --- original
    +++ rewritten
    @@ -26,7 +26,6 @@
             hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
             hidden_states = residual + hidden_states
             hidden_states = self.self_attn_layer_norm(hidden_states)
    -
             residual = hidden_states
             hidden_states = self.activation_fn(self.fc1(hidden_states))
             hidden_states = nn.functional.dropout(
    @@ -37,15 +36,22 @@
             hidden_states = residual + hidden_states
             hidden_states = self.final_layer_norm(hidden_states)

    -        if hidden_states.dtype == torch.float16 and (
    -            torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
    -        ):
    +        def branch_cond_then_1(hidden_states):
                 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
                 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    +            return hidden_states.clone()

    +        def branch_cond_else_1(hidden_states):
    +            return hidden_states.clone()
    +
    +        hidden_states = torch.cond(
    +            hidden_states.dtype == torch.float16
    +            and torch.isinf(hidden_states).any() | torch.isnan(hidden_states).any(),
    +            branch_cond_then_1,
    +            branch_cond_else_1,
    +            [hidden_states],
    +        )
             outputs = (hidden_states,)
    -
             if output_attentions:
    -            outputs += (attn_weights,)
    -
    +            outputs = outputs + (attn_weights,)
             return outputs
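
Putting it together, a hedged end-to-end sketch reusing the value computed
above (``model``, ``args``, ``kwargs`` and ``dynamic_shapes`` remain
placeholders):

.. code-block:: python

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_rewrite
    from onnx_diagnostic.torch_export_patches.patch_module_helper import (
        code_needing_rewriting,
    )

    # Apply the known rewritings for this model, export, then undo them.
    with torch_export_rewrite(rewrite=code_needing_rewriting("BartForConditionalGeneration")):
        ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)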