Commit 962993e

fix many issues

1 parent 7174ff4 commit 962993e

5 files changed (+33, -11 lines)

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 6 additions & 3 deletions
@@ -3,7 +3,7 @@
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
 from onnx_diagnostic.helpers import string_type
-from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
 from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
@@ -408,6 +408,7 @@ def forward(self, xy, z):
     def test_guess_dynamic_shapes_cache(self):
         class Model(torch.nn.Module):
             def forward(self, cache, z):
+                cache = CacheKeyValue(cache)
                 return (
                     z
                     + cache.key_cache[0]
@@ -475,6 +476,7 @@ def forward(self, cache, z):
     def test_guess_dynamic_shapes_cache_str(self):
         class Model(torch.nn.Module):
             def forward(self, cache, z):
+                cache = CacheKeyValue(cache)
                 return (
                     z
                     + cache.key_cache[0]
@@ -812,8 +814,9 @@ def test_couple_input_ds_change_dynamic_dimensions_dynamic_cache(self):
         with torch_export_patches(patch_transformers=True):
             new_inputs = inst.change_dynamic_dimensions()
         self.assertIsInstance(new_inputs["A"], transformers.cache_utils.DynamicCache)
-        self.assertEqual((3, 2, 3, 2), new_inputs["A"].key_cache[0].shape)
-        self.assertEqual((3, 2, 3, 2), new_inputs["A"].value_cache[0].shape)
+        new_inputs_A = CacheKeyValue(new_inputs["A"])
+        self.assertEqual((3, 2, 3, 2), new_inputs_A.key_cache[0].shape)
+        self.assertEqual((3, 2, 3, 2), new_inputs_A.value_cache[0].shape)

     @requires_transformers("4.51")
     def test_dynamic_cache_replace_by_string(self):
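The recurring change in this test file is the same everywhere: a DynamicCache is wrapped in CacheKeyValue before its tensors are read. A minimal standalone sketch of that pattern, assuming onnx_diagnostic at this commit and a transformers version that keeps the tensors in cache.layers (TinyModel is illustrative, not from the repository):

# Wrap the incoming DynamicCache in CacheKeyValue so key_cache/value_cache
# stay accessible regardless of how the cache stores its tensors internally.
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue


class TinyModel(torch.nn.Module):
    def forward(self, cache, z):
        kv = CacheKeyValue(cache)  # uniform view over the cache contents
        return z + kv.key_cache[0] + kv.value_cache[0]


cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
print(TinyModel()(cache, torch.zeros((3, 3))))  # 1 + 2 + 0 -> a tensor of 3s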

_unittests/ut_export/test_serialization.py

Lines changed: 4 additions & 1 deletion
@@ -5,6 +5,7 @@
 from onnx_diagnostic.helpers.cache_helper import (
     make_dynamic_cache,
     flatten_unflatten_for_dynamic_shapes,
+    CacheKeyValue,
 )
 from onnx_diagnostic.export import ModelInputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
@@ -23,13 +24,15 @@ def _get_cache(self, n_layers=2, bsize=2, nheads=4, slen=1, dim=7):
     def test_dynamic_cache(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
+                cache = CacheKeyValue(cache)
                 return cache.key_cache[0]

         cache = self._get_cache()
         DYN = torch.export.Dim.DYNAMIC
         ds = {0: DYN, 1: DYN, 3: DYN}
         dynamic_shapes = ([[ds, ds], [ds, ds]],)
-        exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes)
+        with torch_export_patches(patch_transformers=True):
+            exp = torch.export.export(Model(), (cache,), dynamic_shapes=dynamic_shapes)
         self.assertNotEmpty(exp)

     @requires_transformers("4.50")
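The other change here wraps torch.export.export in the torch_export_patches context. A standalone sketch of the updated test, under the assumption that patch_transformers=True is what registers the cache handling the exporter needs; the shapes mirror _get_cache(n_layers=2, bsize=2, nheads=4, slen=1, dim=7):

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Model(torch.nn.Module):
    def forward(self, cache):
        return CacheKeyValue(cache).key_cache[0]


# two layers of (batch, heads, seq, dim) = (2, 4, 1, 7) key/value tensors
cache = make_dynamic_cache(
    [(torch.randn(2, 4, 1, 7), torch.randn(2, 4, 1, 7)) for _ in range(2)]
)
DYN = torch.export.Dim.DYNAMIC
ds = {0: DYN, 1: DYN, 3: DYN}
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(Model(), (cache,), dynamic_shapes=([[ds, ds], [ds, ds]],))
print(ep)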

_unittests/ut_helpers/test_mini_onnx_builder.py

Lines changed: 7 additions & 5 deletions
@@ -8,7 +8,7 @@
     create_input_tensors_from_onnx_model,
     MiniOnnxBuilder,
 )
-from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
 from onnx_diagnostic.helpers import string_type


@@ -127,8 +127,9 @@ def test_mini_onnx_builder(self):

     def test_mini_onnx_builder_transformers(self):
         cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
-        self.assertEqual(len(cache.key_cache), 1)
-        self.assertEqual(len(cache.value_cache), 1)
+        dc = CacheKeyValue(cache)
+        self.assertEqual(len(dc.key_cache), 1)
+        self.assertEqual(len(dc.value_cache), 1)

         data = [(cache,), cache]

@@ -140,8 +141,9 @@ def test_mini_onnx_builder_transformers(self):

     def test_mini_onnx_builder_transformers_sep(self):
         cache = make_dynamic_cache([(torch.ones((3, 3)), torch.ones((3, 3)) * 2)])
-        self.assertEqual(len(cache.key_cache), 1)
-        self.assertEqual(len(cache.value_cache), 1)
+        dc = CacheKeyValue(cache)
+        self.assertEqual(len(dc.key_cache), 1)
+        self.assertEqual(len(dc.value_cache), 1)

         data = [(cache,), cache]

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 13 additions & 1 deletion
@@ -41,9 +41,14 @@ def __init__(self, cache=None):
                 f"or value_cache={string_type(self.value_cache)}, "
                 f"cache.layers={string_type(cache.layers)}"
             )
-        elif cache is not None:
+        elif cache is not None and hasattr(cache, "key_cache"):
             self.key_cache = cache.key_cache
             self.value_cache = cache.value_cache
+        elif cache is None:
+            self.key_cache = None
+            self.value_cache = None
+        else:
+            raise NotImplementedError(f"type(cache)={type(cache)}")

     def make_dynamic_cache(self):
         """Do the reverse operation."""
@@ -401,6 +406,13 @@ def __init__(self):
         dtype=key_value_pairs[0][0].dtype,
     )
     ca = CacheKeyValue(cache)
+    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+        # transformers>= 4.55.2, layers are empty
+        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+        return cache
+
     for i in range(len(key_value_pairs)):
         assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.key_cache[i].shape}, "

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 3 additions & 1 deletion
@@ -261,7 +261,9 @@ def unflatten_encoder_decoder_cache(
 ) -> EncoderDecoderCache:
     """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
-    return EncoderDecoderCache(**dictionary)
+    return EncoderDecoderCache(
+        dictionary["self_attention_cache"], dictionary["cross_attention_cache"]
+    )


 #############
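A small sketch of the construction the patched unflatten now relies on; the move away from EncoderDecoderCache(**dictionary) presumably sidesteps keyword-name differences between transformers releases (an assumption, the commit message does not say), while positional construction matches the call above:

import torch
from transformers.cache_utils import EncoderDecoderCache
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

# two independent DynamicCache instances standing in for the unflattened values
self_attention = make_dynamic_cache([(torch.ones((2, 4, 3, 7)), torch.ones((2, 4, 3, 7)))])
cross_attention = make_dynamic_cache([(torch.ones((2, 4, 5, 7)), torch.ones((2, 4, 5, 7)))])
ed_cache = EncoderDecoderCache(self_attention, cross_attention)  # positional, no keyword names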
