Skip to content

Commit 36aa79e

Browse files
committed
rewrite patch
1 parent b94c731 commit 36aa79e

File tree

2 files changed

+5
-18
lines changed

2 files changed

+5
-18
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,6 @@ def __init__(self, cache=None):
2828
]
2929
self.key_cache = [layer.keys for layer in layers]
3030
self.value_cache = [layer.values for layer in layers]
31-
if None in self.key_cache or None in self.value_cache:
32-
from .helper import string_type
33-
34-
raise AssertionError(
35-
f"issue with key_cache={string_type(self.key_cache)}, "
36-
f"or value_cache={string_type(self.value_cache)}, "
37-
f"cache.layers={string_type(cache.layers)}"
38-
)
3931
elif cache is not None and hasattr(cache, "key_cache"):
4032
self.key_cache = cache.key_cache
4133
self.value_cache = cache.value_cache

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,15 @@ def lazy_initialization(self, key_states: torch.Tensor):
2727
assert (
2828
hasattr(key_states, "shape") and key_states is not None
2929
), f"Attribute 'shape' is wrong for type {type(key_states)}"
30-
assert isinstance(key_states.shape, tuple), (
31-
f"Unxpected type {type(key_states.shape)} for key_states.shape, "
32-
f"__dict__={key_states.shape.__dict__}"
33-
)
34-
new_shape = list(key_states.shape)
35-
new_shape[-2] = 0
30+
like = key_states[:, :0]
3631
# PATCHED: used a tensor with an empty shape and not an empty list to initialize
3732
if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
3833
with key_states.fake_mode:
39-
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
40-
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
34+
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
35+
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
4136
else:
42-
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
43-
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
37+
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
38+
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
4439
if patch_is_initialized:
4540
self.is_initialized = True
4641

0 commit comments

Comments
 (0)