Skip to content

Commit 8fd602e

Browse files
committed
rename
1 parent 1c682f5 commit 8fd602e

File tree

3 files changed

+74
-65
lines changed

3 files changed

+74
-65
lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ onnx_diagnostic.torch_export_patches
88
eval/index
99
onnx_export_errors
1010
onnx_export_serialization
11-
onnx_export_serialization_impl
1211
patches/index
1312
patch_expressions
1413
patch_inputs
1514
patch_module
1615
patch_module_helper
16+
serialization/index
1717

1818
.. automodule:: onnx_diagnostic.torch_export_patches
1919
:members:
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import unittest
2+
import torch
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_diffusers
4+
from onnx_diagnostic.helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
5+
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
6+
torch_export_patches,
7+
)
8+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
9+
10+
11+
class TestPatchSerializationDiffusers(ExtTestCase):
12+
@ignore_warnings(UserWarning)
13+
@requires_diffusers("0.30")
14+
def test_unet_2d_condition_output(self):
15+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
16+
17+
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
18+
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
19+
bo2 = torch_deepcopy([bo])
20+
self.assertIsInstance(bo2, list)
21+
self.assertEqual(
22+
"UNet2DConditionOutput(sample:T1s4x4x4)",
23+
self.string_type(bo, with_shape=True),
24+
)
25+
26+
with torch_export_patches(patch_diffusers=True):
27+
# internal function
28+
bo2 = torch_deepcopy([bo])
29+
self.assertIsInstance(bo2, list)
30+
self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
31+
self.assertEqualAny([bo], bo2)
32+
self.assertEqual(
33+
"UNet2DConditionOutput(sample:T1s4x4x4)",
34+
self.string_type(bo, with_shape=True),
35+
)
36+
37+
# serialization
38+
flat, _spec = torch.utils._pytree.tree_flatten(bo)
39+
self.assertEqual(
40+
"#1[T1s4x4x4]",
41+
self.string_type(flat, with_shape=True),
42+
)
43+
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
44+
self.assertEqual(
45+
self.string_type(bo, with_shape=True, with_min_max=True),
46+
self.string_type(bo2, with_shape=True, with_min_max=True),
47+
)
48+
49+
# flatten_unflatten
50+
flat, _spec = torch.utils._pytree.tree_flatten(bo)
51+
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
52+
self.assertIsInstance(unflat, dict)
53+
self.assertEqual(list(unflat), ["sample"])
54+
55+
# export
56+
class Model(torch.nn.Module):
57+
def forward(self, cache):
58+
return cache.sample[0]
59+
60+
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
61+
model = Model()
62+
model(bo)
63+
DYN = torch.export.Dim.DYNAMIC
64+
ds = [{0: DYN}]
65+
66+
with torch_export_patches(patch_diffusers=True):
67+
torch.export.export(model, (bo,), dynamic_shapes=(ds,))
68+
69+
70+
if __name__ == "__main__":
71+
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_serialization.py renamed to _unittests/ut_torch_export_patches/test_patch_serialization_transformers.py

Lines changed: 2 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,7 @@
11
import unittest
22
import torch
33
from transformers.modeling_outputs import BaseModelOutput
4-
from onnx_diagnostic.ext_test_case import (
5-
ExtTestCase,
6-
ignore_warnings,
7-
requires_torch,
8-
requires_diffusers,
9-
)
4+
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
105
from onnx_diagnostic.helpers.cache_helper import (
116
make_encoder_decoder_cache,
127
make_dynamic_cache,
@@ -159,7 +154,7 @@ def forward(self, cache):
159154
DYN = torch.export.Dim.DYNAMIC
160155
ds = [{0: DYN}]
161156

162-
with torch_export_patches():
157+
with torch_export_patches(patch_transformers=True):
163158
torch.export.export(model, (bo,), dynamic_shapes=(ds,))
164159

165160
@ignore_warnings(UserWarning)
@@ -218,63 +213,6 @@ def test_sliding_window_cache_flatten(self):
218213
self.string_type(cache2, with_shape=True, with_min_max=True),
219214
)
220215

221-
@ignore_warnings(UserWarning)
222-
@requires_diffusers("0.30")
223-
def test_unet_2d_condition_output(self):
224-
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
225-
226-
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
227-
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
228-
bo2 = torch_deepcopy([bo])
229-
self.assertIsInstance(bo2, list)
230-
self.assertEqual(
231-
"UNet2DConditionOutput(sample:T1s4x4x4)",
232-
self.string_type(bo, with_shape=True),
233-
)
234-
235-
with torch_export_patches():
236-
# internal function
237-
bo2 = torch_deepcopy([bo])
238-
self.assertIsInstance(bo2, list)
239-
self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
240-
self.assertEqualAny([bo], bo2)
241-
self.assertEqual(
242-
"UNet2DConditionOutput(sample:T1s4x4x4)",
243-
self.string_type(bo, with_shape=True),
244-
)
245-
246-
# serialization
247-
flat, _spec = torch.utils._pytree.tree_flatten(bo)
248-
self.assertEqual(
249-
"#1[T1s4x4x4]",
250-
self.string_type(flat, with_shape=True),
251-
)
252-
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
253-
self.assertEqual(
254-
self.string_type(bo, with_shape=True, with_min_max=True),
255-
self.string_type(bo2, with_shape=True, with_min_max=True),
256-
)
257-
258-
# flatten_unflatten
259-
flat, _spec = torch.utils._pytree.tree_flatten(bo)
260-
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
261-
self.assertIsInstance(unflat, dict)
262-
self.assertEqual(list(unflat), ["sample"])
263-
264-
# export
265-
class Model(torch.nn.Module):
266-
def forward(self, cache):
267-
return cache.sample[0]
268-
269-
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
270-
model = Model()
271-
model(bo)
272-
DYN = torch.export.Dim.DYNAMIC
273-
ds = [{0: DYN}]
274-
275-
with torch_export_patches():
276-
torch.export.export(model, (bo,), dynamic_shapes=(ds,))
277-
278216
@ignore_warnings(UserWarning)
279217
@requires_torch("2.7.99")
280218
def test_static_cache(self):

0 commit comments

Comments
 (0)