|
11 | 11 | SlidingWindowCache, |
12 | 12 | StaticCache, |
13 | 13 | ) |
14 | | -from transformers.modeling_outputs import BaseModelOutput |
15 | | - |
16 | | -try: |
17 | | - from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput |
18 | | - from diffusers.models.unets.unet_1d import UNet1DOutput |
19 | | - from diffusers.models.unets.unet_2d import UNet2DOutput |
20 | | - from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput |
21 | | - from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput |
22 | | -except ImportError as e: |
23 | | - try: |
24 | | - import diffusers |
25 | | - except ImportError: |
26 | | - diffusers = None |
27 | | - DecoderOutput, EncoderOutput = None, None |
28 | | - UNet1DOutput, UNet2DOutput = None, None |
29 | | - UNet2DConditionOutput, UNet3DConditionOutput = None, None |
30 | | - if diffusers: |
31 | | - raise e |
32 | 14 |
|
33 | 15 | from ..helpers import string_type |
34 | 16 |
|
35 | 17 |
|
36 | 18 | PATCH_OF_PATCHES: Set[Any] = set() |
37 | | -WRONG_REGISTRATIONS: Dict[str, Optional[str]] = { |
38 | | - DynamicCache: "4.50", |
39 | | - BaseModelOutput: None, |
40 | | - UNet2DConditionOutput: None, |
41 | | -} |
42 | 19 |
|
43 | 20 |
|
44 | 21 | def register_class_serialization( |
@@ -101,6 +78,8 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: |
101 | 78 | Registers many classes with :func:`register_class_serialization`. |
102 | 79 | Returns information needed to undo the registration. |
103 | 80 | """ |
| 81 | + from .onnx_export_serialization_impl import WRONG_REGISTRATIONS |
| 82 | + |
104 | 83 | registration_functions = serialization_functions(verbose=verbose) |
105 | 84 |
|
106 | 85 | # DynamicCache serialization is different in transformers and does not |
@@ -212,7 +191,7 @@ def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool |
212 | 191 | f"flatten_{lname}" in all_functions |
213 | 192 | ), f"Unable to find function 'flatten_{lname}' in {sorted(all_functions)}" |
214 | 193 | transformers_classes[cls] = ( |
215 | | - lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( |
| 194 | + lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( # noqa: E501 |
216 | 195 | cls, |
217 | 196 | _al[f"flatten_{_ln}"], |
218 | 197 | _al[f"unflatten_{_ln}"], |
@@ -253,7 +232,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0): |
253 | 232 |
|
254 | 233 | def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): |
255 | 234 | """Undo all registrations.""" |
256 | | - cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput} | set(undo) |
| 235 | + cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo) |
257 | 236 | for cls in cls_ensemble: |
258 | 237 | if undo.get(cls.__name__, False): |
259 | 238 | unregister_class_serialization(cls, verbose) |
0 commit comments