@@ -160,7 +160,19 @@ def make_dynamic_cache(
160160 )
161161 print(string_type(past_key_values, with_shape=True))
162162 """
163- return transformers .cache_utils .DynamicCache (key_value_pairs )
163+ cache = transformers .cache_utils .DynamicCache (key_value_pairs )
164+ if hasattr (cache , "layers" ) and len (key_value_pairs ) < len (cache .layers ):
165+ # The cache constructor contains the two following lines
166+ # (in cache_utils.py) which append empty layers when the cache is
167+ # initialized. We need to remove them.
168+ # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
169+ # self.append_new_layers(self.num_hidden_layers - 1)
170+ cache .layers [:] = cache .layers [- len (key_value_pairs ) :]
171+ assert not hasattr (cache , "layers" ) or len (key_value_pairs ) == len (cache .layers ), (
172+ f"Unexpected number of layers in the cache ({ len (cache .layers )} ), "
173+ f"{ len (key_value_pairs )} expected."
174+ )
175+ return cache
164176
165177else :
166178
@@ -271,6 +283,17 @@ def __init__(self):
271283 d = key_value_pairs [i ][1 ].shape [2 ]
272284 ca .key_cache [i ][:, :, :d , :] = key_value_pairs [i ][0 ]
273285 ca .value_cache [i ][:, :, :d , :] = key_value_pairs [i ][1 ]
286+ if hasattr (cache , "layers" ) and len (key_value_pairs ) < len (cache .layers ):
287+ # The cache constructor contains the two following lines
288+ # (in cache_utils.py) which append empty layers when the cache is
289+ # initialized. We need to remove them.
290+ # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
291+ # self.append_new_layers(self.num_hidden_layers - 1)
292+ cache .layers [:] = cache .layers [- len (key_value_pairs ) :]
293+ assert not hasattr (cache , "layers" ) or len (key_value_pairs ) == len (cache .layers ), (
294+ f"Unexpected number of layers in the cache ({ len (cache .layers )} ), "
295+ f"{ len (key_value_pairs )} expected."
296+ )
274297 return cache
275298
276299
@@ -355,6 +378,17 @@ def __init__(self):
355378 f"got { key_value_pairs [i ][1 ].shape } "
356379 )
357380 ca .value_cache [i ][:, :, :, :] = key_value_pairs [i ][1 ]
381+ if hasattr (cache , "layers" ) and len (key_value_pairs ) < len (cache .layers ):
382+ # The cache constructor contains the two following lines
383+ # (in cache_utils.py) which append empty layers when the cache is
384+ # initialized. We need to remove them.
385+ # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
386+ # self.append_new_layers(self.num_hidden_layers - 1)
387+ cache .layers [:] = cache .layers [- len (key_value_pairs ) :]
388+ assert not hasattr (cache , "layers" ) or len (key_value_pairs ) == len (cache .layers ), (
389+ f"Unexpected number of layers in the cache ({ len (cache .layers )} ), "
390+ f"{ len (key_value_pairs )} expected."
391+ )
358392 return cache
359393
360394
@@ -500,4 +534,15 @@ class _config:
500534 )
501535 },
502536 )
537+ if hasattr (cache , "layers" ) and len (key_value_pairs ) < len (cache .layers ):
538+ # The cache constructor contains the two following lines
539+ # (in cache_utils.py) which append empty layers when the cache is
540+ # initialized. We need to remove them.
541+ # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
542+ # self.append_new_layers(self.num_hidden_layers - 1)
543+ cache .layers [:] = cache .layers [- len (key_value_pairs ) :]
544+ assert not hasattr (cache , "layers" ) or len (key_value_pairs ) == len (cache .layers ), (
545+ f"Unexpected number of layers in the cache ({ len (cache .layers )} ), "
546+ f"{ len (key_value_pairs )} expected."
547+ )
503548 return cache
0 commit comments