
Commit 245be95

fix patches
1 parent 02eaf02 commit 245be95

File tree

1 file changed: +2 −20 lines


onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 2 additions & 20 deletions
@@ -16,6 +16,7 @@
 from ...helpers.cache_helper import (
     make_dynamic_cache,
     make_hybrid_cache,
+    make_sliding_window_cache,
     make_static_cache,
     CacheKeyValue,
 )
@@ -218,26 +219,7 @@ def unflatten_sliding_window_cache(
 ) -> SlidingWindowCache:
     """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
     key_cache, value_cache = values
-
-    class _config:
-        def __init__(self):
-            self.head_dim = key_cache[0].shape[-1]
-            self.num_attention_heads = key_cache[0].shape[1]
-            self.num_hidden_layers = len(key_cache)
-            self.sliding_window = key_cache[0].shape[2]
-
-    cache = SlidingWindowCache(
-        _config(),
-        max_batch_size=key_cache[0].shape[0],
-        max_cache_len=key_cache[0].shape[2],  # sligding window
-        device=key_cache[0].device,
-        dtype=key_cache[0].dtype,
-    )
-
-    values = dict(zip(context, values))
-    for k, v in values.items():
-        setattr(cache, k, v)
-    return cache
+    return make_sliding_window_cache(list(zip(values[0], values[1])))
 
 
 #####################
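
The change delegates the whole cache construction to the make_sliding_window_cache helper imported in the first hunk, instead of hand-building a fake config and patching attributes afterwards. A minimal usage sketch, assuming (as the new call site suggests) that the helper accepts a list of per-layer (key, value) tensor pairs; the dummy shapes follow the removed code, i.e. (batch, num_heads, sliding_window, head_dim):

# Sketch only; make_sliding_window_cache's exact signature is inferred from the
# call in this diff, not from its documentation.
import torch
from onnx_diagnostic.helpers.cache_helper import make_sliding_window_cache

# Two layers of dummy key/value states shaped (batch, heads, sliding_window, head_dim).
key_cache = [torch.randn(1, 4, 8, 16) for _ in range(2)]
value_cache = [torch.randn(1, 4, 8, 16) for _ in range(2)]

# The helper now owns the details the removed code computed by hand
# (config fields, max_batch_size, max_cache_len, device, dtype).
cache = make_sliding_window_cache(list(zip(key_cache, value_cache)))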
