|
27 | 27 | from onnx_diagnostic.helpers import string_type |
28 | 28 | from onnx_diagnostic.export import ModelInputs |
29 | 29 |
|
| 30 | +# %% |
| 31 | +# We need additional imports in case ``transformers<4.50``. |
| 32 | +# Exporting DynamicCache is not supported before that. |
| 33 | +from onnx_diagnostic.ext_test_case import has_transformers |
| 34 | +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors |
| 35 | + |
30 | 36 |
|
31 | 37 | class Model(torch.nn.Module): |
32 | 38 | def forward(self, x, y): |
@@ -201,6 +207,20 @@ def forward(self, cache, z): |
201 | 207 |
|
202 | 208 | # %% |
203 | 209 | # And finally the export. |
204 | | - |
205 | | -ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False) |
| 210 | +# The export is simple if ``transformers>=4.50``; otherwise, |
| 211 | +# transformers needs to be patched. |
| 212 | +# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` |
| 213 | +# registers functions to serialize ``DynamicCache`` and another class |
| 214 | +# called ``patched_DynamicCache``. The latter is modified so that |
| 215 | +# the shape inference implemented in :epkg:`torch` accepts it. |
| 216 | + |
| 217 | +if has_transformers("4.50"): |
| 218 | + ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False) |
| 219 | +else: |
| 220 | + with bypass_export_some_errors( |
| 221 | + patch_transformers=True, replace_dynamic_cache=True |
| 222 | + ) as modificator: |
| 223 | + ep = torch.export.export( |
| 224 | + model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False |
| 225 | + ) |
206 | 226 | print(ep) |
0 commit comments