Skip to content

Commit 28fe237

Browse files
authored
Support for text-to-image (#165)
* Support for text-to-image * doc * add pick * fix issues * fix issues * type * mypy * refactor * refactor
1 parent 5c3f2a8 commit 28fe237

23 files changed

+731
-347
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.7.2
55
+++++
66

7+
* :pr:`165`: support for task text-to-image
78
* :pr:`162`: improves graphs rendering for historical data
89

910
0.7.1

_doc/api/tasks/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ Or:
4646
summarization
4747
text_classification
4848
text_generation
49+
text_to_image
4950
text2text_generation
5051
zero_shot_image_classification
5152

_doc/api/tasks/text_to_image.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.tasks.text_to_image
3+
===================================
4+
5+
.. automodule:: onnx_diagnostic.tasks.text_to_image
6+
:members:
7+
:no-undoc-members:

_doc/api/torch_export_patches/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ onnx_diagnostic.torch_export_patches
88
eval/index
99
onnx_export_errors
1010
onnx_export_serialization
11+
onnx_export_serialization_impl
1112
patches/index
1213
patch_expressions
1314
patch_inputs
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl
3+
===================================================================
4+
5+
.. automodule:: onnx_diagnostic.torch_export_patches.onnx_export_serialization_impl
6+
:members:
7+
:no-undoc-members:

_doc/conf.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ def linkcode_resolve(domain, info):
9090
"https://sdpython.github.io/doc/experimental-experiment/dev/",
9191
None,
9292
),
93+
# Not a sphinx documentation
94+
# "diffusers": ("https://huggingface.co/docs/diffusers/index", None),
9395
"matplotlib": ("https://matplotlib.org/stable/", None),
9496
"numpy": ("https://numpy.org/doc/stable", None),
9597
"onnx": ("https://onnx.ai/onnx/", None),
@@ -104,6 +106,8 @@ def linkcode_resolve(domain, info):
104106
"sklearn": ("https://scikit-learn.org/stable/", None),
105107
"skl2onnx": ("https://onnx.ai/sklearn-onnx/", None),
106108
"torch": ("https://pytorch.org/docs/main/", None),
109+
# Not a sphinx documentation
110+
# "transformers": ("https://huggingface.co/docs/transformers/index", None),
107111
}
108112

109113
# Check intersphinx reference targets exist
@@ -116,6 +120,7 @@ def linkcode_resolve(domain, info):
116120
("py:class", "True"),
117121
("py:class", "Argument"),
118122
("py:class", "default=sklearn.utils.metadata_routing.UNCHANGED"),
123+
("py:class", "diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput"),
119124
("py:class", "ModelProto"),
120125
("py:class", "Model"),
121126
("py:class", "Module"),

_doc/patches.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ Here is the list of supported caches:
113113

114114
import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
115115

116-
print("\n".join(sorted(p.serialization_functions())))
116+
print("\n".join(sorted(t.__name__ for t in p.serialization_functions())))
117117

118118
.. _l-control-flow-rewriting:
119119

_doc/status/patches_coverage.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ The following code shows the list of serialized classes in transformers.
1414

1515
import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p
1616

17-
print('\n'.join(sorted(p.serialization_functions())))
17+
print('\n'.join(sorted(t.__name__ for t in p.serialization_functions())))
1818

1919
Patched Classes
2020
===============

_unittests/ut_tasks/test_tasks_image_classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
77

88

9-
class TestTasks(ExtTestCase):
9+
class TestTasksImageClassification(ExtTestCase):
1010
@hide_stdout()
1111
def test_image_classification(self):
1212
mid = "hf-internal-testing/tiny-random-BeitForImageClassification"

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1212

1313

14-
class TestTasks(ExtTestCase):
14+
class TestTasksImageTextToText(ExtTestCase):
1515
@hide_stdout()
1616
@requires_transformers("4.52")
1717
@requires_torch("2.7.99")

0 commit comments

Comments
 (0)