diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index 23f136a5..d014ed2e 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -8,12 +8,12 @@ onnx_diagnostic.torch_export_patches eval/index onnx_export_errors onnx_export_serialization - onnx_export_serialization_impl patches/index patch_expressions patch_inputs patch_module patch_module_helper + serialization/index .. automodule:: onnx_diagnostic.torch_export_patches :members: diff --git a/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst b/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst deleted file mode 100644 index 22a94d27..00000000 --- a/_doc/api/torch_export_patches/onnx_export_serialization_impl.rst +++ /dev/null @@ -1,7 +0,0 @@ - -onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl -=================================================================== - -.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl - :members: - :no-undoc-members: diff --git a/_doc/api/torch_export_patches/serialization/diffusers_impl.rst b/_doc/api/torch_export_patches/serialization/diffusers_impl.rst new file mode 100644 index 00000000..d9bb3d09 --- /dev/null +++ b/_doc/api/torch_export_patches/serialization/diffusers_impl.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.serialization.diffusers_impl +================================================================= + +.. automodule:: onnx_diagnostic.torch_export_patches.serialization.diffusers_impl + :members: + :no-undoc-members: diff --git a/_doc/api/torch_export_patches/serialization/index.rst b/_doc/api/torch_export_patches/serialization/index.rst new file mode 100644 index 00000000..654f0ec4 --- /dev/null +++ b/_doc/api/torch_export_patches/serialization/index.rst @@ -0,0 +1,13 @@ +onnx_diagnostic.torch_export_patches.serialization +================================================== + +.. toctree:: + :maxdepth: 1 + :caption: submodules + + diffusers_impl + transformers_impl + +.. automodule:: onnx_diagnostic.torch_export_patches.serialization + :members: + :no-undoc-members: diff --git a/_doc/api/torch_export_patches/serialization/transformers_impl.rst b/_doc/api/torch_export_patches/serialization/transformers_impl.rst new file mode 100644 index 00000000..7929dc0c --- /dev/null +++ b/_doc/api/torch_export_patches/serialization/transformers_impl.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.serialization.transformers_impl +==================================================================== + +.. automodule:: onnx_diagnostic.torch_export_patches.serialization.transformers_impl + :members: + :no-undoc-members: diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py new file mode 100644 index 00000000..7a2475da --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_diffusers.py @@ -0,0 +1,71 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_diffusers +from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes +from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( + torch_export_patches, +) +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy + + +class TestPatchSerializationDiffusers(ExtTestCase): + @ignore_warnings(UserWarning) + @requires_diffusers("0.30") + def test_unet_2d_condition_output(self): + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput + + bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) + self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput") + bo2 = torch_deepcopy([bo]) + self.assertIsInstance(bo2, list) + self.assertEqual( + "UNet2DConditionOutput(sample:T1s4x4x4)", + self.string_type(bo, with_shape=True), + ) + + with torch_export_patches(patch_diffusers=True): + # internal function + bo2 = torch_deepcopy([bo]) + self.assertIsInstance(bo2, list) + self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput") + self.assertEqualAny([bo], bo2) + self.assertEqual( + "UNet2DConditionOutput(sample:T1s4x4x4)", + self.string_type(bo, with_shape=True), + ) + + # serialization + flat, _spec = torch.utils._pytree.tree_flatten(bo) + self.assertEqual( + "#1[T1s4x4x4]", + self.string_type(flat, with_shape=True), + ) + bo2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqual( + self.string_type(bo, with_shape=True, with_min_max=True), + self.string_type(bo2, with_shape=True, with_min_max=True), + ) + + # flatten_unflatten + flat, _spec = torch.utils._pytree.tree_flatten(bo) + unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True) + self.assertIsInstance(unflat, dict) + self.assertEqual(list(unflat), ["sample"]) + + # export + class Model(torch.nn.Module): + def forward(self, cache): + return cache.sample[0] + + bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) + model = Model() + model(bo) + DYN = torch.export.Dim.DYNAMIC + ds = [{0: DYN}] + + with torch_export_patches(patch_diffusers=True): + torch.export.export(model, (bo,), dynamic_shapes=(ds,)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py similarity index 82% rename from _unittests/ut_torch_export_patches/test_patch_serialization.py rename to _unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index 9469cefd..f2432c99 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -1,12 +1,7 @@ import unittest import torch from transformers.modeling_outputs import BaseModelOutput -from onnx_diagnostic.ext_test_case import ( - ExtTestCase, - ignore_warnings, - requires_torch, - requires_diffusers, -) +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch from onnx_diagnostic.helpers.cache_helper import ( make_encoder_decoder_cache, make_dynamic_cache, @@ -159,7 +154,7 @@ def forward(self, cache): DYN = torch.export.Dim.DYNAMIC ds = [{0: DYN}] - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): torch.export.export(model, (bo,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) @@ -218,63 +213,6 @@ def test_sliding_window_cache_flatten(self): self.string_type(cache2, with_shape=True, with_min_max=True), ) - @ignore_warnings(UserWarning) - @requires_diffusers("0.30") - def test_unet_2d_condition_output(self): - from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput - - bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) - self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput") - bo2 = torch_deepcopy([bo]) - self.assertIsInstance(bo2, list) - self.assertEqual( - "UNet2DConditionOutput(sample:T1s4x4x4)", - self.string_type(bo, with_shape=True), - ) - - with torch_export_patches(): - # internal function - bo2 = torch_deepcopy([bo]) - self.assertIsInstance(bo2, list) - self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput") - self.assertEqualAny([bo], bo2) - self.assertEqual( - "UNet2DConditionOutput(sample:T1s4x4x4)", - self.string_type(bo, with_shape=True), - ) - - # serialization - flat, _spec = torch.utils._pytree.tree_flatten(bo) - self.assertEqual( - "#1[T1s4x4x4]", - self.string_type(flat, with_shape=True), - ) - bo2 = torch.utils._pytree.tree_unflatten(flat, _spec) - self.assertEqual( - self.string_type(bo, with_shape=True, with_min_max=True), - self.string_type(bo2, with_shape=True, with_min_max=True), - ) - - # flatten_unflatten - flat, _spec = torch.utils._pytree.tree_flatten(bo) - unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True) - self.assertIsInstance(unflat, dict) - self.assertEqual(list(unflat), ["sample"]) - - # export - class Model(torch.nn.Module): - def forward(self, cache): - return cache.sample[0] - - bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4))) - model = Model() - model(bo) - DYN = torch.export.Dim.DYNAMIC - ds = [{0: DYN}] - - with torch_export_patches(): - torch.export.export(model, (bo,), dynamic_shapes=(ds,)) - @ignore_warnings(UserWarning) @requires_torch("2.7.99") def test_static_cache(self): diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index ad1309e6..07617be2 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo @contextlib.contextmanager def register_additional_serialization_functions( - patch_transformers: bool = False, verbose: int = 0 + patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0 ) -> Callable: """The necessary modifications to run the fx Graph.""" - fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x) - done = register_cache_serialization(verbose=verbose) + fct_callable = ( + replacement_before_exporting + if patch_transformers or patch_diffusers + else (lambda x: x) + ) + done = register_cache_serialization( + patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose + ) try: yield fct_callable finally: @@ -150,6 +156,7 @@ def torch_export_patches( patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, + patch_diffusers: bool = False, catch_constraints: bool = True, stop_if_static: int = 0, verbose: int = 0, @@ -165,6 +172,7 @@ def torch_export_patches( :param patch_sympy: fix missing method ``name`` for IntegerConstant :param patch_torch: patches :epkg:`torch` with supported implementation :param patch_transformers: patches :epkg:`transformers` with supported implementation + :param patch_diffusers: patches :epkg:`diffusers` with supported implementation :param catch_constraints: catch constraints related to dynamic shapes, as a result, some dynamic dimension may turn into static ones, the environment variable ``SKIP_SOLVE_CONSTRAINTS=0`` @@ -249,6 +257,7 @@ def torch_export_patches( patch_sympy=patch_sympy, patch_torch=patch_torch, patch_transformers=patch_transformers, + patch_diffusers=patch_diffusers, catch_constraints=catch_constraints, stop_if_static=stop_if_static, verbose=verbose, @@ -281,7 +290,11 @@ def torch_export_patches( # caches ######## - cache_done = register_cache_serialization(verbose=verbose) + cache_done = register_cache_serialization( + patch_transformers=patch_transformers, + patch_diffusers=patch_diffusers, + verbose=verbose, + ) ############# # patch sympy diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 4f216367..e669a129 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -13,7 +13,7 @@ ) from ..helpers import string_type - +from .serialization import _lower_name_with_ PATCH_OF_PATCHES: Set[Any] = set() @@ -73,14 +73,33 @@ def register_class_serialization( return True -def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: +def register_cache_serialization( + patch_transformers: bool = False, patch_diffusers: bool = True, verbose: int = 0 +) -> Dict[str, bool]: """ Registers many classes with :func:`register_class_serialization`. Returns information needed to undo the registration. + + :param patch_transformers: add serialization function for + :epkg:`transformers` package + :param patch_diffusers: add serialization function for + :epkg:`diffusers` package + :param verbosity: verbosity level + :return: information to unpatch """ - from .onnx_export_serialization_impl import WRONG_REGISTRATIONS + wrong: Dict[type, Optional[str]] = {} + if patch_transformers: + from .serialization.transformers_impl import WRONG_REGISTRATIONS + + wrong |= WRONG_REGISTRATIONS + if patch_diffusers: + from .serialization.diffusers_impl import WRONG_REGISTRATIONS - registration_functions = serialization_functions(verbose=verbose) + wrong |= WRONG_REGISTRATIONS + + registration_functions = serialization_functions( + patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose + ) # DynamicCache serialization is different in transformers and does not # play way with torch.export.export. @@ -92,7 +111,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # so we remove it anyway # BaseModelOutput serialization is incomplete. # It does not include dynamic shapes mapping. - for cls, version in WRONG_REGISTRATIONS.items(): + for cls, version in wrong.items(): if ( cls in torch.utils._pytree.SUPPORTED_NODES and cls not in PATCH_OF_PATCHES @@ -124,73 +143,91 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: return done -def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool]]: +def serialization_functions( + patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0 +) -> Dict[type, Callable[[int], bool]]: """Returns the list of serialization functions.""" - from .onnx_export_serialization_impl import ( - SUPPORTED_DATACLASSES, - _lower_name_with_, - __dict__ as all_functions, - flatten_dynamic_cache, - unflatten_dynamic_cache, - flatten_with_keys_dynamic_cache, - flatten_mamba_cache, - unflatten_mamba_cache, - flatten_with_keys_mamba_cache, - flatten_encoder_decoder_cache, - unflatten_encoder_decoder_cache, - flatten_with_keys_encoder_decoder_cache, - flatten_sliding_window_cache, - unflatten_sliding_window_cache, - flatten_with_keys_sliding_window_cache, - flatten_static_cache, - unflatten_static_cache, - flatten_with_keys_static_cache, - ) - transformers_classes = { - DynamicCache: lambda verbose=verbose: register_class_serialization( - DynamicCache, + supported_classes: Set[type] = set() + classes: Dict[type, Callable[[int], bool]] = {} + all_functions: Dict[type, Optional[str]] = {} + + if patch_transformers: + from .serialization.transformers_impl import ( + __dict__ as dtr, + SUPPORTED_DATACLASSES, flatten_dynamic_cache, unflatten_dynamic_cache, flatten_with_keys_dynamic_cache, - # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), - verbose=verbose, - ), - MambaCache: lambda verbose=verbose: register_class_serialization( - MambaCache, flatten_mamba_cache, unflatten_mamba_cache, flatten_with_keys_mamba_cache, - verbose=verbose, - ), - EncoderDecoderCache: lambda verbose=verbose: register_class_serialization( - EncoderDecoderCache, flatten_encoder_decoder_cache, unflatten_encoder_decoder_cache, flatten_with_keys_encoder_decoder_cache, - verbose=verbose, - ), - SlidingWindowCache: lambda verbose=verbose: register_class_serialization( - SlidingWindowCache, flatten_sliding_window_cache, unflatten_sliding_window_cache, flatten_with_keys_sliding_window_cache, - verbose=verbose, - ), - StaticCache: lambda verbose=verbose: register_class_serialization( - StaticCache, flatten_static_cache, unflatten_static_cache, flatten_with_keys_static_cache, - verbose=verbose, - ), - } - for cls in SUPPORTED_DATACLASSES: + ) + + all_functions.update(dtr) + supported_classes |= SUPPORTED_DATACLASSES + + transformers_classes = { + DynamicCache: lambda verbose=verbose: register_class_serialization( + DynamicCache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + flatten_with_keys_dynamic_cache, + # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), + verbose=verbose, + ), + MambaCache: lambda verbose=verbose: register_class_serialization( + MambaCache, + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + verbose=verbose, + ), + EncoderDecoderCache: lambda verbose=verbose: register_class_serialization( + EncoderDecoderCache, + flatten_encoder_decoder_cache, + unflatten_encoder_decoder_cache, + flatten_with_keys_encoder_decoder_cache, + verbose=verbose, + ), + SlidingWindowCache: lambda verbose=verbose: register_class_serialization( + SlidingWindowCache, + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + verbose=verbose, + ), + StaticCache: lambda verbose=verbose: register_class_serialization( + StaticCache, + flatten_static_cache, + unflatten_static_cache, + flatten_with_keys_static_cache, + verbose=verbose, + ), + } + classes.update(transformers_classes) + + if patch_diffusers: + from .serialization.diffusers_impl import SUPPORTED_DATACLASSES, __dict__ as dfu + + all_functions.update(dfu) + supported_classes |= SUPPORTED_DATACLASSES + + for cls in supported_classes: lname = _lower_name_with_(cls.__name__) assert ( f"flatten_{lname}" in all_functions - ), f"Unable to find function 'flatten_{lname}' in {sorted(all_functions)}" - transformers_classes[cls] = ( + ), f"Unable to find function 'flatten_{lname}' in {list(all_functions)}" + classes[cls] = ( lambda verbose=verbose, _ln=lname, cls=cls, _al=all_functions: register_class_serialization( # noqa: E501 cls, _al[f"flatten_{_ln}"], @@ -199,7 +236,7 @@ def serialization_functions(verbose: int = 0) -> Dict[type, Callable[[int], bool verbose=verbose, ) ) - return transformers_classes + return classes def unregister_class_serialization(cls: type, verbose: int = 0): diff --git a/onnx_diagnostic/torch_export_patches/serialization/__init__.py b/onnx_diagnostic/torch_export_patches/serialization/__init__.py new file mode 100644 index 00000000..965e7cb3 --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/serialization/__init__.py @@ -0,0 +1,46 @@ +import re +from typing import Any, Callable, List, Set, Tuple +import torch + + +def _lower_name_with_(name): + s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() + + +def make_serialization_function_for_dataclass( + cls: type, supported_classes: Set[type] +) -> Tuple[Callable, Callable, Callable]: + """ + Automatically creates serialization function for a class decorated with + ``dataclasses.dataclass``. + """ + + def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]: # type: ignore[valid-type] + """Serializes a ``%s`` with python objects.""" + return list(obj.values()), list(obj.keys()) + + def flatten_with_keys_cls( + obj: cls, # type: ignore[valid-type] + ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a ``%s`` with python objects with keys.""" + values, context = list(obj.values()), list(obj.keys()) + return [ + (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values) + ], context + + def unflatten_cls( + values: List[Any], context: torch.utils._pytree.Context, output_type=None + ) -> cls: # type: ignore[valid-type] + """Restores an instance of ``%s`` from python objects.""" + return cls(**dict(zip(context, values))) + + name = _lower_name_with_(cls.__name__) + flatten_cls.__name__ = f"flatten_{name}" + flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}" + unflatten_cls.__name__ = f"unflatten_{name}" + flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__ + flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__ + unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__ + supported_classes.add(cls) + return flatten_cls, flatten_with_keys_cls, unflatten_cls diff --git a/onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py new file mode 100644 index 00000000..a450ffcf --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py @@ -0,0 +1,34 @@ +from typing import Dict, Optional, Set + +try: + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput +except ImportError as e: + try: + import diffusers + except ImportError: + diffusers = None + UNet2DConditionOutput = None + if diffusers: + raise e + +from . import make_serialization_function_for_dataclass + + +def _make_wrong_registrations() -> Dict[type, Optional[str]]: + res: Dict[type, Optional[str]] = {} + for c in [UNet2DConditionOutput]: + if c is not None: + res[c] = None + return res + + +SUPPORTED_DATACLASSES: Set[type] = set() +WRONG_REGISTRATIONS = _make_wrong_registrations() + + +if UNet2DConditionOutput is not None: + ( + flatten_u_net2_d_condition_output, + flatten_with_keys_u_net2_d_condition_output, + unflatten_u_net2_d_condition_output, + ) = make_serialization_function_for_dataclass(UNet2DConditionOutput, SUPPORTED_DATACLASSES) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py similarity index 73% rename from onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py rename to onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 9557c15f..3b2dc899 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -1,5 +1,4 @@ -import re -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, List, Set, Tuple import torch import transformers from transformers.cache_utils import ( @@ -10,40 +9,15 @@ StaticCache, ) from transformers.modeling_outputs import BaseModelOutput - -try: - from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput - from diffusers.models.unets.unet_1d import UNet1DOutput - from diffusers.models.unets.unet_2d import UNet2DOutput - from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput - from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput -except ImportError as e: - try: - import diffusers - except ImportError: - diffusers = None - DecoderOutput, EncoderOutput = None, None - UNet1DOutput, UNet2DOutput = None, None - UNet2DConditionOutput, UNet3DConditionOutput = None, None - if diffusers: - raise e - -from ..helpers.cache_helper import make_static_cache - - -def _make_wrong_registrations() -> Dict[str, Optional[str]]: - res = { - DynamicCache: "4.50", - BaseModelOutput: None, - } - for c in [UNet2DConditionOutput]: - if c is not None: - res[c] = None - return res +from ...helpers.cache_helper import make_static_cache +from . import make_serialization_function_for_dataclass -SUPPORTED_DATACLASSES = set() -WRONG_REGISTRATIONS = _make_wrong_registrations() +SUPPORTED_DATACLASSES: Set[type] = set() +WRONG_REGISTRATIONS = { + DynamicCache: "4.50", + BaseModelOutput: None, +} ############ @@ -278,57 +252,8 @@ def unflatten_encoder_decoder_cache( ############# -def _lower_name_with_(name): - s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) - return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() - - -def make_serialization_function_for_dataclass(cls) -> Tuple[Callable, Callable, Callable]: - """ - Automatically creates serialization function for a class decorated with - ``dataclasses.dataclass``. - """ - - def flatten_cls(obj: cls) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a ``%s`` with python objects.""" - return list(obj.values()), list(obj.keys()) - - def flatten_with_keys_cls( - obj: cls, - ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """Serializes a ``%s`` with python objects with keys.""" - values, context = list(obj.values()), list(obj.keys()) - return [ - (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values) - ], context - - def unflatten_cls( - values: List[Any], context: torch.utils._pytree.Context, output_type=None - ) -> cls: - """Restores an instance of ``%s`` from python objects.""" - return cls(**dict(zip(context, values))) - - name = _lower_name_with_(cls.__name__) - flatten_cls.__name__ = f"flatten_{name}" - flatten_with_keys_cls.__name__ = f"flatten_with_keys_{name}" - unflatten_cls.__name__ = f"unflatten_{name}" - flatten_cls.__doc__ = flatten_cls.__doc__ % cls.__name__ - flatten_with_keys_cls.__doc__ = flatten_with_keys_cls.__doc__ % cls.__name__ - unflatten_cls.__doc__ = unflatten_cls.__doc__ % cls.__name__ - SUPPORTED_DATACLASSES.add(cls) - return flatten_cls, flatten_with_keys_cls, unflatten_cls - - ( flatten_base_model_output, flatten_with_keys_base_model_output, unflatten_base_model_output, -) = make_serialization_function_for_dataclass(BaseModelOutput) - - -if UNet2DConditionOutput is not None: - ( - flatten_u_net2_d_condition_output, - flatten_with_keys_u_net2_d_condition_output, - unflatten_u_net2_d_condition_output, - ) = make_serialization_function_for_dataclass(UNet2DConditionOutput) +) = make_serialization_function_for_dataclass(BaseModelOutput, SUPPORTED_DATACLASSES)