Skip to content

Commit 85980d6

Browse files
committed
fix patched lazy_initialization for transformers>=5
1 parent 598d5ea commit 85980d6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ class patched_DynamicLayer:
2222
_PATCHES_ = ["lazy_initialization"]
2323
_PATCHED_CLASS_ = DynamicLayer
2424

25-
def lazy_initialization(self, key_states: torch.Tensor):
25+
def lazy_initialization(
26+
self, key_states: torch.Tensor, value_states: torch.Tensor = None
27+
):
2628
self.dtype, self.device = key_states.dtype, key_states.device
2729
assert (
2830
hasattr(key_states, "shape") and key_states is not None

0 commit comments

Comments
 (0)