Skip to content

Commit 6c8ab75

Browse files
committed
fix cache
1 parent bd49ebc commit 6c8ab75

File tree

3 files changed

+11
-18
lines changed

3 files changed

+11
-18
lines changed

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def test_tiny_llm_export_dynamic(self):
3737
dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
3838
)
3939
got = ep.module()(**inputs)
40+
print(ep)
4041
self.assertEqualArrayAny(expected, got)
4142

4243
@requires_transformers("4.52")

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,9 @@ class CacheKeyValue:
2626

2727
def __init__(self, cache):
2828
if hasattr(cache, "layers"):
29-
self.key_cache = [layer.keys for layer in cache.layers if layer.keys is not None]
30-
self.value_cache = [
31-
layer.values for layer in cache.layers if layer.values is not None
32-
]
29+
layers = [layer for layer in cache.layers if layer is not None]
30+
self.key_cache = [layer.keys for layer in layers]
31+
self.value_cache = [layer.values for layer in layers]
3332
else:
3433
self.key_cache = cache.key_cache
3534
self.value_cache = cache.value_cache

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
except ImportError:
1515
from transformers.cache_utils import MambaCache
1616
from transformers.modeling_outputs import BaseModelOutput
17-
from ...helpers.cache_helper import make_hybrid_cache, make_static_cache, CacheKeyValue
17+
from ...helpers.cache_helper import (
18+
make_dynamic_cache,
19+
make_hybrid_cache,
20+
make_static_cache,
21+
CacheKeyValue,
22+
)
1823
from . import make_serialization_function_for_dataclass
1924

2025

@@ -96,8 +101,6 @@ def flatten_dynamic_cache(
96101
dynamic_cache: DynamicCache,
97102
) -> Tuple[List[Any], torch.utils._pytree.Context]:
98103
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
99-
if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
100-
return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache)
101104
ca = CacheKeyValue(dynamic_cache)
102105
flat = [("key_cache", ca.key_cache), ("value_cache", ca.value_cache)]
103106
return [f[1] for f in flat], [f[0] for f in flat]
@@ -107,8 +110,6 @@ def flatten_with_keys_dynamic_cache(
107110
dynamic_cache: DynamicCache,
108111
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
109112
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
110-
if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
111-
return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
112113
values, context = flatten_dynamic_cache(dynamic_cache)
113114
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
114115

@@ -117,15 +118,7 @@ def unflatten_dynamic_cache(
117118
values: List[Any], context: torch.utils._pytree.Context, output_type=None
118119
) -> DynamicCache:
119120
"""Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
120-
if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
121-
assert output_type is None, f"output_type={output_type} not supported"
122-
return transformers.cache_utils._unflatten_dynamic_cache(values, context)
123-
124-
cache = transformers.cache_utils.DynamicCache()
125-
values = dict(zip(context, values))
126-
for k, v in values.items():
127-
setattr(cache, k, v)
128-
return cache
121+
return make_dynamic_cache(list(zip(values[0], values[1])))
129122

130123

131124
#############

0 commit comments

Comments (0)