Skip to content

Commit e724b5e

Browse files
committed
fix issues
1 parent 1a0a46d commit e724b5e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import contextlib
22
import pprint
3-
from typing import Any, Callable, Dict
3+
from typing import Any, Callable, Dict, Set
44
from .onnx_export_serialization import (
55
flatten_with_keys_dynamic_cache,
66
flatten_dynamic_cache,
@@ -65,7 +65,7 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
6565
setattr(original, n, v)
6666

6767

68-
PATCH_OF_PATCHES = set()
68+
PATCH_OF_PATCHES: Set[Any] = set()
6969

7070

7171
def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
@@ -116,7 +116,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
116116
if (
117117
DynamicCache in torch.fx._pytree.SUPPORTED_NODES
118118
and not PATCH_OF_PATCHES
119-
and pv.Version(torch.__version__) >= pv.Version("2.7")
119+
# and pv.Version(torch.__version__) < pv.Version("2.7")
120120
and pv.Version(transformers.__version__) >= pv.Version("4.50")
121121
):
122122
if verbose:
@@ -132,6 +132,10 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
132132
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
133133
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
134134
)
135+
if pv.Version(torch.__version__) < pv.Version("2.7"):
136+
torch.fx._pytree.register_pytree_flatten_spec(
137+
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
138+
)
135139
# To avoid doing it multiple times.
136140
PATCH_OF_PATCHES.add(DynamicCache)
137141

0 commit comments

Comments
 (0)