File tree Expand file tree Collapse file tree 1 file changed +2
-20
lines changed
onnx_diagnostic/torch_export_patches/serialization Expand file tree Collapse file tree 1 file changed +2
-20
lines changed Original file line number Diff line number Diff line change 1616from ...helpers .cache_helper import (
1717 make_dynamic_cache ,
1818 make_hybrid_cache ,
19+ make_sliding_window_cache ,
1920 make_static_cache ,
2021 CacheKeyValue ,
2122)
@@ -218,26 +219,7 @@ def unflatten_sliding_window_cache(
218219) -> SlidingWindowCache :
219220 """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
220221 key_cache , value_cache = values
221-
222- class _config :
223- def __init__ (self ):
224- self .head_dim = key_cache [0 ].shape [- 1 ]
225- self .num_attention_heads = key_cache [0 ].shape [1 ]
226- self .num_hidden_layers = len (key_cache )
227- self .sliding_window = key_cache [0 ].shape [2 ]
228-
229- cache = SlidingWindowCache (
230- _config (),
231- max_batch_size = key_cache [0 ].shape [0 ],
232- max_cache_len = key_cache [0 ].shape [2 ], # sligding window
233- device = key_cache [0 ].device ,
234- dtype = key_cache [0 ].dtype ,
235- )
236-
237- values = dict (zip (context , values ))
238- for k , v in values .items ():
239- setattr (cache , k , v )
240- return cache
222+ return make_sliding_window_cache (list (zip (values [0 ], values [1 ])))
241223
242224
243225#####################
You can’t perform that action at this time.
0 commit comments