Skip to content

Commit b94c731

Browse files
committed
fix
1 parent 31b272f commit b94c731

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ class patched_DynamicLayer:
2424

2525
def lazy_initialization(self, key_states: torch.Tensor):
2626
self.dtype, self.device = key_states.dtype, key_states.device
27+
assert (
28+
hasattr(key_states, "shape") and key_states is not None
29+
), f"Attribute 'shape' is wrong for type {type(key_states)}"
30+
assert isinstance(key_states.shape, tuple), (
31+
f"Unxpected type {type(key_states.shape)} for key_states.shape, "
32+
f"__dict__={key_states.shape.__dict__}"
33+
)
2734
new_shape = list(key_states.shape)
2835
new_shape[-2] = 0
2936
# PATCHED: used a tensor with an empty shape and not en empty list to initialize

0 commit comments

Comments
 (0)