Skip to content

Commit 1c682f5

Browse files
committed
fix wrogn import
1 parent 51f3f24 commit 1c682f5

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,15 @@ def register_cache_serialization(
8787
:param verbosity: verbosity level
8888
:return: information to unpatch
8989
"""
90-
from .onnx_export_serialization_impl import WRONG_REGISTRATIONS
90+
wrong: Dict[type, Optional[str]] = {}
91+
if patch_transformers:
92+
from .serialization.transformers_impl import WRONG_REGISTRATIONS
93+
94+
wrong |= WRONG_REGISTRATIONS
95+
if patch_diffusers:
96+
from .serialization.diffusers_impl import WRONG_REGISTRATIONS
97+
98+
wrong |= WRONG_REGISTRATIONS
9199

92100
registration_functions = serialization_functions(
93101
patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
@@ -103,7 +111,7 @@ def register_cache_serialization(
103111
# so we remove it anyway
104112
# BaseModelOutput serialization is incomplete.
105113
# It does not include dynamic shapes mapping.
106-
for cls, version in WRONG_REGISTRATIONS.items():
114+
for cls, version in wrong.items():
107115
if (
108116
cls in torch.utils._pytree.SUPPORTED_NODES
109117
and cls not in PATCH_OF_PATCHES

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
StaticCache,
1010
)
1111
from transformers.modeling_outputs import BaseModelOutput
12-
from ..helpers.cache_helper import make_static_cache
12+
from ...helpers.cache_helper import make_static_cache
1313
from . import make_serialization_function_for_dataclass
1414

1515

0 commit comments

Comments
 (0)