Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion _doc/api/torch_export_patches/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ onnx_diagnostic.torch_export_patches
eval/index
onnx_export_errors
onnx_export_serialization
onnx_export_serialization_impl
patches/index
patch_expressions
patch_inputs
patch_module
patch_module_helper
serialization/index

.. automodule:: onnx_diagnostic.torch_export_patches
:members:
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_export_patches.serialization.diffusers_impl
=================================================================

.. automodule:: onnx_diagnostic.torch_export_patches.serialization.diffusers_impl
:members:
:no-undoc-members:
13 changes: 13 additions & 0 deletions _doc/api/torch_export_patches/serialization/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
onnx_diagnostic.torch_export_patches.serialization
==================================================

.. toctree::
:maxdepth: 1
:caption: submodules

diffusers_impl
transformers_impl

.. automodule:: onnx_diagnostic.torch_export_patches.serialization
:members:
:no-undoc-members:
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

onnx_diagnostic.torch_export_patches.serialization.transformers_impl
====================================================================

.. automodule:: onnx_diagnostic.torch_export_patches.serialization.transformers_impl
:members:
:no-undoc-members:
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_diffusers
from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
torch_export_patches,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy


class TestPatchSerializationDiffusers(ExtTestCase):
@ignore_warnings(UserWarning)
@requires_diffusers("0.30")
def test_unet_2d_condition_output(self):
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(
"UNet2DConditionOutput(sample:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

with torch_export_patches(patch_diffusers=True):
# internal function
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
self.assertEqualAny([bo], bo2)
self.assertEqual(
"UNet2DConditionOutput(sample:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

# serialization
flat, _spec = torch.utils._pytree.tree_flatten(bo)
self.assertEqual(
"#1[T1s4x4x4]",
self.string_type(flat, with_shape=True),
)
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(bo, with_shape=True, with_min_max=True),
self.string_type(bo2, with_shape=True, with_min_max=True),
)

# flatten_unflatten
flat, _spec = torch.utils._pytree.tree_flatten(bo)
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
self.assertIsInstance(unflat, dict)
self.assertEqual(list(unflat), ["sample"])

# export
class Model(torch.nn.Module):
def forward(self, cache):
return cache.sample[0]

bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
model = Model()
model(bo)
DYN = torch.export.Dim.DYNAMIC
ds = [{0: DYN}]

with torch_export_patches(patch_diffusers=True):
torch.export.export(model, (bo,), dynamic_shapes=(ds,))


if __name__ == "__main__":
unittest.main(verbosity=2)
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import unittest
import torch
from transformers.modeling_outputs import BaseModelOutput
from onnx_diagnostic.ext_test_case import (
ExtTestCase,
ignore_warnings,
requires_torch,
requires_diffusers,
)
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
from onnx_diagnostic.helpers.cache_helper import (
make_encoder_decoder_cache,
make_dynamic_cache,
Expand Down Expand Up @@ -159,7 +154,7 @@ def forward(self, cache):
DYN = torch.export.Dim.DYNAMIC
ds = [{0: DYN}]

with torch_export_patches():
with torch_export_patches(patch_transformers=True):
torch.export.export(model, (bo,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
Expand Down Expand Up @@ -218,63 +213,6 @@ def test_sliding_window_cache_flatten(self):
self.string_type(cache2, with_shape=True, with_min_max=True),
)

@ignore_warnings(UserWarning)
@requires_diffusers("0.30")
def test_unet_2d_condition_output(self):
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(
"UNet2DConditionOutput(sample:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

with torch_export_patches():
# internal function
bo2 = torch_deepcopy([bo])
self.assertIsInstance(bo2, list)
self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
self.assertEqualAny([bo], bo2)
self.assertEqual(
"UNet2DConditionOutput(sample:T1s4x4x4)",
self.string_type(bo, with_shape=True),
)

# serialization
flat, _spec = torch.utils._pytree.tree_flatten(bo)
self.assertEqual(
"#1[T1s4x4x4]",
self.string_type(flat, with_shape=True),
)
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(bo, with_shape=True, with_min_max=True),
self.string_type(bo2, with_shape=True, with_min_max=True),
)

# flatten_unflatten
flat, _spec = torch.utils._pytree.tree_flatten(bo)
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
self.assertIsInstance(unflat, dict)
self.assertEqual(list(unflat), ["sample"])

# export
class Model(torch.nn.Module):
def forward(self, cache):
return cache.sample[0]

bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
model = Model()
model(bo)
DYN = torch.export.Dim.DYNAMIC
ds = [{0: DYN}]

with torch_export_patches():
torch.export.export(model, (bo,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
@requires_torch("2.7.99")
def test_static_cache(self):
Expand Down
21 changes: 17 additions & 4 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo

@contextlib.contextmanager
def register_additional_serialization_functions(
patch_transformers: bool = False, verbose: int = 0
patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
) -> Callable:
"""The necessary modifications to run the fx Graph."""
fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
done = register_cache_serialization(verbose=verbose)
fct_callable = (
replacement_before_exporting
if patch_transformers or patch_diffusers
else (lambda x: x)
)
done = register_cache_serialization(
patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
)
try:
yield fct_callable
finally:
Expand All @@ -150,6 +156,7 @@ def torch_export_patches(
patch_sympy: bool = True,
patch_torch: bool = True,
patch_transformers: bool = False,
patch_diffusers: bool = False,
catch_constraints: bool = True,
stop_if_static: int = 0,
verbose: int = 0,
Expand All @@ -165,6 +172,7 @@ def torch_export_patches(
:param patch_sympy: fix missing method ``name`` for IntegerConstant
:param patch_torch: patches :epkg:`torch` with supported implementation
:param patch_transformers: patches :epkg:`transformers` with supported implementation
:param patch_diffusers: patches :epkg:`diffusers` with supported implementation
:param catch_constraints: catch constraints related to dynamic shapes,
as a result, some dynamic dimension may turn into static ones,
the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
Expand Down Expand Up @@ -249,6 +257,7 @@ def torch_export_patches(
patch_sympy=patch_sympy,
patch_torch=patch_torch,
patch_transformers=patch_transformers,
patch_diffusers=patch_diffusers,
catch_constraints=catch_constraints,
stop_if_static=stop_if_static,
verbose=verbose,
Expand Down Expand Up @@ -281,7 +290,11 @@ def torch_export_patches(
# caches
########

cache_done = register_cache_serialization(verbose=verbose)
cache_done = register_cache_serialization(
patch_transformers=patch_transformers,
patch_diffusers=patch_diffusers,
verbose=verbose,
)

#############
# patch sympy
Expand Down
Loading
Loading