@@ -134,7 +134,54 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
134134 return len (cache2 .key_cache ) == len (cache .value_cache )
135135
136136
137- if pv .Version (transformers .__version__ ) > pv .Version ("4.49.99999" ):
137+ if (
138+ pv .Version (transformers .__version__ ) > pv .Version ("4.99.99999" )
139+ or transformers .__version__ == "4.57.0.dev0"
140+ ):
141+
def make_dynamic_cache(
    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
) -> transformers.cache_utils.DynamicCache:
    """
    Builds a :class:`transformers.cache_utils.DynamicCache` from a list of
    key/value tensors, one pair per layer.

    This definition is the one picked by the enclosing version check
    (``transformers`` strictly newer than ``4.99.99999`` or exactly
    ``4.57.0.dev0``). Every pair is handed to the constructor as a
    ``(None, key, value)`` triple and the resulting cache goes through
    ``finalize_cache`` before being returned.

    :param key_value_pairs: list of pairs of (key, values)
    :return: :class:`transformers.cache_utils.DynamicCache`

    Example:

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers import string_type
        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

        n_layers = 2
        bsize, nheads, slen, dim = 2, 4, 3, 7

        past_key_values = make_dynamic_cache(
            [
                (
                    torch.randn(bsize, nheads, slen, dim),
                    torch.randn(bsize, nheads, slen, dim),
                )
                for i in range(n_layers)
            ]
        )
        print(string_type(past_key_values, with_shape=True))
    """
    # The constructor is fed triples whose first slot is left empty (None).
    triples = [(None, keys, values) for keys, values in key_value_pairs]
    cache = transformers.cache_utils.DynamicCache(triples)

    # The layer count can only be cross-checked when the cache exposes ``layers``.
    if hasattr(cache, "layers"):
        assert len(key_value_pairs) == len(cache.layers), (
            f"Unexpected number of layers in the cache ({len(cache.layers)} ), "
            f"{len(key_value_pairs)} expected."
        )

    return finalize_cache(cache)
183+
184+ elif pv .Version (transformers .__version__ ) > pv .Version ("4.49.99999" ):
138185
139186 def make_dynamic_cache (
140187 key_value_pairs : List [Tuple [torch .Tensor , torch .Tensor ]],
0 commit comments