|
29 | 29 | specified at `dynamic_shapes['past_key_values']` |
30 | 30 | to non-tensor type <class 'transformers.cache_utils.DynamicCache'> |
31 | 31 | at `inputs['past_key_values']` (expected None) |
32 | | - For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation |
| 32 | + For more information about this error, |
| 33 | + see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation |
33 | 34 |
|
34 | 35 | With ``transformers==4.50``, it shows the following: |
35 | 36 |
|
|
67 | 68 | import torch |
68 | 69 | import transformers |
69 | 70 | from onnx_diagnostic import doc |
| 71 | +from onnx_diagnostic.cache_helpers import is_cache_dynamic_registered |
70 | 72 | from onnx_diagnostic.helpers import string_type |
71 | | -from onnx_diagnostic.torch_export_patches.onnx_export_errors import bypass_export_some_errors |
| 73 | +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors |
72 | 74 | from onnx_diagnostic.torch_models.llms import get_tiny_llm |
73 | 75 |
|
74 | 76 |
|
|
92 | 94 | pprint.pprint(dynamic_shapes) |
93 | 95 |
|
94 | 96 | # %% |
95 | | -# We are ready to export. |
| 97 | +# Before exporting, we check :class:`transformers.cache_utils.DynamicCache` |
| 98 | +# can serialized and deserialized otherwise :func:`torch.export.export` |
| 99 | +# fails. |
96 | 100 |
|
97 | | -with bypass_export_some_errors(patch_transformers=True) as modificator: |
| 101 | +print("-- DynamicCache registered: ", is_cache_dynamic_registered()) |
| 102 | + |
| 103 | +# %% |
| 104 | +# If they are not registered, function |
| 105 | +# func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` |
| 106 | +# should take care of it. Then we export. |
| 107 | + |
| 108 | +with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator: |
| 109 | + assert is_cache_dynamic_registered() # it must be true here |
98 | 110 | ep = torch.export.export( |
99 | 111 | untrained_model, |
100 | 112 | (), |
101 | 113 | kwargs=modificator(cloned_inputs), |
102 | 114 | dynamic_shapes=dynamic_shapes, |
| 115 | + strict=False, # mandatory for torch==2.6 |
103 | 116 | ) |
104 | 117 | print("It worked:") |
105 | 118 | print(ep) |
|
114 | 127 |
|
115 | 128 | cloned_inputs = copy.deepcopy(inputs) |
116 | 129 |
|
117 | | -with bypass_export_some_errors(patch_transformers=True) as modificator: |
| 130 | +with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator: |
118 | 131 | ep = torch.export.export( |
119 | 132 | model, |
120 | 133 | (), |
121 | 134 | kwargs=modificator(cloned_inputs), |
122 | 135 | dynamic_shapes=dynamic_shapes, |
| 136 | + strict=False, # mandatory for torch==2.6 |
123 | 137 | ) |
124 | 138 | print("It worked:") |
125 | 139 | print(ep) |
|
0 commit comments