Skip to content

Commit ef5b95c

Browse files
committed
fix remaining issues
1 parent 5de8afa commit ef5b95c

File tree

6 files changed

+37
-22
lines changed

6 files changed

+37
-22
lines changed

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,19 @@
2323
import pprint
2424
import torch
2525
from onnx_diagnostic import doc
26-
from onnx_diagnostic.ext_test_case import has_transformers
2726
from onnx_diagnostic.helpers import string_type
2827
from onnx_diagnostic.helpers.cache_helper import (
2928
flatten_unflatten_for_dynamic_shapes,
3029
make_dynamic_cache,
30+
CacheKeyValue,
3131
)
3232
from onnx_diagnostic.export import ModelInputs
3333
from onnx_diagnostic.torch_export_patches import torch_export_patches
3434

3535

3636
class Model(torch.nn.Module):
3737
def forward(self, cache, z):
38+
cache = CacheKeyValue(cache)
3839
return (
3940
z
4041
+ cache.key_cache[0]
@@ -105,13 +106,8 @@ def forward(self, cache, z):
105106
# registers functions to serialize ``DynamicCache``. This one is modified to make
106107
# the shape inference implemented in :epkg:`torch` happy.
107108

108-
if has_transformers("4.50"):
109+
with torch_export_patches(patch_transformers=True):
109110
ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False)
110-
else:
111-
with torch_export_patches(patch_transformers=True) as modificator:
112-
ep = torch.export.export(
113-
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
114-
)
115111
print(ep)
116112

117113
# %%

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_validate_model_vit_model(self):
258258
@requires_torch("2.7")
259259
@hide_stdout()
260260
@ignore_warnings(FutureWarning)
261-
@requires_transformers("4.53")
261+
@requires_transformers("4.55")
262262
def test_validate_phi35_mini_instruct(self):
263263
mid = "microsoft/Phi-3.5-mini-instruct"
264264
summary, data = validate_model(
@@ -282,7 +282,7 @@ def test_validate_phi35_mini_instruct(self):
282282
@requires_torch("2.7")
283283
@hide_stdout()
284284
@ignore_warnings(FutureWarning)
285-
@requires_transformers("4.53")
285+
@requires_transformers("4.55")
286286
def test_validate_phi35_4k_mini_instruct(self):
287287
mid = "microsoft/Phi-3-mini-4k-instruct"
288288
summary, data = validate_model(

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -887,27 +887,41 @@ def guess_dynamic_shape_object(
887887

888888
# In case DynamicCache is not registered.
889889
if obj.__class__.__name__ == "DynamicCache":
890-
kc = set(len(o.key_cache) for o in objs)
891-
assert (
892-
len(kc) == 1
893-
), f"All attribute 'key_cache' should have the same length but found {kc}"
894-
vc = set(len(o.value_cache) for o in objs)
895-
assert (
896-
len(vc) == 1
897-
), f"All attribute 'value_cache' should have the same length but found {vc}"
890+
if hasattr(obj, "layers"):
891+
kc = set(len(o.layers) for o in objs)
892+
assert (
893+
len(kc) == 1
894+
), f"All attribute 'key_cache' should have the same length but found {kc}"
895+
vc = kc.copy()
896+
else:
897+
kc = set(len(o.key_cache) for o in objs)
898+
assert (
899+
len(kc) == 1
900+
), f"All attribute 'key_cache' should have the same length but found {kc}"
901+
vc = set(len(o.value_cache) for o in objs)
902+
assert (
903+
len(vc) == 1
904+
), f"All attribute 'value_cache' should have the same length but found {vc}"
905+
898906
key_cache = []
899907
for i in range(kc.pop()):
900908
key_cache.append(
901909
self.guess_dynamic_dimensions(
902-
*[o.key_cache[i] for o in objs],
910+
*[
911+
o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
912+
for o in objs
913+
],
903914
auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
904915
)
905916
)
906917
value_cache = []
907918
for i in range(vc.pop()):
908919
value_cache.append(
909920
self.guess_dynamic_dimensions(
910-
*[o.value_cache[i] for o in objs],
921+
*[
922+
o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
923+
for o in objs
924+
],
911925
auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
912926
)
913927
)

onnx_diagnostic/export/shape_helper.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
99
All dimensions are considered as dynamic.
1010
``dim_prefix`` can be a string (the function uses it as a prefix),
1111
or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
12+
Depending on the version of transformers, serializations function
13+
of DynamicCache class is automatically serialized or not (>= 4.51, < 4.55).
1214
1315
.. runpython::
1416
:showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
1719
import torch
1820
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
1921
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
22+
from onnx_diagnostic.torch_export_patches import torch_export_patches
2023
2124
bsize, nheads, slen, dim = 2, 1, 30, 96
2225
inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
2528
position_ids=torch.arange(3, dtype=torch.int64),
2629
past_key_values=make_dynamic_cache(
2730
[(torch.randn(bsize, nheads, slen, dim),
28-
torch.randn(bsize, nheads, slen, dim))]
31+
torch.randn(bsize, nheads, slen, dim))]
2932
),
3033
)
31-
ds = all_dynamic_shape_from_inputs(inputs)
34+
with torch_export_patches(patch_transformers=True):
35+
ds = all_dynamic_shape_from_inputs(inputs)
3236
pprint.pprint(ds)
3337
3438
For this function to work, patches must be enabled if :epkg:`transformers`

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
131131
)
132132
values, spec = torch.utils._pytree.tree_flatten(cache)
133133
cache2 = torch.utils._pytree.tree_unflatten(values, spec)
134+
if hasattr(cache2, "layers") and hasattr(cache, "layers"):
135+
return len(cache2.layers) == len(cache.layers)
134136
return len(cache2.key_cache) == len(cache.value_cache)
135137

136138

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,4 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
160160
if hasattr(config, att):
161161
kwargs[att] = getattr(config, att)
162162
kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
163-
print(kwargs)
164163
return kwargs, get_inputs

0 commit comments

Comments
 (0)