Commit 9e9b996

Improve patches for transformers (#16)
* Improve patches for transformers
* myp
* urls
* fix issues
* fix
* patches
1 parent f6ad410 commit 9e9b996

20 files changed: +628 -256 lines changed

.github/workflows/ci.yml

Lines changed: 12 additions & 2 deletions
@@ -15,8 +15,8 @@ jobs:
1515
strategy:
1616
matrix:
1717
os: [ubuntu-latest]
18-
python: ['3.12']
19-
transformers: ['4.48', 'main']
18+
python: ['3.11', '3.12']
19+
transformers: ['4.48', '4.50', 'main']
2020

2121
steps:
2222
- uses: actions/checkout@v3
@@ -69,6 +69,16 @@ jobs:
6969
export PYTHONPATH=.
7070
python _unittests/ut_torch_models/test_tiny_llms_onnx.py
7171
72+
- name: tiny-llm example
73+
run: |
74+
export PYTHONPATH=.
75+
python _doc/examples/plot_export_tiny_llm.py
76+
77+
- name: tiny-llm bypass
78+
run: |
79+
export PYTHONPATH=.
80+
python _doc/examples/plot_export_tiny_llm_patched.py
81+
7282
- name: run tests
7383
run: |
7484
pip install pytest
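
As a rough illustration of what the widened matrix in `ci.yml` means for CI load, the sketch below enumerates the combinations. It only mirrors the values visible in the hunk above and assumes the workflow has no `include`/`exclude` entries (none are shown in this diff); it is not part of the workflow itself.

```python
from itertools import product

# Values copied from the updated matrix in the hunk above.
os_list = ["ubuntu-latest"]
python_versions = ["3.11", "3.12"]
transformers_versions = ["4.48", "4.50", "main"]

# Assumption: the matrix is a plain cross product of its axes.
jobs = list(product(os_list, python_versions, transformers_versions))

for os_name, py, tr in jobs:
    print(f"{os_name} / python {py} / transformers {tr}")
print(len(jobs), "jobs")
```

Under that assumption the matrix grows from 2 jobs (one Python version, two transformers versions) to 6.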

.github/workflows/documentation.yml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ jobs:
6363
export PYTHONPATH=.
6464
python _unittests/ut_torch_models/test_tiny_llms_onnx.py
6565
66+
- name: tiny-llm example
67+
run: |
68+
export PYTHONPATH=.
69+
python _doc/examples/plot_export_tiny_llm.py
70+
71+
- name: tiny-llm bypass
72+
run: |
73+
export PYTHONPATH=.
74+
python _doc/examples/plot_export_tiny_llm_patched.py
75+
6676
- name: Generate coverage report
6777
run: |
6878
pip install pytest

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.2.1
5+
+++++
6+
7+
* :pr:`16`: refactors patches
8+
49
0.2.0
510
+++++
611

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ Enlightening Examples
5454
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
5555
* `Steal method forward to guess the dynamic shapes (with Tiny-LLM)
5656
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
57+
* `Export Tiny-LLM with patches
58+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_patched.html>`_
5759

5860
**Investigate ONNX models**
5961

_doc/conf.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@
113113
("py:class", "pipeline.Pipeline"),
114114
("py:class", "torch.fx.passes.operator_support.OperatorSupport"),
115115
("py:class", "torch.fx.proxy.TracerBase"),
116+
("py:class", "torch.FloatTensor"),
117+
("py:class", "torch.LongTensor"),
116118
("py:class", "torch.utils._pytree.Context"),
117119
("py:class", "torch.utils._pytree.KeyEntry"),
118120
("py:class", "torch.utils._pytree.TreeSpec"),
@@ -196,8 +198,8 @@
196198
"onnx-extended": "https://sdpython.github.io/doc/onnx-extended/dev/",
197199
"onnx-script": "https://github.com/microsoft/onnxscript",
198200
"onnxscript": "https://github.com/microsoft/onnxscript",
199-
"onnxscript Tutorial": "https://onnxscript.ai/tutorial/index.html",
200-
"Pattern-based Rewrite Using Rules With onnxscript": "https://onnxscript.ai/tutorial/rewriter/rewrite_patterns.html",
201+
"onnxscript Tutorial": "https://microsoft.github.io/onnxscript/tutorial/index.html",
202+
"Pattern-based Rewrite Using Rules With onnxscript": "https://microsoft.github.io/onnxscript/tutorial/rewriter/rewrite_patterns.html",
201203
"opsets": "https://onnx.ai/onnx/intro/concepts.html#what-is-an-opset-version",
202204
"pyinstrument": "https://pyinstrument.readthedocs.io/en/latest/",
203205
"psutil": "https://psutil.readthedocs.io/en/latest/",

_doc/examples/plot_export_tiny_llm.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _forward_(*args, _f=None, **kwargs):
8484
#
8585
# Let's create an untrained model using the config file provided
8686
# `config.json <https://huggingface.co/arnir0/Tiny-LLM/blob/main/config.json>`_
87-
# to create an untrained model: :func:`onnx_diagnostic.torch_models.llms.get_tiny_llm`.
87+
# to create an untrained model: :func:`....get_tiny_llm`.
8888
# Then let's use it.
8989

9090
experiment = get_tiny_llm()
@@ -138,7 +138,7 @@ def _forward_(*args, _f=None, **kwargs):
138138
#
139139
# Let's use the same dummy inputs but we use the downloaded model.
140140
# Dummy inputs and dynamic shapes are created by function
141-
# :func:`onnx_diagnostic.torch_models.llms.get_tiny_llm`.
141+
# :func:`....get_tiny_llm`.
142142

143143
data = get_tiny_llm()
144144
inputs, dynamic_shapes = data["inputs"], data["dynamic_shapes"]
@@ -163,3 +163,7 @@ def _forward_(*args, _f=None, **kwargs):
163163
# * https://github.com/huggingface/transformers/pull/36311
164164
# * https://github.com/huggingface/transformers/pull/36652
165165
print("It failed:", e)
166+
167+
# %%
168+
# If you run into an error, see example
169+
# :ref:`l-plot-tiny-llm-export-patched`.

_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
1+
"""
2+
.. _l-plot-tiny-llm-export-patched:
3+
4+
Export Tiny-LLM with patches
5+
============================
6+
7+
Many models from :epkg:`transformers` cannot be converted because
8+
the implementation uses cache classes. Let's see how to get around that.
9+
We focus on the model
10+
`Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
11+
To avoid downloading any weights, we write a function creating a
12+
random model based on the same architecture.
13+
This continues example :ref:`l-plot-tiny-llm-export`.
14+
15+
Errors
16+
++++++
17+
18+
The errors depend on the transformers version.
19+
20+
``transformers>=4.40,<4.50`` cannot serialize DynamicCache and cannot
21+
map dynamic shapes to instances of DynamicCache. The following errors
22+
would appear:
23+
24+
::
25+
26+
torch._dynamo.exc.UserError: Cannot associate shape
27+
[[{0: <class '....batch'>, 2: <class '....cache_length'>}],
28+
[{0: <class '....batch'>, 2: <class '....cache_length'>}]]
29+
specified at `dynamic_shapes['past_key_values']`
30+
to non-tensor type <class 'transformers.cache_utils.DynamicCache'>
31+
at `inputs['past_key_values']` (expected None)
32+
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
33+
34+
With ``transformers==4.50``, it shows the following:
35+
36+
::
37+
38+
torch._dynamo.exc.UserError: Constraints violated (batch)!
39+
For more information, run with TORCH_LOGS="+dynamic".
40+
- Not all values of batch = L['args'][1]['input_ids'].size()[0]
41+
in the specified range batch <= 1024 are valid
42+
because batch was inferred to be a constant (2).
43+
- Not all values of batch = L['args'][1]['attention_mask'].size()[0]
44+
in the specified range batch <= 1024 are valid
45+
because batch was inferred to be a constant (2).
46+
- Not all values of batch = L['args'][1]['past_key_values']['key_cache'][0].size()[0]
47+
in the specified range batch <= 1024 are valid
48+
because batch was inferred to be a constant (2).
49+
- Not all values of batch = L['args'][1]['past_key_values']['value_cache'][0].size()[0]
50+
in the specified range batch <= 1024 are valid
51+
because batch was inferred to be a constant (2).
52+
Suggested fixes:
53+
batch = 2
54+
55+
However, this package implements a patch mechanism
56+
that replaces the part causing these issues.
57+
58+
.. note:: restart after an export failure
59+
60+
If the export fails, it is better to rerun the script from the start,
61+
or restart the kernel if you are in a notebook.
62+
The export may leave :epkg:`torch` in an unstable state.
63+
"""
64+
65+
import copy
66+
import torch
67+
import transformers
68+
from onnx_diagnostic.torch_export_patches.onnx_export_errors import bypass_export_some_errors
69+
from onnx_diagnostic.torch_models.llms import get_tiny_llm
70+
71+
72+
experiment = get_tiny_llm()
73+
untrained_model, inputs, dynamic_shapes = (
74+
experiment["model"],
75+
experiment["inputs"],
76+
experiment["dynamic_shapes"],
77+
)
78+
79+
cloned_inputs = copy.deepcopy(inputs)
80+
81+
82+
with bypass_export_some_errors(patch_transformers=True) as modificator:
83+
ep = torch.export.export(
84+
untrained_model,
85+
(),
86+
kwargs=modificator(cloned_inputs),
87+
dynamic_shapes=dynamic_shapes,
88+
)
89+
print("It worked:")
90+
print(ep)
91+
92+
# %%
93+
# With the original model
94+
# +++++++++++++++++++++++
95+
96+
MODEL_NAME = "arnir0/Tiny-LLM"
97+
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
98+
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_NAME)
99+
100+
cloned_inputs = copy.deepcopy(inputs)
101+
102+
with bypass_export_some_errors(patch_transformers=True) as modificator:
103+
ep = torch.export.export(
104+
model,
105+
(),
106+
kwargs=modificator(cloned_inputs),
107+
dynamic_shapes=dynamic_shapes,
108+
)
109+
print("It worked:")
110+
print(ep)
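
The example above relies on a context manager that swaps in patched code for the duration of the export and hands back a `modificator` for the inputs. The general pattern can be sketched as follows. This is a minimal illustration, not the implementation of `bypass_export_some_errors`: the class `Cache`, the function `patched_size`, and the helper `bypass_sketch` are hypothetical names, and the yielded function is a plain identity, whereas the real `modificator` may transform its inputs.

```python
import contextlib


class Cache:
    """Hypothetical stand-in for a library class whose method blocks an export."""

    def size(self):
        raise TypeError("not supported while tracing")


def patched_size(self):
    # Hypothetical replacement that avoids the failing code path.
    return 0


@contextlib.contextmanager
def bypass_sketch(cls, name, replacement):
    """Temporarily replace ``cls.name`` and restore it on exit."""
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        # Yield an identity function in the role of the 'modificator'.
        yield lambda inputs: inputs
    finally:
        # Always restore the original, even if the body raised.
        setattr(cls, name, original)


with bypass_sketch(Cache, "size", patched_size) as modificator:
    inside = Cache().size()            # uses the patched method
    forwarded = modificator({"x": 1})  # inputs pass through unchanged here

try:
    Cache().size()
    restored = False
except TypeError:
    restored = True                    # the original method is back
```

Restoring in a `finally` clause matters because exports fail often while debugging, and a patch left in place would affect every later call in the same process.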

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -210,16 +210,13 @@ def forward(self, cache, z):
210210
# The export is simple if ``transformers>=4.50``, otherwise,
211211
# transformers needs to be patched.
212212
# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
213-
# registers functions to serialize ``DynamicCache`` and another class
214-
# called ``patched_DynamicCache``. This one is modified to make
213+
# registers functions to serialize ``DynamicCache``. This class is modified to make
215214
# the shape inference implemented in :epkg:`torch` happy.
216215

217216
if has_transformers("4.50"):
218217
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
219218
else:
220-
with bypass_export_some_errors(
221-
patch_transformers=True, replace_dynamic_cache=True
222-
) as modificator:
219+
with bypass_export_some_errors(patch_transformers=True) as modificator:
223220
ep = torch.export.export(
224221
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
225222
)
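
The hunk above branches on `has_transformers("4.50")`. How that helper is implemented is not shown in this diff; the sketch below only illustrates the general idea of a numeric version gate using the standard library. The names `version_tuple` and `at_least` are hypothetical, and pre-release suffixes are simply truncated, which differs from full PEP 440 ordering.

```python
import re


def version_tuple(version):
    """Turn '4.50.0.dev0' into (4, 50, 0); stop at the first non-numeric piece."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def at_least(installed, required):
    """Compare versions numerically, so that '4.9' is older than '4.50'."""
    return version_tuple(installed) >= version_tuple(required)


print(at_least("4.48.3", "4.50"))
print(at_least("4.50.0.dev0", "4.50"))
print(at_least("4.9", "4.50"))
```

Comparing integer tuples rather than raw strings avoids the classic mistake where `"4.9" > "4.50"` lexicographically.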

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ Enlightening Examples
6262
* :ref:`l-plot-sxport-with-dynamio-shapes-auto`
6363
* :ref:`l-plot-export-with-dynamic-shape`
6464
* :ref:`l-plot-tiny-llm-export`
65+
* :ref:`l-plot-tiny-llm-export-patched`
6566

6667
**Investigate ONNX models**
6768

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
7979
model = Model()
8080
model(x, cache)
8181

82-
with bypass_export_some_errors(replace_dynamic_cache=True, verbose=1):
82+
with bypass_export_some_errors(verbose=1):
8383
cache = MambaCache(_config(), max_batch_size=1, device="cpu")
8484
torch.export.export(Model(), (x, cache))
8585
