Skip to content

Commit 87fa333

Browse files
committed
fix issues
1 parent b23031b commit 87fa333

File tree

2 files changed

+126
-65
lines changed

2 files changed

+126
-65
lines changed

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def linkcode_resolve(domain, info):
9090
"https://sdpython.github.io/doc/experimental-experiment/dev/",
9191
None,
9292
),
93+
"diffusers": ("https://huggingface.co/docs/diffusers/index", None),
9394
"matplotlib": ("https://matplotlib.org/stable/", None),
9495
"numpy": ("https://numpy.org/doc/stable", None),
9596
"onnx": ("https://onnx.ai/onnx/", None),
@@ -104,6 +105,7 @@ def linkcode_resolve(domain, info):
104105
"sklearn": ("https://scikit-learn.org/stable/", None),
105106
"skl2onnx": ("https://onnx.ai/sklearn-onnx/", None),
106107
"torch": ("https://pytorch.org/docs/main/", None),
108+
"transformers": ("https://huggingface.co/docs/transformers/index", None),
107109
}
108110

109111
# Check intersphinx reference targets exist

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 124 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,34 @@
1212
StaticCache,
1313
)
1414
from transformers.modeling_outputs import BaseModelOutput
15+
16+
try:
17+
from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput
18+
from diffusers.models.unets.unet_1d import UNet1DOutput
19+
from diffusers.models.unets.unet_2d import UNet2DOutput
20+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
21+
from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput
22+
except ImportError as e:
23+
try:
24+
import diffusers
25+
except ImportError:
26+
diffusers = None
27+
DecoderOutput, EncoderOutput = None, None
28+
UNet1DOutput, UNet2DOutput = None, None
29+
UNet2DConditionOutput, UNet3DConditionOutput = None, None
30+
if diffusers:
31+
raise e
32+
1533
from ..helpers import string_type
1634
from ..helpers.cache_helper import make_static_cache
1735

1836

1937
PATCH_OF_PATCHES: Set[Any] = set()
38+
WRONG_REGISTRATIONS: Dict[str, str] = {
39+
DynamicCache: "4.50",
40+
BaseModelOutput: None,
41+
UNet2DConditionOutput: None,
42+
}
2043

2144

2245
def register_class_serialization(
@@ -40,10 +63,12 @@ def register_class_serialization(
4063
:return: registered or not
4164
"""
4265
if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
66+
if verbose and cls is not None:
67+
print(f"[register_class_serialization] already registered {cls.__name__}")
4368
return False
4469

4570
if verbose:
46-
print(f"[register_cache_serialization] register {cls}")
71+
print(f"[register_class_serialization] ---------- register {cls.__name__}")
4772
torch.utils._pytree.register_pytree_node(
4873
cls,
4974
f_flatten,
@@ -54,8 +79,8 @@ def register_class_serialization(
5479
if pv.Version(torch.__version__) < pv.Version("2.7"):
5580
if verbose:
5681
print(
57-
f"[register_cache_serialization] "
58-
f"register {cls} for torch=={torch.__version__}"
82+
f"[register_class_serialization] "
83+
f"---------- register {cls.__name__} for torch=={torch.__version__}"
5984
)
6085
torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])
6186

@@ -77,6 +102,8 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
77102
Registers many classes with :func:`register_class_serialization`.
78103
Returns information needed to undo the registration.
79104
"""
105+
registration_functions = serialization_functions(verbose=verbose)
106+
80107
# DynamicCache serialization is different in transformers and does not
81108
# play way with torch.export.export.
82109
# see test test_export_dynamic_cache_cat with NOBYPASS=1
@@ -85,109 +112,102 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
85112
# torch.fx._pytree.register_pytree_flatten_spec(
86113
# DynamicCache, _flatten_dynamic_cache_for_fx)
87114
# so we remove it anyway
88-
if (
89-
DynamicCache in torch.utils._pytree.SUPPORTED_NODES
90-
and DynamicCache not in PATCH_OF_PATCHES
91-
# and pv.Version(torch.__version__) < pv.Version("2.7")
92-
and pv.Version(transformers.__version__) >= pv.Version("4.50")
93-
):
94-
if verbose:
95-
print(
96-
f"[_fix_registration] DynamicCache is unregistered and "
97-
f"registered first for transformers=={transformers.__version__}"
98-
)
99-
unregister(DynamicCache, verbose=verbose)
100-
register_class_serialization(
101-
DynamicCache,
102-
flatten_dynamic_cache,
103-
unflatten_dynamic_cache,
104-
flatten_with_keys_dynamic_cache,
105-
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
106-
verbose=verbose,
107-
)
108-
if verbose:
109-
print("[_fix_registration] DynamicCache done.")
110-
# To avoid doing it multiple times.
111-
PATCH_OF_PATCHES.add(DynamicCache)
112-
113115
# BaseModelOutput serialization is incomplete.
114116
# It does not include dynamic shapes mapping.
115-
if (
116-
BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES
117-
and BaseModelOutput not in PATCH_OF_PATCHES
118-
):
119-
if verbose:
120-
print(
121-
f"[_fix_registration] BaseModelOutput is unregistered and "
122-
f"registered first for transformers=={transformers.__version__}"
117+
for cls, version in WRONG_REGISTRATIONS.items():
118+
if (
119+
cls in torch.utils._pytree.SUPPORTED_NODES
120+
and cls not in PATCH_OF_PATCHES
121+
# and pv.Version(torch.__version__) < pv.Version("2.7")
122+
and (
123+
version is None or pv.Version(transformers.__version__) >= pv.Version(version)
123124
)
124-
unregister(BaseModelOutput, verbose=verbose)
125-
register_class_serialization(
126-
BaseModelOutput,
127-
flatten_base_model_output,
128-
unflatten_base_model_output,
129-
flatten_with_keys_base_model_output,
130-
verbose=verbose,
131-
)
132-
if verbose:
133-
print("[_fix_registration] BaseModelOutput done.")
134-
135-
# To avoid doing it multiple times.
136-
PATCH_OF_PATCHES.add(BaseModelOutput)
137-
138-
return serialization_functions(verbose=verbose)
139-
140-
141-
def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]:
125+
):
126+
assert cls in registration_functions, (
127+
f"{cls} has no registration functions mapped to it, "
128+
f"available {sorted(registration_functions)}"
129+
)
130+
if verbose:
131+
print(
132+
f"[_fix_registration] {cls.__name__} is unregistered and "
133+
f"registered first"
134+
)
135+
unregister_class_serialization(cls, verbose=verbose)
136+
registration_functions[cls](verbose=verbose)
137+
if verbose:
138+
print(f"[_fix_registration] {cls.__name__} done.")
139+
# To avoid doing it multiple times.
140+
PATCH_OF_PATCHES.add(cls)
141+
142+
# classes with no registration at all.
143+
done = {}
144+
for k, v in registration_functions.items():
145+
done[k] = v(verbose=verbose)
146+
return done
147+
148+
149+
def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[], bool], int]]:
142150
"""Returns the list of serialization functions."""
143-
return dict(
144-
DynamicCache=register_class_serialization(
151+
transformers_classes = {
152+
DynamicCache: lambda verbose=verbose: register_class_serialization(
145153
DynamicCache,
146154
flatten_dynamic_cache,
147155
unflatten_dynamic_cache,
148156
flatten_with_keys_dynamic_cache,
149157
# f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
150158
verbose=verbose,
151159
),
152-
MambaCache=register_class_serialization(
160+
MambaCache: lambda verbose=verbose: register_class_serialization(
153161
MambaCache,
154162
flatten_mamba_cache,
155163
unflatten_mamba_cache,
156164
flatten_with_keys_mamba_cache,
157165
verbose=verbose,
158166
),
159-
EncoderDecoderCache=register_class_serialization(
167+
EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
160168
EncoderDecoderCache,
161169
flatten_encoder_decoder_cache,
162170
unflatten_encoder_decoder_cache,
163171
flatten_with_keys_encoder_decoder_cache,
164172
verbose=verbose,
165173
),
166-
BaseModelOutput=register_class_serialization(
174+
BaseModelOutput: lambda verbose=verbose: register_class_serialization(
167175
BaseModelOutput,
168176
flatten_base_model_output,
169177
unflatten_base_model_output,
170178
flatten_with_keys_base_model_output,
171179
verbose=verbose,
172180
),
173-
SlidingWindowCache=register_class_serialization(
181+
SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
174182
SlidingWindowCache,
175183
flatten_sliding_window_cache,
176184
unflatten_sliding_window_cache,
177185
flatten_with_keys_sliding_window_cache,
178186
verbose=verbose,
179187
),
180-
StaticCache=register_class_serialization(
188+
StaticCache: lambda verbose=verbose: register_class_serialization(
181189
StaticCache,
182190
flatten_static_cache,
183191
unflatten_static_cache,
184192
flatten_with_keys_static_cache,
185193
verbose=verbose,
186194
),
187-
)
195+
}
196+
if UNet2DConditionOutput:
197+
diffusers_classes = {
198+
UNet2DConditionOutput: lambda verbose=verbose: register_class_serialization(
199+
UNet2DConditionOutput,
200+
flatten_unet_2d_condition_output,
201+
unflatten_unet_2d_condition_output,
202+
flatten_with_keys_unet_2d_condition_output,
203+
verbose=verbose,
204+
)
205+
}
206+
transformers_classes.update(diffusers_classes)
207+
return transformers_classes
188208

189209

190-
def unregister(cls: type, verbose: int = 0):
210+
def unregister_class_serialization(cls: type, verbose: int = 0):
191211
"""Undo the registration."""
192212
# torch.utils._pytree._deregister_pytree_flatten_spec(cls)
193213
if cls in torch.fx._pytree.SUPPORTED_NODES:
@@ -217,9 +237,10 @@ def unregister(cls: type, verbose: int = 0):
217237

218238
def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
219239
"""Undo all registrations."""
220-
for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
240+
cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput} | set(undo)
241+
for cls in cls_ensemble:
221242
if undo.get(cls.__name__, False):
222-
unregister(cls, verbose)
243+
unregister_class_serialization(cls, verbose)
223244

224245

225246
############
@@ -478,3 +499,41 @@ def unflatten_base_model_output(
478499
from python objects.
479500
"""
480501
return BaseModelOutput(**dict(zip(context, values)))
502+
503+
504+
#######################
505+
# UNet2DConditionOutput
506+
#######################
507+
508+
509+
def flatten_unet_2d_condition_output(
510+
obj: UNet2DConditionOutput,
511+
) -> Tuple[List[Any], torch.utils._pytree.Context]:
512+
"""
513+
Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`
514+
with python objects.
515+
"""
516+
return list(obj.values()), list(obj.keys())
517+
518+
519+
def flatten_with_keys_unet_2d_condition_output(
520+
obj: UNet2DConditionOutput,
521+
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
522+
"""
523+
Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`
524+
with python objects.
525+
"""
526+
values, context = flatten_unet_2d_condition_output(obj)
527+
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
528+
529+
530+
def unflatten_unet_2d_condition_output(
531+
values: List[Any],
532+
context: torch.utils._pytree.Context,
533+
output_type=None,
534+
) -> UNet2DConditionOutput:
535+
"""
536+
Restores a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`
537+
from python objects.
538+
"""
539+
return UNet2DConditionOutput(**dict(zip(context, values)))

0 commit comments

Comments
 (0)