Skip to content

Commit 2783fe9

Browse files
authored
Add task image-to-video (#223)
* Handle more models * Add task image-to-video * add unit test * mypy * remove comma * too old
1 parent 4e21b81 commit 2783fe9

File tree

8 files changed

+366
-63
lines changed

8 files changed

+366
-63
lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.11
55
++++++
66

7+
* :pr:`223`: adds task image-to-video
8+
* :pr:`220`: adds option --ort-logs to display onnxruntime logs when creating the session
79
* :pr:`220`: adds a patch for PR `#40791 <https://github.com/huggingface/transformers/pull/40791>`_ in transformers
810

911
0.7.10
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import unittest
2+
import torch
3+
import transformers
4+
from onnx_diagnostic.ext_test_case import (
5+
ExtTestCase,
6+
hide_stdout,
7+
requires_diffusers,
8+
requires_torch,
9+
requires_transformers,
10+
)
11+
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
12+
from onnx_diagnostic.torch_export_patches import torch_export_patches
13+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
14+
15+
16+
class TestTasksImageToVideo(ExtTestCase):
    """Unit test for the ``image-to-video`` task."""

    @hide_stdout()
    @requires_diffusers("0.35")
    @requires_transformers("4.55")
    @requires_torch("2.8.99")
    def test_image_to_video(self):
        """
        Builds an untrained model from a small hand-written configuration,
        runs it on both generated input sets and exports it with dynamic shapes.
        """
        # Small configuration; the key order is kept as it is forwarded
        # to the configuration constructor as keyword arguments.
        config_kwargs = dict(
            _diffusers_version="0.34.0.dev0",
            _class_name="CosmosTransformer3DModel",
            max_size=[128, 240, 240],
            text_embed_dim=128,
            use_cache=True,
            in_channels=3,
            out_channels=16,
            num_layers=2,
            model_type="dia",
            patch_size=[1, 2, 2],
            rope_scale=[1.0, 3.0, 3.0],
            attention_head_dim=16,
            mlp_ratio=0.4,
            initializer_range=0.02,
            num_attention_heads=16,
            is_encoder_decoder=True,
            adaln_lora_dim=16,
            concat_padding_mask=True,
            extra_pos_embed_type=None,
        )
        tiny_config = transformers.DiaConfig(**config_kwargs)
        model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
        data = get_untrained_model_with_inputs(
            model_id,
            verbose=1,
            add_second_input=True,
            subfolder="transformer",
            config=tiny_config,
            inputs_kwargs={"image_height": 8 * 50, "image_width": 8 * 80},
        )
        self.assertEqual(data["task"], "image-to-video")

        model = data["model"]
        inputs = data["inputs"]
        ds = data["dynamic_shapes"]

        # The model must run on both sets of inputs before being exported.
        model(**inputs)
        model(**data["inputs2"])

        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
            with torch_export_patches(
                patch_transformers=True,
                patch_diffusers=True,
                verbose=10,
                stop_if_static=1,
            ):
                torch.export.export(
                    model,
                    (),
                    kwargs=inputs,
                    dynamic_shapes=use_dyn_not_str(ds),
                    strict=False,
                )
65+
66+
67+
if __name__ == "__main__":
68+
unittest.main(verbosity=2)

onnx_diagnostic/tasks/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
fill_mask,
66
image_classification,
77
image_text_to_text,
8+
image_to_video,
9+
mask_generation,
810
mixture_of_expert,
911
object_detection,
1012
sentence_similarity,
@@ -14,7 +16,6 @@
1416
text_to_image,
1517
text2text_generation,
1618
zero_shot_image_classification,
17-
mask_generation,
1819
)
1920

2021
__TASKS__ = [
@@ -23,6 +24,8 @@
2324
fill_mask,
2425
image_classification,
2526
image_text_to_text,
27+
image_to_video,
28+
mask_generation,
2629
mixture_of_expert,
2730
object_detection,
2831
sentence_similarity,
@@ -32,7 +35,6 @@
3235
text_to_image,
3336
text2text_generation,
3437
zero_shot_image_classification,
35-
mask_generation,
3638
]
3739

3840

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
from typing import Any, Callable, Dict, Optional, Tuple
2+
import torch
3+
from ..helpers.config_helper import (
4+
update_config,
5+
check_hasattr,
6+
default_num_hidden_layers as nhl,
7+
)
8+
9+
__TASK__ = "image-to-video"
10+
11+
12+
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Reduces a model size by capping its number of layers.

    :param config: configuration to update
    :return: the updated attributes, an empty dictionary when the
        configuration exposes neither ``num_layers`` nor ``num_hidden_layers``
    """
    layer_attributes = ("num_layers", "num_hidden_layers")
    if not any(hasattr(config, name) for name in reversed(layer_attributes)):
        # We cannot reduce.
        return {}
    check_hasattr(config, ("num_hidden_layers", "num_layers"))

    # Every layer counter found in the configuration is capped to the default.
    reduced = {
        name: min(getattr(config, name), nhl())
        for name in layer_attributes
        if hasattr(config, name)
    }
    update_config(config, reduced)
    return reduced
26+
27+
28+
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    text_embed_dim: int,
    latent_channels: int,
    batch_size: int = 2,
    image_height: int = 704,
    image_width: int = 1280,
    latent_frames: int = 1,
    text_maxlen: int = 512,
    add_second_input: int = 1,
    **kwargs,  # only checked for ``cls_cache`` and forwarded to the nested call
):
    """
    Generates inputs for task ``image-to-video``.

    :param model: model, only forwarded to the nested call
    :param config: configuration, only forwarded to the nested call
    :param text_embed_dim: last dimension of ``encoder_hidden_states``
    :param latent_channels: number of channels of ``hidden_states``
    :param batch_size: batch size
    :param image_height: image height, the latent height is 8 times smaller
    :param image_width: image width, the latent width is 8 times smaller
    :param latent_frames: number of latent frames
    :param text_maxlen: sequence length of ``encoder_hidden_states``
    :param add_second_input: if nonzero, adds a second set of inputs
        under key ``inputs2``, negative values are not implemented
    :param kwargs: ``cls_cache`` is not supported
    :return: dictionary with keys ``inputs``, ``dynamic_shapes``
        and optionally ``inputs2``
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    float_type = torch.float32
    height8, width8 = image_height // 8, image_width // 8
    latent_axes = {
        0: "batch_size",
        2: "latent_frames",
        3: "latent_height",
        4: "latent_width",
    }

    # The three random tensors are created in this order to keep
    # the same random stream for a given seed.
    tensors = {
        "hidden_states": torch.randn(
            batch_size, latent_channels, latent_frames, height8, width8, dtype=float_type
        ),
        "timestep": torch.tensor([1.0] * batch_size, dtype=float_type),
        "encoder_hidden_states": torch.randn(
            batch_size, text_maxlen, text_embed_dim, dtype=float_type
        ),
        "padding_mask": torch.ones(1, 1, image_height, image_width, dtype=float_type),
        "fps": torch.tensor([16] * batch_size, dtype=float_type),
        "condition_mask": torch.randn(
            batch_size, 1, latent_frames, height8, width8, dtype=float_type
        ),
    }
    dynamic = {
        "hidden_states": dict(latent_axes),
        "timestep": {0: "batch_size"},
        "encoder_hidden_states": {0: "batch_size"},
        "padding_mask": {0: "batch_size", 2: "height", 3: "width"},
        "fps": {0: "batch_size"},
        "condition_mask": dict(latent_axes),
    }
    result = {"inputs": tensors, "dynamic_shapes": dynamic}
    if not add_second_input:
        return result

    assert (
        add_second_input > 0
    ), f"Not implemented for add_second_input={add_second_input}."
    second = get_inputs(
        model=model,
        config=config,
        text_embed_dim=text_embed_dim,
        latent_channels=latent_channels,
        batch_size=batch_size,
        image_height=image_height,
        image_width=image_width,
        latent_frames=latent_frames,
        text_maxlen=text_maxlen,
        add_second_input=0,
        **kwargs,
    )
    result["inputs2"] = second["inputs"]
    return result
108+
109+
110+
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.

    :param config: model configuration or None, when given,
        it must define ``in_channels`` and ``text_embed_dim``
    :return: the keyword arguments to use to generate the inputs
        and the function generating them (:func:`get_inputs`)
    """
    if config is not None:
        # Plain call statement; presumably fails when one of the attributes
        # read below is missing (see check_hasattr).
        check_hasattr(config, "in_channels", "text_embed_dim")
    kwargs = dict(
        text_embed_dim=1024 if config is None else config.text_embed_dim,
        # One input channel is not counted as a latent channel,
        # presumably the concatenated condition mask - TODO confirm.
        latent_channels=16 if config is None else config.in_channels - 1,
        batch_size=1,
        image_height=8 * 50,
        image_width=8 * 80,
        latent_frames=1,
        text_maxlen=512,
    )
    return kwargs, get_inputs

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 69 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,51 @@ def task_from_arch(
177177
return data[arch]
178178

179179

180+
def _trygetattr(config, attname):
    """Returns attribute *attname* of *config* or None if it does not exist."""
    return getattr(config, attname, None)
185+
186+
187+
def architecture_from_config(config) -> Optional[str]:
    """
    Guesses the architecture (class) of the model described by this config.

    The lookup considers, in this order, ``_class_name``, ``architecture``
    and the first element of ``architectures``. It is done on the dictionary
    keys if *config* is a dictionary, then on the attributes, finally on
    ``config.__dict__``. The function returns None if nothing was found.
    """
    single_names = ("_class_name", "architecture")

    if isinstance(config, dict):
        for key in single_names:
            if key in config:
                return config[key]
        if config.get("architectures", []):
            return config["architectures"][0]

    for name in single_names:
        if hasattr(config, name):
            return getattr(config, name)
    if hasattr(config, "architectures") and config.architectures:
        return config.architectures[0]

    if hasattr(config, "__dict__"):
        members = config.__dict__
        for key in single_names:
            if key in members:
                return members[key]
        if members.get("architectures", []):
            return members["architectures"][0]
    return None
210+
211+
212+
def find_package_source(config) -> Optional[str]:
    """
    Guesses the package the model class comes from.

    Returns ``"diffusers"`` if the configuration holds ``_diffusers_version``
    (as a dictionary key, an attribute or a key of ``config.__dict__``),
    ``"transformers"`` otherwise.
    """
    marker = "_diffusers_version"
    found = (
        (isinstance(config, dict) and marker in config)
        or hasattr(config, marker)
        or (hasattr(config, "__dict__") and marker in config.__dict__)
    )
    return "diffusers" if found else "transformers"
223+
224+
180225
def task_from_id(
181226
model_id: str,
182227
default_value: Optional[str] = None,
@@ -202,28 +247,30 @@ def task_from_id(
202247
if not fall_back_to_pretrained:
203248
raise
204249
config = get_pretrained_config(model_id, subfolder=subfolder)
205-
try:
206-
return config.pipeline_tag
207-
except AttributeError:
208-
guess = _guess_task_from_config(config)
209-
if guess is not None:
210-
return guess
211-
data = load_architecture_task()
212-
if model_id in data:
213-
return data[model_id]
214-
if type(config) is dict and "_class_name" in config:
215-
return task_from_arch(config["_class_name"], default_value=default_value)
216-
if not config.architectures or not config.architectures:
217-
# Some hardcoded values until a better solution is found.
218-
if model_id.startswith("google/bert_"):
219-
return "fill-mask"
220-
assert config.architectures is not None and len(config.architectures) == 1, (
221-
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
222-
f"architectures={config.architectures} in config={config}. "
223-
f"The task can be added in "
224-
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
225-
)
226-
return task_from_arch(config.architectures[0], default_value=default_value)
250+
tag = _trygetattr(config, "pipeline_tag")
251+
if tag is not None:
252+
return tag
253+
254+
guess = _guess_task_from_config(config)
255+
if guess is not None:
256+
return guess
257+
data = load_architecture_task()
258+
if subfolder:
259+
full_id = f"{model_id}//{subfolder}"
260+
if full_id in data:
261+
return data[full_id]
262+
if model_id in data:
263+
return data[model_id]
264+
arch = architecture_from_config(config)
265+
if arch is None:
266+
if model_id.startswith("google/bert_"):
267+
return "fill-mask"
268+
assert arch is not None, (
269+
f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
270+
f"config={config}. The task can be added in "
271+
f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
272+
)
273+
return task_from_arch(arch, default_value=default_value)
227274

228275

229276
def task_from_tags(tags: Union[str, List[str]]) -> str:

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ConvBertModel,feature-extraction
3131
ConvNextForImageClassification,image-classification
3232
ConvNextV2Model,image-feature-extraction
33+
CosmosTransformer3DModel,image-to-video
3334
CvtModel,feature-extraction
3435
DPTModel,image-feature-extraction
3536
Data2VecAudioModel,feature-extraction
@@ -156,7 +157,8 @@
156157
YolosForObjectDetection,object-detection
157158
YolosModel,image-feature-extraction
158159
Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
159-
emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
160+
emilyalsentzer/Bio_ClinicalBERT,fill-mask
161+
nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video"""
160162
)
161163

162164
__data_tasks__ = [

0 commit comments

Comments
 (0)