1414except ImportError :
1515 from transformers .cache_utils import MambaCache
1616from transformers .modeling_outputs import BaseModelOutput
17- from ...helpers .cache_helper import make_hybrid_cache , make_static_cache , CacheKeyValue
17+ from ...helpers .cache_helper import (
18+ make_dynamic_cache ,
19+ make_hybrid_cache ,
20+ make_static_cache ,
21+ CacheKeyValue ,
22+ )
1823from . import make_serialization_function_for_dataclass
1924
2025
@@ -96,8 +101,6 @@ def flatten_dynamic_cache(
96101 dynamic_cache : DynamicCache ,
97102) -> Tuple [List [Any ], torch .utils ._pytree .Context ]:
98103 """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
99- if hasattr (transformers .cache_utils , "_flatten_dynamic_cache" ):
100- return transformers .cache_utils ._flatten_dynamic_cache (dynamic_cache )
101104 ca = CacheKeyValue (dynamic_cache )
102105 flat = [("key_cache" , ca .key_cache ), ("value_cache" , ca .value_cache )]
103106 return [f [1 ] for f in flat ], [f [0 ] for f in flat ]
@@ -107,8 +110,6 @@ def flatten_with_keys_dynamic_cache(
107110 dynamic_cache : DynamicCache ,
108111) -> Tuple [List [Tuple [torch .utils ._pytree .KeyEntry , Any ]], torch .utils ._pytree .Context ]:
109112 """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
110- if hasattr (transformers .cache_utils , "_flatten_with_keys_dynamic_cache" ):
111- return transformers .cache_utils ._flatten_with_keys_dynamic_cache (dynamic_cache )
112113 values , context = flatten_dynamic_cache (dynamic_cache )
113114 return [(torch .utils ._pytree .MappingKey (k ), v ) for k , v in zip (context , values )], context
114115
@@ -117,15 +118,7 @@ def unflatten_dynamic_cache(
117118 values : List [Any ], context : torch .utils ._pytree .Context , output_type = None
118119) -> DynamicCache :
119120 """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
120- if hasattr (transformers .cache_utils , "_unflatten_dynamic_cache" ):
121- assert output_type is None , f"output_type={ output_type } not supported"
122- return transformers .cache_utils ._unflatten_dynamic_cache (values , context )
123-
124- cache = transformers .cache_utils .DynamicCache ()
125- values = dict (zip (context , values ))
126- for k , v in values .items ():
127- setattr (cache , k , v )
128- return cache
121+ return make_dynamic_cache (list (zip (values [0 ], values [1 ])))
129122
130123
131124#############
0 commit comments