
Commit 6d85d19

Add task image-to-video
1 parent cc64994 commit 6d85d19


6 files changed, +294 -62 lines changed


onnx_diagnostic/tasks/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -5,6 +5,8 @@
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -14,7 +16,6 @@
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 )

 __TASKS__ = [
@@ -23,6 +24,8 @@
     fill_mask,
     image_classification,
     image_text_to_text,
+    image_to_video,
+    mask_generation,
     mixture_of_expert,
     object_detection,
     sentence_similarity,
@@ -32,7 +35,6 @@
     text_to_image,
     text2text_generation,
     zero_shot_image_classification,
-    mask_generation,
 ]
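The registration is just the import plus the new entry in ``__TASKS__``. As a hedged sketch (my assumption about how the list is typically consumed, not code from this commit, and assuming every task module defines a ``__TASK__`` string the way ``image_to_video`` does below), the list can be turned into a name-to-module lookup:

# Hypothetical sketch: map task names to their modules.
from onnx_diagnostic import tasks

registry = {mod.__TASK__: mod for mod in tasks.__TASKS__}
assert "image-to-video" in registry
print(sorted(registry))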

onnx_diagnostic/tasks/image_to_video.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "image-to-video"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
+        # We cannot reduce.
+        return {}
+    check_hasattr(config, ("num_hidden_layers", "num_layers"))
+    kwargs = {}
+    if hasattr(config, "num_layers"):
+        kwargs["num_layers"] = min(config.num_layers, nhl())
+    if hasattr(config, "num_hidden_layers"):
+        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())
+
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    text_embed_dim: int,
+    latent_channels: int,
+    batch_size: int = 2,
+    image_height: int = 704,
+    image_width: int = 1280,
+    latent_frames: int = 1,
+    text_maxlen: int = 512,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-video``.
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    latent_height = image_height // 8
+    latent_width = image_width // 8
+    dtype = torch.float32
+
+    inputs = dict(
+        hidden_states=torch.randn(
+            batch_size,
+            latent_channels,
+            latent_frames,
+            latent_height,
+            latent_width,
+            dtype=dtype,
+        ),
+        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
+        encoder_hidden_states=torch.randn(
+            batch_size, text_maxlen, text_embed_dim, dtype=dtype
+        ),
+        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
+        fps=torch.tensor([16] * batch_size, dtype=dtype),
+        condition_mask=torch.randn(
+            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
+        ),
+    )
+    shapes = dict(
+        hidden_states={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+        timestep={0: "batch_size"},
+        encoder_hidden_states={0: "batch_size"},
+        padding_mask={0: "batch_size", 2: "height", 3: "width"},
+        fps={0: "batch_size"},
+        condition_mask={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            text_embed_dim=text_embed_dim,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            image_height=image_height,
+            image_width=image_width,
+            latent_frames=latent_frames,
+            text_maxlen=text_maxlen,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "in_channels", "text_embed_dim")
+    kwargs = dict(
+        text_embed_dim=1024 if config is None else config.text_embed_dim,
+        latent_channels=16 if config is None else config.in_channels - 1,
+        batch_size=1,
+        image_height=8 * 50,
+        image_width=8 * 80,
+        latent_frames=1,
+        text_maxlen=512,
+    )
+    return kwargs, get_inputs
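A minimal usage sketch (not part of the commit): with ``config=None``, ``random_input_kwargs`` picks the typical dimensions above and the returned callable builds the dummy tensors together with their dynamic shapes. The import path assumes the new file is ``onnx_diagnostic/tasks/image_to_video.py``.

# Hedged sketch, not from the commit: exercise the new task module without a config.
from onnx_diagnostic.tasks import image_to_video

kwargs, fct = image_to_video.random_input_kwargs(None)  # None -> typical dimensions
data = fct(model=None, config=None, **kwargs)  # model/config are only forwarded here

print({k: tuple(v.shape) for k, v in data["inputs"].items()})
# hidden_states: (1, 16, 1, 50, 80), encoder_hidden_states: (1, 512, 1024), ...
print(sorted(data))  # ['dynamic_shapes', 'inputs', 'inputs2']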

onnx_diagnostic/torch_models/hghub/hub_api.py

Lines changed: 69 additions & 22 deletions
@@ -177,6 +177,51 @@ def task_from_arch(
     return data[arch]


+def _trygetattr(config, attname):
+    try:
+        return getattr(config, attname)
+    except AttributeError:
+        return None
+
+
+def architecture_from_config(config) -> Optional[str]:
+    """Guesses the architecture (class) of the model described by this config."""
+    if isinstance(config, dict):
+        if "_class_name" in config:
+            return config["_class_name"]
+        if "architecture" in config:
+            return config["architecture"]
+        if config.get("architectures", []):
+            return config["architectures"][0]
+    if hasattr(config, "_class_name"):
+        return config._class_name
+    if hasattr(config, "architecture"):
+        return config.architecture
+    if hasattr(config, "architectures") and config.architectures:
+        return config.architectures[0]
+    if hasattr(config, "__dict__"):
+        if "_class_name" in config.__dict__:
+            return config.__dict__["_class_name"]
+        if "architecture" in config.__dict__:
+            return config.__dict__["architecture"]
+        if config.__dict__.get("architectures", []):
+            return config.__dict__["architectures"][0]
+    return None
+
+
+def find_package_source(config) -> Optional[str]:
+    """Guesses the package the model class comes from."""
+    if isinstance(config, dict):
+        if "_diffusers_version" in config:
+            return "diffusers"
+    if hasattr(config, "_diffusers_version"):
+        return "diffusers"
+    if hasattr(config, "__dict__"):
+        if "_diffusers_version" in config.__dict__:
+            return "diffusers"
+    return "transformers"
+
+
 def task_from_id(
     model_id: str,
     default_value: Optional[str] = None,
@@ -202,28 +247,30 @@ def task_from_id(
         if not fall_back_to_pretrained:
             raise
     config = get_pretrained_config(model_id, subfolder=subfolder)
-    try:
-        return config.pipeline_tag
-    except AttributeError:
-        guess = _guess_task_from_config(config)
-        if guess is not None:
-            return guess
-        data = load_architecture_task()
-        if model_id in data:
-            return data[model_id]
-        if type(config) is dict and "_class_name" in config:
-            return task_from_arch(config["_class_name"], default_value=default_value)
-        if not config.architectures or not config.architectures:
-            # Some hardcoded values until a better solution is found.
-            if model_id.startswith("google/bert_"):
-                return "fill-mask"
-        assert config.architectures is not None and len(config.architectures) == 1, (
-            f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
-            f"architectures={config.architectures} in config={config}. "
-            f"The task can be added in "
-            f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
-        )
-        return task_from_arch(config.architectures[0], default_value=default_value)
+    tag = _trygetattr(config, "pipeline_tag")
+    if tag is not None:
+        return tag
+
+    guess = _guess_task_from_config(config)
+    if guess is not None:
+        return guess
+    data = load_architecture_task()
+    if subfolder:
+        full_id = f"{model_id}//{subfolder}"
+        if full_id in data:
+            return data[full_id]
+    if model_id in data:
+        return data[model_id]
+    arch = architecture_from_config(config)
+    if arch is None:
+        if model_id.startswith("google/bert_"):
+            return "fill-mask"
+    assert arch is not None, (
+        f"Cannot return the task of {model_id!r}, pipeline_tag is not setup, "
+        f"config={config}. The task can be added in "
+        f"``onnx_diagnostic.torch_models.hghub.hub_data.__data_arch__``."
+    )
+    return task_from_arch(arch, default_value=default_value)


 def task_from_tags(tags: Union[str, List[str]]) -> str:
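Both new helpers only inspect the configuration object, so they can be exercised offline. A hedged illustration (the class name and version string below are invented for the example):

# Illustration only: the dict mimics a diffusers config, the class a transformers one.
from onnx_diagnostic.torch_models.hghub.hub_api import (
    architecture_from_config,
    find_package_source,
)

diffusers_like = {"_class_name": "SomeTransformer3DModel", "_diffusers_version": "0.33.0"}
assert architecture_from_config(diffusers_like) == "SomeTransformer3DModel"
assert find_package_source(diffusers_like) == "diffusers"

class FakeConfig:  # transformers-style: an architectures list, no _diffusers_version
    architectures = ["BertForMaskedLM"]

assert architecture_from_config(FakeConfig()) == "BertForMaskedLM"
assert find_package_source(FakeConfig()) == "transformers"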

onnx_diagnostic/torch_models/hghub/hub_data.py

Lines changed: 2 additions & 1 deletion
@@ -156,7 +156,8 @@
 YolosForObjectDetection,object-detection
 YolosModel,image-feature-extraction
 Alibaba-NLP/gte-large-en-v1.5,sentence-similarity
-emilyalsentzer/Bio_ClinicalBERT,fill-mask"""
+emilyalsentzer/Bio_ClinicalBERT,fill-mask
+nvidia/Cosmos-Predict2-2B-Video2World//transformer,image-to-video"""
 )

 __data_tasks__ = [
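The new row registers the ``transformer`` subfolder of the Cosmos checkpoint, which is exactly what the ``model_id//subfolder`` lookup added to ``task_from_id`` above resolves. A minimal sketch of that resolution (my assumption about how ``load_architecture_task`` exposes these rows; not code from the commit):

# Hedged sketch: the CSV row should surface as a "model_id//subfolder" key.
from onnx_diagnostic.torch_models.hghub.hub_api import load_architecture_task

data = load_architecture_task()
print(data.get("nvidia/Cosmos-Predict2-2B-Video2World//transformer"))  # expected: image-to-video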

0 commit comments
