Skip to content

Commit b687a15

Browse files
committed
fix cachekeyvalue
1 parent 3ebf494 commit b687a15

File tree

3 files changed

+45
-3
lines changed

3 files changed

+45
-3
lines changed

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ def test_replace_by(self):
5050
past_key_values = make_dynamic_cache(
5151
[(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
5252
)
53+
self.assertEqual(
54+
"DynamicCache(key_cache=#1[T1s2x4x3x7], value_cache=#1[T1s2x4x3x7])",
55+
self.string_type(past_key_values, with_shape=True),
56+
)
5357
kwargs = dict(
5458
input_ids=torch.zeros(2, 3),
5559
attention_mask=torch.zeros(2, 3),

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,27 @@ class CacheKeyValue:
2020
.. code-block:: python
2121
2222
capi = CacheKeyValue(cache)
23-
capi.cache.key_cache
24-
capi.cache.value_cache
23+
capi.key_cache
24+
capi.value_cache
2525
"""
2626

2727
def __init__(self, cache):
2828
if hasattr(cache, "layers"):
29-
layers = [layer for layer in cache.layers if layer is not None]
29+
layers = [
30+
layer
31+
for layer in cache.layers
32+
if layer is not None and layer.keys is not None and layer.values is not None
33+
]
3034
self.key_cache = [layer.keys for layer in layers]
3135
self.value_cache = [layer.values for layer in layers]
36+
if None in self.key_cache or None in self.value_cache:
37+
from .helper import string_type
38+
39+
raise AssertionError(
40+
f"issue with key_cache={string_type(self.key_cache)}, "
41+
f"or value_cache={string_type(self.value_cache)}, "
42+
f"cache.layers={string_type(cache.layers)}"
43+
)
3244
else:
3345
self.key_cache = cache.key_cache
3446
self.value_cache = cache.value_cache

onnx_diagnostic/helpers/helper.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,32 @@ def string_type(
710710
)
711711
return f"{obj.__class__.__name__}[{obj.cache_type}]{s}"
712712

713+
if obj.__class__.__name__ == "DynamicLayer":
714+
import transformers
715+
716+
assert isinstance(
717+
obj, transformers.cache_utils.DynamicLayer
718+
), f"Unexpected type {type(obj)}"
719+
if verbose:
720+
print(f"[string_type] LY0:{type(obj)}")
721+
s1 = string_type(
722+
obj.keys,
723+
with_shape=with_shape,
724+
with_min_max=with_min_max,
725+
with_device=with_device,
726+
limit=limit,
727+
verbose=verbose,
728+
)
729+
s2 = string_type(
730+
obj.values,
731+
with_shape=with_shape,
732+
with_min_max=with_min_max,
733+
with_device=with_device,
734+
limit=limit,
735+
verbose=verbose,
736+
)
737+
return f"{obj.__class__.__name__}(keys={s1}, values={s2})"
738+
713739
if isinstance(obj, torch.nn.Module):
714740
if verbose:
715741
print(f"[string_type] MM:{type(obj)}")

0 commit comments

Comments
 (0)