Skip to content

Commit 2821a35

Browse files
committed
refactor
1 parent b37ec29 commit 2821a35

File tree

2 files changed

+18
-26
lines changed

2 files changed

+18
-26
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,34 +11,11 @@
1111
SlidingWindowCache,
1212
StaticCache,
1313
)
14-
from transformers.modeling_outputs import BaseModelOutput
15-
16-
try:
17-
from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput
18-
from diffusers.models.unets.unet_1d import UNet1DOutput
19-
from diffusers.models.unets.unet_2d import UNet2DOutput
20-
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
21-
from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput
22-
except ImportError as e:
23-
try:
24-
import diffusers
25-
except ImportError:
26-
diffusers = None
27-
DecoderOutput, EncoderOutput = None, None
28-
UNet1DOutput, UNet2DOutput = None, None
29-
UNet2DConditionOutput, UNet3DConditionOutput = None, None
30-
if diffusers:
31-
raise e
3214

3315
from ..helpers import string_type
3416

3517

3618
PATCH_OF_PATCHES: Set[Any] = set()
37-
WRONG_REGISTRATIONS: Dict[str, Optional[str]] = {
38-
DynamicCache: "4.50",
39-
BaseModelOutput: None,
40-
UNet2DConditionOutput: None,
41-
}
4219

4320

4421
def register_class_serialization(
@@ -101,6 +78,8 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
10178
Registers many classes with :func:`register_class_serialization`.
10279
Returns information needed to undo the registration.
10380
"""
81+
from .onnx_export_serialization_impl import WRONG_REGISTRATIONS
82+
10483
registration_functions = serialization_functions(verbose=verbose)
10584

10685
# DynamicCache serialization is different in transformers and does not
@@ -212,7 +191,7 @@ def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool
212191
f"flatten_{lname}" in all_functions
213192
), f"Unable to find function 'flatten_{lname}' in {sorted(all_functions)}"
214193
transformers_classes[cls] = (
215-
lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization(
194+
lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( # noqa: E501
216195
cls,
217196
_al[f"flatten_{_ln}"],
218197
_al[f"unflatten_{_ln}"],
@@ -253,7 +232,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
253232

254233
def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
255234
"""Undo all registrations."""
256-
cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput} | set(undo)
235+
cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
257236
for cls in cls_ensemble:
258237
if undo.get(cls.__name__, False):
259238
unregister_class_serialization(cls, verbose)

onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Any, Callable, List, Tuple
2+
from typing import Any, Callable, Dict, List, Optional, Tuple
33
import torch
44
import transformers
55
from transformers.cache_utils import (
@@ -31,7 +31,20 @@
3131
from ..helpers.cache_helper import make_static_cache
3232

3333

34+
def _make_wrong_registrations() -> Dict[str, Optional[str]]:
35+
res = {
36+
DynamicCache: "4.50",
37+
BaseModelOutput: None,
38+
}
39+
for c in [UNet2DConditionOutput]:
40+
if c is not None:
41+
res[c] = None
42+
return res
43+
44+
3445
SUPPORTED_DATACLASSES = set()
46+
WRONG_REGISTRATIONS = _make_wrong_registrations()
47+
3548

3649
############
3750
# MambaCache

0 commit comments

Comments
 (0)