Skip to content

Commit 1a0a46d

Browse files
committed
fix patches
1 parent 4bd2185 commit 1a0a46d

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,13 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
6565
setattr(original, n, v)
6666

6767

68+
PATCH_OF_PATCHES = set()
69+
70+
6871
def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
6972
# Cache serialization: to be moved into appropriate packages
7073
import torch
74+
import transformers
7175
import packaging.version as pv
7276

7377
try:
@@ -109,11 +113,17 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
109113
# torch.fx._pytree.register_pytree_flatten_spec(
110114
# DynamicCache, _flatten_dynamic_cache_for_fx)
111115
# so we remove it anyway
112-
if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(
113-
torch.__version__
114-
) >= pv.Version("2.7"):
116+
if (
117+
DynamicCache in torch.fx._pytree.SUPPORTED_NODES
118+
and not PATCH_OF_PATCHES
119+
and pv.Version(torch.__version__) >= pv.Version("2.7")
120+
and pv.Version(transformers.__version__) >= pv.Version("4.50")
121+
):
115122
if verbose:
116-
print("[_register_cache_serialization] DynamicCache is unregistered first.")
123+
print(
124+
"[_register_cache_serialization] DynamicCache "
125+
"is unregistered and registered first."
126+
)
117127
_unregister(DynamicCache)
118128
torch.utils._pytree.register_pytree_node(
119129
DynamicCache,
@@ -122,6 +132,8 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
122132
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
123133
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
124134
)
135+
# To avoid doing it multiple times.
136+
PATCH_OF_PATCHES.add(DynamicCache)
125137

126138
unregistered_dynamic_cache = True
127139
if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
@@ -138,9 +150,10 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
138150
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
139151
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
140152
)
141-
torch.fx._pytree.register_pytree_flatten_spec(
142-
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
143-
)
153+
if pv.Version(torch.__version__) < pv.Version("2.7"):
154+
torch.fx._pytree.register_pytree_flatten_spec(
155+
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
156+
)
144157

145158
# check
146159
from ..helpers.cache_helper import make_dynamic_cache

0 commit comments

Comments
 (0)