Commit c30f912

Support for text-to-image
1 parent 5c3f2a8 commit c30f912

12 files changed: +210 −12 lines

_unittests/ut_tasks/test_tasks_image_classification.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


-class TestTasks(ExtTestCase):
+class TestTasksImageClassification(ExtTestCase):
     @hide_stdout()
     def test_image_classification(self):
         mid = "hf-internal-testing/tiny-random-BeitForImageClassification"

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


-class TestTasks(ExtTestCase):
+class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
     @requires_transformers("4.52")
     @requires_torch("2.7.99")

_unittests/ut_tasks/test_tasks_object_detection.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


-class TestTasks(ExtTestCase):
+class TestTasksObjectDetection(ExtTestCase):
     @hide_stdout()
     def test_object_detection(self):
         mid = "hustvl/yolos-tiny"

_unittests/ut_tasks/test_tasks_text_to_image.py

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    hide_stdout,
+    requires_transformers,
+    requires_torch,
+)
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestTasksTextToImage(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.52")
+    @requires_torch("2.7.99")
+    def test_text_to_image(self):
+        mid = "diffusers/tiny-torch-full-checker"
+        data = get_untrained_model_with_inputs(
+            mid, verbose=1, add_second_input=True, subfolder="unet"
+        )
+        self.assertEqual(data["task"], "text-to-image")
+        self.assertIn((data["size"], data["n_weights"]), [(5708048, 1427012)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**inputs)
+        model(**data["inputs2"])
+        with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


-class TestTasks(ExtTestCase):
+class TestTasksZeroShotImageClassification(ExtTestCase):
     @requires_torch("2.7.99")
     @hide_stdout()
     def test_zero_shot_image_classification(self):

_unittests/ut_tasks/try_tasks.py

Lines changed: 21 additions & 0 deletions
@@ -569,6 +569,27 @@ def test_object_detection(self):
                 f"{round(score.item(), 3)} at location {box}"
             )

+    @never_test()
+    def test_text_to_image(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k test_text_to_image
+        import torch
+        from diffusers import StableDiffusionPipeline
+
+        model_id = "diffusers/tiny-torch-full-checker"  # "stabilityai/stable-diffusion-2"
+        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(
+            "cuda"
+        )
+
+        prompt = "a photo of an astronaut riding a horse on mars and on jupyter"
+        print()
+        with steal_forward(pipe.unet, with_min_max=True):
+            image = pipe(prompt).images[0]
+        print("-- output", self.string_type(image, with_shape=True, with_min_max=True))
+        # stolen forward for class UNet2DConditionModel -- iteration 44
+        # sample=T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
+        # time_step=T7s=101
+        # encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

onnx_diagnostic/helpers/config_helper.py

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,10 @@ def update_config(config: Any, mkwargs: Dict[str, Any]):
             else:
                 update_config(getattr(config, k), v)
             continue
-        setattr(config, k, v)
+        if type(config) is dict:
+            config[k] = v
+        else:
+            setattr(config, k, v)


 def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
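
The dict branch matters here because diffusers configurations loaded from a subfolder (such as the unet config used by the new text-to-image task) behave as plain dictionaries rather than attribute-style config objects. A minimal sketch of the new behavior; the literal values are hypothetical, not from the repository:

    from onnx_diagnostic.helpers.config_helper import update_config

    # A dict-style config now takes the item-assignment path;
    # setattr on a plain dict would raise AttributeError.
    config = {"sample_size": 96, "cross_attention_dim": 1024}
    update_config(config, {"sample_size": 32, "cross_attention_dim": 64})
    assert config == {"sample_size": 32, "cross_attention_dim": 64}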

onnx_diagnostic/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -11,6 +11,7 @@
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 )
@@ -27,6 +28,7 @@
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 ]
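
Each task module exposes its name through a __TASK__ constant ("text-to-image" here), so adding text_to_image to both the import block and the module list above is what makes the new task discoverable. A minimal check, assuming only the layout shown in this commit:

    from onnx_diagnostic.tasks import text_to_image

    assert text_to_image.__TASK__ == "text-to-image"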

onnx_diagnostic/tasks/text_to_image.py

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "text-to-image"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "sample_size", "cross_attention_dim")
+    kwargs = dict(
+        sample_size=min(config["sample_size"], 32),
+        cross_attention_dim=min(config["cross_attention_dim"], 64),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    cache_length: int,
+    in_channels: int,
+    sample_size: int,
+    cross_attention_dim: int,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-to-image``.
+    Example:
+
+    ::
+
+        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
+        timestep:T7s=101
+        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    shapes = {
+        "sample": {0: batch},
+        "timestep": {},
+        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
+    }
+    inputs = dict(
+        sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to(
+            torch.float32
+        ),
+        timestep=torch.tensor([101], dtype=torch.int64),
+        encoder_hidden_states=torch.randn(
+            (batch_size, sequence_length, cross_attention_dim)
+        ).to(torch.float32),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length,
+            cache_length=cache_length + 1,
+            in_channels=in_channels,
+            sample_size=sample_size,
+            cross_attention_dim=cross_attention_dim,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=config["in_channels"],
+        cache_length=77,
+        in_channels=config["in_channels"],
+        sample_size=config["sample_size"],
+        cross_attention_dim=config["cross_attention_dim"],
+    )
+    return kwargs, get_inputs
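
For illustration, random_input_kwargs can be driven directly by the tiny UNet configuration cached below in hub_data_cached_configs.py. This is a sketch, not code from the commit: it assumes check_hasattr accepts dict-style configs, as the subscript access above implies, and the inline dict only mirrors the relevant cached values:

    from onnx_diagnostic.tasks.text_to_image import random_input_kwargs

    # Values copied from the cached diffusers/tiny-torch-full-checker unet config.
    config = {"sample_size": 32, "cross_attention_dim": 32, "in_channels": 4}
    kwargs, fct = random_input_kwargs(config)
    # kwargs: batch_size=2, sequence_length=4, cache_length=77, ...
    res = fct(model=None, config=config, add_second_input=True, **kwargs)
    print(res["inputs"]["sample"].shape)                 # torch.Size([2, 4, 32, 32])
    print(res["inputs"]["encoder_hidden_states"].shape)  # torch.Size([2, 4, 32])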

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 28 additions & 0 deletions
@@ -4302,3 +4302,31 @@ def _ccached_microsoft_phi_35_mini_instruct():
         "vocab_size": 32064,
     }
 )
+
+
+def _ccached_diffusers_tiny_torch_full_checker_unet():
+    "diffusers/tiny-torch-full-checker/unet"
+    return {
+        "_class_name": "UNet2DConditionModel",
+        "_diffusers_version": "0.8.0",
+        "_name_or_path": "https://huggingface.co/diffusers/tiny-torch-full-checker/blob/main/unet/config.json",
+        "act_fn": "silu",
+        "attention_head_dim": 8,
+        "block_out_channels": [32, 64],
+        "center_input_sample": False,
+        "cross_attention_dim": 32,
+        "down_block_types": ["DownBlock2D", "CrossAttnDownBlock2D"],
+        "downsample_padding": 1,
+        "dual_cross_attention": False,
+        "flip_sin_to_cos": True,
+        "freq_shift": 0,
+        "in_channels": 4,
+        "layers_per_block": 2,
+        "mid_block_scale_factor": 1,
+        "norm_eps": 1e-05,
+        "norm_num_groups": 32,
+        "out_channels": 4,
+        "sample_size": 32,
+        "up_block_types": ["CrossAttnUpBlock2D", "UpBlock2D"],
+        "use_linear_projection": False,
+    }
