Skip to content

Commit 7be71c1

Browse files
committed
fix issues
1 parent 87fa333 commit 7be71c1

File tree

6 files changed

+67
-9
lines changed

6 files changed

+67
-9
lines changed

_doc/conf.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def linkcode_resolve(domain, info):
9090
"https://sdpython.github.io/doc/experimental-experiment/dev/",
9191
None,
9292
),
93-
"diffusers": ("https://huggingface.co/docs/diffusers/index", None),
93+
# Not a sphinx documentation
94+
# "diffusers": ("https://huggingface.co/docs/diffusers/index", None),
9495
"matplotlib": ("https://matplotlib.org/stable/", None),
9596
"numpy": ("https://numpy.org/doc/stable", None),
9697
"onnx": ("https://onnx.ai/onnx/", None),
@@ -105,7 +106,8 @@ def linkcode_resolve(domain, info):
105106
"sklearn": ("https://scikit-learn.org/stable/", None),
106107
"skl2onnx": ("https://onnx.ai/sklearn-onnx/", None),
107108
"torch": ("https://pytorch.org/docs/main/", None),
108-
"transformers": ("https://huggingface.co/docs/transformers/index", None),
109+
# Not a sphinx documentation
110+
# "transformers": ("https://huggingface.co/docs/transformers/index", None),
109111
}
110112

111113
# Check intersphinx reference targets exist
@@ -118,6 +120,7 @@ def linkcode_resolve(domain, info):
118120
("py:class", "True"),
119121
("py:class", "Argument"),
120122
("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
123+
("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"),
121124
("py:class", "ModelProto"),
122125
("py:class", "Model"),
123126
("py:class", "Module"),

_doc/patches.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Here is the list of supported caches:
113113

114114
import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
115115

116-
print("\n".join(sorted(p.serialization_functions())))
116+
print("\n".join(sorted(t.__name__ for t in p.serialization_functions())))
117117

118118
.. _l-control-flow-rewriting:
119119

_doc/status/patches_coverage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The following code shows the list of serialized classes in transformers.
1414

1515
import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
1616

17-
print('\n'.join(sorted(p.serialization_functions())))
17+
print('\n'.join(sorted(t.__name__ for t in p.serialization_functions())))
1818

1919
Patched Classes
2020
===============

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import unittest
22
import torch
33
from transformers.modeling_outputs import BaseModelOutput
4-
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch
4+
from onnx_diagnostic.ext_test_case import (
5+
ExtTestCase,
6+
ignore_warnings,
7+
requires_torch,
8+
requires_diffusers,
9+
)
510
from onnx_diagnostic.helpers.cache_helper import (
611
make_encoder_decoder_cache,
712
make_dynamic_cache,
@@ -212,6 +217,56 @@ def test_sliding_window_cache_flatten(self):
212217
self.string_type(cache2, with_shape=True, with_min_max=True),
213218
)
214219

220+
@ignore_warnings(UserWarning)
221+
@requires_diffusers("0.30")
222+
def test_unet_2d_condition_output(self):
223+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
224+
225+
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
226+
self.assertEqual(bo.__class__.__name__, "UNet2DConditionOutput")
227+
with torch_export_patches():
228+
# internal function
229+
bo2 = torch_deepcopy([bo])
230+
self.assertIsInstance(bo2, list)
231+
self.assertEqual(bo2[0].__class__.__name__, "UNet2DConditionOutput")
232+
self.assertEqualAny([bo], bo2)
233+
self.assertEqual(
234+
"UNet2DConditionOutput(sample:T1s4x4x4)",
235+
self.string_type(bo, with_shape=True),
236+
)
237+
238+
# serialization
239+
flat, _spec = torch.utils._pytree.tree_flatten(bo)
240+
self.assertEqual(
241+
"#1[T1s4x4x4]",
242+
self.string_type(flat, with_shape=True),
243+
)
244+
bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
245+
self.assertEqual(
246+
self.string_type(bo, with_shape=True, with_min_max=True),
247+
self.string_type(bo2, with_shape=True, with_min_max=True),
248+
)
249+
250+
# flatten_unflatten
251+
flat, _spec = torch.utils._pytree.tree_flatten(bo)
252+
unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
253+
self.assertIsInstance(unflat, dict)
254+
self.assertEqual(list(unflat), ["sample"])
255+
256+
# export
257+
class Model(torch.nn.Module):
258+
def forward(self, cache):
259+
return cache.sample[0]
260+
261+
bo = UNet2DConditionOutput(sample=torch.rand((4, 4, 4)))
262+
model = Model()
263+
model(bo)
264+
DYN = torch.export.Dim.DYNAMIC
265+
ds = [{0: DYN}]
266+
267+
with torch_export_patches():
268+
torch.export.export(model, (bo,), dynamic_shapes=(ds,))
269+
215270

216271
if __name__ == "__main__":
217272
unittest.main(verbosity=2)

onnx_diagnostic/helpers/config_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
7171

7272
def pick(config, name: str, default_value: Any) -> Any:
7373
"""
74-
Returns the vlaue of a attribute if config has it
74+
Returns the value of a attribute if config has it
7575
otherwise the default value.
7676
"""
7777
if not config:

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636

3737
PATCH_OF_PATCHES: Set[Any] = set()
38-
WRONG_REGISTRATIONS: Dict[str, str] = {
38+
WRONG_REGISTRATIONS: Dict[str, Optional[str]] = {
3939
DynamicCache: "4.50",
4040
BaseModelOutput: None,
4141
UNet2DConditionOutput: None,
@@ -125,7 +125,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
125125
):
126126
assert cls in registration_functions, (
127127
f"{cls} has no registration functions mapped to it, "
128-
f"available {sorted(registration_functions)}"
128+
f"available options are {list(registration_functions)}"
129129
)
130130
if verbose:
131131
print(
@@ -146,7 +146,7 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
146146
return done
147147

148148

149-
def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[], bool], int]]:
149+
def serialization_functions(verbose: int = 0) -> Dict[type, Union[Callable[[int], bool], int]]:
150150
"""Returns the list of serialization functions."""
151151
transformers_classes = {
152152
DynamicCache: lambda verbose=verbose: register_class_serialization(

0 commit comments

Comments
 (0)