Skip to content

Commit 2d3e317

Browse files
committed
fix cache
1 parent 3eb75b0 commit 2d3e317

File tree

1 file changed

+46
-1
lines changed

1 file changed

+46
-1
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,19 @@ def make_dynamic_cache(
160160
)
161161
print(string_type(past_key_values, with_shape=True))
162162
"""
163-
return transformers.cache_utils.DynamicCache(key_value_pairs)
163+
cache = transformers.cache_utils.DynamicCache(key_value_pairs)
164+
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
165+
# The cache constructor contains the two following lines
166+
# (in cache_utils.py) which append empty layers when the cache is
167+
# initialized. We need to remove them.
168+
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
169+
# self.append_new_layers(self.num_hidden_layers - 1)
170+
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
171+
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
172+
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
173+
f"{len(key_value_pairs)} expected."
174+
)
175+
return cache
164176

165177
else:
166178

@@ -271,6 +283,17 @@ def __init__(self):
271283
d = key_value_pairs[i][1].shape[2]
272284
ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
273285
ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
286+
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
287+
# The cache constructor contains the two following lines
288+
# (in cache_utils.py) which append empty layers when the cache is
289+
# initialized. We need to remove them.
290+
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
291+
# self.append_new_layers(self.num_hidden_layers - 1)
292+
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
293+
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
294+
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
295+
f"{len(key_value_pairs)} expected."
296+
)
274297
return cache
275298

276299

@@ -355,6 +378,17 @@ def __init__(self):
355378
f"got {key_value_pairs[i][1].shape}"
356379
)
357380
ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
381+
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
382+
# The cache constructor contains the two following lines
383+
# (in cache_utils.py) which append empty layers when the cache is
384+
# initialized. We need to remove them.
385+
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
386+
# self.append_new_layers(self.num_hidden_layers - 1)
387+
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
388+
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
389+
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
390+
f"{len(key_value_pairs)} expected."
391+
)
358392
return cache
359393

360394

@@ -500,4 +534,15 @@ class _config:
500534
)
501535
},
502536
)
537+
if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
538+
# The cache constructor contains the two following lines
539+
# (in cache_utils.py) which append empty layers when the cache is
540+
# initialized. We need to remove them.
541+
# self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
542+
# self.append_new_layers(self.num_hidden_layers - 1)
543+
cache.layers[:] = cache.layers[-len(key_value_pairs) :]
544+
assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
545+
f"Unexpected number of layers in the cache ({len(cache.layers)}), "
546+
f"{len(key_value_pairs)} expected."
547+
)
503548
return cache

0 commit comments

Comments
 (0)