Skip to content

Commit 18d1257

Browse files
committed
simple patch
1 parent 6e38709 commit 18d1257

File tree

2 files changed

+2
-45
lines changed

2 files changed

+2
-45
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -426,14 +426,6 @@ def torch_export_patches(
426426
patch_transformers_list, verbose=verbose
427427
)
428428

429-
if patch_transformers_list.patch_is_initialized:
430-
if verbose:
431-
print(
432-
"[torch_export_patches] patches "
433-
"transformers.cache_utils.CacheLayerMixin.is_initialized"
434-
)
435-
patch_transformers_list.apply_patch_for_is_initialized()
436-
437429
if (
438430
masking_utils
439431
and patch_transformers_list.patch_masking_utils
@@ -697,14 +689,6 @@ def torch_export_patches(
697689
"in ALL_MASK_ATTENTION_FUNCTIONS"
698690
)
699691

700-
if patch_transformers_list.patch_is_initialized:
701-
if verbose:
702-
print(
703-
"[torch_export_patches] restores "
704-
"transformers.cache_utils.CacheLayerMixin.is_initialized"
705-
)
706-
patch_transformers_list.disable_patch_for_is_initialized()
707-
708692
########
709693
# caches
710694
########

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,35 +38,6 @@
3838
patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
3939

4040

41-
def _get_is_initialized(self):
42-
return self.keys is not None
43-
44-
45-
def _set_is_initialized(self, value):
46-
assert (value and self.keys is not None) or (not value and self.keys is None), (
47-
f"The patch does not set is_initialized but checks the it is consistent with "
48-
f"``self.keys is not None``, value={value}, "
49-
f"self.keys is not None={self.keys is not None}"
50-
)
51-
52-
53-
def apply_patch_for_is_initialized():
54-
"""
55-
Fixes export issues introduced by PR `40791 <https://github.com/huggingface/transformers/pull/40791>`_.
56-
The attribute is_initialized does not seem to be captured by :func:`torch.export.export`.
57-
"""
58-
if patch_is_initialized:
59-
transformers.cache_utils.CacheLayerMixin.is_initialized = property(
60-
_get_is_initialized, _set_is_initialized
61-
)
62-
63-
64-
def disable_patch_for_is_initialized():
65-
"""Disables the patch applied by function :func:`applies_patch_for_is_initialized`."""
66-
if patch_is_initialized:
67-
delattr(transformers.cache_utils.CacheLayerMixin, "is_initialized")
68-
69-
7041
if patch_masking_utils:
7142
# Introduced in 4.52
7243
from transformers.masking_utils import (
@@ -245,6 +216,8 @@ def lazy_initialization(self, key_states: torch.Tensor):
245216
new_shape[-2] = 0
246217
self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
247218
self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
219+
if patch_is_initialized:
220+
self.is_initialized = True
248221

249222

250223
def _patch_make_causal_mask(

0 commit comments

Comments
 (0)