Skip to content

Commit d03c565

Browse files
committed
fix ci
1 parent 23a6901 commit d03c565

File tree

6 files changed

+56
-9
lines changed

6 files changed

+56
-9
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
os: [ubuntu-latest]
1818
python: ['3.11', '3.12']
1919
transformers: ['4.48', '4.50', 'main']
20-
torch: ['main']
20+
torch: ['2.6', 'main']
2121

2222
steps:
2323
- uses: actions/checkout@v3

CHANGELOGS.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ Change Logs
44
0.2.1
55
+++++
66

7-
* :pr:`16`: refactors patches
7+
* :pr:`16`: refactors patches, add model Phi2, implements
8+
a tweak to raise an exception with a dynamic dimension
9+
becomes static when exporting a model
810

911
0.2.0
1012
+++++

_doc/examples/plot_export_tiny_llm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _forward_(*args, _f=None, **kwargs):
125125

126126
try:
127127
ep = torch.export.export(
128-
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes
128+
untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
129129
)
130130
print("It worked:")
131131
print(ep)
@@ -159,7 +159,9 @@ def _forward_(*args, _f=None, **kwargs):
159159
# And Let's finally export.
160160

161161
try:
162-
ep = torch.export.export(model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes)
162+
ep = torch.export.export(
163+
model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
164+
)
163165
print("It worked:")
164166
print(ep)
165167
except Exception as e:

_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
specified at `dynamic_shapes['past_key_values']`
3030
to non-tensor type <class 'transformers.cache_utils.DynamicCache'>
3131
at `inputs['past_key_values']` (expected None)
32-
For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
32+
For more information about this error,
33+
see: https://pytorch.org/docs/main/generated/exportdb/index.html#dynamic-shapes-validation
3334
3435
With ``transformers==4.50``, it shows the following:
3536
@@ -67,8 +68,9 @@
6768
import torch
6869
import transformers
6970
from onnx_diagnostic import doc
71+
from onnx_diagnostic.cache_helpers import is_cache_dynamic_registered
7072
from onnx_diagnostic.helpers import string_type
71-
from onnx_diagnostic.torch_export_patches.onnx_export_errors import bypass_export_some_errors
73+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
7274
from onnx_diagnostic.torch_models.llms import get_tiny_llm
7375

7476

@@ -92,14 +94,25 @@
9294
pprint.pprint(dynamic_shapes)
9395

9496
# %%
95-
# We are ready to export.
97+
# Before exporting, we check :class:`transformers.cache_utils.DynamicCache`
98+
# can serialized and deserialized otherwise :func:`torch.export.export`
99+
# fails.
96100

97-
with bypass_export_some_errors(patch_transformers=True) as modificator:
101+
print("-- DynamicCache registered: ", is_cache_dynamic_registered())
102+
103+
# %%
104+
# If they are not registered, function
105+
# func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
106+
# should take care of it. Then we export.
107+
108+
with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
109+
assert is_cache_dynamic_registered() # it must be true here
98110
ep = torch.export.export(
99111
untrained_model,
100112
(),
101113
kwargs=modificator(cloned_inputs),
102114
dynamic_shapes=dynamic_shapes,
115+
strict=False, # mandatory for torch==2.6
103116
)
104117
print("It worked:")
105118
print(ep)
@@ -114,12 +127,13 @@
114127

115128
cloned_inputs = copy.deepcopy(inputs)
116129

117-
with bypass_export_some_errors(patch_transformers=True) as modificator:
130+
with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator:
118131
ep = torch.export.export(
119132
model,
120133
(),
121134
kwargs=modificator(cloned_inputs),
122135
dynamic_shapes=dynamic_shapes,
136+
strict=False, # mandatory for torch==2.6
123137
)
124138
print("It worked:")
125139
print(ep)

_unittests/ut_torch_models/test_tiny_llms_onnx.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
ExtTestCase,
77
ignore_warnings,
88
hide_stdout,
9+
has_torch,
910
requires_transformers,
1011
)
1112
from onnx_diagnostic.torch_models.llms import get_tiny_llm
@@ -35,6 +36,9 @@ def test_onnx_export_tiny_llm_official(self):
3536
dynamo=True,
3637
optimize=True,
3738
)
39+
# There are some discrepancies with torch==2.6
40+
if not has_torch("2.7"):
41+
raise unittest.SkipTest("discrepancies observed with torch<2.7")
3842
self.assert_onnx_disc(
3943
inspect.currentframe().f_code.co_name, ep.model_proto, model, inputs, verbose=1
4044
)
@@ -96,6 +100,9 @@ def test_bypass_onnx_export_tiny_llm_official_full(self):
96100
dynamo=True,
97101
optimize=True,
98102
)
103+
# There are some discrepancies with torch==2.6
104+
if not has_torch("2.7"):
105+
raise unittest.SkipTest("discrepancies observed with torch<2.7")
99106
self.assert_onnx_disc(
100107
inspect.currentframe().f_code.co_name, ep.model_proto, model, inputs, verbose=1
101108
)

onnx_diagnostic/cache_helpers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,28 @@
44
import transformers
55
import transformers.cache_utils
66

7+
8+
def is_cache_dynamic_registered() -> bool:
9+
"""
10+
Tells class :class:`transformers.cache_utils.DynamicCache` can be
11+
serialized and deserialized. Only then, :func:`torch.export.export`
12+
can export a model.
13+
"""
14+
bsize, nheads, slen, dim = 2, 4, 3, 7
15+
cache = make_dynamic_cache(
16+
[
17+
(
18+
torch.randn(bsize, nheads, slen, dim),
19+
torch.randn(bsize, nheads, slen, dim),
20+
)
21+
for i in range(2)
22+
]
23+
)
24+
values, spec = torch.utils._pytree.tree_flatten(cache)
25+
cache2 = torch.utils._pytree.tree_unflatten(values, spec)
26+
return len(cache2.key_cache) == len(cache.value_cache)
27+
28+
729
if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
830

931
def make_dynamic_cache(

0 commit comments

Comments
 (0)