Skip to content

Commit 6e38709

Browse files
committed
patch
1 parent e94564f commit 6e38709

File tree

4 files changed

+52
-0
lines changed

4 files changed

+52
-0
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.11
55
++++++
66

7+
* :pr:`220`: adds a patch for PR `#40791 <https://github.com/huggingface/transformers/pull/40791>`_ in transformers
8+
79
0.7.10
810
++++++
911

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,14 @@ def torch_export_patches(
426426
patch_transformers_list, verbose=verbose
427427
)
428428

429+
if patch_transformers_list.patch_is_initialized:
430+
if verbose:
431+
print(
432+
"[torch_export_patches] patches "
433+
"transformers.cache_utils.CacheLayerMixin.is_initialized"
434+
)
435+
patch_transformers_list.apply_patch_for_is_initialized()
436+
429437
if (
430438
masking_utils
431439
and patch_transformers_list.patch_masking_utils
@@ -689,6 +697,14 @@ def torch_export_patches(
689697
"in ALL_MASK_ATTENTION_FUNCTIONS"
690698
)
691699

700+
if patch_transformers_list.patch_is_initialized:
701+
if verbose:
702+
print(
703+
"[torch_export_patches] restores "
704+
"transformers.cache_utils.CacheLayerMixin.is_initialized"
705+
)
706+
patch_transformers_list.disable_patch_for_is_initialized()
707+
692708
########
693709
# caches
694710
########

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,38 @@
3535
from ...ext_test_case import has_transformers
3636
from ...helpers.torch_helper import is_torchdynamo_exporting
3737

38+
patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
39+
40+
41+
def _get_is_initialized(self):
42+
return self.keys is not None
43+
44+
45+
def _set_is_initialized(self, value):
46+
assert (value and self.keys is not None) or (not value and self.keys is None), (
47+
f"The patch does not set is_initialized but checks the it is consistent with "
48+
f"``self.keys is not None``, value={value}, "
49+
f"self.keys is not None={self.keys is not None}"
50+
)
51+
52+
53+
def apply_patch_for_is_initialized():
54+
"""
55+
Fixes export issues introduced by PR `40791 <https://github.com/huggingface/transformers/pull/40791>`_.
56+
The attribute is_initialized does not seem to be captured by :func:`torch.export.export`.
57+
"""
58+
if patch_is_initialized:
59+
transformers.cache_utils.CacheLayerMixin.is_initialized = property(
60+
_get_is_initialized, _set_is_initialized
61+
)
62+
63+
64+
def disable_patch_for_is_initialized():
65+
"""Disables the patch applied by function :func:`applies_patch_for_is_initialized`."""
66+
if patch_is_initialized:
67+
delattr(transformers.cache_utils.CacheLayerMixin, "is_initialized")
68+
69+
3870
if patch_masking_utils:
3971
# Introduced in 4.52
4072
from transformers.masking_utils import (

onnx_diagnostic/torch_models/validate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,8 @@ def call_torch_export_custom(
15181518
"default+onnxruntime+os_ort",
15191519
None,
15201520
}
1521+
if optimization == "none":
1522+
optimization = ""
15211523
assert (
15221524
optimization in available
15231525
), f"unexpected value for optimization={optimization}, available={available}"

0 commit comments

Comments
 (0)