Commit aaa5372

Refactors serialization
1 parent 57c920b commit aaa5372

File tree

9 files changed, +208 -141 lines changed

9 files changed

+208
-141
lines changed

_doc/api/torch_export_patches/onnx_export_serialization_impl.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_export_patches.serialization.diffusers_impl
+=================================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.serialization.diffusers_impl
+    :members:
+    :no-undoc-members:
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+onnx_diagnostic.torch_export_patches.serialization
+==================================================
+
+.. toctree::
+    :maxdepth: 1
+    :caption: submodules
+
+    diffusers_impl
+    transformers_impl
+
+.. automodule:: onnx_diagnostic.torch_export_patches.serialization
+    :members:
+    :no-undoc-members:
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.torch_export_patches.serialization.transformers_impl
+====================================================================
+
+.. automodule:: onnx_diagnostic.torch_export_patches.serialization.transformers_impl
+    :members:
+    :no-undoc-members:

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 17 additions & 4 deletions

@@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo

 @contextlib.contextmanager
 def register_additional_serialization_functions(
-    patch_transformers: bool = False, verbose: int = 0
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
 ) -> Callable:
     """The necessary modifications to run the fx Graph."""
-    fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
-    done = register_cache_serialization(verbose=verbose)
+    fct_callable = (
+        replacement_before_exporting
+        if patch_transformers or patch_diffusers
+        else (lambda x: x)
+    )
+    done = register_cache_serialization(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
     try:
         yield fct_callable
     finally:
@@ -150,6 +156,7 @@ def torch_export_patches(
     patch_sympy: bool = True,
     patch_torch: bool = True,
     patch_transformers: bool = False,
+    patch_diffusers: bool = False,
     catch_constraints: bool = True,
     stop_if_static: int = 0,
     verbose: int = 0,
@@ -165,6 +172,7 @@ def torch_export_patches(
     :param patch_sympy: fix missing method ``name`` for IntegerConstant
     :param patch_torch: patches :epkg:`torch` with supported implementation
     :param patch_transformers: patches :epkg:`transformers` with supported implementation
+    :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
     :param catch_constraints: catch constraints related to dynamic shapes,
         as a result, some dynamic dimension may turn into static ones,
         the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
@@ -249,6 +257,7 @@ def torch_export_patches(
         patch_sympy=patch_sympy,
         patch_torch=patch_torch,
         patch_transformers=patch_transformers,
+        patch_diffusers=patch_diffusers,
         catch_constraints=catch_constraints,
         stop_if_static=stop_if_static,
         verbose=verbose,
@@ -281,7 +290,11 @@ def torch_export_patches(
     # caches
     ########

-    cache_done = register_cache_serialization(verbose=verbose)
+    cache_done = register_cache_serialization(
+        patch_transformers=patch_transformers,
+        patch_diffusers=patch_diffusers,
+        verbose=verbose,
+    )

     #############
     # patch sympy
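
A minimal usage sketch of the new flag (not part of the commit): it assumes the context manager is imported from the module shown above, that the tiny placeholder model and inputs stand in for a real transformers/diffusers model, and that applying the yielded callable to a plain tuple of tensors passes it through unchanged.

# Hypothetical usage sketch, not part of the commit.
import torch
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    register_additional_serialization_functions,
)


class TinyModel(torch.nn.Module):  # placeholder model
    def forward(self, x):
        return x * 2


model, inputs = TinyModel(), (torch.rand(2, 3),)

with register_additional_serialization_functions(
    patch_transformers=True,  # transformers caches (DynamicCache, MambaCache, ...)
    patch_diffusers=True,  # new in this commit: diffusers dataclasses too
    verbose=1,
) as replace_inputs:
    # replace_inputs is replacement_before_exporting when either flag is set,
    # otherwise an identity function (assumed to leave plain tensors untouched)
    ep = torch.export.export(model, replace_inputs(inputs))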

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 77 additions & 48 deletions

@@ -13,7 +13,7 @@
 )

 from ..helpers import string_type
-
+from .serialization import _lower_name_with_

 PATCH_OF_PATCHES: Set[Any] = set()

@@ -73,14 +73,25 @@ def register_class_serialization(
     return True


-def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
+def register_cache_serialization(
+    patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0
+) -> Dict[str, bool]:
     """
     Registers many classes with :func:`register_class_serialization`.
     Returns information needed to undo the registration.
+
+    :param patch_transformers: add serialization function for
+        :epkg:`transformers` package
+    :param patch_diffusers: add serialization function for
+        :epkg:`diffusers` package
+    :param verbose: verbosity level
+    :return: information to unpatch
     """
     from .onnx_export_serialization_impl import WRONG_REGISTRATIONS

-    registration_functions = serialization_functions(verbose=verbose)
+    registration_functions = serialization_functions(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )

     # DynamicCache serialization is different in transformers and does not
     # play way with torch.export.export.
@@ -124,68 +135,86 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     return done


-def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool]]:
+def serialization_functions(
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
+) -> Dict[type, Callable[[int], bool]]:
     """Returns the list of serialization functions."""
-    from .onnx_export_serialization_impl import (
-        SUPPORTED_DATACLASSES,
-        _lower_name_with_,
-        __dict__ as all_functions,
-        flatten_dynamic_cache,
-        unflatten_dynamic_cache,
-        flatten_with_keys_dynamic_cache,
-        flatten_mamba_cache,
-        unflatten_mamba_cache,
-        flatten_with_keys_mamba_cache,
-        flatten_encoder_decoder_cache,
-        unflatten_encoder_decoder_cache,
-        flatten_with_keys_encoder_decoder_cache,
-        flatten_sliding_window_cache,
-        unflatten_sliding_window_cache,
-        flatten_with_keys_sliding_window_cache,
-        flatten_static_cache,
-        unflatten_static_cache,
-        flatten_with_keys_static_cache,
-    )

-    transformers_classes = {
-        DynamicCache: lambda verbose=verbose: register_class_serialization(
-            DynamicCache,
+    supported_classes = set()
+    classes = {}
+    all_functions = {}
+
+    if patch_transformers:
+        from .serialization.transformers_impl import (
+            __dict__ as dtr,
+            SUPPORTED_DATACLASSES,
             flatten_dynamic_cache,
             unflatten_dynamic_cache,
             flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        ),
-        MambaCache: lambda verbose=verbose: register_class_serialization(
-            MambaCache,
             flatten_mamba_cache,
             unflatten_mamba_cache,
             flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
-        EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
-            EncoderDecoderCache,
             flatten_encoder_decoder_cache,
             unflatten_encoder_decoder_cache,
             flatten_with_keys_encoder_decoder_cache,
-            verbose=verbose,
-        ),
-        SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
-            SlidingWindowCache,
             flatten_sliding_window_cache,
             unflatten_sliding_window_cache,
             flatten_with_keys_sliding_window_cache,
-            verbose=verbose,
-        ),
-        StaticCache: lambda verbose=verbose: register_class_serialization(
-            StaticCache,
             flatten_static_cache,
             unflatten_static_cache,
             flatten_with_keys_static_cache,
-            verbose=verbose,
-        ),
-    }
-    for cls in SUPPORTED_DATACLASSES:
+        )
+
+        all_functions.update(dtr)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+        transformers_classes = {
+            DynamicCache: lambda verbose=verbose: register_class_serialization(
+                DynamicCache,
+                flatten_dynamic_cache,
+                unflatten_dynamic_cache,
+                flatten_with_keys_dynamic_cache,
+                # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+                verbose=verbose,
+            ),
+            MambaCache: lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            ),
+            EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
+                EncoderDecoderCache,
+                flatten_encoder_decoder_cache,
+                unflatten_encoder_decoder_cache,
+                flatten_with_keys_encoder_decoder_cache,
+                verbose=verbose,
+            ),
+            SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
+                SlidingWindowCache,
+                flatten_sliding_window_cache,
+                unflatten_sliding_window_cache,
+                flatten_with_keys_sliding_window_cache,
+                verbose=verbose,
+            ),
+            StaticCache: lambda verbose=verbose: register_class_serialization(
+                StaticCache,
+                flatten_static_cache,
+                unflatten_static_cache,
+                flatten_with_keys_static_cache,
+                verbose=verbose,
+            ),
+        }
+        classes.update(transformers_classes)
+
+    if patch_diffusers:
+        from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu
+
+        all_functions.update(dfu)
+        supported_classes |= SUPPORTED_DATACLASSES
+
+    for cls in supported_classes:
         lname = _lower_name_with_(cls.__name__)
         assert (
             f"flatten_{lname}" in all_functions
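
The assert at the end of the hunk encodes a naming contract: for every dataclass in SUPPORTED_DATACLASSES, the implementation module must expose flatten_/unflatten_/flatten_with_keys_ helpers named after _lower_name_with_(cls.__name__). A small illustrative check, not part of the commit, written against the diffusers module added below:

# Illustrative check of the naming contract (not in the commit).
from onnx_diagnostic.torch_export_patches.serialization import (
    _lower_name_with_,
    diffusers_impl,
)

for cls in diffusers_impl.SUPPORTED_DATACLASSES:
    lname = _lower_name_with_(cls.__name__)
    for prefix in ("flatten_", "unflatten_", "flatten_with_keys_"):
        assert hasattr(diffusers_impl, f"{prefix}{lname}"), f"missing {prefix}{lname}"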
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
+import re
+from typing import Any, Callable, List, Set, Tuple
+import torch
+
+
+def _lower_name_with_(name):
+    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
+
+
+def make_serialization_function_for_dataclass(
+    cls: type, supported_classes: Set[type]
+) -> Tuple[Callable, Callable, Callable]:
+    """
+    Automatically creates serialization functions for a class decorated with
+    ``dataclasses.dataclass``.
+    """
+
+    def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects."""
+        return list(obj.values()), list(obj.keys())
+
+    def flatten_with_keys_cls(
+        obj: cls,
+    ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+        """Serializes a ``%s`` with python objects with keys."""
+        values, context = list(obj.values()), list(obj.keys())
+        return [
+            (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+        ], context
+
+    def unflatten_cls(
+        values: List[Any], context: torch.utils._pytree.Context, output_type=None
+    ) -> cls:
+        """Restores an instance of ``%s`` from python objects."""
+        return cls(**dict(zip(context, values)))
+
+    name = _lower_name_with_(cls.__name__)
+    flatten_cls.__name__ = f"flatten_{name}"
+    flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}"
+    unflatten_cls.__name__ = f"unflatten_{name}"
+    flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__
+    flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__
+    unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__
+    supported_classes.add(cls)
+    return flatten_cls, flatten_with_keys_cls, unflatten_cls
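
A hedged round-trip example of the helper above (not part of the commit). It assumes this new file is the serialization package __init__, as the imports elsewhere in the commit suggest; DummyOutput is a hypothetical stand-in for a dict-like output dataclass exposing keys()/values(), which is what flatten_cls relies on.

# Hypothetical round-trip, not part of the commit.
import dataclasses
import torch
from onnx_diagnostic.torch_export_patches.serialization import (
    make_serialization_function_for_dataclass,
)


@dataclasses.dataclass
class DummyOutput:  # placeholder mimicking a diffusers/transformers output class
    sample: torch.Tensor

    def keys(self):
        return [f.name for f in dataclasses.fields(self)]

    def values(self):
        return [getattr(self, f.name) for f in dataclasses.fields(self)]


supported = set()
flatten, flatten_with_keys, unflatten = make_serialization_function_for_dataclass(
    DummyOutput, supported
)

obj = DummyOutput(sample=torch.rand(2, 2))
values, context = flatten(obj)  # -> ([tensor], ["sample"])
restored = unflatten(values, context)  # -> DummyOutput(sample=tensor)
assert isinstance(restored, DummyOutput) and DummyOutput in supported
assert flatten.__name__ == "flatten_dummy_output"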
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
+from typing import Dict, Optional
+
+try:
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+except ImportError as e:
+    try:
+        import diffusers
+    except ImportError:
+        diffusers = None
+    UNet2DConditionOutput = None
+    if diffusers:
+        raise e
+
+from . import make_serialization_function_for_dataclass
+
+
+def _make_wrong_registrations() -> Dict[str, Optional[str]]:
+    res = {}
+    for c in [UNet2DConditionOutput]:
+        if c is not None:
+            res[c] = None
+    return res
+
+
+SUPPORTED_DATACLASSES = set()
+WRONG_REGISTRATIONS = _make_wrong_registrations()
+
+
+if UNet2DConditionOutput is not None:
+    (
+        flatten_u_net2_d_condition_output,
+        flatten_with_keys_u_net2_d_condition_output,
+        unflatten_u_net2_d_condition_output,
+    ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES)
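
The generated names unpacked above follow directly from _lower_name_with_; a quick demonstration (not part of the commit):

# How _lower_name_with_ derives the helper names used above.
from onnx_diagnostic.torch_export_patches.serialization import _lower_name_with_

assert _lower_name_with_("UNet2DConditionOutput") == "u_net2_d_condition_output"
assert _lower_name_with_("DynamicCache") == "dynamic_cache"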
