Commit 7cc8de9

fix patched lazy_initialization for transformers>=5 (#376)
* fix patched lazy_initialization for transformers>=5
* doc
1 parent 598d5ea commit 7cc8de9

2 files changed, +4 -1 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.8.8
 +++++
 
+* :pr:`376`: fix patched lazy_initialization for transformers>=5
 * :pr:`372`: fix patch on rotary embedding
 * :pr:`371`: fix make_fake_with_dynamic_dimensions
 
onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py

Lines changed: 3 additions & 1 deletion
@@ -22,7 +22,9 @@ class patched_DynamicLayer:
     _PATCHES_ = ["lazy_initialization"]
     _PATCHED_CLASS_ = DynamicLayer
 
-    def lazy_initialization(self, key_states: torch.Tensor):
+    def lazy_initialization(
+        self, key_states: torch.Tensor, value_states: torch.Tensor = None
+    ):
         self.dtype, self.device = key_states.dtype, key_states.device
         assert (
             hasattr(key_states, "shape") and key_states is not None
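
A minimal sketch of what the signature change enables, assuming (as the patch implies) that transformers>=5 invokes the layer's lazy_initialization with both key and value states while older versions pass key_states alone. PatchedLayerSketch is a hypothetical stand-in for illustration, not the real patched class:

import torch


class PatchedLayerSketch:
    # Hypothetical stand-in: accepting an optional value_states keeps both
    # the transformers<5 and transformers>=5 call patterns working.
    def lazy_initialization(
        self, key_states: torch.Tensor, value_states: torch.Tensor = None
    ):
        # dtype and device are inferred from the key states, as in the patch
        self.dtype, self.device = key_states.dtype, key_states.device


layer = PatchedLayerSketch()
k = v = torch.zeros(2, 4, 8, 16)
layer.lazy_initialization(k)      # call pattern before transformers 5
layer.lazy_initialization(k, v)   # call pattern with transformers>=5
print(layer.dtype, layer.device)  # torch.float32 cpu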
